From 754998d8fcf15a6480c8558e476aecab7796a07d Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 16 Jul 2021 11:00:45 +0800 Subject: [PATCH] refactor more --- convert.go | 4 ++++ rows.go | 7 ++++++- session.go | 23 +++++++++++++---------- session_find.go | 7 ++++++- session_get.go | 2 +- 5 files changed, 30 insertions(+), 13 deletions(-) diff --git a/convert.go b/convert.go index 6886ccc3..116dd783 100644 --- a/convert.go +++ b/convert.go @@ -651,6 +651,10 @@ func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, conver bs = []byte(t) case []byte: bs = t + case *sql.NullString: + if t.Valid { + bs = []byte(t.String) + } } if bs != nil { diff --git a/rows.go b/rows.go index a56ea1c9..fbdcf422 100644 --- a/rows.go +++ b/rows.go @@ -125,12 +125,17 @@ func (rows *Rows) Scan(bean interface{}) error { return err } + types, err := rows.rows.ColumnTypes() + if err != nil { + return err + } + fields, err := rows.rows.Columns() if err != nil { return err } - scanResults, err := rows.session.row2Slice(rows.rows, fields, bean) + scanResults, err := rows.session.row2Slice(rows.rows, types, fields, bean) if err != nil { return err } diff --git a/session.go b/session.go index 2b410e85..b3e7d0ab 100644 --- a/session.go +++ b/session.go @@ -389,7 +389,7 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s // Cell cell is a result of one column field type Cell *interface{} -func (session *Session) rows2Beans(rows *core.Rows, fields []string, +func (session *Session) rows2Beans(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, newElemFunc func([]string) reflect.Value, sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { for rows.Next() { @@ -398,7 +398,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, dataStruct := newValue.Elem() // handle beforeClosures - scanResults, err := session.row2Slice(rows, fields, bean) + scanResults, err := session.row2Slice(rows, types, fields, bean) if err != nil { return err } @@ -417,15 +417,18 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, return nil } -func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interface{}) ([]interface{}, error) { +func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) ([]interface{}, error) { for _, closure := range session.beforeClosures { closure(bean) } - scanResults := make([]interface{}, len(fields)) + var scanResults = make([]interface{}, len(fields)) + var err error for i := 0; i < len(fields); i++ { - var cell interface{} - scanResults[i] = &cell + scanResults[i], err = session.engine.driver.GenScanResult(types[i].DatabaseTypeName()) + if err != nil { + return nil, err + } } if err := rows.Scan(scanResults...); err != nil { return nil, err @@ -468,8 +471,7 @@ func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Typ return nil } -func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, - scanResult interface{}, table *schemas.Table) error { +func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, scanResult interface{}) error { v, ok := scanResult.(*interface{}) if ok { scanResult = *v @@ -525,7 +527,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } else { e = fieldValue.Elem() } - if err := session.convertBeanField(col, &e, scanResult, table); err != nil { + if err := session.convertBeanField(col, &e, scanResult); err != nil { return err } if fieldValue.IsNil() { @@ -554,6 +556,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec fieldValue.Set(x.Elem()) return nil case reflect.Slice, reflect.Array: + fmt.Printf("======%T\n", scanResult) switch rawValueType.Elem().Kind() { case reflect.Uint8: if fieldType.Elem().Kind() == reflect.Uint8 { @@ -711,7 +714,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } var scanResult = scanResults[ii] - if err := session.convertBeanField(col, fieldValue, scanResult, table); err != nil { + if err := session.convertBeanField(col, fieldValue, scanResult); err != nil { return nil, err } if col.IsPrimaryKey { diff --git a/session_find.go b/session_find.go index 261e6b7f..4ce6fa71 100644 --- a/session_find.go +++ b/session_find.go @@ -167,6 +167,11 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect } defer rows.Close() + types, err := rows.ColumnTypes() + if err != nil { + return err + } + fields, err := rows.Columns() if err != nil { return err @@ -241,7 +246,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect if err != nil { return err } - err = session.rows2Beans(rows, fields, tb, newElemFunc, containerValueSetFunc) + err = session.rows2Beans(rows, types, fields, tb, newElemFunc, containerValueSetFunc) rows.Close() if err != nil { return err diff --git a/session_get.go b/session_get.go index cc6427d7..d2d5057d 100644 --- a/session_get.go +++ b/session_get.go @@ -268,7 +268,7 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields } func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { - scanResults, err := session.row2Slice(rows, fields, bean) + scanResults, err := session.row2Slice(rows, types, fields, bean) if err != nil { return false, err }