Improve code

This commit is contained in:
Lunny Xiao 2021-06-27 17:18:55 +08:00
parent f22f863fc7
commit d3593cd8de
3 changed files with 41 additions and 15 deletions

View File

@ -7,7 +7,6 @@ package integrations
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"strconv"
"strings"
@ -26,6 +25,9 @@ type NullType struct {
CustomStruct CustomStruct `xorm:"varchar(64) null"`
}
var _ sql.Scanner = &CustomStruct{}
var _ driver.Valuer = &CustomStruct{}
type CustomStruct struct {
Year int
Month int
@ -50,7 +52,7 @@ func (m *CustomStruct) Scan(value interface{}) error {
return nil
}
return errors.New("scan data not fit []byte")
return fmt.Errorf("scan data type %#v not fit []byte", value)
}
func (m CustomStruct) Value() (driver.Value, error) {

20
scan.go
View File

@ -156,13 +156,20 @@ func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[
return result, nil
}
func genScanResult(driver dialects.Driver, fieldType reflect.Type, columnType *sql.ColumnType) (interface{}, error) {
if fieldType.Implements(scannerType) || fieldType.Implements(conversionType) {
return &sql.RawBytes{}, nil
func genScanResult(driver dialects.Driver, fieldValue reflect.Value, columnType *sql.ColumnType) (interface{}, error) {
fieldType := fieldValue.Type()
if fieldValue.Type().Implements(scannerType) || fieldValue.Type().Implements(conversionType) {
return fieldValue.Interface(), nil
}
if fieldValue.CanAddr() && fieldValue.Type().Kind() != reflect.Ptr {
rType := reflect.PtrTo(fieldType)
if rType.Implements(scannerType) || rType.Implements(conversionType) {
return fieldValue.Addr().Interface(), nil
}
}
switch fieldType.Kind() {
case reflect.Ptr:
return genScanResult(driver, fieldType.Elem(), columnType)
return genScanResult(driver, fieldValue.Elem(), columnType)
case reflect.Array, reflect.Slice:
return &sql.RawBytes{}, nil
default:
@ -183,7 +190,7 @@ func genScanResults(driver dialects.Driver, types []*sql.ColumnType) ([]interfac
return scanResults, nil
}
func genScanResultsWithTable(driver dialects.Driver, types []*sql.ColumnType, fields []string, table *schemas.Table) ([]interface{}, error) {
func genScanResultsWithTable(driver dialects.Driver, types []*sql.ColumnType, fields []string, values []reflect.Value, table *schemas.Table) ([]interface{}, error) {
var scanResults = make([]interface{}, 0, len(types))
for i, tp := range types {
col := table.GetColumn(fields[i])
@ -192,7 +199,8 @@ func genScanResultsWithTable(driver dialects.Driver, types []*sql.ColumnType, fi
scanResults = append(scanResults, &EmptyScanner{})
continue
}
scanResult, err := genScanResult(driver, col.Type, tp)
fmt.Println("=========,,,,,,", col.Name)
scanResult, err := genScanResult(driver, values[i], tp)
if err != nil {
return nil, err
}

View File

@ -9,6 +9,7 @@ import (
"crypto/rand"
"crypto/sha256"
"database/sql"
"database/sql/driver"
"encoding/hex"
"errors"
"fmt"
@ -423,7 +424,12 @@ func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fiel
closure(bean)
}
scanResults, err := genScanResultsWithTable(session.engine.driver, types, fields, table)
values, err := getValues(bean, fields)
if err != nil {
return nil, err
}
scanResults, err := genScanResultsWithTable(session.engine.driver, types, fields, values, table)
if err != nil {
return nil, err
}
@ -495,12 +501,12 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
}
}
if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok {
if scanner, ok := fieldValue.Interface().(sql.Scanner); ok {
fmt.Println("===========111111111111")
return scanner.Scan(src)
}
if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok {
switch t := src.(type) {
case *sql.RawBytes:
if fieldValue.IsNil() {
@ -526,6 +532,16 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
return nil
}
if fieldValue.Type().Implements(valuerType) {
fmt.Println("--------333333--3--33-3")
return nil
}
if _, ok := fieldValue.Interface().(driver.Valuer); ok {
fmt.Println("22222222222")
return nil
}
rawValueType := reflect.TypeOf(rawValue.Interface())
vv := reflect.ValueOf(rawValue.Interface())
fieldType := fieldValue.Type()
@ -957,7 +973,7 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
} // switch fieldType.Kind()
if !hasAssigned {
return fmt.Errorf("unsupported convertion from %#v to %#v", src, fieldValue.Interface())
return fmt.Errorf("unsupported convertion from %#v to %#v on %s", src, fieldValue.Interface(), columnName)
}
return nil