From d508e86ddba2e7a8c405a761760c69d133a3d521 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 17 Jul 2021 18:07:33 +0800 Subject: [PATCH] Fix tests --- engine.go | 7 ++++++- scan.go | 22 +++++++--------------- session_get.go | 8 ++++---- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/engine.go b/engine.go index d3ee8a8c..b4ef9593 100644 --- a/engine.go +++ b/engine.go @@ -543,6 +543,11 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } + fields, err := rows.Columns() + if err != nil { + return err + } + sess := engine.NewSession() defer sess.Close() for rows.Next() { @@ -551,7 +556,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } - scanResults, err := sess.engine.scanStringInterface(rows, types) + scanResults, err := sess.engine.scanStringInterface(rows, fields, types) if err != nil { return err } diff --git a/scan.go b/scan.go index 3896d459..d8a1ac3d 100644 --- a/scan.go +++ b/scan.go @@ -72,6 +72,7 @@ func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { switch t := bean.(type) { case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, + *sql.RawBytes, *string, *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, @@ -175,17 +176,14 @@ func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (ma return result, nil } -func (engine *Engine) scanStringInterface(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { +func (engine *Engine) scanStringInterface(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) { var scanResults = make([]interface{}, len(types)) for i := 0; i < len(types); i++ { var s sql.NullString scanResults[i] = &s } - if err := engine.driver.Scan(&dialects.ScanContext{ - DBLocation: engine.DatabaseTZ, - UserLocation: engine.TZLocation, - }, rows, types, scanResults...); err != nil { + if err := engine.scan(rows, fields, types, scanResults...); err != nil { return nil, err } return scanResults, nil @@ -246,7 +244,7 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column return nil } -func (engine *Engine) scanInterfaces(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { +func (engine *Engine) scanInterfaces(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) { var scanResultContainers = make([]interface{}, len(types)) for i := 0; i < len(types); i++ { scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) @@ -255,17 +253,14 @@ func (engine *Engine) scanInterfaces(rows *core.Rows, types []*sql.ColumnType) ( } scanResultContainers[i] = scanResult } - if err := engine.driver.Scan(&dialects.ScanContext{ - DBLocation: engine.DatabaseTZ, - UserLocation: engine.TZLocation, - }, rows, types, scanResultContainers...); err != nil { + if err := engine.scan(rows, fields, types, scanResultContainers...); err != nil { return nil, err } return scanResultContainers, nil } func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { - scanResults, err := engine.scanStringInterface(rows, types) + scanResults, err := engine.scanStringInterface(rows, fields, types) if err != nil { return nil, err } @@ -307,10 +302,7 @@ func (engine *Engine) row2mapInterface(rows *core.Rows, types []*sql.ColumnType, } scanResultContainers[i] = scanResult } - if err := engine.driver.Scan(&dialects.ScanContext{ - DBLocation: engine.DatabaseTZ, - UserLocation: engine.TZLocation, - }, rows, types, scanResultContainers...); err != nil { + if err := engine.scan(rows, fields, types, scanResultContainers...); err != nil { return nil, err } diff --git a/session_get.go b/session_get.go index cc6427d7..96b1ee87 100644 --- a/session_get.go +++ b/session_get.go @@ -192,7 +192,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { switch t := bean.(type) { case *[]string: - res, err := session.engine.scanStringInterface(rows, types) + res, err := session.engine.scanStringInterface(rows, fields, types) if err != nil { return true, err } @@ -207,7 +207,7 @@ func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, field } return true, nil case *[]interface{}: - scanResults, err := session.engine.scanInterfaces(rows, types) + scanResults, err := session.engine.scanInterfaces(rows, fields, types) if err != nil { return true, err } @@ -232,7 +232,7 @@ func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, field func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { switch t := bean.(type) { case *map[string]string: - scanResults, err := session.engine.scanStringInterface(rows, types) + scanResults, err := session.engine.scanStringInterface(rows, fields, types) if err != nil { return true, err } @@ -241,7 +241,7 @@ func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields } return true, nil case *map[string]interface{}: - scanResults, err := session.engine.scanInterfaces(rows, types) + scanResults, err := session.engine.scanInterfaces(rows, fields, types) if err != nil { return true, err }