Improve code

This commit is contained in:
Lunny Xiao 2021-06-27 17:18:55 +08:00
parent f22f863fc7
commit d3593cd8de
3 changed files with 41 additions and 15 deletions

View File

@ -7,7 +7,6 @@ package integrations
import ( import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@ -26,6 +25,9 @@ type NullType struct {
CustomStruct CustomStruct `xorm:"varchar(64) null"` CustomStruct CustomStruct `xorm:"varchar(64) null"`
} }
var _ sql.Scanner = &CustomStruct{}
var _ driver.Valuer = &CustomStruct{}
type CustomStruct struct { type CustomStruct struct {
Year int Year int
Month int Month int
@ -50,7 +52,7 @@ func (m *CustomStruct) Scan(value interface{}) error {
return nil return nil
} }
return errors.New("scan data not fit []byte") return fmt.Errorf("scan data type %#v not fit []byte", value)
} }
func (m CustomStruct) Value() (driver.Value, error) { func (m CustomStruct) Value() (driver.Value, error) {

20
scan.go
View File

@ -156,13 +156,20 @@ func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[
return result, nil return result, nil
} }
func genScanResult(driver dialects.Driver, fieldType reflect.Type, columnType *sql.ColumnType) (interface{}, error) { func genScanResult(driver dialects.Driver, fieldValue reflect.Value, columnType *sql.ColumnType) (interface{}, error) {
if fieldType.Implements(scannerType) || fieldType.Implements(conversionType) { fieldType := fieldValue.Type()
return &sql.RawBytes{}, nil if fieldValue.Type().Implements(scannerType) || fieldValue.Type().Implements(conversionType) {
return fieldValue.Interface(), nil
}
if fieldValue.CanAddr() && fieldValue.Type().Kind() != reflect.Ptr {
rType := reflect.PtrTo(fieldType)
if rType.Implements(scannerType) || rType.Implements(conversionType) {
return fieldValue.Addr().Interface(), nil
}
} }
switch fieldType.Kind() { switch fieldType.Kind() {
case reflect.Ptr: case reflect.Ptr:
return genScanResult(driver, fieldType.Elem(), columnType) return genScanResult(driver, fieldValue.Elem(), columnType)
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
return &sql.RawBytes{}, nil return &sql.RawBytes{}, nil
default: default:
@ -183,7 +190,7 @@ func genScanResults(driver dialects.Driver, types []*sql.ColumnType) ([]interfac
return scanResults, nil return scanResults, nil
} }
func genScanResultsWithTable(driver dialects.Driver, types []*sql.ColumnType, fields []string, table *schemas.Table) ([]interface{}, error) { func genScanResultsWithTable(driver dialects.Driver, types []*sql.ColumnType, fields []string, values []reflect.Value, table *schemas.Table) ([]interface{}, error) {
var scanResults = make([]interface{}, 0, len(types)) var scanResults = make([]interface{}, 0, len(types))
for i, tp := range types { for i, tp := range types {
col := table.GetColumn(fields[i]) col := table.GetColumn(fields[i])
@ -192,7 +199,8 @@ func genScanResultsWithTable(driver dialects.Driver, types []*sql.ColumnType, fi
scanResults = append(scanResults, &EmptyScanner{}) scanResults = append(scanResults, &EmptyScanner{})
continue continue
} }
scanResult, err := genScanResult(driver, col.Type, tp) fmt.Println("=========,,,,,,", col.Name)
scanResult, err := genScanResult(driver, values[i], tp)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -9,6 +9,7 @@ import (
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"database/sql" "database/sql"
"database/sql/driver"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
@ -423,7 +424,12 @@ func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fiel
closure(bean) closure(bean)
} }
scanResults, err := genScanResultsWithTable(session.engine.driver, types, fields, table) values, err := getValues(bean, fields)
if err != nil {
return nil, err
}
scanResults, err := genScanResultsWithTable(session.engine.driver, types, fields, values, table)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -495,12 +501,12 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
} }
} }
if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { if scanner, ok := fieldValue.Interface().(sql.Scanner); ok {
if scanner, ok := fieldValue.Interface().(sql.Scanner); ok { fmt.Println("===========111111111111")
fmt.Println("===========111111111111") return scanner.Scan(src)
return scanner.Scan(src) }
}
if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok {
switch t := src.(type) { switch t := src.(type) {
case *sql.RawBytes: case *sql.RawBytes:
if fieldValue.IsNil() { if fieldValue.IsNil() {
@ -526,6 +532,16 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
return nil return nil
} }
if fieldValue.Type().Implements(valuerType) {
fmt.Println("--------333333--3--33-3")
return nil
}
if _, ok := fieldValue.Interface().(driver.Valuer); ok {
fmt.Println("22222222222")
return nil
}
rawValueType := reflect.TypeOf(rawValue.Interface()) rawValueType := reflect.TypeOf(rawValue.Interface())
vv := reflect.ValueOf(rawValue.Interface()) vv := reflect.ValueOf(rawValue.Interface())
fieldType := fieldValue.Type() fieldType := fieldValue.Type()
@ -957,7 +973,7 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
} // switch fieldType.Kind() } // switch fieldType.Kind()
if !hasAssigned { if !hasAssigned {
return fmt.Errorf("unsupported convertion from %#v to %#v", src, fieldValue.Interface()) return fmt.Errorf("unsupported convertion from %#v to %#v on %s", src, fieldValue.Interface(), columnName)
} }
return nil return nil