From d3593cd8debf8213160719a7ad8b9c30226b8569 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 27 Jun 2021 17:18:55 +0800 Subject: [PATCH] Improve code --- integrations/types_null_test.go | 6 ++++-- scan.go | 20 ++++++++++++++------ session.go | 30 +++++++++++++++++++++++------- 3 files changed, 41 insertions(+), 15 deletions(-) diff --git a/integrations/types_null_test.go b/integrations/types_null_test.go index 98bd86b9..a3fffa9e 100644 --- a/integrations/types_null_test.go +++ b/integrations/types_null_test.go @@ -7,7 +7,6 @@ package integrations import ( "database/sql" "database/sql/driver" - "errors" "fmt" "strconv" "strings" @@ -26,6 +25,9 @@ type NullType struct { CustomStruct CustomStruct `xorm:"varchar(64) null"` } +var _ sql.Scanner = &CustomStruct{} +var _ driver.Valuer = &CustomStruct{} + type CustomStruct struct { Year int Month int @@ -50,7 +52,7 @@ func (m *CustomStruct) Scan(value interface{}) error { return nil } - return errors.New("scan data not fit []byte") + return fmt.Errorf("scan data type %#v not fit []byte", value) } func (m CustomStruct) Value() (driver.Value, error) { diff --git a/scan.go b/scan.go index c50712db..64021c0d 100644 --- a/scan.go +++ b/scan.go @@ -156,13 +156,20 @@ func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[ return result, nil } -func genScanResult(driver dialects.Driver, fieldType reflect.Type, columnType *sql.ColumnType) (interface{}, error) { - if fieldType.Implements(scannerType) || fieldType.Implements(conversionType) { - return &sql.RawBytes{}, nil +func genScanResult(driver dialects.Driver, fieldValue reflect.Value, columnType *sql.ColumnType) (interface{}, error) { + fieldType := fieldValue.Type() + if fieldValue.Type().Implements(scannerType) || fieldValue.Type().Implements(conversionType) { + return fieldValue.Interface(), nil + } + if fieldValue.CanAddr() && fieldValue.Type().Kind() != reflect.Ptr { + rType := reflect.PtrTo(fieldType) + if rType.Implements(scannerType) || rType.Implements(conversionType) { + return fieldValue.Addr().Interface(), nil + } } switch fieldType.Kind() { case reflect.Ptr: - return genScanResult(driver, fieldType.Elem(), columnType) + return genScanResult(driver, fieldValue.Elem(), columnType) case reflect.Array, reflect.Slice: return &sql.RawBytes{}, nil default: @@ -183,7 +190,7 @@ func genScanResults(driver dialects.Driver, types []*sql.ColumnType) ([]interfac return scanResults, nil } -func genScanResultsWithTable(driver dialects.Driver, types []*sql.ColumnType, fields []string, table *schemas.Table) ([]interface{}, error) { +func genScanResultsWithTable(driver dialects.Driver, types []*sql.ColumnType, fields []string, values []reflect.Value, table *schemas.Table) ([]interface{}, error) { var scanResults = make([]interface{}, 0, len(types)) for i, tp := range types { col := table.GetColumn(fields[i]) @@ -192,7 +199,8 @@ func genScanResultsWithTable(driver dialects.Driver, types []*sql.ColumnType, fi scanResults = append(scanResults, &EmptyScanner{}) continue } - scanResult, err := genScanResult(driver, col.Type, tp) + fmt.Println("=========,,,,,,", col.Name) + scanResult, err := genScanResult(driver, values[i], tp) if err != nil { return nil, err } diff --git a/session.go b/session.go index 791eb0bb..695761bf 100644 --- a/session.go +++ b/session.go @@ -9,6 +9,7 @@ import ( "crypto/rand" "crypto/sha256" "database/sql" + "database/sql/driver" "encoding/hex" "errors" "fmt" @@ -423,7 +424,12 @@ func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fiel closure(bean) } - scanResults, err := genScanResultsWithTable(session.engine.driver, types, fields, table) + values, err := getValues(bean, fields) + if err != nil { + return nil, err + } + + scanResults, err := genScanResultsWithTable(session.engine.driver, types, fields, values, table) if err != nil { return nil, err } @@ -495,12 +501,12 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri } } - if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { - if scanner, ok := fieldValue.Interface().(sql.Scanner); ok { - fmt.Println("===========111111111111") - return scanner.Scan(src) - } + if scanner, ok := fieldValue.Interface().(sql.Scanner); ok { + fmt.Println("===========111111111111") + return scanner.Scan(src) + } + if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { switch t := src.(type) { case *sql.RawBytes: if fieldValue.IsNil() { @@ -526,6 +532,16 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri return nil } + if fieldValue.Type().Implements(valuerType) { + fmt.Println("--------333333--3--33-3") + return nil + } + + if _, ok := fieldValue.Interface().(driver.Valuer); ok { + fmt.Println("22222222222") + return nil + } + rawValueType := reflect.TypeOf(rawValue.Interface()) vv := reflect.ValueOf(rawValue.Interface()) fieldType := fieldValue.Type() @@ -957,7 +973,7 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri } // switch fieldType.Kind() if !hasAssigned { - return fmt.Errorf("unsupported convertion from %#v to %#v", src, fieldValue.Interface()) + return fmt.Errorf("unsupported convertion from %#v to %#v on %s", src, fieldValue.Interface(), columnName) } return nil