refactor more
This commit is contained in:
parent
b1a98491f5
commit
a7b6479309
59
convert.go
59
convert.go
|
@ -318,6 +318,65 @@ func asBytes(src interface{}) ([]byte, bool) {
|
|||
return nil, false
|
||||
}
|
||||
|
||||
func asTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time.Time, error) {
|
||||
switch t := src.(type) {
|
||||
case string:
|
||||
return convert.String2Time(t, dbLoc, uiLoc)
|
||||
case *sql.NullString:
|
||||
if !t.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return convert.String2Time(t.String, dbLoc, uiLoc)
|
||||
case []uint8:
|
||||
if t == nil {
|
||||
return nil, nil
|
||||
}
|
||||
fmt.Printf("====== %#v,,%v,,%v\n", string(t), dbLoc.String(), uiLoc.String())
|
||||
return convert.String2Time(string(t), dbLoc, uiLoc)
|
||||
case *sql.NullTime:
|
||||
if !t.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
z, _ := t.Time.Zone()
|
||||
if len(z) == 0 || t.Time.Year() == 0 || t.Time.Location().String() != dbLoc.String() {
|
||||
tm := time.Date(t.Time.Year(), t.Time.Month(), t.Time.Day(), t.Time.Hour(),
|
||||
t.Time.Minute(), t.Time.Second(), t.Time.Nanosecond(), dbLoc).In(uiLoc)
|
||||
return &tm, nil
|
||||
}
|
||||
tm := t.Time.In(uiLoc)
|
||||
return &tm, nil
|
||||
case *time.Time:
|
||||
z, _ := t.Zone()
|
||||
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() {
|
||||
tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
|
||||
t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc)
|
||||
return &tm, nil
|
||||
}
|
||||
tm := t.In(uiLoc)
|
||||
return &tm, nil
|
||||
case time.Time:
|
||||
z, _ := t.Zone()
|
||||
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() {
|
||||
tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
|
||||
t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc)
|
||||
return &tm, nil
|
||||
}
|
||||
tm := t.In(uiLoc)
|
||||
return &tm, nil
|
||||
case int:
|
||||
tm := time.Unix(int64(t), 0).In(uiLoc)
|
||||
return &tm, nil
|
||||
case int64:
|
||||
tm := time.Unix(t, 0).In(uiLoc)
|
||||
return &tm, nil
|
||||
case *sql.NullInt64:
|
||||
tm := time.Unix(t.Int64, 0).In(uiLoc)
|
||||
return &tm, nil
|
||||
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported value %#v as time", src)
|
||||
}
|
||||
|
||||
// convertAssign copies to dest the value in src, converting it if possible.
|
||||
// An error is returned if the copy would result in loss of information.
|
||||
// dest should be a pointer type.
|
||||
|
|
|
@ -6,6 +6,7 @@ package convert
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -32,6 +33,12 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t
|
|||
}
|
||||
dt = dt.In(convertedLocation)
|
||||
return &dt, nil
|
||||
} else {
|
||||
i, err := strconv.ParseInt(s, 10, 64)
|
||||
if err == nil {
|
||||
tm := time.Unix(i, 0).In(convertedLocation)
|
||||
return &tm, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported convertion from %s to time", s)
|
||||
}
|
||||
|
|
6
rows.go
6
rows.go
|
@ -129,8 +129,12 @@ func (rows *Rows) Scan(bean interface{}) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
types, err := rows.rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
scanResults, err := rows.session.row2Slice(rows.rows, fields, bean)
|
||||
scanResults, err := rows.session.row2Slice(rows.rows, fields, types, bean)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
9
scan.go
9
scan.go
|
@ -20,6 +20,8 @@ import (
|
|||
// genScanResultsByBeanNullabale generates scan result
|
||||
func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) {
|
||||
switch t := bean.(type) {
|
||||
case *interface{}:
|
||||
return t, false, nil
|
||||
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes:
|
||||
return t, false, nil
|
||||
case *time.Time:
|
||||
|
@ -71,6 +73,8 @@ func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) {
|
|||
|
||||
func genScanResultsByBean(bean interface{}) (interface{}, bool, error) {
|
||||
switch t := bean.(type) {
|
||||
case *interface{}:
|
||||
return t, false, nil
|
||||
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString,
|
||||
*sql.RawBytes,
|
||||
*string,
|
||||
|
@ -194,7 +198,7 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column
|
|||
var scanResults = make([]interface{}, 0, len(types))
|
||||
var replaces = make([]bool, 0, len(types))
|
||||
var err error
|
||||
for _, v := range vv {
|
||||
for i, v := range vv {
|
||||
var replaced bool
|
||||
var scanResult interface{}
|
||||
switch t := v.(type) {
|
||||
|
@ -222,6 +226,8 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column
|
|||
}
|
||||
}
|
||||
|
||||
fmt.Printf("----- %v ----- %#v\n", fields[i], scanResult)
|
||||
|
||||
scanResults = append(scanResults, scanResult)
|
||||
replaces = append(replaces, replaced)
|
||||
}
|
||||
|
@ -235,6 +241,7 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column
|
|||
|
||||
for i, replaced := range replaces {
|
||||
if replaced {
|
||||
fmt.Printf("===== %v %#v\n", fields[i], scanResults[i])
|
||||
if err = convertAssign(vv[i], scanResults[i], engine.DatabaseTZ, engine.TZLocation); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
91
session.go
91
session.go
|
@ -16,7 +16,6 @@ import (
|
|||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"xorm.io/xorm/contexts"
|
||||
"xorm.io/xorm/convert"
|
||||
|
@ -389,7 +388,7 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s
|
|||
// Cell cell is a result of one column field
|
||||
type Cell *interface{}
|
||||
|
||||
func (session *Session) rows2Beans(rows *core.Rows, fields []string,
|
||||
func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sql.ColumnType,
|
||||
table *schemas.Table, newElemFunc func([]string) reflect.Value,
|
||||
sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error {
|
||||
for rows.Next() {
|
||||
|
@ -398,7 +397,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string,
|
|||
dataStruct := newValue.Elem()
|
||||
|
||||
// handle beforeClosures
|
||||
scanResults, err := session.row2Slice(rows, fields, bean)
|
||||
scanResults, err := session.row2Slice(rows, fields, types, bean)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -417,7 +416,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string,
|
|||
return nil
|
||||
}
|
||||
|
||||
func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interface{}) ([]interface{}, error) {
|
||||
func (session *Session) row2Slice(rows *core.Rows, fields []string, types []*sql.ColumnType, bean interface{}) ([]interface{}, error) {
|
||||
for _, closure := range session.beforeClosures {
|
||||
closure(bean)
|
||||
}
|
||||
|
@ -427,7 +426,7 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa
|
|||
var cell interface{}
|
||||
scanResults[i] = &cell
|
||||
}
|
||||
if err := rows.Scan(scanResults...); err != nil {
|
||||
if err := session.engine.scan(rows, fields, types, scanResults...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -555,14 +554,11 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
|
|||
}
|
||||
return nil
|
||||
case reflect.Slice, reflect.Array:
|
||||
switch rawValueType.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
switch rawValueType.Elem().Kind() {
|
||||
case reflect.Uint8:
|
||||
if fieldType.Elem().Kind() == reflect.Uint8 {
|
||||
bs, ok := asBytes(scanResult)
|
||||
if ok && fieldType.Elem().Kind() == reflect.Uint8 {
|
||||
if col.SQLType.IsText() {
|
||||
x := reflect.New(fieldType)
|
||||
err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface())
|
||||
err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -582,39 +578,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
|
|||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.String:
|
||||
if rawValueType.Kind() == reflect.String {
|
||||
fieldValue.SetString(vv.String())
|
||||
return nil
|
||||
}
|
||||
case reflect.Bool:
|
||||
if rawValueType.Kind() == reflect.Bool {
|
||||
fieldValue.SetBool(vv.Bool())
|
||||
return nil
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
switch rawValueType.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
fieldValue.SetInt(vv.Int())
|
||||
return nil
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
switch rawValueType.Kind() {
|
||||
case reflect.Float32, reflect.Float64:
|
||||
fieldValue.SetFloat(vv.Float())
|
||||
return nil
|
||||
}
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
switch rawValueType.Kind() {
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
fieldValue.SetUint(vv.Uint())
|
||||
return nil
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
fieldValue.SetUint(uint64(vv.Int()))
|
||||
return nil
|
||||
}
|
||||
case reflect.Struct:
|
||||
if fieldType.ConvertibleTo(schemas.BigFloatType) {
|
||||
v, err := asBigFloat(scanResult)
|
||||
|
@ -631,47 +594,13 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
|
|||
dbTZ = col.TimeZone
|
||||
}
|
||||
|
||||
if rawValueType == schemas.TimeType {
|
||||
t := vv.Convert(schemas.TimeType).Interface().(time.Time)
|
||||
|
||||
z, _ := t.Zone()
|
||||
// set new location if database don't save timezone or give an incorrect timezone
|
||||
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
|
||||
session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", col.Name, t, z, *t.Location())
|
||||
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
|
||||
t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
|
||||
}
|
||||
|
||||
t = t.In(session.engine.TZLocation)
|
||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
|
||||
return nil
|
||||
} else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type ||
|
||||
rawValueType == schemas.Int32Type {
|
||||
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation)
|
||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
|
||||
return nil
|
||||
} else {
|
||||
if d, ok := vv.Interface().([]uint8); ok {
|
||||
t, err := session.byte2Time(col, d)
|
||||
t, err := asTime(scanResult, dbTZ, session.engine.TZLocation)
|
||||
if err != nil {
|
||||
session.engine.logger.Errorf("byte2Time error: %v", err)
|
||||
} else {
|
||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
} else if d, ok := vv.Interface().(string); ok {
|
||||
t, err := session.str2Time(col, d)
|
||||
if err != nil {
|
||||
session.engine.logger.Errorf("byte2Time error: %v", err)
|
||||
} else {
|
||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
|
||||
fieldValue.Set(reflect.ValueOf(*t).Convert(fieldType))
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface())
|
||||
}
|
||||
}
|
||||
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
|
||||
err := nulVal.Scan(vv.Interface())
|
||||
if err == nil {
|
||||
|
|
|
@ -68,7 +68,3 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time
|
|||
outTime = x.In(session.engine.TZLocation)
|
||||
return
|
||||
}
|
||||
|
||||
func (session *Session) byte2Time(col *schemas.Column, data []byte) (outTime time.Time, outErr error) {
|
||||
return session.str2Time(col, string(data))
|
||||
}
|
||||
|
|
|
@ -172,6 +172,11 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
|
|||
return err
|
||||
}
|
||||
|
||||
types, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var newElemFunc func(fields []string) reflect.Value
|
||||
elemType := containerValue.Type().Elem()
|
||||
var isPointer bool
|
||||
|
@ -241,7 +246,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = session.rows2Beans(rows, fields, tb, newElemFunc, containerValueSetFunc)
|
||||
err = session.rows2Beans(rows, fields, types, tb, newElemFunc, containerValueSetFunc)
|
||||
rows.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -268,7 +268,7 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields
|
|||
}
|
||||
|
||||
func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) {
|
||||
scanResults, err := session.row2Slice(rows, fields, bean)
|
||||
scanResults, err := session.row2Slice(rows, fields, types, bean)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue