From a5030dc7a444c098b8c97d514f06df90ed6b57f8 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 6 Jul 2021 16:06:04 +0800 Subject: [PATCH] refactor get (#1967) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1967 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- convert.go | 416 +++++++++++++++++++++++++++++++++++++++++-- dialects/driver.go | 11 ++ dialects/mysql.go | 1 + dialects/postgres.go | 6 + dialects/sqlite3.go | 6 + scan.go | 197 +++++++++++++++++++- session_find.go | 2 +- session_get.go | 299 ++++++++++++++++--------------- session_insert.go | 6 +- session_query.go | 2 +- 10 files changed, 784 insertions(+), 162 deletions(-) diff --git a/convert.go b/convert.go index b7f30cad..f7d733ad 100644 --- a/convert.go +++ b/convert.go @@ -5,12 +5,15 @@ package xorm import ( + "database/sql" "database/sql/driver" "errors" "fmt" "reflect" "strconv" "time" + + "xorm.io/xorm/convert" ) var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error @@ -37,6 +40,12 @@ func asString(src interface{}) string { return v case []byte: return string(v) + case *sql.NullString: + return v.String + case *sql.NullInt32: + return fmt.Sprintf("%d", v.Int32) + case *sql.NullInt64: + return fmt.Sprintf("%d", v.Int64) } rv := reflect.ValueOf(src) switch rv.Kind() { @@ -54,6 +63,156 @@ func asString(src interface{}) string { return fmt.Sprintf("%v", src) } +func asInt64(src interface{}) (int64, error) { + switch v := src.(type) { + case int: + return int64(v), nil + case int16: + return int64(v), nil + case int32: + return int64(v), nil + case int8: + return int64(v), nil + case int64: + return v, nil + case uint: + return int64(v), nil + case uint8: + return int64(v), nil + case uint16: + return int64(v), nil + case uint32: + return int64(v), nil + case uint64: + return int64(v), nil + case []byte: + return strconv.ParseInt(string(v), 10, 64) + case string: + return strconv.ParseInt(v, 10, 64) + case *sql.NullString: + return strconv.ParseInt(v.String, 10, 64) + case *sql.NullInt32: + return int64(v.Int32), nil + case *sql.NullInt64: + return int64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return int64(rv.Uint()), nil + case reflect.Float64: + return int64(rv.Float()), nil + case reflect.Float32: + return int64(rv.Float()), nil + case reflect.String: + return strconv.ParseInt(rv.String(), 10, 64) + } + return 0, fmt.Errorf("unsupported value %T as int64", src) +} + +func asUint64(src interface{}) (uint64, error) { + switch v := src.(type) { + case int: + return uint64(v), nil + case int16: + return uint64(v), nil + case int32: + return uint64(v), nil + case int8: + return uint64(v), nil + case int64: + return uint64(v), nil + case uint: + return uint64(v), nil + case uint8: + return uint64(v), nil + case uint16: + return uint64(v), nil + case uint32: + return uint64(v), nil + case uint64: + return v, nil + case []byte: + return strconv.ParseUint(string(v), 10, 64) + case string: + return strconv.ParseUint(v, 10, 64) + case *sql.NullString: + return strconv.ParseUint(v.String, 10, 64) + case *sql.NullInt32: + return uint64(v.Int32), nil + case *sql.NullInt64: + return uint64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return uint64(rv.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return uint64(rv.Uint()), nil + case reflect.Float64: + return uint64(rv.Float()), nil + case reflect.Float32: + return uint64(rv.Float()), nil + case reflect.String: + return strconv.ParseUint(rv.String(), 10, 64) + } + return 0, fmt.Errorf("unsupported value %T as uint64", src) +} + +func asFloat64(src interface{}) (float64, error) { + switch v := src.(type) { + case int: + return float64(v), nil + case int16: + return float64(v), nil + case int32: + return float64(v), nil + case int8: + return float64(v), nil + case int64: + return float64(v), nil + case uint: + return float64(v), nil + case uint8: + return float64(v), nil + case uint16: + return float64(v), nil + case uint32: + return float64(v), nil + case uint64: + return float64(v), nil + case []byte: + return strconv.ParseFloat(string(v), 64) + case string: + return strconv.ParseFloat(v, 64) + case *sql.NullString: + return strconv.ParseFloat(v.String, 64) + case *sql.NullInt32: + return float64(v.Int32), nil + case *sql.NullInt64: + return float64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return float64(rv.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return float64(rv.Uint()), nil + case reflect.Float64: + return float64(rv.Float()), nil + case reflect.Float32: + return float64(rv.Float()), nil + case reflect.String: + return strconv.ParseFloat(rv.String(), 64) + } + return 0, fmt.Errorf("unsupported value %T as int64", src) +} + func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { switch rv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -76,7 +235,7 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { // convertAssign copies to dest the value in src, converting it if possible. // An error is returned if the copy would result in loss of information. // dest should be a pointer type. -func convertAssign(dest, src interface{}) error { +func convertAssign(dest, src interface{}, originalLocation *time.Location, convertedLocation *time.Location) error { // Common cases, without reflect. switch s := src.(type) { case string: @@ -143,6 +302,163 @@ func convertAssign(dest, src interface{}) error { *d = nil return nil } + case *sql.NullString: + switch d := dest.(type) { + case *int: + if s.Valid { + *d, _ = strconv.Atoi(s.String) + } + case *int64: + if s.Valid { + *d, _ = strconv.ParseInt(s.String, 10, 64) + } + case *string: + if s.Valid { + *d = s.String + } + return nil + case *time.Time: + if s.Valid { + var err error + dt, err := convert.String2Time(s.String, originalLocation, convertedLocation) + if err != nil { + return err + } + *d = *dt + } + return nil + case *sql.NullTime: + if s.Valid { + var err error + dt, err := convert.String2Time(s.String, originalLocation, convertedLocation) + if err != nil { + return err + } + d.Valid = true + d.Time = *dt + } + } + case *sql.NullInt32: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Int32) + } + return nil + case *int8: + if s.Valid { + *d = int8(s.Int32) + } + return nil + case *int16: + if s.Valid { + *d = int16(s.Int32) + } + return nil + case *int32: + if s.Valid { + *d = s.Int32 + } + return nil + case *int64: + if s.Valid { + *d = int64(s.Int32) + } + return nil + } + case *sql.NullInt64: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Int64) + } + return nil + case *int8: + if s.Valid { + *d = int8(s.Int64) + } + return nil + case *int16: + if s.Valid { + *d = int16(s.Int64) + } + return nil + case *int32: + if s.Valid { + *d = int32(s.Int64) + } + return nil + case *int64: + if s.Valid { + *d = s.Int64 + } + return nil + } + case *sql.NullFloat64: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Float64) + } + return nil + case *float64: + if s.Valid { + *d = s.Float64 + } + return nil + } + case *sql.NullBool: + switch d := dest.(type) { + case *bool: + if s.Valid { + *d = s.Bool + } + return nil + } + case *sql.NullTime: + switch d := dest.(type) { + case *time.Time: + if s.Valid { + *d = s.Time + } + return nil + case *string: + if s.Valid { + *d = s.Time.In(convertedLocation).Format("2006-01-02 15:04:05") + } + return nil + } + case *NullUint32: + switch d := dest.(type) { + case *uint8: + if s.Valid { + *d = uint8(s.Uint32) + } + return nil + case *uint16: + if s.Valid { + *d = uint16(s.Uint32) + } + return nil + case *uint: + if s.Valid { + *d = uint(s.Uint32) + } + return nil + } + case *NullUint64: + switch d := dest.(type) { + case *uint64: + if s.Valid { + *d = s.Uint64 + } + return nil + } + case *sql.RawBytes: + switch d := dest.(type) { + case convert.Conversion: + return d.FromDB(*s) + } } var sv reflect.Value @@ -175,10 +491,10 @@ func convertAssign(dest, src interface{}) error { return nil } - return convertAssignV(reflect.ValueOf(dest), src) + return convertAssignV(reflect.ValueOf(dest), src, originalLocation, convertedLocation) } -func convertAssignV(dpv reflect.Value, src interface{}) error { +func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, convertedLocation *time.Location) error { if dpv.Kind() != reflect.Ptr { return errors.New("destination not a pointer") } @@ -212,31 +528,28 @@ func convertAssignV(dpv reflect.Value, src interface{}) error { } dv.Set(reflect.New(dv.Type().Elem())) - return convertAssign(dv.Interface(), src) + return convertAssign(dv.Interface(), src, originalLocation, convertedLocation) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - s := asString(src) - i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + i64, err := asInt64(src) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) } dv.SetInt(i64) return nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - s := asString(src) - u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + u64, err := asUint64(src) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) } dv.SetUint(u64) return nil case reflect.Float32, reflect.Float64: - s := asString(src) - f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + f64, err := asFloat64(src) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) } dv.SetFloat(f64) return nil @@ -376,3 +689,80 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) { } return v.Interface(), nil } + +var ( + _ sql.Scanner = &NullUint64{} +) + +// NullUint64 represents an uint64 that may be null. +// NullUint64 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullUint64 struct { + Uint64 uint64 + Valid bool +} + +// Scan implements the Scanner interface. +func (n *NullUint64) Scan(value interface{}) error { + if value == nil { + n.Uint64, n.Valid = 0, false + return nil + } + n.Valid = true + var err error + n.Uint64, err = asUint64(value) + return err +} + +// Value implements the driver Valuer interface. +func (n NullUint64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Uint64, nil +} + +var ( + _ sql.Scanner = &NullUint32{} +) + +// NullUint32 represents an uint32 that may be null. +// NullUint32 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullUint32 struct { + Uint32 uint32 + Valid bool // Valid is true if Uint32 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullUint32) Scan(value interface{}) error { + if value == nil { + n.Uint32, n.Valid = 0, false + return nil + } + n.Valid = true + i64, err := asUint64(value) + if err != nil { + return err + } + n.Uint32 = uint32(i64) + return nil +} + +// Value implements the driver Valuer interface. +func (n NullUint32) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return int64(n.Uint32), nil +} + +var ( + _ sql.Scanner = &EmptyScanner{} +) + +type EmptyScanner struct{} + +func (EmptyScanner) Scan(value interface{}) error { + return nil +} diff --git a/dialects/driver.go b/dialects/driver.go index c511b665..0b6187d3 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -18,9 +18,14 @@ type ScanContext struct { UserLocation *time.Location } +type DriverFeatures struct { + SupportNullable bool +} + // Driver represents a database driver type Driver interface { Parse(string, string) (*URI, error) + Features() DriverFeatures GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error } @@ -77,3 +82,9 @@ type baseDriver struct{} func (b *baseDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, v ...interface{}) error { return rows.Scan(v...) } + +func (b *baseDriver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: true, + } +} diff --git a/dialects/mysql.go b/dialects/mysql.go index 03bc9a4b..a341ce05 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -633,6 +633,7 @@ func (db *mysql) Filters() []Filter { } type mysqlDriver struct { + baseDriver } func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { diff --git a/dialects/postgres.go b/dialects/postgres.go index e4641509..a2611c60 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -1302,6 +1302,12 @@ type pqDriver struct { baseDriver } +func (b *pqDriver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: false, + } +} + type values map[string]string func (vs values) Set(k, v string) { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 306f377c..1bc0b218 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -576,3 +576,9 @@ func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) { return &r, nil } } + +func (b *sqlite3Driver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: false, + } +} diff --git a/scan.go b/scan.go index e11d6e8d..c5cb77ff 100644 --- a/scan.go +++ b/scan.go @@ -6,12 +6,120 @@ package xorm import ( "database/sql" + "fmt" + "reflect" + "time" "xorm.io/xorm/convert" "xorm.io/xorm/core" "xorm.io/xorm/dialects" ) +// genScanResultsByBeanNullabale generates scan result +func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { + switch t := bean.(type) { + case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes: + return t, false, nil + case *time.Time: + return &sql.NullTime{}, true, nil + case *string: + return &sql.NullString{}, true, nil + case *int, *int8, *int16, *int32: + return &sql.NullInt32{}, true, nil + case *int64: + return &sql.NullInt64{}, true, nil + case *uint, *uint8, *uint16, *uint32: + return &NullUint32{}, true, nil + case *uint64: + return &NullUint64{}, true, nil + case *float32, *float64: + return &sql.NullFloat64{}, true, nil + case *bool: + return &sql.NullBool{}, true, nil + case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString, + time.Time, + string, + int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + float32, float64, + bool: + return nil, false, fmt.Errorf("unsupported scan type: %t", t) + case convert.Conversion: + return &sql.RawBytes{}, true, nil + } + + tp := reflect.TypeOf(bean).Elem() + switch tp.Kind() { + case reflect.String: + return &sql.NullString{}, true, nil + case reflect.Int64: + return &sql.NullInt64{}, true, nil + case reflect.Int32, reflect.Int, reflect.Int16, reflect.Int8: + return &sql.NullInt32{}, true, nil + case reflect.Uint64: + return &NullUint64{}, true, nil + case reflect.Uint32, reflect.Uint, reflect.Uint16, reflect.Uint8: + return &NullUint32{}, true, nil + default: + return nil, false, fmt.Errorf("unsupported type: %#v", bean) + } +} + +func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { + switch t := bean.(type) { + case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, + *string, + *int, *int8, *int16, *int32, *int64, + *uint, *uint8, *uint16, *uint32, *uint64, + *float32, *float64, + *bool: + return t, false, nil + case *time.Time: + return &sql.NullTime{}, true, nil + case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString, + time.Time, + string, + int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + bool: + return nil, false, fmt.Errorf("unsupported scan type: %t", t) + case convert.Conversion: + return &sql.RawBytes{}, true, nil + } + + tp := reflect.TypeOf(bean).Elem() + switch tp.Kind() { + case reflect.String: + return new(string), true, nil + case reflect.Int64: + return new(int64), true, nil + case reflect.Int32: + return new(int32), true, nil + case reflect.Int: + return new(int32), true, nil + case reflect.Int16: + return new(int32), true, nil + case reflect.Int8: + return new(int32), true, nil + case reflect.Uint64: + return new(uint64), true, nil + case reflect.Uint32: + return new(uint32), true, nil + case reflect.Uint: + return new(uint), true, nil + case reflect.Uint16: + return new(uint16), true, nil + case reflect.Uint8: + return new(uint8), true, nil + case reflect.Float32: + return new(float32), true, nil + case reflect.Float64: + return new(float64), true, nil + default: + return nil, false, fmt.Errorf("unsupported type: %#v", bean) + } +} + func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { var scanResults = make([]interface{}, len(fields)) for i := 0; i < len(fields); i++ { @@ -50,18 +158,97 @@ func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (ma return result, nil } -func row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { - results := make([]string, 0, len(fields)) - var scanResults = make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { +func (engine *Engine) scanStringInterface(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { + var scanResults = make([]interface{}, len(types)) + for i := 0; i < len(types); i++ { var s sql.NullString scanResults[i] = &s } - if err := rows.Scan(scanResults...); err != nil { + if err := engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResults...); err != nil { + return nil, err + } + return scanResults, nil +} + +// scan is a wrap of driver.Scan but will automatically change the input values according requirements +func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.ColumnType, vv ...interface{}) error { + var scanResults = make([]interface{}, 0, len(types)) + var replaces = make([]bool, 0, len(types)) + var err error + for _, v := range vv { + var replaced bool + var scanResult interface{} + if _, ok := v.(sql.Scanner); !ok { + var useNullable = true + if engine.driver.Features().SupportNullable { + nullable, ok := types[0].Nullable() + useNullable = ok && nullable + } + + if useNullable { + scanResult, replaced, err = genScanResultsByBeanNullable(v) + } else { + scanResult, replaced, err = genScanResultsByBean(v) + } + if err != nil { + return err + } + } else { + scanResult = v + } + scanResults = append(scanResults, scanResult) + replaces = append(replaces, replaced) + } + + var scanCtx = dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + } + + if err = engine.driver.Scan(&scanCtx, rows, types, scanResults...); err != nil { + return err + } + + for i, replaced := range replaces { + if replaced { + if err = convertAssign(vv[i], scanResults[i], scanCtx.DBLocation, engine.TZLocation); err != nil { + return err + } + } + } + + return nil +} + +func (engine *Engine) scanInterfaces(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { + var scanResultContainers = make([]interface{}, len(types)) + for i := 0; i < len(types); i++ { + scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) + if err != nil { + return nil, err + } + scanResultContainers[i] = scanResult + } + if err := engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResultContainers...); err != nil { + return nil, err + } + return scanResultContainers, nil +} + +func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { + scanResults, err := engine.scanStringInterface(rows, types) + if err != nil { return nil, err } + var results = make([]string, 0, len(fields)) for i := 0; i < len(fields); i++ { results = append(results, scanResults[i].(*sql.NullString).String) } diff --git a/session_find.go b/session_find.go index 0daea005..261e6b7f 100644 --- a/session_find.go +++ b/session_find.go @@ -276,7 +276,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error { cols := table.PKColumns() if len(cols) == 1 { - return convertAssign(dst, pk[0]) + return convertAssign(dst, pk[0], nil, nil) } dst = pk diff --git a/session_get.go b/session_get.go index e303176d..cb2bda75 100644 --- a/session_get.go +++ b/session_get.go @@ -6,12 +6,16 @@ package xorm import ( "database/sql" + "database/sql/driver" "errors" "fmt" "reflect" "strconv" + "time" "xorm.io/xorm/caches" + "xorm.io/xorm/convert" + "xorm.io/xorm/core" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -108,6 +112,17 @@ func (session *Session) get(bean interface{}) (bool, error) { return true, nil } +var ( + valuerTypePlaceHolder driver.Valuer + valuerType = reflect.TypeOf(&valuerTypePlaceHolder).Elem() + + scannerTypePlaceHolder sql.Scanner + scannerType = reflect.TypeOf(&scannerTypePlaceHolder).Elem() + + conversionTypePlaceHolder convert.Conversion + conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem() +) + func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil { @@ -122,155 +137,161 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, return false, nil } - switch bean.(type) { - case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString: - return true, rows.Scan(&bean) - case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString: - return true, rows.Scan(bean) - case *string: - var res sql.NullString - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*string)) = res.String - } - return true, nil - case *int: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int)) = int(res.Int64) - } - return true, nil - case *int8: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int8)) = int8(res.Int64) - } - return true, nil - case *int16: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int16)) = int16(res.Int64) - } - return true, nil - case *int32: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int32)) = int32(res.Int64) - } - return true, nil - case *int64: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int64)) = int64(res.Int64) - } - return true, nil - case *uint: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint)) = uint(res.Int64) - } - return true, nil - case *uint8: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint8)) = uint8(res.Int64) - } - return true, nil - case *uint16: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint16)) = uint16(res.Int64) - } - return true, nil - case *uint32: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint32)) = uint32(res.Int64) - } - return true, nil - case *uint64: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint64)) = uint64(res.Int64) - } - return true, nil - case *bool: - var res sql.NullBool - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*bool)) = res.Bool - } - return true, nil + // WARN: Alougth rows return true, but we may also return error. + types, err := rows.ColumnTypes() + if err != nil { + return true, err + } + fields, err := rows.Columns() + if err != nil { + return true, err } - switch beanKind { case reflect.Struct: - fields, err := rows.Columns() - if err != nil { - // WARN: Alougth rows return true, but get fields failed - return true, err + if _, ok := bean.(*time.Time); ok { + break } - - scanResults, err := session.row2Slice(rows, fields, bean) - if err != nil { - return false, err + if _, ok := bean.(sql.Scanner); ok { + break } - // close it before convert data - rows.Close() - - dataStruct := utils.ReflectValue(bean) - _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) - if err != nil { - return true, err + if _, ok := bean.(convert.Conversion); len(types) == 1 && ok { + break } - - return true, session.executeProcessors() + return session.getStruct(rows, types, fields, table, bean) case reflect.Slice: - err = rows.ScanSlice(bean) + return session.getSlice(rows, types, fields, bean) case reflect.Map: - err = rows.ScanMap(bean) - case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - err = rows.Scan(bean) - default: - err = rows.Scan(bean) + return session.getMap(rows, types, fields, bean) } - return true, err + return session.getVars(rows, types, fields, bean) +} + +func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { + switch t := bean.(type) { + case *[]string: + res, err := session.engine.scanStringInterface(rows, types) + if err != nil { + return true, err + } + + var needAppend = len(*t) == 0 // both support slice is empty or has been initlized + for i, r := range res { + if needAppend { + *t = append(*t, r.(*sql.NullString).String) + } else { + (*t)[i] = r.(*sql.NullString).String + } + } + return true, nil + case *[]interface{}: + scanResults, err := session.engine.scanInterfaces(rows, types) + if err != nil { + return true, err + } + var needAppend = len(*t) == 0 + for ii := range fields { + s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii]) + if err != nil { + return true, err + } + if needAppend { + *t = append(*t, s) + } else { + (*t)[ii] = s + } + } + return true, nil + default: + return true, fmt.Errorf("unspoorted slice type: %t", t) + } +} + +func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { + switch t := bean.(type) { + case *map[string]string: + scanResults, err := session.engine.scanStringInterface(rows, types) + if err != nil { + return true, err + } + for ii, key := range fields { + (*t)[key] = scanResults[ii].(*sql.NullString).String + } + return true, nil + case *map[string]interface{}: + scanResults, err := session.engine.scanInterfaces(rows, types) + if err != nil { + return true, err + } + for ii, key := range fields { + s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii]) + if err != nil { + return true, err + } + (*t)[key] = s + } + return true, nil + default: + return true, fmt.Errorf("unspoorted map type: %t", t) + } +} + +func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields []string, beans ...interface{}) (bool, error) { + if len(beans) != len(types) { + return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans)) + } + var scanResults = make([]interface{}, 0, len(types)) + var replaceds = make([]bool, 0, len(types)) + for _, bean := range beans { + switch t := bean.(type) { + case sql.Scanner: + scanResults = append(scanResults, t) + replaceds = append(replaceds, false) + case convert.Conversion: + scanResults = append(scanResults, &sql.RawBytes{}) + replaceds = append(replaceds, true) + default: + scanResults = append(scanResults, bean) + replaceds = append(replaceds, false) + } + } + + err := session.engine.scan(rows, fields, types, scanResults...) + if err != nil { + return true, err + } + for i, replaced := range replaceds { + if replaced { + err = convertAssign(beans[i], scanResults[i], session.engine.DatabaseTZ, session.engine.TZLocation) + if err != nil { + return true, err + } + } + } + return true, nil +} + +func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { + fields, err := rows.Columns() + if err != nil { + // WARN: Alougth rows return true, but get fields failed + return true, err + } + + scanResults, err := session.row2Slice(rows, fields, bean) + if err != nil { + return false, err + } + // close it before convert data + rows.Close() + + dataStruct := utils.ReflectValue(bean) + _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) + if err != nil { + return true, err + } + + return true, session.executeProcessors() } func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { diff --git a/session_insert.go b/session_insert.go index e733e06e..7f8f3008 100644 --- a/session_insert.go +++ b/session_insert.go @@ -375,7 +375,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - return 1, convertAssignV(aiValue.Addr(), id) + return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation) } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || session.engine.dialect.URI().DBType == schemas.MSSQL) { res, err := session.queryBytes(sqlStr, args...) @@ -415,7 +415,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - return 1, convertAssignV(aiValue.Addr(), id) + return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation) } res, err := session.exec(sqlStr, args...) @@ -455,7 +455,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } - if err := convertAssignV(aiValue.Addr(), id); err != nil { + if err := convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation); err != nil { return 0, err } diff --git a/session_query.go b/session_query.go index 01cd6f44..fa33496d 100644 --- a/session_query.go +++ b/session_query.go @@ -54,7 +54,7 @@ func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]stri } for rows.Next() { - record, err := row2sliceStr(rows, types, fields) + record, err := session.engine.row2sliceStr(rows, types, fields) if err != nil { return nil, err }