refactor more

This commit is contained in:
Lunny Xiao 2021-07-16 13:21:55 +08:00
parent 754998d8fc
commit 27ff0fd873
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
5 changed files with 118 additions and 89 deletions

View File

@ -15,6 +15,7 @@ import (
"time"
"xorm.io/xorm/convert"
"xorm.io/xorm/schemas"
)
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
@ -192,6 +193,8 @@ func asFloat64(src interface{}) (float64, error) {
return float64(v.Int32), nil
case *sql.NullInt64:
return float64(v.Int64), nil
case *sql.NullFloat64:
return v.Float64, nil
}
rv := reflect.ValueOf(src)
@ -208,6 +211,42 @@ func asFloat64(src interface{}) (float64, error) {
return 0, fmt.Errorf("unsupported value %T as int64", src)
}
func asTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time.Time, error) {
switch t := src.(type) {
case string:
return convert.String2Time(t, dbLoc, uiLoc)
case *sql.NullString:
if !t.Valid {
return nil, nil
}
return convert.String2Time(t.String, dbLoc, uiLoc)
case []uint8:
if t == nil {
return nil, nil
}
return convert.String2Time(string(t), dbLoc, uiLoc)
case *sql.NullTime:
tm := t.Time
return &tm, nil
case *time.Time:
z, _ := t.Zone()
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() {
tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc)
return &tm, nil
}
tm := t.In(uiLoc)
return &tm, nil
case int:
tm := time.Unix(int64(t), 0).In(uiLoc)
return &tm, nil
case int64:
tm := time.Unix(t, 0).In(uiLoc)
return &tm, nil
}
return nil, fmt.Errorf("unsupported value %#v as time", src)
}
func asBigFloat(src interface{}) (*big.Float, error) {
res := big.NewFloat(0)
switch v := src.(type) {
@ -285,23 +324,33 @@ func asBigFloat(src interface{}) (*big.Float, error) {
return nil, fmt.Errorf("unsupported value %T as big.Float", src)
}
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
func asBytes(src interface{}) ([]byte, bool) {
switch t := src.(type) {
case []byte:
return t, true
case *sql.NullString:
return []byte(t.String), true
case *sql.RawBytes:
return *t, true
}
rv := reflect.ValueOf(src)
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.AppendInt(buf, rv.Int(), 10), true
return strconv.AppendInt(nil, rv.Int(), 10), true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.AppendUint(buf, rv.Uint(), 10), true
return strconv.AppendUint(nil, rv.Uint(), 10), true
case reflect.Float32:
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true
return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 32), true
case reflect.Float64:
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true
return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 64), true
case reflect.Bool:
return strconv.AppendBool(buf, rv.Bool()), true
return strconv.AppendBool(nil, rv.Bool()), true
case reflect.String:
s := rv.String()
return append(buf, s...), true
return []byte(s), true
}
return
return nil, false
}
// convertAssign copies to dest the value in src, converting it if possible.
@ -559,8 +608,7 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve
return nil
}
case *[]byte:
sv = reflect.ValueOf(src)
if b, ok := asBytes(nil, sv); ok {
if b, ok := asBytes(src); ok {
*d = b
return nil
}
@ -678,6 +726,8 @@ func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, conver
func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
switch tp.Kind() {
case reflect.Ptr:
return asKind(vv.Elem(), tp.Elem())
case reflect.Int64:
return vv.Int(), nil
case reflect.Int:
@ -708,7 +758,11 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
}
return v, nil
}
case reflect.Struct:
if vv.Type().ConvertibleTo(schemas.NullInt64Type) {
r := vv.Convert(schemas.NullInt64Type)
return r.Interface().(sql.NullInt64).Int64, nil
}
}
return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv)
}
@ -743,6 +797,10 @@ func asBool(src interface{}) (bool, error) {
return strconv.ParseBool(string(v))
case string:
return strconv.ParseBool(v)
case *sql.NullInt64:
return v.Int64 > 0, nil
case *sql.NullInt32:
return v.Int32 > 0, nil
default:
return false, fmt.Errorf("unknow type %T as bool", src)
}

View File

@ -94,7 +94,7 @@ func executeBeforeClosures(session *Session, bean interface{}) {
func executeBeforeSet(bean interface{}, fields []string, scanResults []interface{}) {
if b, hasBeforeSet := bean.(BeforeSetProcessor); hasBeforeSet {
for ii, key := range fields {
b.BeforeSet(key, Cell(scanResults[ii].(*interface{})))
b.BeforeSet(key, Cell(scanResults[ii]))
}
}
}
@ -102,7 +102,7 @@ func executeBeforeSet(bean interface{}, fields []string, scanResults []interface
func executeAfterSet(bean interface{}, fields []string, scanResults []interface{}) {
if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet {
for ii, key := range fields {
b.AfterSet(key, Cell(scanResults[ii].(*interface{})))
b.AfterSet(key, Cell(scanResults[ii]))
}
}
}

View File

@ -5,6 +5,7 @@
package schemas
import (
"database/sql"
"math/big"
"reflect"
"sort"
@ -248,6 +249,7 @@ var (
uintDefault uint
timeDefault time.Time
bigFloatDefault big.Float
nullInt64Default sql.NullInt64
)
// enumerates all types
@ -277,6 +279,8 @@ var (
TimeType = reflect.TypeOf(timeDefault)
BigFloatType = reflect.TypeOf(bigFloatDefault)
NullInt64Type = reflect.TypeOf(nullInt64Default)
)
// enumerates all types

View File

@ -16,7 +16,6 @@ import (
"io"
"reflect"
"strings"
"time"
"xorm.io/xorm/contexts"
"xorm.io/xorm/convert"
@ -387,7 +386,7 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s
}
// Cell cell is a result of one column field
type Cell *interface{}
type Cell interface{}
func (session *Session) rows2Beans(rows *core.Rows, types []*sql.ColumnType, fields []string,
table *schemas.Table, newElemFunc func([]string) reflect.Value,
@ -439,14 +438,17 @@ func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fiel
return scanResults, nil
}
func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Type, vv reflect.Value, rawValueType reflect.Type) error {
func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Type, scanResult interface{}) error {
var bs []byte
if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(schemas.BytesType) {
bs = vv.Bytes()
} else {
return fmt.Errorf("unsupported database data type: %v", rawValueType.Kind())
switch t := scanResult.(type) {
case string:
bs = []byte(t)
case []byte:
bs = t
case *sql.NullString:
bs = []byte(t.String)
default:
return fmt.Errorf("unsupported database data type: %#v", scanResult)
}
if len(bs) > 0 {
@ -487,26 +489,33 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
}
if fieldValue.CanAddr() {
if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
return scanner.Scan(scanResult)
}
if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
data, err := value2Bytes(&rawValue)
if err != nil {
return err
data, ok := asBytes(scanResult)
if !ok {
return fmt.Errorf("cannot convert %#v as bytes", scanResult)
}
if err := structConvert.FromDB(data); err != nil {
return err
}
return nil
return structConvert.FromDB(data)
}
}
if _, ok := fieldValue.Interface().(convert.Conversion); ok {
if data, err := value2Bytes(&rawValue); err == nil {
if scanner, ok := fieldValue.Interface().(sql.Scanner); ok {
return scanner.Scan(scanResult)
}
if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok {
data, ok := asBytes(scanResult)
if !ok {
return fmt.Errorf("cannot convert %#v as bytes", scanResult)
}
if data != nil {
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
return fieldValue.Interface().(convert.Conversion).FromDB(data)
}
fieldValue.Interface().(convert.Conversion).FromDB(data)
} else {
return err
return structConvert.FromDB(data)
}
return nil
}
@ -516,7 +525,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
fieldType := fieldValue.Type()
if col.IsJSON {
return session.setJSON(fieldValue, fieldType, vv, rawValueType)
return session.setJSON(fieldValue, fieldType, scanResult)
}
switch fieldType.Kind() {
@ -535,13 +544,13 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
}
return nil
case reflect.Complex64, reflect.Complex128:
return session.setJSON(fieldValue, fieldType, vv, rawValueType)
return session.setJSON(fieldValue, fieldType, scanResult)
case reflect.Map:
switch rawValueType.Kind() {
case reflect.String:
return session.setJSON(fieldValue, fieldType, vv, rawValueType)
return session.setJSON(fieldValue, fieldType, scanResult)
case reflect.Slice:
return session.setJSON(fieldValue, fieldType, vv, rawValueType)
return session.setJSON(fieldValue, fieldType, scanResult)
default:
return fmt.Errorf("unsupported %v -> %T", scanResult, fieldType)
}
@ -556,7 +565,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
fieldValue.Set(x.Elem())
return nil
case reflect.Slice, reflect.Array:
fmt.Printf("======%T\n", scanResult)
switch rawValueType.Elem().Kind() {
case reflect.Uint8:
if fieldType.Elem().Kind() == reflect.Uint8 {
@ -600,53 +608,12 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
dbTZ = col.TimeZone
}
if rawValueType == schemas.TimeType {
t := vv.Convert(schemas.TimeType).Interface().(time.Time)
z, _ := t.Zone()
// set new location if database don't save timezone or give an incorrect timezone
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", col.Name, t, z, *t.Location())
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
}
t = t.In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
return nil
} else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type ||
rawValueType == schemas.Int32Type {
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
return nil
} else {
if d, ok := vv.Interface().([]uint8); ok {
t, err := session.byte2Time(col, d)
tm, err := asTime(scanResult, dbTZ, session.engine.TZLocation)
if err != nil {
session.engine.logger.Errorf("byte2Time error: %v", err)
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
return err
}
fieldValue.Set(reflect.ValueOf(*tm).Convert(fieldType))
return nil
}
} else if d, ok := vv.Interface().(string); ok {
t, err := session.str2Time(col, d)
if err != nil {
session.engine.logger.Errorf("byte2Time error: %v", err)
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
return nil
}
} else {
return fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface())
}
}
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
err := nulVal.Scan(vv.Interface())
if err == nil {
return nil
}
session.engine.logger.Errorf("sql.Sanner error: %v", err)
} else if session.statement.UseCascade {
table, err := session.engine.tagParser.ParseWithCache(*fieldValue)
if err != nil {
@ -679,7 +646,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
}
return nil
}
return session.setJSON(fieldValue, fieldType, vv, rawValueType)
return session.setJSON(fieldValue, fieldType, scanResult)
} // switch fieldType.Kind()
return convertAssignV(fieldValue.Addr(), scanResult, session.engine.DatabaseTZ, session.engine.TZLocation)

View File

@ -96,14 +96,14 @@ func value2String(rawValue *reflect.Value) (str string, err error) {
str = "0"
}
default:
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
err = fmt.Errorf("Unsupported struct type %v as array", vv.Type().Name())
}
// time type
case reflect.Struct:
if aa.ConvertibleTo(schemas.TimeType) {
str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano)
} else {
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
err = fmt.Errorf("Unsupported struct type %v as struct", vv.Type().Name())
}
case reflect.Bool:
str = strconv.FormatBool(vv.Bool())
@ -117,7 +117,7 @@ func value2String(rawValue *reflect.Value) (str string, err error) {
case reflect.Chan, reflect.Func, reflect.Interface:
*/
default:
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
err = fmt.Errorf("Unsupported struct type %v as %v", vv.Type().Name(), aa.Kind())
}
return
}