From d90767bcb797f9c12a0d99b7d287fa9ec387ea93 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 4 Jul 2021 23:49:59 +0800 Subject: [PATCH] refactor get --- convert.go | 244 +++++++++++++++++++++++++++++++++++- dialects/driver.go | 11 ++ dialects/mysql.go | 1 + dialects/sqlite3.go | 6 + scan.go | 198 ++++++++++++++++++++++++++++- session_find.go | 2 +- session_get.go | 299 ++++++++++++++++++++++++-------------------- session_insert.go | 6 +- session_query.go | 2 +- 9 files changed, 616 insertions(+), 153 deletions(-) diff --git a/convert.go b/convert.go index b7f30cad..c4774d97 100644 --- a/convert.go +++ b/convert.go @@ -5,12 +5,15 @@ package xorm import ( + "database/sql" "database/sql/driver" "errors" "fmt" "reflect" "strconv" "time" + + "xorm.io/xorm/convert" ) var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error @@ -76,7 +79,7 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { // convertAssign copies to dest the value in src, converting it if possible. // An error is returned if the copy would result in loss of information. // dest should be a pointer type. -func convertAssign(dest, src interface{}) error { +func convertAssign(dest, src interface{}, originalLocation *time.Location, convertedLocation *time.Location) error { // Common cases, without reflect. switch s := src.(type) { case string: @@ -143,6 +146,163 @@ func convertAssign(dest, src interface{}) error { *d = nil return nil } + case *sql.NullString: + switch d := dest.(type) { + case *int: + if s.Valid { + *d, _ = strconv.Atoi(s.String) + } + case *int64: + if s.Valid { + *d, _ = strconv.ParseInt(s.String, 10, 64) + } + case *string: + if s.Valid { + *d = s.String + } + return nil + case *time.Time: + if s.Valid { + var err error + dt, err := convert.String2Time(s.String, originalLocation, convertedLocation) + if err != nil { + return err + } + *d = *dt + } + return nil + case *sql.NullTime: + if s.Valid { + var err error + dt, err := convert.String2Time(s.String, originalLocation, convertedLocation) + if err != nil { + return err + } + d.Valid = true + d.Time = *dt + } + } + case *sql.NullInt32: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Int32) + } + return nil + case *int8: + if s.Valid { + *d = int8(s.Int32) + } + return nil + case *int16: + if s.Valid { + *d = int16(s.Int32) + } + return nil + case *int32: + if s.Valid { + *d = s.Int32 + } + return nil + case *int64: + if s.Valid { + *d = int64(s.Int32) + } + return nil + } + case *sql.NullInt64: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Int64) + } + return nil + case *int8: + if s.Valid { + *d = int8(s.Int64) + } + return nil + case *int16: + if s.Valid { + *d = int16(s.Int64) + } + return nil + case *int32: + if s.Valid { + *d = int32(s.Int64) + } + return nil + case *int64: + if s.Valid { + *d = s.Int64 + } + return nil + } + case *sql.NullFloat64: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Float64) + } + return nil + case *float64: + if s.Valid { + *d = s.Float64 + } + return nil + } + case *sql.NullBool: + switch d := dest.(type) { + case *bool: + if s.Valid { + *d = s.Bool + } + return nil + } + case *sql.NullTime: + switch d := dest.(type) { + case *time.Time: + if s.Valid { + *d = s.Time + } + return nil + case *string: + if s.Valid { + *d = s.Time.In(convertedLocation).Format("2006-01-02 15:04:05") + } + return nil + } + case *NullUint32: + switch d := dest.(type) { + case *uint8: + if s.Valid { + *d = uint8(s.Uint32) + } + return nil + case *uint16: + if s.Valid { + *d = uint16(s.Uint32) + } + return nil + case *uint: + if s.Valid { + *d = uint(s.Uint32) + } + return nil + } + case *NullUint64: + switch d := dest.(type) { + case *uint64: + if s.Valid { + *d = s.Uint64 + } + return nil + } + case *sql.RawBytes: + switch d := dest.(type) { + case convert.Conversion: + return d.FromDB(*s) + } } var sv reflect.Value @@ -175,10 +335,10 @@ func convertAssign(dest, src interface{}) error { return nil } - return convertAssignV(reflect.ValueOf(dest), src) + return convertAssignV(reflect.ValueOf(dest), src, originalLocation, convertedLocation) } -func convertAssignV(dpv reflect.Value, src interface{}) error { +func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, convertedLocation *time.Location) error { if dpv.Kind() != reflect.Ptr { return errors.New("destination not a pointer") } @@ -212,7 +372,7 @@ func convertAssignV(dpv reflect.Value, src interface{}) error { } dv.Set(reflect.New(dv.Type().Elem())) - return convertAssign(dv.Interface(), src) + return convertAssign(dv.Interface(), src, originalLocation, convertedLocation) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: s := asString(src) i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) @@ -376,3 +536,79 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) { } return v.Interface(), nil } + +var ( + _ sql.Scanner = &NullUint64{} +) + +// NullUint64 represents an uint64 that may be null. +// NullUint64 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullUint64 struct { + Uint64 uint64 + Valid bool // Valid is true if Uint64 is not NULL + OriginalLocation *time.Location + ConvertedLocation *time.Location +} + +// Scan implements the Scanner interface. +func (n *NullUint64) Scan(value interface{}) error { + if value == nil { + n.Uint64, n.Valid = 0, false + return nil + } + n.Valid = true + fmt.Println("======44444") + return convertAssign(&n.Uint64, value, n.OriginalLocation, n.ConvertedLocation) +} + +// Value implements the driver Valuer interface. +func (n NullUint64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Uint64, nil +} + +var ( + _ sql.Scanner = &NullUint32{} +) + +// NullUint32 represents an uint32 that may be null. +// NullUint32 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullUint32 struct { + Uint32 uint32 + Valid bool // Valid is true if Uint32 is not NULL + OriginalLocation *time.Location + ConvertedLocation *time.Location +} + +// Scan implements the Scanner interface. +func (n *NullUint32) Scan(value interface{}) error { + if value == nil { + n.Uint32, n.Valid = 0, false + return nil + } + n.Valid = true + fmt.Println("555555") + return convertAssign(&n.Uint32, value, n.OriginalLocation, n.ConvertedLocation) +} + +// Value implements the driver Valuer interface. +func (n NullUint32) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return int64(n.Uint32), nil +} + +var ( + _ sql.Scanner = &EmptyScanner{} +) + +type EmptyScanner struct{} + +func (EmptyScanner) Scan(value interface{}) error { + return nil +} diff --git a/dialects/driver.go b/dialects/driver.go index c511b665..0b6187d3 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -18,9 +18,14 @@ type ScanContext struct { UserLocation *time.Location } +type DriverFeatures struct { + SupportNullable bool +} + // Driver represents a database driver type Driver interface { Parse(string, string) (*URI, error) + Features() DriverFeatures GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error } @@ -77,3 +82,9 @@ type baseDriver struct{} func (b *baseDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, v ...interface{}) error { return rows.Scan(v...) } + +func (b *baseDriver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: true, + } +} diff --git a/dialects/mysql.go b/dialects/mysql.go index 03bc9a4b..a341ce05 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -633,6 +633,7 @@ func (db *mysql) Filters() []Filter { } type mysqlDriver struct { + baseDriver } func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 306f377c..1bc0b218 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -576,3 +576,9 @@ func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) { return &r, nil } } + +func (b *sqlite3Driver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: false, + } +} diff --git a/scan.go b/scan.go index e11d6e8d..b23785d8 100644 --- a/scan.go +++ b/scan.go @@ -6,12 +6,121 @@ package xorm import ( "database/sql" + "fmt" + "reflect" + "time" "xorm.io/xorm/convert" "xorm.io/xorm/core" "xorm.io/xorm/dialects" ) +// genScanResultsByBeanNullabale generates scan result +func genScanResultsByBeanNullable(bean interface{}, originalLocation, convertedLocation *time.Location) (interface{}, bool, error) { + switch t := bean.(type) { + case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes: + return t, false, nil + case *time.Time: + return &sql.NullTime{}, true, nil + case *string: + return &sql.NullString{}, true, nil + case *int, *int8, *int16, *int32: + return &sql.NullInt32{}, true, nil + case *int64: + return &sql.NullInt64{}, true, nil + case *uint, *uint8, *uint16, *uint32: + return &NullUint32{ + OriginalLocation: originalLocation, + ConvertedLocation: convertedLocation, + }, true, nil + case *uint64: + return &NullUint64{ + OriginalLocation: originalLocation, + ConvertedLocation: convertedLocation, + }, true, nil + case *float32, *float64: + return &sql.NullFloat64{}, true, nil + case *bool: + return &sql.NullBool{}, true, nil + case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString, + time.Time, + string, + int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + float32, float64, + bool: + return nil, false, fmt.Errorf("unsupported scan type: %t", t) + case convert.Conversion: + return &sql.RawBytes{}, true, nil + } + + tp := reflect.TypeOf(bean).Elem() + switch tp.Kind() { + case reflect.String: + return &sql.NullString{}, true, nil + case reflect.Int64: + return &sql.NullInt64{}, true, nil + case reflect.Int32, reflect.Int, reflect.Int16, reflect.Int8: + return &sql.NullInt32{}, true, nil + case reflect.Uint64: + return &NullUint64{}, true, nil + case reflect.Uint32, reflect.Uint, reflect.Uint16, reflect.Uint8: + return &NullUint32{}, true, nil + default: + return nil, false, fmt.Errorf("unsupported type: %#v", bean) + } +} + +func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { + switch t := bean.(type) { + case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, + *string, + *int, *int8, *int16, *int32, *int64, + *uint, *uint8, *uint16, *uint32, *uint64, + *bool: + return t, false, nil + case *time.Time: + return &sql.NullTime{}, true, nil + case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString, + time.Time, + string, + int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + bool: + return nil, false, fmt.Errorf("unsupported scan type: %t", t) + case convert.Conversion: + return &sql.RawBytes{}, true, nil + } + + tp := reflect.TypeOf(bean).Elem() + switch tp.Kind() { + case reflect.String: + return new(string), true, nil + case reflect.Int64: + return new(int64), true, nil + case reflect.Int32: + return new(int32), true, nil + case reflect.Int: + return new(int32), true, nil + case reflect.Int16: + return new(int32), true, nil + case reflect.Int8: + return new(int32), true, nil + case reflect.Uint64: + return new(uint64), true, nil + case reflect.Uint32: + return new(uint32), true, nil + case reflect.Uint: + return new(uint), true, nil + case reflect.Uint16: + return new(uint16), true, nil + case reflect.Uint8: + return new(uint8), true, nil + default: + return nil, false, fmt.Errorf("unsupported type: %#v", bean) + } +} + func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { var scanResults = make([]interface{}, len(fields)) for i := 0; i < len(fields); i++ { @@ -50,18 +159,97 @@ func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (ma return result, nil } -func row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { - results := make([]string, 0, len(fields)) - var scanResults = make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { +func (engine *Engine) scanStringInterface(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { + var scanResults = make([]interface{}, len(types)) + for i := 0; i < len(types); i++ { var s sql.NullString scanResults[i] = &s } - if err := rows.Scan(scanResults...); err != nil { + if err := engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResults...); err != nil { + return nil, err + } + return scanResults, nil +} + +// scan is a wrap of driver.Scan but will automatically change the input values according requirements +func (engine *Engine) scan(rows *core.Rows, types []*sql.ColumnType, vv ...interface{}) error { + var scanResults = make([]interface{}, 0, len(types)) + var replaces = make([]bool, 0, len(types)) + var err error + for _, v := range vv { + var replaced bool + var scanResult interface{} + if _, ok := v.(sql.Scanner); !ok { + var useNullable = true + if engine.driver.Features().SupportNullable { + nullable, ok := types[0].Nullable() + useNullable = ok && !nullable + } + + if useNullable { + scanResult, replaced, err = genScanResultsByBeanNullable(v, engine.DatabaseTZ, engine.TZLocation) + } else { + scanResult, replaced, err = genScanResultsByBean(v) + } + if err != nil { + return err + } + } else { + scanResult = v + } + scanResults = append(scanResults, scanResult) + replaces = append(replaces, replaced) + } + + var scanCtx = dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + } + + if err = engine.driver.Scan(&scanCtx, rows, types, scanResults...); err != nil { + return err + } + + for i, replaced := range replaces { + if replaced { + if err = convertAssign(vv[i], scanResults[i], scanCtx.DBLocation, engine.TZLocation); err != nil { + return err + } + } + } + + return nil +} + +func (engine *Engine) scanInterfaces(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { + var scanResultContainers = make([]interface{}, len(types)) + for i := 0; i < len(types); i++ { + scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) + if err != nil { + return nil, err + } + scanResultContainers[i] = scanResult + } + if err := engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResultContainers...); err != nil { + return nil, err + } + return scanResultContainers, nil +} + +func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { + scanResults, err := engine.scanStringInterface(rows, types) + if err != nil { return nil, err } + var results = make([]string, 0, len(fields)) for i := 0; i < len(fields); i++ { results = append(results, scanResults[i].(*sql.NullString).String) } diff --git a/session_find.go b/session_find.go index 0daea005..261e6b7f 100644 --- a/session_find.go +++ b/session_find.go @@ -276,7 +276,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error { cols := table.PKColumns() if len(cols) == 1 { - return convertAssign(dst, pk[0]) + return convertAssign(dst, pk[0], nil, nil) } dst = pk diff --git a/session_get.go b/session_get.go index e303176d..a84d3745 100644 --- a/session_get.go +++ b/session_get.go @@ -6,12 +6,16 @@ package xorm import ( "database/sql" + "database/sql/driver" "errors" "fmt" "reflect" "strconv" + "time" "xorm.io/xorm/caches" + "xorm.io/xorm/convert" + "xorm.io/xorm/core" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -108,6 +112,17 @@ func (session *Session) get(bean interface{}) (bool, error) { return true, nil } +var ( + valuerTypePlaceHolder driver.Valuer + valuerType = reflect.TypeOf(&valuerTypePlaceHolder).Elem() + + scannerTypePlaceHolder sql.Scanner + scannerType = reflect.TypeOf(&scannerTypePlaceHolder).Elem() + + conversionTypePlaceHolder convert.Conversion + conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem() +) + func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil { @@ -122,155 +137,161 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, return false, nil } - switch bean.(type) { - case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString: - return true, rows.Scan(&bean) - case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString: - return true, rows.Scan(bean) - case *string: - var res sql.NullString - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*string)) = res.String - } - return true, nil - case *int: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int)) = int(res.Int64) - } - return true, nil - case *int8: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int8)) = int8(res.Int64) - } - return true, nil - case *int16: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int16)) = int16(res.Int64) - } - return true, nil - case *int32: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int32)) = int32(res.Int64) - } - return true, nil - case *int64: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int64)) = int64(res.Int64) - } - return true, nil - case *uint: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint)) = uint(res.Int64) - } - return true, nil - case *uint8: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint8)) = uint8(res.Int64) - } - return true, nil - case *uint16: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint16)) = uint16(res.Int64) - } - return true, nil - case *uint32: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint32)) = uint32(res.Int64) - } - return true, nil - case *uint64: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint64)) = uint64(res.Int64) - } - return true, nil - case *bool: - var res sql.NullBool - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*bool)) = res.Bool - } - return true, nil + // WARN: Alougth rows return true, but we may also return error. + types, err := rows.ColumnTypes() + if err != nil { + return true, err + } + fields, err := rows.Columns() + if err != nil { + return true, err } - switch beanKind { case reflect.Struct: - fields, err := rows.Columns() - if err != nil { - // WARN: Alougth rows return true, but get fields failed - return true, err + if _, ok := bean.(*time.Time); ok { + break } - - scanResults, err := session.row2Slice(rows, fields, bean) - if err != nil { - return false, err + if _, ok := bean.(sql.Scanner); ok { + break } - // close it before convert data - rows.Close() - - dataStruct := utils.ReflectValue(bean) - _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) - if err != nil { - return true, err + if _, ok := bean.(convert.Conversion); len(types) == 1 && ok { + break } - - return true, session.executeProcessors() + return session.getStruct(rows, types, fields, table, bean) case reflect.Slice: - err = rows.ScanSlice(bean) + return session.getSlice(rows, types, fields, bean) case reflect.Map: - err = rows.ScanMap(bean) - case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - err = rows.Scan(bean) - default: - err = rows.Scan(bean) + return session.getMap(rows, types, fields, bean) } - return true, err + return session.getVars(rows, types, fields, bean) +} + +func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { + switch t := bean.(type) { + case *[]string: + res, err := session.engine.scanStringInterface(rows, types) + if err != nil { + return true, err + } + + var needAppend = len(*t) == 0 // both support slice is empty or has been initlized + for i, r := range res { + if needAppend { + *t = append(*t, r.(*sql.NullString).String) + } else { + (*t)[i] = r.(*sql.NullString).String + } + } + return true, nil + case *[]interface{}: + scanResults, err := session.engine.scanInterfaces(rows, types) + if err != nil { + return true, err + } + var needAppend = len(*t) == 0 + for ii := range fields { + s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii]) + if err != nil { + return true, err + } + if needAppend { + *t = append(*t, s) + } else { + (*t)[ii] = s + } + } + return true, nil + default: + return true, fmt.Errorf("unspoorted slice type: %t", t) + } +} + +func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { + switch t := bean.(type) { + case *map[string]string: + scanResults, err := session.engine.scanStringInterface(rows, types) + if err != nil { + return true, err + } + for ii, key := range fields { + (*t)[key] = scanResults[ii].(*sql.NullString).String + } + return true, nil + case *map[string]interface{}: + scanResults, err := session.engine.scanInterfaces(rows, types) + if err != nil { + return true, err + } + for ii, key := range fields { + s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii]) + if err != nil { + return true, err + } + (*t)[key] = s + } + return true, nil + default: + return true, fmt.Errorf("unspoorted map type: %t", t) + } +} + +func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields []string, beans ...interface{}) (bool, error) { + if len(beans) != len(types) { + return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans)) + } + var scanResults = make([]interface{}, 0, len(types)) + var replaceds = make([]bool, 0, len(types)) + for _, bean := range beans { + switch t := bean.(type) { + case sql.Scanner: + scanResults = append(scanResults, t) + replaceds = append(replaceds, false) + case convert.Conversion: + scanResults = append(scanResults, &sql.RawBytes{}) + replaceds = append(replaceds, true) + default: + scanResults = append(scanResults, bean) + replaceds = append(replaceds, false) + } + } + + err := session.engine.scan(rows, types, scanResults...) + if err != nil { + return true, err + } + for i, replaced := range replaceds { + if replaced { + err = convertAssign(beans[i], scanResults[i], session.engine.DatabaseTZ, session.engine.TZLocation) + if err != nil { + return true, err + } + } + } + return true, nil +} + +func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { + fields, err := rows.Columns() + if err != nil { + // WARN: Alougth rows return true, but get fields failed + return true, err + } + + scanResults, err := session.row2Slice(rows, fields, bean) + if err != nil { + return false, err + } + // close it before convert data + rows.Close() + + dataStruct := utils.ReflectValue(bean) + _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) + if err != nil { + return true, err + } + + return true, session.executeProcessors() } func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { diff --git a/session_insert.go b/session_insert.go index e733e06e..7f8f3008 100644 --- a/session_insert.go +++ b/session_insert.go @@ -375,7 +375,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - return 1, convertAssignV(aiValue.Addr(), id) + return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation) } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || session.engine.dialect.URI().DBType == schemas.MSSQL) { res, err := session.queryBytes(sqlStr, args...) @@ -415,7 +415,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - return 1, convertAssignV(aiValue.Addr(), id) + return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation) } res, err := session.exec(sqlStr, args...) @@ -455,7 +455,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } - if err := convertAssignV(aiValue.Addr(), id); err != nil { + if err := convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation); err != nil { return 0, err } diff --git a/session_query.go b/session_query.go index 01cd6f44..fa33496d 100644 --- a/session_query.go +++ b/session_query.go @@ -54,7 +54,7 @@ func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]stri } for rows.Next() { - record, err := row2sliceStr(rows, types, fields) + record, err := session.engine.row2sliceStr(rows, types, fields) if err != nil { return nil, err }