From 27ff0fd87332e1c8221f7e95a42e6a18ae28eeec Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 16 Jul 2021 13:21:55 +0800 Subject: [PATCH] refactor more --- convert.go | 80 +++++++++++++++++++++++++++++----- processors.go | 4 +- schemas/type.go | 4 ++ session.go | 113 +++++++++++++++++------------------------------- session_raw.go | 6 +-- 5 files changed, 118 insertions(+), 89 deletions(-) diff --git a/convert.go b/convert.go index 116dd783..5bf6ac78 100644 --- a/convert.go +++ b/convert.go @@ -15,6 +15,7 @@ import ( "time" "xorm.io/xorm/convert" + "xorm.io/xorm/schemas" ) var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error @@ -192,6 +193,8 @@ func asFloat64(src interface{}) (float64, error) { return float64(v.Int32), nil case *sql.NullInt64: return float64(v.Int64), nil + case *sql.NullFloat64: + return v.Float64, nil } rv := reflect.ValueOf(src) @@ -208,6 +211,42 @@ func asFloat64(src interface{}) (float64, error) { return 0, fmt.Errorf("unsupported value %T as int64", src) } +func asTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time.Time, error) { + switch t := src.(type) { + case string: + return convert.String2Time(t, dbLoc, uiLoc) + case *sql.NullString: + if !t.Valid { + return nil, nil + } + return convert.String2Time(t.String, dbLoc, uiLoc) + case []uint8: + if t == nil { + return nil, nil + } + return convert.String2Time(string(t), dbLoc, uiLoc) + case *sql.NullTime: + tm := t.Time + return &tm, nil + case *time.Time: + z, _ := t.Zone() + if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { + tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), + t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.In(uiLoc) + return &tm, nil + case int: + tm := time.Unix(int64(t), 0).In(uiLoc) + return &tm, nil + case int64: + tm := time.Unix(t, 0).In(uiLoc) + return &tm, nil + } + return nil, fmt.Errorf("unsupported value %#v as time", src) +} + func asBigFloat(src interface{}) (*big.Float, error) { res := big.NewFloat(0) switch v := src.(type) { @@ -285,23 +324,33 @@ func asBigFloat(src interface{}) (*big.Float, error) { return nil, fmt.Errorf("unsupported value %T as big.Float", src) } -func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { +func asBytes(src interface{}) ([]byte, bool) { + switch t := src.(type) { + case []byte: + return t, true + case *sql.NullString: + return []byte(t.String), true + case *sql.RawBytes: + return *t, true + } + + rv := reflect.ValueOf(src) switch rv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return strconv.AppendInt(buf, rv.Int(), 10), true + return strconv.AppendInt(nil, rv.Int(), 10), true case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return strconv.AppendUint(buf, rv.Uint(), 10), true + return strconv.AppendUint(nil, rv.Uint(), 10), true case reflect.Float32: - return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true + return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 32), true case reflect.Float64: - return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true + return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 64), true case reflect.Bool: - return strconv.AppendBool(buf, rv.Bool()), true + return strconv.AppendBool(nil, rv.Bool()), true case reflect.String: s := rv.String() - return append(buf, s...), true + return []byte(s), true } - return + return nil, false } // convertAssign copies to dest the value in src, converting it if possible. @@ -559,8 +608,7 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve return nil } case *[]byte: - sv = reflect.ValueOf(src) - if b, ok := asBytes(nil, sv); ok { + if b, ok := asBytes(src); ok { *d = b return nil } @@ -678,6 +726,8 @@ func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, conver func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { switch tp.Kind() { + case reflect.Ptr: + return asKind(vv.Elem(), tp.Elem()) case reflect.Int64: return vv.Int(), nil case reflect.Int: @@ -708,7 +758,11 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { } return v, nil } - + case reflect.Struct: + if vv.Type().ConvertibleTo(schemas.NullInt64Type) { + r := vv.Convert(schemas.NullInt64Type) + return r.Interface().(sql.NullInt64).Int64, nil + } } return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) } @@ -743,6 +797,10 @@ func asBool(src interface{}) (bool, error) { return strconv.ParseBool(string(v)) case string: return strconv.ParseBool(v) + case *sql.NullInt64: + return v.Int64 > 0, nil + case *sql.NullInt32: + return v.Int32 > 0, nil default: return false, fmt.Errorf("unknow type %T as bool", src) } diff --git a/processors.go b/processors.go index 8697e302..b17ef648 100644 --- a/processors.go +++ b/processors.go @@ -94,7 +94,7 @@ func executeBeforeClosures(session *Session, bean interface{}) { func executeBeforeSet(bean interface{}, fields []string, scanResults []interface{}) { if b, hasBeforeSet := bean.(BeforeSetProcessor); hasBeforeSet { for ii, key := range fields { - b.BeforeSet(key, Cell(scanResults[ii].(*interface{}))) + b.BeforeSet(key, Cell(scanResults[ii])) } } } @@ -102,7 +102,7 @@ func executeBeforeSet(bean interface{}, fields []string, scanResults []interface func executeAfterSet(bean interface{}, fields []string, scanResults []interface{}) { if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet { for ii, key := range fields { - b.AfterSet(key, Cell(scanResults[ii].(*interface{}))) + b.AfterSet(key, Cell(scanResults[ii])) } } } diff --git a/schemas/type.go b/schemas/type.go index 62e66c2e..7dff9cf6 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -5,6 +5,7 @@ package schemas import ( + "database/sql" "math/big" "reflect" "sort" @@ -248,6 +249,7 @@ var ( uintDefault uint timeDefault time.Time bigFloatDefault big.Float + nullInt64Default sql.NullInt64 ) // enumerates all types @@ -277,6 +279,8 @@ var ( TimeType = reflect.TypeOf(timeDefault) BigFloatType = reflect.TypeOf(bigFloatDefault) + + NullInt64Type = reflect.TypeOf(nullInt64Default) ) // enumerates all types diff --git a/session.go b/session.go index b3e7d0ab..7889b447 100644 --- a/session.go +++ b/session.go @@ -16,7 +16,6 @@ import ( "io" "reflect" "strings" - "time" "xorm.io/xorm/contexts" "xorm.io/xorm/convert" @@ -387,7 +386,7 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s } // Cell cell is a result of one column field -type Cell *interface{} +type Cell interface{} func (session *Session) rows2Beans(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, newElemFunc func([]string) reflect.Value, @@ -439,14 +438,17 @@ func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fiel return scanResults, nil } -func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Type, vv reflect.Value, rawValueType reflect.Type) error { +func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Type, scanResult interface{}) error { var bs []byte - if rawValueType.Kind() == reflect.String { - bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(schemas.BytesType) { - bs = vv.Bytes() - } else { - return fmt.Errorf("unsupported database data type: %v", rawValueType.Kind()) + switch t := scanResult.(type) { + case string: + bs = []byte(t) + case []byte: + bs = t + case *sql.NullString: + bs = []byte(t.String) + default: + return fmt.Errorf("unsupported database data type: %#v", scanResult) } if len(bs) > 0 { @@ -487,26 +489,33 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } if fieldValue.CanAddr() { + if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + return scanner.Scan(scanResult) + } if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - data, err := value2Bytes(&rawValue) - if err != nil { - return err + data, ok := asBytes(scanResult) + if !ok { + return fmt.Errorf("cannot convert %#v as bytes", scanResult) } - if err := structConvert.FromDB(data); err != nil { - return err - } - return nil + return structConvert.FromDB(data) } } - if _, ok := fieldValue.Interface().(convert.Conversion); ok { - if data, err := value2Bytes(&rawValue); err == nil { + if scanner, ok := fieldValue.Interface().(sql.Scanner); ok { + return scanner.Scan(scanResult) + } + + if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { + data, ok := asBytes(scanResult) + if !ok { + return fmt.Errorf("cannot convert %#v as bytes", scanResult) + } + if data != nil { if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + return fieldValue.Interface().(convert.Conversion).FromDB(data) } - fieldValue.Interface().(convert.Conversion).FromDB(data) - } else { - return err + return structConvert.FromDB(data) } return nil } @@ -516,7 +525,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec fieldType := fieldValue.Type() if col.IsJSON { - return session.setJSON(fieldValue, fieldType, vv, rawValueType) + return session.setJSON(fieldValue, fieldType, scanResult) } switch fieldType.Kind() { @@ -535,13 +544,13 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } return nil case reflect.Complex64, reflect.Complex128: - return session.setJSON(fieldValue, fieldType, vv, rawValueType) + return session.setJSON(fieldValue, fieldType, scanResult) case reflect.Map: switch rawValueType.Kind() { case reflect.String: - return session.setJSON(fieldValue, fieldType, vv, rawValueType) + return session.setJSON(fieldValue, fieldType, scanResult) case reflect.Slice: - return session.setJSON(fieldValue, fieldType, vv, rawValueType) + return session.setJSON(fieldValue, fieldType, scanResult) default: return fmt.Errorf("unsupported %v -> %T", scanResult, fieldType) } @@ -556,7 +565,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec fieldValue.Set(x.Elem()) return nil case reflect.Slice, reflect.Array: - fmt.Printf("======%T\n", scanResult) switch rawValueType.Elem().Kind() { case reflect.Uint8: if fieldType.Elem().Kind() == reflect.Uint8 { @@ -600,53 +608,12 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec dbTZ = col.TimeZone } - if rawValueType == schemas.TimeType { - t := vv.Convert(schemas.TimeType).Interface().(time.Time) - - z, _ := t.Zone() - // set new location if database don't save timezone or give an incorrect timezone - if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location - session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", col.Name, t, z, *t.Location()) - t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), - t.Minute(), t.Second(), t.Nanosecond(), dbTZ) - } - - t = t.In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type || - rawValueType == schemas.Int32Type { - t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } else { - if d, ok := vv.Interface().([]uint8); ok { - t, err := session.byte2Time(col, d) - if err != nil { - session.engine.logger.Errorf("byte2Time error: %v", err) - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } - - } else if d, ok := vv.Interface().(string); ok { - t, err := session.str2Time(col, d) - if err != nil { - session.engine.logger.Errorf("byte2Time error: %v", err) - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } - } else { - return fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) - } + tm, err := asTime(scanResult, dbTZ, session.engine.TZLocation) + if err != nil { + return err } - } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - err := nulVal.Scan(vv.Interface()) - if err == nil { - return nil - } - session.engine.logger.Errorf("sql.Sanner error: %v", err) + fieldValue.Set(reflect.ValueOf(*tm).Convert(fieldType)) + return nil } else if session.statement.UseCascade { table, err := session.engine.tagParser.ParseWithCache(*fieldValue) if err != nil { @@ -679,7 +646,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } return nil } - return session.setJSON(fieldValue, fieldType, vv, rawValueType) + return session.setJSON(fieldValue, fieldType, scanResult) } // switch fieldType.Kind() return convertAssignV(fieldValue.Addr(), scanResult, session.engine.DatabaseTZ, session.engine.TZLocation) diff --git a/session_raw.go b/session_raw.go index bf32c6ed..7e9ff52b 100644 --- a/session_raw.go +++ b/session_raw.go @@ -96,14 +96,14 @@ func value2String(rawValue *reflect.Value) (str string, err error) { str = "0" } default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + err = fmt.Errorf("Unsupported struct type %v as array", vv.Type().Name()) } // time type case reflect.Struct: if aa.ConvertibleTo(schemas.TimeType) { str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) } else { - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + err = fmt.Errorf("Unsupported struct type %v as struct", vv.Type().Name()) } case reflect.Bool: str = strconv.FormatBool(vv.Bool()) @@ -117,7 +117,7 @@ func value2String(rawValue *reflect.Value) (str string, err error) { case reflect.Chan, reflect.Func, reflect.Interface: */ default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + err = fmt.Errorf("Unsupported struct type %v as %v", vv.Type().Name(), aa.Kind()) } return }