improve code

This commit is contained in:
Lunny Xiao 2021-06-25 17:24:32 +08:00
parent 2acd543562
commit 257fea90c6
5 changed files with 51 additions and 59 deletions

View File

@ -135,7 +135,7 @@ func (rows *Rows) Scan(bean interface{}) error {
return err
}
scanResults, err := rows.session.row2Slice(rows.rows, types, fields, bean)
scanResults, err := rows.session.row2Slice(rows.rows, types, fields, bean, rows.session.statement.RefTable)
if err != nil {
return err
}

45
scan.go
View File

@ -13,6 +13,7 @@ import (
"xorm.io/xorm/convert"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
"xorm.io/xorm/schemas"
)
// genScanResultsByBeanNullabale generates scan result
@ -120,6 +121,19 @@ func genScanResultsByBean(bean interface{}) (interface{}, bool, error) {
}
}
// genRowsScanResults generating scan results according column types
func genRowsScanResults(driver dialects.Driver, rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) {
var scanResults = make([]interface{}, len(types))
var err error
for i, t := range types {
scanResults[i], err = driver.GenScanResult(t.DatabaseTypeName())
if err != nil {
return nil, err
}
}
return scanResults, nil
}
func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) {
var scanResults = make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
@ -142,6 +156,37 @@ func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[
return result, nil
}
func genColScanResult(driver dialects.Driver, fieldType reflect.Type, columnType *sql.ColumnType) (interface{}, error) {
if fieldType.Implements(scannerType) || fieldType.Implements(conversionType) {
return &sql.RawBytes{}, nil
}
switch fieldType.Kind() {
case reflect.Ptr:
return genColScanResult(driver, fieldType.Elem(), columnType)
case reflect.Array, reflect.Slice:
return &sql.RawBytes{}, nil
default:
return driver.GenScanResult(columnType.DatabaseTypeName())
}
}
func genScanResults(driver dialects.Driver, types []*sql.ColumnType, fields []string, table *schemas.Table) ([]interface{}, error) {
var scanResults = make([]interface{}, 0, len(types))
for i, tp := range types {
col := table.GetColumn(fields[i])
if col == nil {
scanResults = append(scanResults, &sql.RawBytes{})
continue
}
scanResult, err := genColScanResult(driver, col.Type, tp)
if err != nil {
return nil, err
}
scanResults = append(scanResults, scanResult)
}
return scanResults, nil
}
func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string][]byte, error) {
var scanResults = make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {

View File

@ -399,7 +399,7 @@ func (session *Session) rows2Beans(rows *core.Rows, types []*sql.ColumnType, fie
dataStruct := newValue.Elem()
// handle beforeClosures
scanResults, err := session.row2Slice(rows, types, fields, bean)
scanResults, err := session.row2Slice(rows, types, fields, bean, table)
if err != nil {
return err
}
@ -418,24 +418,12 @@ func (session *Session) rows2Beans(rows *core.Rows, types []*sql.ColumnType, fie
return nil
}
func (session *Session) genScanResultsByTypes(types []*sql.ColumnType) ([]interface{}, error) {
scanResults := make([]interface{}, len(types))
for i := 0; i < len(types); i++ {
result, err := session.engine.driver.GenScanResult(types[i].DatabaseTypeName())
if err != nil {
return nil, err
}
scanResults[i] = result
}
return scanResults, nil
}
func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) ([]interface{}, error) {
func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}, table *schemas.Table) ([]interface{}, error) {
for _, closure := range session.beforeClosures {
closure(bean)
}
scanResults, err := session.genScanResultsByTypes(types)
scanResults, err := genScanResults(session.engine.driver, types, fields, table)
if err != nil {
return nil, err
}
@ -608,16 +596,6 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
}
return nil
case reflect.Slice, reflect.Array:
switch t := src.(type) {
case *sql.NullString:
hasAssigned = true
fmt.Printf("====== %#v <-------- %#v \n", fieldValue.Interface(), t)
if t.Valid {
if fieldType.Elem().Kind() == reflect.Uint8 {
fieldValue.SetBytes([]byte(t.String))
}
}
}
switch rawValueType.Kind() {
case reflect.Slice, reflect.Array:
switch rawValueType.Elem().Kind() {

View File

@ -243,13 +243,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
}
if elemType.Kind() == reflect.Struct {
var newValue = newElemFunc(fields)
dataStruct := utils.ReflectValue(newValue.Interface())
tb, err := session.engine.tagParser.ParseWithCache(dataStruct)
if err != nil {
return err
}
err = session.rows2Beans(rows, types, fields, tb, newElemFunc, containerValueSetFunc)
err = session.rows2Beans(rows, types, fields, table, newElemFunc, containerValueSetFunc)
rows.Close()
if err != nil {
return err

View File

@ -301,35 +301,10 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields
}
func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) {
var scanResults = make([]interface{}, 0, len(types))
for i, tp := range types {
col := table.GetColumn(fields[i])
if col == nil {
return true, fmt.Errorf("cannot find column named %v from columns %v", fields[i], table.ColumnsSeq())
}
if col.Type.Implements(scannerType) {
scanResults = append(scanResults, &sql.RawBytes{})
} else if col.Type.Implements(conversionType) {
scanResults = append(scanResults, &sql.RawBytes{})
} else {
v, err := session.engine.driver.GenScanResult(tp.DatabaseTypeName())
scanResults, err := session.row2Slice(rows, types, fields, bean, table)
if err != nil {
return true, err
}
scanResults = append(scanResults, v)
}
}
for _, closure := range session.beforeClosures {
closure(bean)
}
err := session.engine.scan(rows, types, scanResults...)
if err != nil {
return true, err
}
executeBeforeSet(bean, fields, scanResults)
// close it before convert data
rows.Close()