From a7b647930996d0f7399eb6b7f7bfe03cc6a9aad5 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 18 Jul 2021 08:59:05 +0800 Subject: [PATCH] refactor more --- convert.go | 59 +++++++++++++++++++++ convert/time.go | 7 +++ rows.go | 6 ++- scan.go | 9 +++- session.go | 127 ++++++++++----------------------------------- session_convert.go | 4 -- session_find.go | 7 ++- session_get.go | 2 +- 8 files changed, 114 insertions(+), 107 deletions(-) diff --git a/convert.go b/convert.go index 7d30ec62..bab2c67b 100644 --- a/convert.go +++ b/convert.go @@ -318,6 +318,65 @@ func asBytes(src interface{}) ([]byte, bool) { return nil, false } +func asTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time.Time, error) { + switch t := src.(type) { + case string: + return convert.String2Time(t, dbLoc, uiLoc) + case *sql.NullString: + if !t.Valid { + return nil, nil + } + return convert.String2Time(t.String, dbLoc, uiLoc) + case []uint8: + if t == nil { + return nil, nil + } + fmt.Printf("====== %#v,,%v,,%v\n", string(t), dbLoc.String(), uiLoc.String()) + return convert.String2Time(string(t), dbLoc, uiLoc) + case *sql.NullTime: + if !t.Valid { + return nil, nil + } + z, _ := t.Time.Zone() + if len(z) == 0 || t.Time.Year() == 0 || t.Time.Location().String() != dbLoc.String() { + tm := time.Date(t.Time.Year(), t.Time.Month(), t.Time.Day(), t.Time.Hour(), + t.Time.Minute(), t.Time.Second(), t.Time.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.Time.In(uiLoc) + return &tm, nil + case *time.Time: + z, _ := t.Zone() + if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { + tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), + t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.In(uiLoc) + return &tm, nil + case time.Time: + z, _ := t.Zone() + if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { + tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), + t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.In(uiLoc) + return &tm, nil + case int: + tm := time.Unix(int64(t), 0).In(uiLoc) + return &tm, nil + case int64: + tm := time.Unix(t, 0).In(uiLoc) + return &tm, nil + case *sql.NullInt64: + tm := time.Unix(t.Int64, 0).In(uiLoc) + return &tm, nil + + } + return nil, fmt.Errorf("unsupported value %#v as time", src) +} + // convertAssign copies to dest the value in src, converting it if possible. // An error is returned if the copy would result in loss of information. // dest should be a pointer type. diff --git a/convert/time.go b/convert/time.go index a9936b75..5a3e5246 100644 --- a/convert/time.go +++ b/convert/time.go @@ -6,6 +6,7 @@ package convert import ( "fmt" + "strconv" "time" ) @@ -32,6 +33,12 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t } dt = dt.In(convertedLocation) return &dt, nil + } else { + i, err := strconv.ParseInt(s, 10, 64) + if err == nil { + tm := time.Unix(i, 0).In(convertedLocation) + return &tm, nil + } } return nil, fmt.Errorf("unsupported convertion from %s to time", s) } diff --git a/rows.go b/rows.go index a56ea1c9..5e0a1ffe 100644 --- a/rows.go +++ b/rows.go @@ -129,8 +129,12 @@ func (rows *Rows) Scan(bean interface{}) error { if err != nil { return err } + types, err := rows.rows.ColumnTypes() + if err != nil { + return err + } - scanResults, err := rows.session.row2Slice(rows.rows, fields, bean) + scanResults, err := rows.session.row2Slice(rows.rows, fields, types, bean) if err != nil { return err } diff --git a/scan.go b/scan.go index d8a1ac3d..4bf609db 100644 --- a/scan.go +++ b/scan.go @@ -20,6 +20,8 @@ import ( // genScanResultsByBeanNullabale generates scan result func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { switch t := bean.(type) { + case *interface{}: + return t, false, nil case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes: return t, false, nil case *time.Time: @@ -71,6 +73,8 @@ func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { switch t := bean.(type) { + case *interface{}: + return t, false, nil case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes, *string, @@ -194,7 +198,7 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column var scanResults = make([]interface{}, 0, len(types)) var replaces = make([]bool, 0, len(types)) var err error - for _, v := range vv { + for i, v := range vv { var replaced bool var scanResult interface{} switch t := v.(type) { @@ -222,6 +226,8 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column } } + fmt.Printf("----- %v ----- %#v\n", fields[i], scanResult) + scanResults = append(scanResults, scanResult) replaces = append(replaces, replaced) } @@ -235,6 +241,7 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column for i, replaced := range replaces { if replaced { + fmt.Printf("===== %v %#v\n", fields[i], scanResults[i]) if err = convertAssign(vv[i], scanResults[i], engine.DatabaseTZ, engine.TZLocation); err != nil { return err } diff --git a/session.go b/session.go index 806feed3..5557d717 100644 --- a/session.go +++ b/session.go @@ -16,7 +16,6 @@ import ( "io" "reflect" "strings" - "time" "xorm.io/xorm/contexts" "xorm.io/xorm/convert" @@ -389,7 +388,7 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s // Cell cell is a result of one column field type Cell *interface{} -func (session *Session) rows2Beans(rows *core.Rows, fields []string, +func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sql.ColumnType, table *schemas.Table, newElemFunc func([]string) reflect.Value, sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { for rows.Next() { @@ -398,7 +397,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, dataStruct := newValue.Elem() // handle beforeClosures - scanResults, err := session.row2Slice(rows, fields, bean) + scanResults, err := session.row2Slice(rows, fields, types, bean) if err != nil { return err } @@ -417,7 +416,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, return nil } -func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interface{}) ([]interface{}, error) { +func (session *Session) row2Slice(rows *core.Rows, fields []string, types []*sql.ColumnType, bean interface{}) ([]interface{}, error) { for _, closure := range session.beforeClosures { closure(bean) } @@ -427,7 +426,7 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa var cell interface{} scanResults[i] = &cell } - if err := rows.Scan(scanResults...); err != nil { + if err := session.engine.scan(rows, fields, types, scanResults...); err != nil { return nil, err } @@ -555,64 +554,28 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } return nil case reflect.Slice, reflect.Array: - switch rawValueType.Kind() { - case reflect.Slice, reflect.Array: - switch rawValueType.Elem().Kind() { - case reflect.Uint8: - if fieldType.Elem().Kind() == reflect.Uint8 { - if col.SQLType.IsText() { - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } else { - if fieldValue.Len() > 0 { - for i := 0; i < fieldValue.Len(); i++ { - if i < vv.Len() { - fieldValue.Index(i).Set(vv.Index(i)) - } - } - } else { - for i := 0; i < vv.Len(); i++ { - fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) - } + bs, ok := asBytes(scanResult) + if ok && fieldType.Elem().Kind() == reflect.Uint8 { + if col.SQLType.IsText() { + x := reflect.New(fieldType) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) + if err != nil { + return err + } + fieldValue.Set(x.Elem()) + } else { + if fieldValue.Len() > 0 { + for i := 0; i < fieldValue.Len(); i++ { + if i < vv.Len() { + fieldValue.Index(i).Set(vv.Index(i)) } } - return nil + } else { + for i := 0; i < vv.Len(); i++ { + fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) + } } } - } - case reflect.String: - if rawValueType.Kind() == reflect.String { - fieldValue.SetString(vv.String()) - return nil - } - case reflect.Bool: - if rawValueType.Kind() == reflect.Bool { - fieldValue.SetBool(vv.Bool()) - return nil - } - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - switch rawValueType.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - fieldValue.SetInt(vv.Int()) - return nil - } - case reflect.Float32, reflect.Float64: - switch rawValueType.Kind() { - case reflect.Float32, reflect.Float64: - fieldValue.SetFloat(vv.Float()) - return nil - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - switch rawValueType.Kind() { - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - fieldValue.SetUint(vv.Uint()) - return nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - fieldValue.SetUint(uint64(vv.Int())) return nil } case reflect.Struct: @@ -631,47 +594,13 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec dbTZ = col.TimeZone } - if rawValueType == schemas.TimeType { - t := vv.Convert(schemas.TimeType).Interface().(time.Time) - - z, _ := t.Zone() - // set new location if database don't save timezone or give an incorrect timezone - if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location - session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", col.Name, t, z, *t.Location()) - t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), - t.Minute(), t.Second(), t.Nanosecond(), dbTZ) - } - - t = t.In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type || - rawValueType == schemas.Int32Type { - t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } else { - if d, ok := vv.Interface().([]uint8); ok { - t, err := session.byte2Time(col, d) - if err != nil { - session.engine.logger.Errorf("byte2Time error: %v", err) - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } - - } else if d, ok := vv.Interface().(string); ok { - t, err := session.str2Time(col, d) - if err != nil { - session.engine.logger.Errorf("byte2Time error: %v", err) - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } - } else { - return fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) - } + t, err := asTime(scanResult, dbTZ, session.engine.TZLocation) + if err != nil { + return err } + + fieldValue.Set(reflect.ValueOf(*t).Convert(fieldType)) + return nil } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { err := nulVal.Scan(vv.Interface()) if err == nil { diff --git a/session_convert.go b/session_convert.go index ceeae44c..452801e2 100644 --- a/session_convert.go +++ b/session_convert.go @@ -68,7 +68,3 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time outTime = x.In(session.engine.TZLocation) return } - -func (session *Session) byte2Time(col *schemas.Column, data []byte) (outTime time.Time, outErr error) { - return session.str2Time(col, string(data)) -} diff --git a/session_find.go b/session_find.go index 261e6b7f..41d68479 100644 --- a/session_find.go +++ b/session_find.go @@ -172,6 +172,11 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect return err } + types, err := rows.ColumnTypes() + if err != nil { + return err + } + var newElemFunc func(fields []string) reflect.Value elemType := containerValue.Type().Elem() var isPointer bool @@ -241,7 +246,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect if err != nil { return err } - err = session.rows2Beans(rows, fields, tb, newElemFunc, containerValueSetFunc) + err = session.rows2Beans(rows, fields, types, tb, newElemFunc, containerValueSetFunc) rows.Close() if err != nil { return err diff --git a/session_get.go b/session_get.go index 96b1ee87..fa97e68e 100644 --- a/session_get.go +++ b/session_get.go @@ -268,7 +268,7 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields } func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { - scanResults, err := session.row2Slice(rows, fields, bean) + scanResults, err := session.row2Slice(rows, fields, types, bean) if err != nil { return false, err }