This commit is contained in:
Lunny Xiao 2021-07-16 14:40:14 +08:00
parent de08656fa3
commit 3871329f03
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
5 changed files with 110 additions and 103 deletions

View File

@ -226,7 +226,16 @@ func asTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time.
}
return convert.String2Time(string(t), dbLoc, uiLoc)
case *sql.NullTime:
tm := t.Time
if !t.Valid {
return nil, nil
}
z, _ := t.Time.Zone()
if len(z) == 0 || t.Time.Year() == 0 || t.Time.Location().String() != dbLoc.String() {
tm := time.Date(t.Time.Year(), t.Time.Month(), t.Time.Day(), t.Time.Hour(),
t.Time.Minute(), t.Time.Second(), t.Time.Nanosecond(), dbLoc).In(uiLoc)
return &tm, nil
}
tm := t.Time.In(uiLoc)
return &tm, nil
case *time.Time:
z, _ := t.Zone()
@ -243,6 +252,9 @@ func asTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time.
case int64:
tm := time.Unix(t, 0).In(uiLoc)
return &tm, nil
case *sql.NullInt64:
tm := time.Unix(t.Int64, 0).In(uiLoc)
return &tm, nil
}
return nil, fmt.Errorf("unsupported value %#v as time", src)
}
@ -329,6 +341,9 @@ func asBytes(src interface{}) ([]byte, bool) {
case []byte:
return t, true
case *sql.NullString:
if !t.Valid {
return nil, true
}
return []byte(t.String), true
case *sql.RawBytes:
return *t, true
@ -763,6 +778,10 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
r := vv.Convert(schemas.NullInt64Type)
return r.Interface().(sql.NullInt64).Int64, nil
}
if vv.Type().ConvertibleTo(schemas.NullStringType) {
r := vv.Convert(schemas.NullStringType)
return r.Interface().(sql.NullString).String, nil
}
}
return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv)
}

View File

@ -109,6 +109,7 @@ func TestGetBytes(t *testing.T) {
type ConvString string
func (s *ConvString) FromDB(data []byte) error {
fmt.Println("3333", string(data))
*s = ConvString("prefix---" + string(data))
return nil
}
@ -127,6 +128,7 @@ func (s *ConvConfig) FromDB(data []byte) error {
s = nil
return nil
}
fmt.Println("11111", string(data))
return json.DefaultJSONHandler.Unmarshal(data, s)
}
@ -140,6 +142,7 @@ func (s *ConvConfig) ToDB() ([]byte, error) {
type SliceType []*ConvConfig
func (s *SliceType) FromDB(data []byte) error {
fmt.Println("2222", string(data))
return json.DefaultJSONHandler.Unmarshal(data, s)
}

46
scan.go
View File

@ -191,38 +191,37 @@ func (engine *Engine) scanStringInterface(rows *core.Rows, types []*sql.ColumnTy
return scanResults, nil
}
func (engine *Engine) genScanResult(tp *sql.ColumnType, v interface{}) (interface{}, bool, error) {
switch t := v.(type) {
case sql.Scanner:
return t, false, nil
case convert.Conversion:
return &sql.RawBytes{}, true, nil
case *big.Float:
return &sql.NullString{}, true, nil
default:
var useNullable = true
if engine.driver.Features().SupportNullable {
nullable, ok := tp.Nullable()
useNullable = ok && nullable
}
if useNullable {
return genScanResultsByBeanNullable(v)
}
return genScanResultsByBean(v)
}
}
// scan is a wrap of driver.Scan but will automatically change the input values according requirements
func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.ColumnType, vv ...interface{}) error {
var scanResults = make([]interface{}, 0, len(types))
var replaces = make([]bool, 0, len(types))
var err error
for _, v := range vv {
var replaced bool
var scanResult interface{}
switch t := v.(type) {
case sql.Scanner:
scanResult = t
case convert.Conversion:
scanResult = &sql.RawBytes{}
replaced = true
case *big.Float:
scanResult = &sql.NullString{}
replaced = true
default:
var useNullable = true
if engine.driver.Features().SupportNullable {
nullable, ok := types[0].Nullable()
useNullable = ok && nullable
}
if useNullable {
scanResult, replaced, err = genScanResultsByBeanNullable(v)
} else {
scanResult, replaced, err = genScanResultsByBean(v)
}
scanResult, replaced, err := engine.genScanResult(types[0], v)
if err != nil {
return err
}
}
scanResults = append(scanResults, scanResult)
replaces = append(replaces, replaced)
@ -242,7 +241,6 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column
}
}
}
return nil
}

View File

@ -250,6 +250,7 @@ var (
timeDefault time.Time
bigFloatDefault big.Float
nullInt64Default sql.NullInt64
nullStringDefault sql.NullString
)
// enumerates all types
@ -281,6 +282,7 @@ var (
BigFloatType = reflect.TypeOf(bigFloatDefault)
NullInt64Type = reflect.TypeOf(nullInt64Default)
NullStringType = reflect.TypeOf(nullStringDefault)
)
// enumerates all types

View File

@ -429,7 +429,7 @@ func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fiel
return nil, err
}
}
if err := rows.Scan(scanResults...); err != nil {
if err := session.engine.scan(rows, fields, types, scanResults...); err != nil {
return nil, err
}
@ -439,23 +439,19 @@ func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fiel
}
func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Type, scanResult interface{}) error {
var bs []byte
switch t := scanResult.(type) {
case string:
bs = []byte(t)
case []byte:
bs = t
case *sql.NullString:
bs = []byte(t.String)
default:
bs, ok := asBytes(scanResult)
if !ok {
return fmt.Errorf("unsupported database data type: %#v", scanResult)
}
if len(bs) == 0 {
return nil
}
if len(bs) > 0 {
if fieldType.Kind() == reflect.String {
fieldValue.SetString(string(bs))
return nil
}
if fieldValue.CanAddr() {
err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
@ -469,7 +465,6 @@ func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Typ
}
fieldValue.Set(x.Elem())
}
}
return nil
}
@ -497,6 +492,9 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
if !ok {
return fmt.Errorf("cannot convert %#v as bytes", scanResult)
}
if len(data) == 0 {
return nil
}
return structConvert.FromDB(data)
}
}
@ -510,15 +508,15 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
if !ok {
return fmt.Errorf("cannot convert %#v as bytes", scanResult)
}
if data != nil {
if data == nil {
return nil
}
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
return fieldValue.Interface().(convert.Conversion).FromDB(data)
}
return structConvert.FromDB(data)
}
return nil
}
rawValueType := reflect.TypeOf(rawValue.Interface())
vv := reflect.ValueOf(rawValue.Interface())
@ -539,59 +537,43 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
if err := session.convertBeanField(col, &e, scanResult); err != nil {
return err
}
if fieldValue.IsNil() {
if fieldValue.IsNil() && !e.Addr().IsNil() {
fieldValue.Set(e.Addr())
}
return nil
case reflect.Complex64, reflect.Complex128:
return session.setJSON(fieldValue, fieldType, scanResult)
case reflect.Map:
switch rawValueType.Kind() {
case reflect.String:
return session.setJSON(fieldValue, fieldType, scanResult)
case reflect.Slice:
switch scanResult.(type) {
case string, []byte, *sql.NullString, *sql.RawBytes:
return session.setJSON(fieldValue, fieldType, scanResult)
default:
return fmt.Errorf("unsupported %v -> %T", scanResult, fieldType)
return fmt.Errorf("unsupported %#v -> %T map", scanResult, fieldType)
}
case reflect.Slice, reflect.Array:
switch rawValueType.Kind() {
case reflect.String:
x := reflect.New(fieldType)
err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface())
if err != nil {
return err
bs, ok := asBytes(scanResult)
if !ok {
return fmt.Errorf("unsupported %#v -> %T slice,array", scanResult, fieldType)
}
fieldValue.Set(x.Elem())
if bs == nil {
return nil
case reflect.Slice, reflect.Array:
switch rawValueType.Elem().Kind() {
case reflect.Uint8:
}
if fieldType.Elem().Kind() == reflect.Uint8 {
if fieldValue.Len() > 0 {
for i := 0; i < fieldValue.Len(); i++ {
if i < vv.Len() {
fieldValue.Index(i).Set(vv.Index(i))
if i < len(bs) {
fieldValue.Index(i).Set(reflect.ValueOf(bs[i]))
}
}
} else {
for i := 0; i < vv.Len(); i++ {
fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i)))
fieldValue.Set(reflect.Append(*fieldValue, reflect.ValueOf(bs[i])))
}
}
return nil
}
if col.SQLType.IsText() {
x := reflect.New(fieldType)
err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface())
if err != nil {
return err
}
fieldValue.Set(x.Elem())
return nil
}
}
}
return session.setJSON(fieldValue, fieldType, scanResult)
case reflect.Struct:
if fieldType.ConvertibleTo(schemas.BigFloatType) {
v, err := asBigFloat(scanResult)
@ -612,6 +594,9 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
if err != nil {
return err
}
if tm == nil {
return nil
}
fieldValue.Set(reflect.ValueOf(*tm).Convert(fieldType))
return nil
} else if session.statement.UseCascade {