refactor more

This commit is contained in:
Lunny Xiao 2021-07-18 08:59:05 +08:00
parent b1a98491f5
commit a7b6479309
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
8 changed files with 114 additions and 107 deletions

View File

@ -318,6 +318,65 @@ func asBytes(src interface{}) ([]byte, bool) {
return nil, false return nil, false
} }
func asTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time.Time, error) {
switch t := src.(type) {
case string:
return convert.String2Time(t, dbLoc, uiLoc)
case *sql.NullString:
if !t.Valid {
return nil, nil
}
return convert.String2Time(t.String, dbLoc, uiLoc)
case []uint8:
if t == nil {
return nil, nil
}
fmt.Printf("====== %#v,,%v,,%v\n", string(t), dbLoc.String(), uiLoc.String())
return convert.String2Time(string(t), dbLoc, uiLoc)
case *sql.NullTime:
if !t.Valid {
return nil, nil
}
z, _ := t.Time.Zone()
if len(z) == 0 || t.Time.Year() == 0 || t.Time.Location().String() != dbLoc.String() {
tm := time.Date(t.Time.Year(), t.Time.Month(), t.Time.Day(), t.Time.Hour(),
t.Time.Minute(), t.Time.Second(), t.Time.Nanosecond(), dbLoc).In(uiLoc)
return &tm, nil
}
tm := t.Time.In(uiLoc)
return &tm, nil
case *time.Time:
z, _ := t.Zone()
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() {
tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc)
return &tm, nil
}
tm := t.In(uiLoc)
return &tm, nil
case time.Time:
z, _ := t.Zone()
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() {
tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc)
return &tm, nil
}
tm := t.In(uiLoc)
return &tm, nil
case int:
tm := time.Unix(int64(t), 0).In(uiLoc)
return &tm, nil
case int64:
tm := time.Unix(t, 0).In(uiLoc)
return &tm, nil
case *sql.NullInt64:
tm := time.Unix(t.Int64, 0).In(uiLoc)
return &tm, nil
}
return nil, fmt.Errorf("unsupported value %#v as time", src)
}
// convertAssign copies to dest the value in src, converting it if possible. // convertAssign copies to dest the value in src, converting it if possible.
// An error is returned if the copy would result in loss of information. // An error is returned if the copy would result in loss of information.
// dest should be a pointer type. // dest should be a pointer type.

View File

@ -6,6 +6,7 @@ package convert
import ( import (
"fmt" "fmt"
"strconv"
"time" "time"
) )
@ -32,6 +33,12 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t
} }
dt = dt.In(convertedLocation) dt = dt.In(convertedLocation)
return &dt, nil return &dt, nil
} else {
i, err := strconv.ParseInt(s, 10, 64)
if err == nil {
tm := time.Unix(i, 0).In(convertedLocation)
return &tm, nil
}
} }
return nil, fmt.Errorf("unsupported convertion from %s to time", s) return nil, fmt.Errorf("unsupported convertion from %s to time", s)
} }

View File

@ -129,8 +129,12 @@ func (rows *Rows) Scan(bean interface{}) error {
if err != nil { if err != nil {
return err return err
} }
types, err := rows.rows.ColumnTypes()
if err != nil {
return err
}
scanResults, err := rows.session.row2Slice(rows.rows, fields, bean) scanResults, err := rows.session.row2Slice(rows.rows, fields, types, bean)
if err != nil { if err != nil {
return err return err
} }

View File

@ -20,6 +20,8 @@ import (
// genScanResultsByBeanNullabale generates scan result // genScanResultsByBeanNullabale generates scan result
func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) {
switch t := bean.(type) { switch t := bean.(type) {
case *interface{}:
return t, false, nil
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes: case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes:
return t, false, nil return t, false, nil
case *time.Time: case *time.Time:
@ -71,6 +73,8 @@ func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) {
func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { func genScanResultsByBean(bean interface{}) (interface{}, bool, error) {
switch t := bean.(type) { switch t := bean.(type) {
case *interface{}:
return t, false, nil
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString,
*sql.RawBytes, *sql.RawBytes,
*string, *string,
@ -194,7 +198,7 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column
var scanResults = make([]interface{}, 0, len(types)) var scanResults = make([]interface{}, 0, len(types))
var replaces = make([]bool, 0, len(types)) var replaces = make([]bool, 0, len(types))
var err error var err error
for _, v := range vv { for i, v := range vv {
var replaced bool var replaced bool
var scanResult interface{} var scanResult interface{}
switch t := v.(type) { switch t := v.(type) {
@ -222,6 +226,8 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column
} }
} }
fmt.Printf("----- %v ----- %#v\n", fields[i], scanResult)
scanResults = append(scanResults, scanResult) scanResults = append(scanResults, scanResult)
replaces = append(replaces, replaced) replaces = append(replaces, replaced)
} }
@ -235,6 +241,7 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column
for i, replaced := range replaces { for i, replaced := range replaces {
if replaced { if replaced {
fmt.Printf("===== %v %#v\n", fields[i], scanResults[i])
if err = convertAssign(vv[i], scanResults[i], engine.DatabaseTZ, engine.TZLocation); err != nil { if err = convertAssign(vv[i], scanResults[i], engine.DatabaseTZ, engine.TZLocation); err != nil {
return err return err
} }

View File

@ -16,7 +16,6 @@ import (
"io" "io"
"reflect" "reflect"
"strings" "strings"
"time"
"xorm.io/xorm/contexts" "xorm.io/xorm/contexts"
"xorm.io/xorm/convert" "xorm.io/xorm/convert"
@ -389,7 +388,7 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s
// Cell cell is a result of one column field // Cell cell is a result of one column field
type Cell *interface{} type Cell *interface{}
func (session *Session) rows2Beans(rows *core.Rows, fields []string, func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sql.ColumnType,
table *schemas.Table, newElemFunc func([]string) reflect.Value, table *schemas.Table, newElemFunc func([]string) reflect.Value,
sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error {
for rows.Next() { for rows.Next() {
@ -398,7 +397,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string,
dataStruct := newValue.Elem() dataStruct := newValue.Elem()
// handle beforeClosures // handle beforeClosures
scanResults, err := session.row2Slice(rows, fields, bean) scanResults, err := session.row2Slice(rows, fields, types, bean)
if err != nil { if err != nil {
return err return err
} }
@ -417,7 +416,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string,
return nil return nil
} }
func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interface{}) ([]interface{}, error) { func (session *Session) row2Slice(rows *core.Rows, fields []string, types []*sql.ColumnType, bean interface{}) ([]interface{}, error) {
for _, closure := range session.beforeClosures { for _, closure := range session.beforeClosures {
closure(bean) closure(bean)
} }
@ -427,7 +426,7 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa
var cell interface{} var cell interface{}
scanResults[i] = &cell scanResults[i] = &cell
} }
if err := rows.Scan(scanResults...); err != nil { if err := session.engine.scan(rows, fields, types, scanResults...); err != nil {
return nil, err return nil, err
} }
@ -555,14 +554,11 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
} }
return nil return nil
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
switch rawValueType.Kind() { bs, ok := asBytes(scanResult)
case reflect.Slice, reflect.Array: if ok && fieldType.Elem().Kind() == reflect.Uint8 {
switch rawValueType.Elem().Kind() {
case reflect.Uint8:
if fieldType.Elem().Kind() == reflect.Uint8 {
if col.SQLType.IsText() { if col.SQLType.IsText() {
x := reflect.New(fieldType) x := reflect.New(fieldType)
err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface())
if err != nil { if err != nil {
return err return err
} }
@ -582,39 +578,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
} }
return nil return nil
} }
}
}
case reflect.String:
if rawValueType.Kind() == reflect.String {
fieldValue.SetString(vv.String())
return nil
}
case reflect.Bool:
if rawValueType.Kind() == reflect.Bool {
fieldValue.SetBool(vv.Bool())
return nil
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch rawValueType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
fieldValue.SetInt(vv.Int())
return nil
}
case reflect.Float32, reflect.Float64:
switch rawValueType.Kind() {
case reflect.Float32, reflect.Float64:
fieldValue.SetFloat(vv.Float())
return nil
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
switch rawValueType.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
fieldValue.SetUint(vv.Uint())
return nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
fieldValue.SetUint(uint64(vv.Int()))
return nil
}
case reflect.Struct: case reflect.Struct:
if fieldType.ConvertibleTo(schemas.BigFloatType) { if fieldType.ConvertibleTo(schemas.BigFloatType) {
v, err := asBigFloat(scanResult) v, err := asBigFloat(scanResult)
@ -631,47 +594,13 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
dbTZ = col.TimeZone dbTZ = col.TimeZone
} }
if rawValueType == schemas.TimeType { t, err := asTime(scanResult, dbTZ, session.engine.TZLocation)
t := vv.Convert(schemas.TimeType).Interface().(time.Time)
z, _ := t.Zone()
// set new location if database don't save timezone or give an incorrect timezone
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", col.Name, t, z, *t.Location())
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
}
t = t.In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
return nil
} else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type ||
rawValueType == schemas.Int32Type {
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
return nil
} else {
if d, ok := vv.Interface().([]uint8); ok {
t, err := session.byte2Time(col, d)
if err != nil { if err != nil {
session.engine.logger.Errorf("byte2Time error: %v", err) return err
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
return nil
} }
} else if d, ok := vv.Interface().(string); ok { fieldValue.Set(reflect.ValueOf(*t).Convert(fieldType))
t, err := session.str2Time(col, d)
if err != nil {
session.engine.logger.Errorf("byte2Time error: %v", err)
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
return nil return nil
}
} else {
return fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface())
}
}
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
err := nulVal.Scan(vv.Interface()) err := nulVal.Scan(vv.Interface())
if err == nil { if err == nil {

View File

@ -68,7 +68,3 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time
outTime = x.In(session.engine.TZLocation) outTime = x.In(session.engine.TZLocation)
return return
} }
func (session *Session) byte2Time(col *schemas.Column, data []byte) (outTime time.Time, outErr error) {
return session.str2Time(col, string(data))
}

View File

@ -172,6 +172,11 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
return err return err
} }
types, err := rows.ColumnTypes()
if err != nil {
return err
}
var newElemFunc func(fields []string) reflect.Value var newElemFunc func(fields []string) reflect.Value
elemType := containerValue.Type().Elem() elemType := containerValue.Type().Elem()
var isPointer bool var isPointer bool
@ -241,7 +246,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
if err != nil { if err != nil {
return err return err
} }
err = session.rows2Beans(rows, fields, tb, newElemFunc, containerValueSetFunc) err = session.rows2Beans(rows, fields, types, tb, newElemFunc, containerValueSetFunc)
rows.Close() rows.Close()
if err != nil { if err != nil {
return err return err

View File

@ -268,7 +268,7 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields
} }
func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) {
scanResults, err := session.row2Slice(rows, fields, bean) scanResults, err := session.row2Slice(rows, fields, types, bean)
if err != nil { if err != nil {
return false, err return false, err
} }