Improve code
This commit is contained in:
parent
f22f863fc7
commit
d3593cd8de
|
@ -7,7 +7,6 @@ package integrations
|
|||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -26,6 +25,9 @@ type NullType struct {
|
|||
CustomStruct CustomStruct `xorm:"varchar(64) null"`
|
||||
}
|
||||
|
||||
var _ sql.Scanner = &CustomStruct{}
|
||||
var _ driver.Valuer = &CustomStruct{}
|
||||
|
||||
type CustomStruct struct {
|
||||
Year int
|
||||
Month int
|
||||
|
@ -50,7 +52,7 @@ func (m *CustomStruct) Scan(value interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
return errors.New("scan data not fit []byte")
|
||||
return fmt.Errorf("scan data type %#v not fit []byte", value)
|
||||
}
|
||||
|
||||
func (m CustomStruct) Value() (driver.Value, error) {
|
||||
|
|
20
scan.go
20
scan.go
|
@ -156,13 +156,20 @@ func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func genScanResult(driver dialects.Driver, fieldType reflect.Type, columnType *sql.ColumnType) (interface{}, error) {
|
||||
if fieldType.Implements(scannerType) || fieldType.Implements(conversionType) {
|
||||
return &sql.RawBytes{}, nil
|
||||
func genScanResult(driver dialects.Driver, fieldValue reflect.Value, columnType *sql.ColumnType) (interface{}, error) {
|
||||
fieldType := fieldValue.Type()
|
||||
if fieldValue.Type().Implements(scannerType) || fieldValue.Type().Implements(conversionType) {
|
||||
return fieldValue.Interface(), nil
|
||||
}
|
||||
if fieldValue.CanAddr() && fieldValue.Type().Kind() != reflect.Ptr {
|
||||
rType := reflect.PtrTo(fieldType)
|
||||
if rType.Implements(scannerType) || rType.Implements(conversionType) {
|
||||
return fieldValue.Addr().Interface(), nil
|
||||
}
|
||||
}
|
||||
switch fieldType.Kind() {
|
||||
case reflect.Ptr:
|
||||
return genScanResult(driver, fieldType.Elem(), columnType)
|
||||
return genScanResult(driver, fieldValue.Elem(), columnType)
|
||||
case reflect.Array, reflect.Slice:
|
||||
return &sql.RawBytes{}, nil
|
||||
default:
|
||||
|
@ -183,7 +190,7 @@ func genScanResults(driver dialects.Driver, types []*sql.ColumnType) ([]interfac
|
|||
return scanResults, nil
|
||||
}
|
||||
|
||||
func genScanResultsWithTable(driver dialects.Driver, types []*sql.ColumnType, fields []string, table *schemas.Table) ([]interface{}, error) {
|
||||
func genScanResultsWithTable(driver dialects.Driver, types []*sql.ColumnType, fields []string, values []reflect.Value, table *schemas.Table) ([]interface{}, error) {
|
||||
var scanResults = make([]interface{}, 0, len(types))
|
||||
for i, tp := range types {
|
||||
col := table.GetColumn(fields[i])
|
||||
|
@ -192,7 +199,8 @@ func genScanResultsWithTable(driver dialects.Driver, types []*sql.ColumnType, fi
|
|||
scanResults = append(scanResults, &EmptyScanner{})
|
||||
continue
|
||||
}
|
||||
scanResult, err := genScanResult(driver, col.Type, tp)
|
||||
fmt.Println("=========,,,,,,", col.Name)
|
||||
scanResult, err := genScanResult(driver, values[i], tp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
22
session.go
22
session.go
|
@ -9,6 +9,7 @@ import (
|
|||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -423,7 +424,12 @@ func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fiel
|
|||
closure(bean)
|
||||
}
|
||||
|
||||
scanResults, err := genScanResultsWithTable(session.engine.driver, types, fields, table)
|
||||
values, err := getValues(bean, fields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
scanResults, err := genScanResultsWithTable(session.engine.driver, types, fields, values, table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -495,12 +501,12 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
|
|||
}
|
||||
}
|
||||
|
||||
if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok {
|
||||
if scanner, ok := fieldValue.Interface().(sql.Scanner); ok {
|
||||
fmt.Println("===========111111111111")
|
||||
return scanner.Scan(src)
|
||||
}
|
||||
|
||||
if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok {
|
||||
switch t := src.(type) {
|
||||
case *sql.RawBytes:
|
||||
if fieldValue.IsNil() {
|
||||
|
@ -526,6 +532,16 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
|
|||
return nil
|
||||
}
|
||||
|
||||
if fieldValue.Type().Implements(valuerType) {
|
||||
fmt.Println("--------333333--3--33-3")
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, ok := fieldValue.Interface().(driver.Valuer); ok {
|
||||
fmt.Println("22222222222")
|
||||
return nil
|
||||
}
|
||||
|
||||
rawValueType := reflect.TypeOf(rawValue.Interface())
|
||||
vv := reflect.ValueOf(rawValue.Interface())
|
||||
fieldType := fieldValue.Type()
|
||||
|
@ -957,7 +973,7 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
|
|||
} // switch fieldType.Kind()
|
||||
|
||||
if !hasAssigned {
|
||||
return fmt.Errorf("unsupported convertion from %#v to %#v", src, fieldValue.Interface())
|
||||
return fmt.Errorf("unsupported convertion from %#v to %#v on %s", src, fieldValue.Interface(), columnName)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
Loading…
Reference in New Issue