refactor more
This commit is contained in:
parent
754998d8fc
commit
27ff0fd873
80
convert.go
80
convert.go
|
@ -15,6 +15,7 @@ import (
|
|||
"time"
|
||||
|
||||
"xorm.io/xorm/convert"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
|
||||
|
@ -192,6 +193,8 @@ func asFloat64(src interface{}) (float64, error) {
|
|||
return float64(v.Int32), nil
|
||||
case *sql.NullInt64:
|
||||
return float64(v.Int64), nil
|
||||
case *sql.NullFloat64:
|
||||
return v.Float64, nil
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(src)
|
||||
|
@ -208,6 +211,42 @@ func asFloat64(src interface{}) (float64, error) {
|
|||
return 0, fmt.Errorf("unsupported value %T as int64", src)
|
||||
}
|
||||
|
||||
func asTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time.Time, error) {
|
||||
switch t := src.(type) {
|
||||
case string:
|
||||
return convert.String2Time(t, dbLoc, uiLoc)
|
||||
case *sql.NullString:
|
||||
if !t.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return convert.String2Time(t.String, dbLoc, uiLoc)
|
||||
case []uint8:
|
||||
if t == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return convert.String2Time(string(t), dbLoc, uiLoc)
|
||||
case *sql.NullTime:
|
||||
tm := t.Time
|
||||
return &tm, nil
|
||||
case *time.Time:
|
||||
z, _ := t.Zone()
|
||||
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() {
|
||||
tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
|
||||
t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc)
|
||||
return &tm, nil
|
||||
}
|
||||
tm := t.In(uiLoc)
|
||||
return &tm, nil
|
||||
case int:
|
||||
tm := time.Unix(int64(t), 0).In(uiLoc)
|
||||
return &tm, nil
|
||||
case int64:
|
||||
tm := time.Unix(t, 0).In(uiLoc)
|
||||
return &tm, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported value %#v as time", src)
|
||||
}
|
||||
|
||||
func asBigFloat(src interface{}) (*big.Float, error) {
|
||||
res := big.NewFloat(0)
|
||||
switch v := src.(type) {
|
||||
|
@ -285,23 +324,33 @@ func asBigFloat(src interface{}) (*big.Float, error) {
|
|||
return nil, fmt.Errorf("unsupported value %T as big.Float", src)
|
||||
}
|
||||
|
||||
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
|
||||
func asBytes(src interface{}) ([]byte, bool) {
|
||||
switch t := src.(type) {
|
||||
case []byte:
|
||||
return t, true
|
||||
case *sql.NullString:
|
||||
return []byte(t.String), true
|
||||
case *sql.RawBytes:
|
||||
return *t, true
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(src)
|
||||
switch rv.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return strconv.AppendInt(buf, rv.Int(), 10), true
|
||||
return strconv.AppendInt(nil, rv.Int(), 10), true
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return strconv.AppendUint(buf, rv.Uint(), 10), true
|
||||
return strconv.AppendUint(nil, rv.Uint(), 10), true
|
||||
case reflect.Float32:
|
||||
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true
|
||||
return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 32), true
|
||||
case reflect.Float64:
|
||||
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true
|
||||
return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 64), true
|
||||
case reflect.Bool:
|
||||
return strconv.AppendBool(buf, rv.Bool()), true
|
||||
return strconv.AppendBool(nil, rv.Bool()), true
|
||||
case reflect.String:
|
||||
s := rv.String()
|
||||
return append(buf, s...), true
|
||||
return []byte(s), true
|
||||
}
|
||||
return
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// convertAssign copies to dest the value in src, converting it if possible.
|
||||
|
@ -559,8 +608,7 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve
|
|||
return nil
|
||||
}
|
||||
case *[]byte:
|
||||
sv = reflect.ValueOf(src)
|
||||
if b, ok := asBytes(nil, sv); ok {
|
||||
if b, ok := asBytes(src); ok {
|
||||
*d = b
|
||||
return nil
|
||||
}
|
||||
|
@ -678,6 +726,8 @@ func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, conver
|
|||
|
||||
func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
|
||||
switch tp.Kind() {
|
||||
case reflect.Ptr:
|
||||
return asKind(vv.Elem(), tp.Elem())
|
||||
case reflect.Int64:
|
||||
return vv.Int(), nil
|
||||
case reflect.Int:
|
||||
|
@ -708,7 +758,11 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
|
|||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
if vv.Type().ConvertibleTo(schemas.NullInt64Type) {
|
||||
r := vv.Convert(schemas.NullInt64Type)
|
||||
return r.Interface().(sql.NullInt64).Int64, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv)
|
||||
}
|
||||
|
@ -743,6 +797,10 @@ func asBool(src interface{}) (bool, error) {
|
|||
return strconv.ParseBool(string(v))
|
||||
case string:
|
||||
return strconv.ParseBool(v)
|
||||
case *sql.NullInt64:
|
||||
return v.Int64 > 0, nil
|
||||
case *sql.NullInt32:
|
||||
return v.Int32 > 0, nil
|
||||
default:
|
||||
return false, fmt.Errorf("unknow type %T as bool", src)
|
||||
}
|
||||
|
|
|
@ -94,7 +94,7 @@ func executeBeforeClosures(session *Session, bean interface{}) {
|
|||
func executeBeforeSet(bean interface{}, fields []string, scanResults []interface{}) {
|
||||
if b, hasBeforeSet := bean.(BeforeSetProcessor); hasBeforeSet {
|
||||
for ii, key := range fields {
|
||||
b.BeforeSet(key, Cell(scanResults[ii].(*interface{})))
|
||||
b.BeforeSet(key, Cell(scanResults[ii]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -102,7 +102,7 @@ func executeBeforeSet(bean interface{}, fields []string, scanResults []interface
|
|||
func executeAfterSet(bean interface{}, fields []string, scanResults []interface{}) {
|
||||
if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet {
|
||||
for ii, key := range fields {
|
||||
b.AfterSet(key, Cell(scanResults[ii].(*interface{})))
|
||||
b.AfterSet(key, Cell(scanResults[ii]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package schemas
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"sort"
|
||||
|
@ -248,6 +249,7 @@ var (
|
|||
uintDefault uint
|
||||
timeDefault time.Time
|
||||
bigFloatDefault big.Float
|
||||
nullInt64Default sql.NullInt64
|
||||
)
|
||||
|
||||
// enumerates all types
|
||||
|
@ -277,6 +279,8 @@ var (
|
|||
|
||||
TimeType = reflect.TypeOf(timeDefault)
|
||||
BigFloatType = reflect.TypeOf(bigFloatDefault)
|
||||
|
||||
NullInt64Type = reflect.TypeOf(nullInt64Default)
|
||||
)
|
||||
|
||||
// enumerates all types
|
||||
|
|
111
session.go
111
session.go
|
@ -16,7 +16,6 @@ import (
|
|||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"xorm.io/xorm/contexts"
|
||||
"xorm.io/xorm/convert"
|
||||
|
@ -387,7 +386,7 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s
|
|||
}
|
||||
|
||||
// Cell cell is a result of one column field
|
||||
type Cell *interface{}
|
||||
type Cell interface{}
|
||||
|
||||
func (session *Session) rows2Beans(rows *core.Rows, types []*sql.ColumnType, fields []string,
|
||||
table *schemas.Table, newElemFunc func([]string) reflect.Value,
|
||||
|
@ -439,14 +438,17 @@ func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fiel
|
|||
return scanResults, nil
|
||||
}
|
||||
|
||||
func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Type, vv reflect.Value, rawValueType reflect.Type) error {
|
||||
func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Type, scanResult interface{}) error {
|
||||
var bs []byte
|
||||
if rawValueType.Kind() == reflect.String {
|
||||
bs = []byte(vv.String())
|
||||
} else if rawValueType.ConvertibleTo(schemas.BytesType) {
|
||||
bs = vv.Bytes()
|
||||
} else {
|
||||
return fmt.Errorf("unsupported database data type: %v", rawValueType.Kind())
|
||||
switch t := scanResult.(type) {
|
||||
case string:
|
||||
bs = []byte(t)
|
||||
case []byte:
|
||||
bs = t
|
||||
case *sql.NullString:
|
||||
bs = []byte(t.String)
|
||||
default:
|
||||
return fmt.Errorf("unsupported database data type: %#v", scanResult)
|
||||
}
|
||||
|
||||
if len(bs) > 0 {
|
||||
|
@ -487,26 +489,33 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
|
|||
}
|
||||
|
||||
if fieldValue.CanAddr() {
|
||||
if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
|
||||
return scanner.Scan(scanResult)
|
||||
}
|
||||
if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
|
||||
data, err := value2Bytes(&rawValue)
|
||||
if err != nil {
|
||||
return err
|
||||
data, ok := asBytes(scanResult)
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot convert %#v as bytes", scanResult)
|
||||
}
|
||||
if err := structConvert.FromDB(data); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return structConvert.FromDB(data)
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := fieldValue.Interface().(convert.Conversion); ok {
|
||||
if data, err := value2Bytes(&rawValue); err == nil {
|
||||
if scanner, ok := fieldValue.Interface().(sql.Scanner); ok {
|
||||
return scanner.Scan(scanResult)
|
||||
}
|
||||
|
||||
if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok {
|
||||
data, ok := asBytes(scanResult)
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot convert %#v as bytes", scanResult)
|
||||
}
|
||||
if data != nil {
|
||||
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
|
||||
return fieldValue.Interface().(convert.Conversion).FromDB(data)
|
||||
}
|
||||
fieldValue.Interface().(convert.Conversion).FromDB(data)
|
||||
} else {
|
||||
return err
|
||||
return structConvert.FromDB(data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -516,7 +525,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
|
|||
fieldType := fieldValue.Type()
|
||||
|
||||
if col.IsJSON {
|
||||
return session.setJSON(fieldValue, fieldType, vv, rawValueType)
|
||||
return session.setJSON(fieldValue, fieldType, scanResult)
|
||||
}
|
||||
|
||||
switch fieldType.Kind() {
|
||||
|
@ -535,13 +544,13 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
|
|||
}
|
||||
return nil
|
||||
case reflect.Complex64, reflect.Complex128:
|
||||
return session.setJSON(fieldValue, fieldType, vv, rawValueType)
|
||||
return session.setJSON(fieldValue, fieldType, scanResult)
|
||||
case reflect.Map:
|
||||
switch rawValueType.Kind() {
|
||||
case reflect.String:
|
||||
return session.setJSON(fieldValue, fieldType, vv, rawValueType)
|
||||
return session.setJSON(fieldValue, fieldType, scanResult)
|
||||
case reflect.Slice:
|
||||
return session.setJSON(fieldValue, fieldType, vv, rawValueType)
|
||||
return session.setJSON(fieldValue, fieldType, scanResult)
|
||||
default:
|
||||
return fmt.Errorf("unsupported %v -> %T", scanResult, fieldType)
|
||||
}
|
||||
|
@ -556,7 +565,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
|
|||
fieldValue.Set(x.Elem())
|
||||
return nil
|
||||
case reflect.Slice, reflect.Array:
|
||||
fmt.Printf("======%T\n", scanResult)
|
||||
switch rawValueType.Elem().Kind() {
|
||||
case reflect.Uint8:
|
||||
if fieldType.Elem().Kind() == reflect.Uint8 {
|
||||
|
@ -600,53 +608,12 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
|
|||
dbTZ = col.TimeZone
|
||||
}
|
||||
|
||||
if rawValueType == schemas.TimeType {
|
||||
t := vv.Convert(schemas.TimeType).Interface().(time.Time)
|
||||
|
||||
z, _ := t.Zone()
|
||||
// set new location if database don't save timezone or give an incorrect timezone
|
||||
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
|
||||
session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", col.Name, t, z, *t.Location())
|
||||
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
|
||||
t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
|
||||
}
|
||||
|
||||
t = t.In(session.engine.TZLocation)
|
||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
|
||||
return nil
|
||||
} else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type ||
|
||||
rawValueType == schemas.Int32Type {
|
||||
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation)
|
||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
|
||||
return nil
|
||||
} else {
|
||||
if d, ok := vv.Interface().([]uint8); ok {
|
||||
t, err := session.byte2Time(col, d)
|
||||
tm, err := asTime(scanResult, dbTZ, session.engine.TZLocation)
|
||||
if err != nil {
|
||||
session.engine.logger.Errorf("byte2Time error: %v", err)
|
||||
} else {
|
||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
|
||||
return err
|
||||
}
|
||||
fieldValue.Set(reflect.ValueOf(*tm).Convert(fieldType))
|
||||
return nil
|
||||
}
|
||||
|
||||
} else if d, ok := vv.Interface().(string); ok {
|
||||
t, err := session.str2Time(col, d)
|
||||
if err != nil {
|
||||
session.engine.logger.Errorf("byte2Time error: %v", err)
|
||||
} else {
|
||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface())
|
||||
}
|
||||
}
|
||||
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
|
||||
err := nulVal.Scan(vv.Interface())
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
session.engine.logger.Errorf("sql.Sanner error: %v", err)
|
||||
} else if session.statement.UseCascade {
|
||||
table, err := session.engine.tagParser.ParseWithCache(*fieldValue)
|
||||
if err != nil {
|
||||
|
@ -679,7 +646,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
|
|||
}
|
||||
return nil
|
||||
}
|
||||
return session.setJSON(fieldValue, fieldType, vv, rawValueType)
|
||||
return session.setJSON(fieldValue, fieldType, scanResult)
|
||||
} // switch fieldType.Kind()
|
||||
|
||||
return convertAssignV(fieldValue.Addr(), scanResult, session.engine.DatabaseTZ, session.engine.TZLocation)
|
||||
|
|
|
@ -96,14 +96,14 @@ func value2String(rawValue *reflect.Value) (str string, err error) {
|
|||
str = "0"
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
|
||||
err = fmt.Errorf("Unsupported struct type %v as array", vv.Type().Name())
|
||||
}
|
||||
// time type
|
||||
case reflect.Struct:
|
||||
if aa.ConvertibleTo(schemas.TimeType) {
|
||||
str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano)
|
||||
} else {
|
||||
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
|
||||
err = fmt.Errorf("Unsupported struct type %v as struct", vv.Type().Name())
|
||||
}
|
||||
case reflect.Bool:
|
||||
str = strconv.FormatBool(vv.Bool())
|
||||
|
@ -117,7 +117,7 @@ func value2String(rawValue *reflect.Value) (str string, err error) {
|
|||
case reflect.Chan, reflect.Func, reflect.Interface:
|
||||
*/
|
||||
default:
|
||||
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
|
||||
err = fmt.Errorf("Unsupported struct type %v as %v", vv.Type().Name(), aa.Kind())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue