From 257fea90c6ca3a7f53350b375c2d4a36b8d48c08 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 25 Jun 2021 17:24:32 +0800 Subject: [PATCH] improve code --- rows.go | 2 +- scan.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ session.go | 28 +++------------------------- session_find.go | 8 +------- session_get.go | 27 +-------------------------- 5 files changed, 51 insertions(+), 59 deletions(-) diff --git a/rows.go b/rows.go index dbae9862..a041f74d 100644 --- a/rows.go +++ b/rows.go @@ -135,7 +135,7 @@ func (rows *Rows) Scan(bean interface{}) error { return err } - scanResults, err := rows.session.row2Slice(rows.rows, types, fields, bean) + scanResults, err := rows.session.row2Slice(rows.rows, types, fields, bean, rows.session.statement.RefTable) if err != nil { return err } diff --git a/scan.go b/scan.go index 038c0488..14b17cee 100644 --- a/scan.go +++ b/scan.go @@ -13,6 +13,7 @@ import ( "xorm.io/xorm/convert" "xorm.io/xorm/core" "xorm.io/xorm/dialects" + "xorm.io/xorm/schemas" ) // genScanResultsByBeanNullabale generates scan result @@ -120,6 +121,19 @@ func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { } } +// genRowsScanResults generating scan results according column types +func genRowsScanResults(driver dialects.Driver, rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { + var scanResults = make([]interface{}, len(types)) + var err error + for i, t := range types { + scanResults[i], err = driver.GenScanResult(t.DatabaseTypeName()) + if err != nil { + return nil, err + } + } + return scanResults, nil +} + func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { var scanResults = make([]interface{}, len(fields)) for i := 0; i < len(fields); i++ { @@ -142,6 +156,37 @@ func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[ return result, nil } +func genColScanResult(driver dialects.Driver, fieldType reflect.Type, columnType *sql.ColumnType) (interface{}, error) { + if fieldType.Implements(scannerType) || fieldType.Implements(conversionType) { + return &sql.RawBytes{}, nil + } + switch fieldType.Kind() { + case reflect.Ptr: + return genColScanResult(driver, fieldType.Elem(), columnType) + case reflect.Array, reflect.Slice: + return &sql.RawBytes{}, nil + default: + return driver.GenScanResult(columnType.DatabaseTypeName()) + } +} + +func genScanResults(driver dialects.Driver, types []*sql.ColumnType, fields []string, table *schemas.Table) ([]interface{}, error) { + var scanResults = make([]interface{}, 0, len(types)) + for i, tp := range types { + col := table.GetColumn(fields[i]) + if col == nil { + scanResults = append(scanResults, &sql.RawBytes{}) + continue + } + scanResult, err := genColScanResult(driver, col.Type, tp) + if err != nil { + return nil, err + } + scanResults = append(scanResults, scanResult) + } + return scanResults, nil +} + func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string][]byte, error) { var scanResults = make([]interface{}, len(fields)) for i := 0; i < len(fields); i++ { diff --git a/session.go b/session.go index cb2d4694..ef4ab6b2 100644 --- a/session.go +++ b/session.go @@ -399,7 +399,7 @@ func (session *Session) rows2Beans(rows *core.Rows, types []*sql.ColumnType, fie dataStruct := newValue.Elem() // handle beforeClosures - scanResults, err := session.row2Slice(rows, types, fields, bean) + scanResults, err := session.row2Slice(rows, types, fields, bean, table) if err != nil { return err } @@ -418,24 +418,12 @@ func (session *Session) rows2Beans(rows *core.Rows, types []*sql.ColumnType, fie return nil } -func (session *Session) genScanResultsByTypes(types []*sql.ColumnType) ([]interface{}, error) { - scanResults := make([]interface{}, len(types)) - for i := 0; i < len(types); i++ { - result, err := session.engine.driver.GenScanResult(types[i].DatabaseTypeName()) - if err != nil { - return nil, err - } - scanResults[i] = result - } - return scanResults, nil -} - -func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) ([]interface{}, error) { +func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}, table *schemas.Table) ([]interface{}, error) { for _, closure := range session.beforeClosures { closure(bean) } - scanResults, err := session.genScanResultsByTypes(types) + scanResults, err := genScanResults(session.engine.driver, types, fields, table) if err != nil { return nil, err } @@ -608,16 +596,6 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri } return nil case reflect.Slice, reflect.Array: - switch t := src.(type) { - case *sql.NullString: - hasAssigned = true - fmt.Printf("====== %#v <-------- %#v \n", fieldValue.Interface(), t) - if t.Valid { - if fieldType.Elem().Kind() == reflect.Uint8 { - fieldValue.SetBytes([]byte(t.String)) - } - } - } switch rawValueType.Kind() { case reflect.Slice, reflect.Array: switch rawValueType.Elem().Kind() { diff --git a/session_find.go b/session_find.go index cead02ed..3994da97 100644 --- a/session_find.go +++ b/session_find.go @@ -243,13 +243,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect } if elemType.Kind() == reflect.Struct { - var newValue = newElemFunc(fields) - dataStruct := utils.ReflectValue(newValue.Interface()) - tb, err := session.engine.tagParser.ParseWithCache(dataStruct) - if err != nil { - return err - } - err = session.rows2Beans(rows, types, fields, tb, newElemFunc, containerValueSetFunc) + err = session.rows2Beans(rows, types, fields, table, newElemFunc, containerValueSetFunc) rows.Close() if err != nil { return err diff --git a/session_get.go b/session_get.go index 556a99f6..db586c04 100644 --- a/session_get.go +++ b/session_get.go @@ -301,36 +301,11 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields } func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { - var scanResults = make([]interface{}, 0, len(types)) - for i, tp := range types { - col := table.GetColumn(fields[i]) - if col == nil { - return true, fmt.Errorf("cannot find column named %v from columns %v", fields[i], table.ColumnsSeq()) - } - if col.Type.Implements(scannerType) { - scanResults = append(scanResults, &sql.RawBytes{}) - } else if col.Type.Implements(conversionType) { - scanResults = append(scanResults, &sql.RawBytes{}) - } else { - v, err := session.engine.driver.GenScanResult(tp.DatabaseTypeName()) - if err != nil { - return true, err - } - scanResults = append(scanResults, v) - } - } - - for _, closure := range session.beforeClosures { - closure(bean) - } - - err := session.engine.scan(rows, types, scanResults...) + scanResults, err := session.row2Slice(rows, types, fields, bean, table) if err != nil { return true, err } - executeBeforeSet(bean, fields, scanResults) - // close it before convert data rows.Close()