From 3871329f03833c9169e43159043165d7e4c6f845 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 16 Jul 2021 14:40:14 +0800 Subject: [PATCH] Fix --- convert.go | 21 +++++- integrations/types_test.go | 3 + scan.go | 50 +++++++------- schemas/type.go | 4 +- session.go | 135 +++++++++++++++++-------------------- 5 files changed, 110 insertions(+), 103 deletions(-) diff --git a/convert.go b/convert.go index 5bf6ac78..b1322e06 100644 --- a/convert.go +++ b/convert.go @@ -226,7 +226,16 @@ func asTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time. } return convert.String2Time(string(t), dbLoc, uiLoc) case *sql.NullTime: - tm := t.Time + if !t.Valid { + return nil, nil + } + z, _ := t.Time.Zone() + if len(z) == 0 || t.Time.Year() == 0 || t.Time.Location().String() != dbLoc.String() { + tm := time.Date(t.Time.Year(), t.Time.Month(), t.Time.Day(), t.Time.Hour(), + t.Time.Minute(), t.Time.Second(), t.Time.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.Time.In(uiLoc) return &tm, nil case *time.Time: z, _ := t.Zone() @@ -243,6 +252,9 @@ func asTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time. case int64: tm := time.Unix(t, 0).In(uiLoc) return &tm, nil + case *sql.NullInt64: + tm := time.Unix(t.Int64, 0).In(uiLoc) + return &tm, nil } return nil, fmt.Errorf("unsupported value %#v as time", src) } @@ -329,6 +341,9 @@ func asBytes(src interface{}) ([]byte, bool) { case []byte: return t, true case *sql.NullString: + if !t.Valid { + return nil, true + } return []byte(t.String), true case *sql.RawBytes: return *t, true @@ -763,6 +778,10 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { r := vv.Convert(schemas.NullInt64Type) return r.Interface().(sql.NullInt64).Int64, nil } + if vv.Type().ConvertibleTo(schemas.NullStringType) { + r := vv.Convert(schemas.NullStringType) + return r.Interface().(sql.NullString).String, nil + } } return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) } diff --git a/integrations/types_test.go b/integrations/types_test.go index f192c1ff..98989496 100644 --- a/integrations/types_test.go +++ b/integrations/types_test.go @@ -109,6 +109,7 @@ func TestGetBytes(t *testing.T) { type ConvString string func (s *ConvString) FromDB(data []byte) error { + fmt.Println("3333", string(data)) *s = ConvString("prefix---" + string(data)) return nil } @@ -127,6 +128,7 @@ func (s *ConvConfig) FromDB(data []byte) error { s = nil return nil } + fmt.Println("11111", string(data)) return json.DefaultJSONHandler.Unmarshal(data, s) } @@ -140,6 +142,7 @@ func (s *ConvConfig) ToDB() ([]byte, error) { type SliceType []*ConvConfig func (s *SliceType) FromDB(data []byte) error { + fmt.Println("2222", string(data)) return json.DefaultJSONHandler.Unmarshal(data, s) } diff --git a/scan.go b/scan.go index 2fedd415..a92c729e 100644 --- a/scan.go +++ b/scan.go @@ -191,37 +191,36 @@ func (engine *Engine) scanStringInterface(rows *core.Rows, types []*sql.ColumnTy return scanResults, nil } +func (engine *Engine) genScanResult(tp *sql.ColumnType, v interface{}) (interface{}, bool, error) { + switch t := v.(type) { + case sql.Scanner: + return t, false, nil + case convert.Conversion: + return &sql.RawBytes{}, true, nil + case *big.Float: + return &sql.NullString{}, true, nil + default: + var useNullable = true + if engine.driver.Features().SupportNullable { + nullable, ok := tp.Nullable() + useNullable = ok && nullable + } + if useNullable { + return genScanResultsByBeanNullable(v) + } + return genScanResultsByBean(v) + } +} + // scan is a wrap of driver.Scan but will automatically change the input values according requirements func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.ColumnType, vv ...interface{}) error { var scanResults = make([]interface{}, 0, len(types)) var replaces = make([]bool, 0, len(types)) var err error for _, v := range vv { - var replaced bool - var scanResult interface{} - switch t := v.(type) { - case sql.Scanner: - scanResult = t - case convert.Conversion: - scanResult = &sql.RawBytes{} - replaced = true - case *big.Float: - scanResult = &sql.NullString{} - replaced = true - default: - var useNullable = true - if engine.driver.Features().SupportNullable { - nullable, ok := types[0].Nullable() - useNullable = ok && nullable - } - if useNullable { - scanResult, replaced, err = genScanResultsByBeanNullable(v) - } else { - scanResult, replaced, err = genScanResultsByBean(v) - } - if err != nil { - return err - } + scanResult, replaced, err := engine.genScanResult(types[0], v) + if err != nil { + return err } scanResults = append(scanResults, scanResult) @@ -242,7 +241,6 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column } } } - return nil } diff --git a/schemas/type.go b/schemas/type.go index 7dff9cf6..da1e51d8 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -250,6 +250,7 @@ var ( timeDefault time.Time bigFloatDefault big.Float nullInt64Default sql.NullInt64 + nullStringDefault sql.NullString ) // enumerates all types @@ -280,7 +281,8 @@ var ( TimeType = reflect.TypeOf(timeDefault) BigFloatType = reflect.TypeOf(bigFloatDefault) - NullInt64Type = reflect.TypeOf(nullInt64Default) + NullInt64Type = reflect.TypeOf(nullInt64Default) + NullStringType = reflect.TypeOf(nullStringDefault) ) // enumerates all types diff --git a/session.go b/session.go index 7889b447..81cae5d2 100644 --- a/session.go +++ b/session.go @@ -429,7 +429,7 @@ func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fiel return nil, err } } - if err := rows.Scan(scanResults...); err != nil { + if err := session.engine.scan(rows, fields, types, scanResults...); err != nil { return nil, err } @@ -439,36 +439,31 @@ func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fiel } func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Type, scanResult interface{}) error { - var bs []byte - switch t := scanResult.(type) { - case string: - bs = []byte(t) - case []byte: - bs = t - case *sql.NullString: - bs = []byte(t.String) - default: + bs, ok := asBytes(scanResult) + if !ok { return fmt.Errorf("unsupported database data type: %#v", scanResult) } + if len(bs) == 0 { + return nil + } - if len(bs) > 0 { - if fieldType.Kind() == reflect.String { - fieldValue.SetString(string(bs)) - return nil + if fieldType.Kind() == reflect.String { + fieldValue.SetString(string(bs)) + return nil + } + + if fieldValue.CanAddr() { + err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) + if err != nil { + return err } - if fieldValue.CanAddr() { - err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) - if err != nil { - return err - } - } else { - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) + } else { + x := reflect.New(fieldType) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) + if err != nil { + return err } + fieldValue.Set(x.Elem()) } return nil } @@ -497,6 +492,9 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec if !ok { return fmt.Errorf("cannot convert %#v as bytes", scanResult) } + if len(data) == 0 { + return nil + } return structConvert.FromDB(data) } } @@ -510,14 +508,14 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec if !ok { return fmt.Errorf("cannot convert %#v as bytes", scanResult) } - if data != nil { - if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) - return fieldValue.Interface().(convert.Conversion).FromDB(data) - } - return structConvert.FromDB(data) + if data == nil { + return nil } - return nil + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + return fieldValue.Interface().(convert.Conversion).FromDB(data) + } + return structConvert.FromDB(data) } rawValueType := reflect.TypeOf(rawValue.Interface()) @@ -539,59 +537,43 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec if err := session.convertBeanField(col, &e, scanResult); err != nil { return err } - if fieldValue.IsNil() { + if fieldValue.IsNil() && !e.Addr().IsNil() { fieldValue.Set(e.Addr()) } return nil case reflect.Complex64, reflect.Complex128: return session.setJSON(fieldValue, fieldType, scanResult) case reflect.Map: - switch rawValueType.Kind() { - case reflect.String: - return session.setJSON(fieldValue, fieldType, scanResult) - case reflect.Slice: + switch scanResult.(type) { + case string, []byte, *sql.NullString, *sql.RawBytes: return session.setJSON(fieldValue, fieldType, scanResult) default: - return fmt.Errorf("unsupported %v -> %T", scanResult, fieldType) + return fmt.Errorf("unsupported %#v -> %T map", scanResult, fieldType) } case reflect.Slice, reflect.Array: - switch rawValueType.Kind() { - case reflect.String: - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - return nil - case reflect.Slice, reflect.Array: - switch rawValueType.Elem().Kind() { - case reflect.Uint8: - if fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - for i := 0; i < fieldValue.Len(); i++ { - if i < vv.Len() { - fieldValue.Index(i).Set(vv.Index(i)) - } - } - } else { - for i := 0; i < vv.Len(); i++ { - fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) - } - } - return nil - } - if col.SQLType.IsText() { - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - return nil - } - } + bs, ok := asBytes(scanResult) + if !ok { + return fmt.Errorf("unsupported %#v -> %T slice,array", scanResult, fieldType) } + if bs == nil { + return nil + } + + if fieldType.Elem().Kind() == reflect.Uint8 { + if fieldValue.Len() > 0 { + for i := 0; i < fieldValue.Len(); i++ { + if i < len(bs) { + fieldValue.Index(i).Set(reflect.ValueOf(bs[i])) + } + } + } else { + for i := 0; i < vv.Len(); i++ { + fieldValue.Set(reflect.Append(*fieldValue, reflect.ValueOf(bs[i]))) + } + } + return nil + } + return session.setJSON(fieldValue, fieldType, scanResult) case reflect.Struct: if fieldType.ConvertibleTo(schemas.BigFloatType) { v, err := asBigFloat(scanResult) @@ -612,6 +594,9 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec if err != nil { return err } + if tm == nil { + return nil + } fieldValue.Set(reflect.ValueOf(*tm).Convert(fieldType)) return nil } else if session.statement.UseCascade {