diff --git a/convert.go b/convert.go index 69277734..533dbe99 100644 --- a/convert.go +++ b/convert.go @@ -7,6 +7,7 @@ package xorm import ( "database/sql" "database/sql/driver" + "encoding/json" "errors" "fmt" "math/big" @@ -285,23 +286,94 @@ func asBigFloat(src interface{}) (*big.Float, error) { return nil, fmt.Errorf("unsupported value %T as big.Float", src) } -func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { +func asBytes(src interface{}) ([]byte, bool) { + switch t := src.(type) { + case []byte: + return t, true + case *sql.NullString: + if !t.Valid { + return nil, true + } + return []byte(t.String), true + case *sql.RawBytes: + return *t, true + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return strconv.AppendInt(buf, rv.Int(), 10), true + return strconv.AppendInt(nil, rv.Int(), 10), true case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return strconv.AppendUint(buf, rv.Uint(), 10), true + return strconv.AppendUint(nil, rv.Uint(), 10), true case reflect.Float32: - return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true + return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 32), true case reflect.Float64: - return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true + return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 64), true case reflect.Bool: - return strconv.AppendBool(buf, rv.Bool()), true + return strconv.AppendBool(nil, rv.Bool()), true case reflect.String: - s := rv.String() - return append(buf, s...), true + return []byte(rv.String()), true } - return + return nil, false +} + +func asTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time.Time, error) { + switch t := src.(type) { + case string: + return convert.String2Time(t, dbLoc, uiLoc) + case *sql.NullString: + if !t.Valid { + return nil, nil + } + return convert.String2Time(t.String, dbLoc, uiLoc) + case []uint8: + if t == nil { + return nil, nil + } + return convert.String2Time(string(t), dbLoc, uiLoc) + case *sql.NullTime: + if !t.Valid { + return nil, nil + } + z, _ := t.Time.Zone() + if len(z) == 0 || t.Time.Year() == 0 || t.Time.Location().String() != dbLoc.String() { + tm := time.Date(t.Time.Year(), t.Time.Month(), t.Time.Day(), t.Time.Hour(), + t.Time.Minute(), t.Time.Second(), t.Time.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.Time.In(uiLoc) + return &tm, nil + case *time.Time: + z, _ := t.Zone() + if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { + tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), + t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.In(uiLoc) + return &tm, nil + case time.Time: + z, _ := t.Zone() + if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { + tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), + t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.In(uiLoc) + return &tm, nil + case int: + tm := time.Unix(int64(t), 0).In(uiLoc) + return &tm, nil + case int64: + tm := time.Unix(t, 0).In(uiLoc) + return &tm, nil + case *sql.NullInt64: + tm := time.Unix(t.Int64, 0).In(uiLoc) + return &tm, nil + + } + return nil, fmt.Errorf("unsupported value %#v as time", src) } // convertAssign copies to dest the value in src, converting it if possible. @@ -559,8 +631,7 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve return nil } case *[]byte: - sv = reflect.ValueOf(src) - if b, ok := asBytes(nil, sv); ok { + if b, ok := asBytes(src); ok { *d = b return nil } @@ -575,44 +646,24 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve return nil } - return convertAssignV(reflect.ValueOf(dest), src, originalLocation, convertedLocation) + return convertAssignV(reflect.ValueOf(dest), src) } -func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, convertedLocation *time.Location) error { - if dpv.Kind() != reflect.Ptr { - return errors.New("destination not a pointer") - } - if dpv.IsNil() { - return errNilPtr - } - - var sv = reflect.ValueOf(src) - - dv := reflect.Indirect(dpv) - if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { - switch b := src.(type) { - case []byte: - dv.Set(reflect.ValueOf(cloneBytes(b))) - default: - dv.Set(sv) - } +func convertAssignV(dv reflect.Value, src interface{}) error { + if src == nil { return nil } - if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { - dv.Set(sv.Convert(dv.Type())) - return nil + if dv.Type().Implements(scannerType) { + return dv.Interface().(sql.Scanner).Scan(src) } switch dv.Kind() { case reflect.Ptr: - if src == nil { - dv.Set(reflect.Zero(dv.Type())) - return nil + if dv.IsNil() { + dv.Set(reflect.New(dv.Type().Elem())) } - - dv.Set(reflect.New(dv.Type().Elem())) - return convertAssign(dv.Interface(), src, originalLocation, convertedLocation) + return convertAssignV(dv.Elem(), src) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: i64, err := asInt64(src) if err != nil { @@ -640,9 +691,28 @@ func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, conver case reflect.String: dv.SetString(asString(src)) return nil + case reflect.Bool: + b, err := asBool(src) + if err != nil { + return err + } + dv.SetBool(b) + return nil + case reflect.Slice, reflect.Map, reflect.Struct, reflect.Array: + data, ok := asBytes(src) + if !ok { + return fmt.Errorf("onvertAssignV: src cannot be as bytes %#v", src) + } + if data == nil { + return nil + } + if dv.Kind() != reflect.Ptr { + dv = dv.Addr() + } + return json.Unmarshal(data, dv.Interface()) + default: + return fmt.Errorf("convertAssignV: unsupported Scan, storing driver.Value type %T into type %T", src, dv.Interface()) } - - return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dpv.Interface()) } func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { @@ -682,16 +752,43 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) } -func asBool(bs []byte) (bool, error) { - if len(bs) == 0 { - return false, nil +func asBool(src interface{}) (bool, error) { + switch v := src.(type) { + case bool: + return v, nil + case *bool: + return *v, nil + case *sql.NullBool: + return v.Bool, nil + case int64: + return v > 0, nil + case int: + return v > 0, nil + case int8: + return v > 0, nil + case int16: + return v > 0, nil + case int32: + return v > 0, nil + case []byte: + if len(v) == 0 { + return false, nil + } + if v[0] == 0x00 { + return false, nil + } else if v[0] == 0x01 { + return true, nil + } + return strconv.ParseBool(string(v)) + case string: + return strconv.ParseBool(v) + case *sql.NullInt64: + return v.Int64 > 0, nil + case *sql.NullInt32: + return v.Int32 > 0, nil + default: + return false, fmt.Errorf("unknow type %T as bool", src) } - if bs[0] == 0x00 { - return false, nil - } else if bs[0] == 0x01 { - return true, nil - } - return strconv.ParseBool(string(bs)) } // str2PK convert string value to primary key value according to tp diff --git a/convert/time.go b/convert/time.go index 696b301c..5a3e5246 100644 --- a/convert/time.go +++ b/convert/time.go @@ -6,6 +6,7 @@ package convert import ( "fmt" + "strconv" "time" ) @@ -19,7 +20,7 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t dt = dt.In(convertedLocation) return &dt, nil } else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' { - dt, err := time.ParseInLocation(time.RFC3339, s, originalLocation) + dt, err := time.ParseInLocation("2006-01-02T15:04:05", s[:19], originalLocation) if err != nil { return nil, err } @@ -32,6 +33,12 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t } dt = dt.In(convertedLocation) return &dt, nil + } else { + i, err := strconv.ParseInt(s, 10, 64) + if err == nil { + tm := time.Unix(i, 0).In(convertedLocation) + return &tm, nil + } } return nil, fmt.Errorf("unsupported convertion from %s to time", s) } diff --git a/dialects/mysql.go b/dialects/mysql.go index db45cd62..9312c071 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -15,7 +15,6 @@ import ( "strings" "time" - "xorm.io/xorm/convert" "xorm.io/xorm/core" "xorm.io/xorm/schemas" ) @@ -733,52 +732,6 @@ func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) { } } -func (p *mysqlDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, scanResults ...interface{}) error { - var v2 = make([]interface{}, 0, len(scanResults)) - var turnBackIdxes = make([]int, 0, 5) - for i, vv := range scanResults { - switch vv.(type) { - case *time.Time: - v2 = append(v2, &sql.NullString{}) - turnBackIdxes = append(turnBackIdxes, i) - case *sql.NullTime: - v2 = append(v2, &sql.NullString{}) - turnBackIdxes = append(turnBackIdxes, i) - default: - v2 = append(v2, scanResults[i]) - } - } - if err := rows.Scan(v2...); err != nil { - return err - } - for _, i := range turnBackIdxes { - switch t := scanResults[i].(type) { - case *time.Time: - var s = *(v2[i].(*sql.NullString)) - if !s.Valid { - break - } - dt, err := convert.String2Time(s.String, ctx.DBLocation, ctx.UserLocation) - if err != nil { - return err - } - *t = *dt - case *sql.NullTime: - var s = *(v2[i].(*sql.NullString)) - if !s.Valid { - break - } - dt, err := convert.String2Time(s.String, ctx.DBLocation, ctx.UserLocation) - if err != nil { - return err - } - t.Time = *dt - t.Valid = true - } - } - return nil -} - type mymysqlDriver struct { mysqlDriver } diff --git a/engine.go b/engine.go index d3ee8a8c..b4ef9593 100644 --- a/engine.go +++ b/engine.go @@ -543,6 +543,11 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } + fields, err := rows.Columns() + if err != nil { + return err + } + sess := engine.NewSession() defer sess.Close() for rows.Next() { @@ -551,7 +556,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } - scanResults, err := sess.engine.scanStringInterface(rows, types) + scanResults, err := sess.engine.scanStringInterface(rows, fields, types) if err != nil { return err } diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index ca894d59..b1dffe14 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -914,7 +914,7 @@ func TestGetTime(t *testing.T) { assertSync(t, new(GetTimeStruct)) var gts = GetTimeStruct{ - CreateTime: time.Now(), + CreateTime: time.Now().In(testEngine.GetTZLocation()), } _, err := testEngine.Insert(>s) assert.NoError(t, err) diff --git a/integrations/time_test.go b/integrations/time_test.go index 6d8d812c..50fd1847 100644 --- a/integrations/time_test.go +++ b/integrations/time_test.go @@ -53,9 +53,18 @@ func TestTimeUserTimeDiffLoc(t *testing.T) { assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) + oldTZLoc := testEngine.GetTZLocation() + defer func() { + testEngine.SetTZLocation(oldTZLoc) + }() testEngine.SetTZLocation(loc) + dbLoc, err := time.LoadLocation("America/New_York") assert.NoError(t, err) + oldDBLoc := testEngine.GetTZDatabase() + defer func() { + testEngine.SetTZDatabase(oldDBLoc) + }() testEngine.SetTZDatabase(dbLoc) type TimeUser2 struct { @@ -118,9 +127,18 @@ func TestTimeUserCreatedDiffLoc(t *testing.T) { assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) + oldTZLoc := testEngine.GetTZLocation() + defer func() { + testEngine.SetTZLocation(oldTZLoc) + }() testEngine.SetTZLocation(loc) + dbLoc, err := time.LoadLocation("America/New_York") assert.NoError(t, err) + oldDBLoc := testEngine.GetTZDatabase() + defer func() { + testEngine.SetTZDatabase(oldDBLoc) + }() testEngine.SetTZDatabase(dbLoc) type UserCreated2 struct { @@ -204,9 +222,18 @@ func TestTimeUserUpdatedDiffLoc(t *testing.T) { assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) + oldTZLoc := testEngine.GetTZLocation() + defer func() { + testEngine.SetTZLocation(oldTZLoc) + }() testEngine.SetTZLocation(loc) + dbLoc, err := time.LoadLocation("America/New_York") assert.NoError(t, err) + oldDBLoc := testEngine.GetTZDatabase() + defer func() { + testEngine.SetTZDatabase(oldDBLoc) + }() testEngine.SetTZDatabase(dbLoc) type UserUpdated2 struct { @@ -311,9 +338,18 @@ func TestTimeUserDeletedDiffLoc(t *testing.T) { assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) + oldTZLoc := testEngine.GetTZLocation() + defer func() { + testEngine.SetTZLocation(oldTZLoc) + }() testEngine.SetTZLocation(loc) + dbLoc, err := time.LoadLocation("America/New_York") assert.NoError(t, err) + oldDBLoc := testEngine.GetTZDatabase() + defer func() { + testEngine.SetTZDatabase(oldDBLoc) + }() testEngine.SetTZDatabase(dbLoc) type UserDeleted2 struct { @@ -435,9 +471,18 @@ func TestCustomTimeUserDeletedDiffLoc(t *testing.T) { assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) + oldTZLoc := testEngine.GetTZLocation() + defer func() { + testEngine.SetTZLocation(oldTZLoc) + }() testEngine.SetTZLocation(loc) + dbLoc, err := time.LoadLocation("America/New_York") assert.NoError(t, err) + oldDBLoc := testEngine.GetTZDatabase() + defer func() { + testEngine.SetTZDatabase(oldDBLoc) + }() testEngine.SetTZDatabase(dbLoc) type UserDeleted4 struct { diff --git a/integrations/types_null_test.go b/integrations/types_null_test.go index 98bd86b9..86ce1939 100644 --- a/integrations/types_null_test.go +++ b/integrations/types_null_test.go @@ -7,7 +7,6 @@ package integrations import ( "database/sql" "database/sql/driver" - "errors" "fmt" "strconv" "strings" @@ -42,15 +41,22 @@ func (m *CustomStruct) Scan(value interface{}) error { return nil } - if s, ok := value.([]byte); ok { - seps := strings.Split(string(s), "/") + var s string + switch t := value.(type) { + case string: + s = t + case []byte: + s = string(t) + } + if len(s) > 0 { + seps := strings.Split(s, "/") m.Year, _ = strconv.Atoi(seps[0]) m.Month, _ = strconv.Atoi(seps[1]) m.Day, _ = strconv.Atoi(seps[2]) return nil } - return errors.New("scan data not fit []byte") + return fmt.Errorf("scan data %#v not fit []byte", value) } func (m CustomStruct) Value() (driver.Value, error) { diff --git a/rows.go b/rows.go index a56ea1c9..5e0a1ffe 100644 --- a/rows.go +++ b/rows.go @@ -129,8 +129,12 @@ func (rows *Rows) Scan(bean interface{}) error { if err != nil { return err } + types, err := rows.rows.ColumnTypes() + if err != nil { + return err + } - scanResults, err := rows.session.row2Slice(rows.rows, fields, bean) + scanResults, err := rows.session.row2Slice(rows.rows, fields, types, bean) if err != nil { return err } diff --git a/scan.go b/scan.go index 2fedd415..e4c0e4a1 100644 --- a/scan.go +++ b/scan.go @@ -20,6 +20,8 @@ import ( // genScanResultsByBeanNullabale generates scan result func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { switch t := bean.(type) { + case *interface{}: + return t, false, nil case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes: return t, false, nil case *time.Time: @@ -71,7 +73,10 @@ func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { switch t := bean.(type) { + case *interface{}: + return t, false, nil case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, + *sql.RawBytes, *string, *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, @@ -175,17 +180,14 @@ func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (ma return result, nil } -func (engine *Engine) scanStringInterface(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { +func (engine *Engine) scanStringInterface(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) { var scanResults = make([]interface{}, len(types)) for i := 0; i < len(types); i++ { var s sql.NullString scanResults[i] = &s } - if err := engine.driver.Scan(&dialects.ScanContext{ - DBLocation: engine.DatabaseTZ, - UserLocation: engine.TZLocation, - }, rows, types, scanResults...); err != nil { + if err := engine.scan(rows, fields, types, scanResults...); err != nil { return nil, err } return scanResults, nil @@ -200,14 +202,14 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column var replaced bool var scanResult interface{} switch t := v.(type) { + case *big.Float, *time.Time, *sql.NullTime: + scanResult = &sql.NullString{} + replaced = true case sql.Scanner: scanResult = t case convert.Conversion: scanResult = &sql.RawBytes{} replaced = true - case *big.Float: - scanResult = &sql.NullString{} - replaced = true default: var useNullable = true if engine.driver.Features().SupportNullable { @@ -246,7 +248,7 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column return nil } -func (engine *Engine) scanInterfaces(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { +func (engine *Engine) scanInterfaces(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) { var scanResultContainers = make([]interface{}, len(types)) for i := 0; i < len(types); i++ { scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) @@ -255,17 +257,14 @@ func (engine *Engine) scanInterfaces(rows *core.Rows, types []*sql.ColumnType) ( } scanResultContainers[i] = scanResult } - if err := engine.driver.Scan(&dialects.ScanContext{ - DBLocation: engine.DatabaseTZ, - UserLocation: engine.TZLocation, - }, rows, types, scanResultContainers...); err != nil { + if err := engine.scan(rows, fields, types, scanResultContainers...); err != nil { return nil, err } return scanResultContainers, nil } func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { - scanResults, err := engine.scanStringInterface(rows, types) + scanResults, err := engine.scanStringInterface(rows, fields, types) if err != nil { return nil, err } @@ -307,10 +306,7 @@ func (engine *Engine) row2mapInterface(rows *core.Rows, types []*sql.ColumnType, } scanResultContainers[i] = scanResult } - if err := engine.driver.Scan(&dialects.ScanContext{ - DBLocation: engine.DatabaseTZ, - UserLocation: engine.TZLocation, - }, rows, types, scanResultContainers...); err != nil { + if err := engine.scan(rows, fields, types, scanResultContainers...); err != nil { return nil, err } diff --git a/session.go b/session.go index 486911a5..5557d717 100644 --- a/session.go +++ b/session.go @@ -16,7 +16,6 @@ import ( "io" "reflect" "strings" - "time" "xorm.io/xorm/contexts" "xorm.io/xorm/convert" @@ -389,7 +388,7 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s // Cell cell is a result of one column field type Cell *interface{} -func (session *Session) rows2Beans(rows *core.Rows, fields []string, +func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sql.ColumnType, table *schemas.Table, newElemFunc func([]string) reflect.Value, sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { for rows.Next() { @@ -398,7 +397,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, dataStruct := newValue.Elem() // handle beforeClosures - scanResults, err := session.row2Slice(rows, fields, bean) + scanResults, err := session.row2Slice(rows, fields, types, bean) if err != nil { return err } @@ -417,7 +416,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, return nil } -func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interface{}) ([]interface{}, error) { +func (session *Session) row2Slice(rows *core.Rows, fields []string, types []*sql.ColumnType, bean interface{}) ([]interface{}, error) { for _, closure := range session.beforeClosures { closure(bean) } @@ -427,7 +426,7 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa var cell interface{} scanResults[i] = &cell } - if err := rows.Scan(scanResults...); err != nil { + if err := session.engine.scan(rows, fields, types, scanResults...); err != nil { return nil, err } @@ -454,27 +453,28 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec if fieldValue.CanAddr() { if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - data, err := value2Bytes(&rawValue) - if err != nil { - return err + data, ok := asBytes(scanResult) + if !ok { + return fmt.Errorf("cannot convert %#v as bytes", scanResult) } - if err := structConvert.FromDB(data); err != nil { - return err - } - return nil + return structConvert.FromDB(data) } } - if _, ok := fieldValue.Interface().(convert.Conversion); ok { - if data, err := value2Bytes(&rawValue); err == nil { - if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) - } - fieldValue.Interface().(convert.Conversion).FromDB(data) - } else { - return err + if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { + data, ok := asBytes(scanResult) + if !ok { + return fmt.Errorf("cannot convert %#v as bytes", scanResult) } - return nil + if data == nil { + return nil + } + + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + return fieldValue.Interface().(convert.Conversion).FromDB(data) + } + return structConvert.FromDB(data) } rawValueType := reflect.TypeOf(rawValue.Interface()) @@ -554,64 +554,28 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } return nil case reflect.Slice, reflect.Array: - switch rawValueType.Kind() { - case reflect.Slice, reflect.Array: - switch rawValueType.Elem().Kind() { - case reflect.Uint8: - if fieldType.Elem().Kind() == reflect.Uint8 { - if col.SQLType.IsText() { - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } else { - if fieldValue.Len() > 0 { - for i := 0; i < fieldValue.Len(); i++ { - if i < vv.Len() { - fieldValue.Index(i).Set(vv.Index(i)) - } - } - } else { - for i := 0; i < vv.Len(); i++ { - fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) - } + bs, ok := asBytes(scanResult) + if ok && fieldType.Elem().Kind() == reflect.Uint8 { + if col.SQLType.IsText() { + x := reflect.New(fieldType) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) + if err != nil { + return err + } + fieldValue.Set(x.Elem()) + } else { + if fieldValue.Len() > 0 { + for i := 0; i < fieldValue.Len(); i++ { + if i < vv.Len() { + fieldValue.Index(i).Set(vv.Index(i)) } } - return nil + } else { + for i := 0; i < vv.Len(); i++ { + fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) + } } } - } - case reflect.String: - if rawValueType.Kind() == reflect.String { - fieldValue.SetString(vv.String()) - return nil - } - case reflect.Bool: - if rawValueType.Kind() == reflect.Bool { - fieldValue.SetBool(vv.Bool()) - return nil - } - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - switch rawValueType.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - fieldValue.SetInt(vv.Int()) - return nil - } - case reflect.Float32, reflect.Float64: - switch rawValueType.Kind() { - case reflect.Float32, reflect.Float64: - fieldValue.SetFloat(vv.Float()) - return nil - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - switch rawValueType.Kind() { - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - fieldValue.SetUint(vv.Uint()) - return nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - fieldValue.SetUint(uint64(vv.Int())) return nil } case reflect.Struct: @@ -630,47 +594,13 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec dbTZ = col.TimeZone } - if rawValueType == schemas.TimeType { - t := vv.Convert(schemas.TimeType).Interface().(time.Time) - - z, _ := t.Zone() - // set new location if database don't save timezone or give an incorrect timezone - if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location - session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", col.Name, t, z, *t.Location()) - t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), - t.Minute(), t.Second(), t.Nanosecond(), dbTZ) - } - - t = t.In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type || - rawValueType == schemas.Int32Type { - t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } else { - if d, ok := vv.Interface().([]uint8); ok { - t, err := session.byte2Time(col, d) - if err != nil { - session.engine.logger.Errorf("byte2Time error: %v", err) - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } - - } else if d, ok := vv.Interface().(string); ok { - t, err := session.str2Time(col, d) - if err != nil { - session.engine.logger.Errorf("byte2Time error: %v", err) - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } - } else { - return fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) - } + t, err := asTime(scanResult, dbTZ, session.engine.TZLocation) + if err != nil { + return err } + + fieldValue.Set(reflect.ValueOf(*t).Convert(fieldType)) + return nil } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { err := nulVal.Scan(vv.Interface()) if err == nil { @@ -733,12 +663,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } } // switch fieldType.Kind() - data, err := value2Bytes(&rawValue) - if err != nil { - return err - } - - return session.bytes2Value(col, fieldValue, data) + return convertAssignV(fieldValue.Addr(), scanResult) } func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) { diff --git a/session_convert.go b/session_convert.go index b8218a77..452801e2 100644 --- a/session_convert.go +++ b/session_convert.go @@ -5,16 +5,11 @@ package xorm import ( - "database/sql" - "errors" "fmt" - "reflect" "strconv" "strings" "time" - "xorm.io/xorm/convert" - "xorm.io/xorm/internal/json" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -73,449 +68,3 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time outTime = x.In(session.engine.TZLocation) return } - -func (session *Session) byte2Time(col *schemas.Column, data []byte) (outTime time.Time, outErr error) { - return session.str2Time(col, string(data)) -} - -// convert a db data([]byte) to a field value -func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Value, data []byte) error { - if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - return structConvert.FromDB(data) - } - - if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { - return structConvert.FromDB(data) - } - - var v interface{} - key := col.Name - fieldType := fieldValue.Type() - - switch fieldType.Kind() { - case reflect.Complex64, reflect.Complex128: - x := reflect.New(fieldType) - if len(data) > 0 { - err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - case reflect.Slice, reflect.Array, reflect.Map: - v = data - t := fieldType.Elem() - k := t.Kind() - if col.SQLType.IsText() { - x := reflect.New(fieldType) - if len(data) > 0 { - err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - } else if col.SQLType.IsBlob() { - if k == reflect.Uint8 { - fieldValue.Set(reflect.ValueOf(v)) - } else { - x := reflect.New(fieldType) - if len(data) > 0 { - err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - } - } else { - return ErrUnSupportedType - } - case reflect.String: - fieldValue.SetString(string(data)) - case reflect.Bool: - v, err := asBool(data) - if err != nil { - return fmt.Errorf("arg %v as bool: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(v)) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - sdata := string(data) - var x int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - session.engine.dialect.URI().DBType == schemas.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API - if len(data) == 1 { - x = int64(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x, err = strconv.ParseInt(sdata, 16, 64) - } else if strings.HasPrefix(sdata, "0") { - x, err = strconv.ParseInt(sdata, 8, 64) - } else if strings.EqualFold(sdata, "true") { - x = 1 - } else if strings.EqualFold(sdata, "false") { - x = 0 - } else { - x, err = strconv.ParseInt(sdata, 10, 64) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.SetInt(x) - case reflect.Float32, reflect.Float64: - x, err := strconv.ParseFloat(string(data), 64) - if err != nil { - return fmt.Errorf("arg %v as float64: %s", key, err.Error()) - } - fieldValue.SetFloat(x) - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.SetUint(x) - //Currently only support Time type - case reflect.Struct: - // !! 增加支持sql.Scanner接口的结构,如sql.NullString - if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - if err := nulVal.Scan(data); err != nil { - return fmt.Errorf("sql.Scan(%v) failed: %s ", data, err.Error()) - } - } else { - if fieldType.ConvertibleTo(schemas.TimeType) { - x, err := session.byte2Time(col, data) - if err != nil { - return err - } - v = x - fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) - } else if session.statement.UseCascade { - table, err := session.engine.tagParser.ParseWithCache(*fieldValue) - if err != nil { - return err - } - - // TODO: current only support 1 primary key - if len(table.PrimaryKeys) > 1 { - return errors.New("unsupported composited primary key cascade") - } - - var pk = make(schemas.PK, len(table.PrimaryKeys)) - rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) - pk[0], err = str2PK(string(data), rawValueType) - if err != nil { - return err - } - - if !pk.IsZero() { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return err - } - if has { - v = structInter.Elem().Interface() - fieldValue.Set(reflect.ValueOf(v)) - } else { - return errors.New("cascade obj is not exist") - } - } - } - } - case reflect.Ptr: - // !nashtsai! TODO merge duplicated codes above - //typeStr := fieldType.String() - switch fieldType.Elem().Kind() { - // case "*string": - case schemas.StringType.Kind(): - x := string(data) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*bool": - case schemas.BoolType.Kind(): - d := string(data) - v, err := strconv.ParseBool(d) - if err != nil { - return fmt.Errorf("arg %v as bool: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&v).Convert(fieldType)) - // case "*complex64": - case schemas.Complex64Type.Kind(): - var x complex64 - if len(data) > 0 { - err := json.DefaultJSONHandler.Unmarshal(data, &x) - if err != nil { - return err - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - } - // case "*complex128": - case schemas.Complex128Type.Kind(): - var x complex128 - if len(data) > 0 { - err := json.DefaultJSONHandler.Unmarshal(data, &x) - if err != nil { - return err - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - } - // case "*float64": - case schemas.Float64Type.Kind(): - x, err := strconv.ParseFloat(string(data), 64) - if err != nil { - return fmt.Errorf("arg %v as float64: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*float32": - case schemas.Float32Type.Kind(): - var x float32 - x1, err := strconv.ParseFloat(string(data), 32) - if err != nil { - return fmt.Errorf("arg %v as float32: %s", key, err.Error()) - } - x = float32(x1) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*uint64": - case schemas.Uint64Type.Kind(): - var x uint64 - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*uint": - case schemas.UintType.Kind(): - var x uint - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - x = uint(x1) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*uint32": - case schemas.Uint32Type.Kind(): - var x uint32 - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - x = uint32(x1) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*uint8": - case schemas.Uint8Type.Kind(): - var x uint8 - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - x = uint8(x1) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*uint16": - case schemas.Uint16Type.Kind(): - var x uint16 - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - x = uint16(x1) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*int64": - case schemas.Int64Type.Kind(): - sdata := string(data) - var x int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - strings.Contains(session.engine.DriverName(), "mysql") { - if len(data) == 1 { - x = int64(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x, err = strconv.ParseInt(sdata, 16, 64) - } else if strings.HasPrefix(sdata, "0") { - x, err = strconv.ParseInt(sdata, 8, 64) - } else { - x, err = strconv.ParseInt(sdata, 10, 64) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*int": - case schemas.IntType.Kind(): - sdata := string(data) - var x int - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - strings.Contains(session.engine.DriverName(), "mysql") { - if len(data) == 1 { - x = int(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int(x1) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*int32": - case schemas.Int32Type.Kind(): - sdata := string(data) - var x int32 - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - session.engine.dialect.URI().DBType == schemas.MYSQL { - if len(data) == 1 { - x = int32(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int32(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int32(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int32(x1) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*int8": - case schemas.Int8Type.Kind(): - sdata := string(data) - var x int8 - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - strings.Contains(session.engine.DriverName(), "mysql") { - if len(data) == 1 { - x = int8(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int8(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int8(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int8(x1) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*int16": - case schemas.Int16Type.Kind(): - sdata := string(data) - var x int16 - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - strings.Contains(session.engine.DriverName(), "mysql") { - if len(data) == 1 { - x = int16(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int16(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int16(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int16(x1) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*SomeStruct": - case reflect.Struct: - switch fieldType { - // case "*.time.Time": - case schemas.PtrTimeType: - x, err := session.byte2Time(col, data) - if err != nil { - return err - } - v = x - fieldValue.Set(reflect.ValueOf(&x)) - default: - if session.statement.UseCascade { - structInter := reflect.New(fieldType.Elem()) - table, err := session.engine.tagParser.ParseWithCache(structInter.Elem()) - if err != nil { - return err - } - - if len(table.PrimaryKeys) > 1 { - return errors.New("unsupported composited primary key cascade") - } - - var pk = make(schemas.PK, len(table.PrimaryKeys)) - rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) - pk[0], err = str2PK(string(data), rawValueType) - if err != nil { - return err - } - - if !pk.IsZero() { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return err - } - if has { - v = structInter.Interface() - fieldValue.Set(reflect.ValueOf(v)) - } else { - return errors.New("cascade obj is not exist") - } - } - } else { - return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String()) - } - } - default: - return fmt.Errorf("unsupported type in Scan: %s", fieldValue.Type().String()) - } - default: - return fmt.Errorf("unsupported type in Scan: %s", fieldValue.Type().String()) - } - - return nil -} diff --git a/session_find.go b/session_find.go index 261e6b7f..41d68479 100644 --- a/session_find.go +++ b/session_find.go @@ -172,6 +172,11 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect return err } + types, err := rows.ColumnTypes() + if err != nil { + return err + } + var newElemFunc func(fields []string) reflect.Value elemType := containerValue.Type().Elem() var isPointer bool @@ -241,7 +246,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect if err != nil { return err } - err = session.rows2Beans(rows, fields, tb, newElemFunc, containerValueSetFunc) + err = session.rows2Beans(rows, fields, types, tb, newElemFunc, containerValueSetFunc) rows.Close() if err != nil { return err diff --git a/session_get.go b/session_get.go index cc6427d7..fa97e68e 100644 --- a/session_get.go +++ b/session_get.go @@ -192,7 +192,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { switch t := bean.(type) { case *[]string: - res, err := session.engine.scanStringInterface(rows, types) + res, err := session.engine.scanStringInterface(rows, fields, types) if err != nil { return true, err } @@ -207,7 +207,7 @@ func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, field } return true, nil case *[]interface{}: - scanResults, err := session.engine.scanInterfaces(rows, types) + scanResults, err := session.engine.scanInterfaces(rows, fields, types) if err != nil { return true, err } @@ -232,7 +232,7 @@ func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, field func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { switch t := bean.(type) { case *map[string]string: - scanResults, err := session.engine.scanStringInterface(rows, types) + scanResults, err := session.engine.scanStringInterface(rows, fields, types) if err != nil { return true, err } @@ -241,7 +241,7 @@ func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields } return true, nil case *map[string]interface{}: - scanResults, err := session.engine.scanInterfaces(rows, types) + scanResults, err := session.engine.scanInterfaces(rows, fields, types) if err != nil { return true, err } @@ -268,7 +268,7 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields } func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { - scanResults, err := session.row2Slice(rows, fields, bean) + scanResults, err := session.row2Slice(rows, fields, types, bean) if err != nil { return false, err } diff --git a/session_insert.go b/session_insert.go index 7f8f3008..b41dbbac 100644 --- a/session_insert.go +++ b/session_insert.go @@ -375,7 +375,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation) + return 1, convertAssignV(aiValue.Addr(), id) } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || session.engine.dialect.URI().DBType == schemas.MSSQL) { res, err := session.queryBytes(sqlStr, args...) @@ -415,7 +415,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation) + return 1, convertAssignV(*aiValue, id) } res, err := session.exec(sqlStr, args...) @@ -455,7 +455,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } - if err := convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation); err != nil { + if err := convertAssignV(*aiValue, id); err != nil { return 0, err } diff --git a/session_raw.go b/session_raw.go index bf32c6ed..7eb8585d 100644 --- a/session_raw.go +++ b/session_raw.go @@ -6,13 +6,8 @@ package xorm import ( "database/sql" - "fmt" - "reflect" - "strconv" - "time" "xorm.io/xorm/core" - "xorm.io/xorm/schemas" ) func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { @@ -75,61 +70,6 @@ func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row { return core.NewRow(session.queryRows(sqlStr, args...)) } -func value2String(rawValue *reflect.Value) (str string, err error) { - aa := reflect.TypeOf((*rawValue).Interface()) - vv := reflect.ValueOf((*rawValue).Interface()) - switch aa.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - str = strconv.FormatInt(vv.Int(), 10) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - str = strconv.FormatUint(vv.Uint(), 10) - case reflect.Float32, reflect.Float64: - str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) - case reflect.String: - str = vv.String() - case reflect.Array, reflect.Slice: - switch aa.Elem().Kind() { - case reflect.Uint8: - data := rawValue.Interface().([]byte) - str = string(data) - if str == "\x00" { - str = "0" - } - default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - // time type - case reflect.Struct: - if aa.ConvertibleTo(schemas.TimeType) { - str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) - } else { - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - case reflect.Bool: - str = strconv.FormatBool(vv.Bool()) - case reflect.Complex128, reflect.Complex64: - str = fmt.Sprintf("%v", vv.Complex()) - /* TODO: unsupported types below - case reflect.Map: - case reflect.Ptr: - case reflect.Uintptr: - case reflect.UnsafePointer: - case reflect.Chan, reflect.Func, reflect.Interface: - */ - default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - return -} - -func value2Bytes(rawValue *reflect.Value) ([]byte, error) { - str, err := value2String(rawValue) - if err != nil { - return nil, err - } - return []byte(str), nil -} - func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil {