improve code
This commit is contained in:
parent
2acd543562
commit
257fea90c6
2
rows.go
2
rows.go
|
@ -135,7 +135,7 @@ func (rows *Rows) Scan(bean interface{}) error {
|
|||
return err
|
||||
}
|
||||
|
||||
scanResults, err := rows.session.row2Slice(rows.rows, types, fields, bean)
|
||||
scanResults, err := rows.session.row2Slice(rows.rows, types, fields, bean, rows.session.statement.RefTable)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
45
scan.go
45
scan.go
|
@ -13,6 +13,7 @@ import (
|
|||
"xorm.io/xorm/convert"
|
||||
"xorm.io/xorm/core"
|
||||
"xorm.io/xorm/dialects"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
// genScanResultsByBeanNullabale generates scan result
|
||||
|
@ -120,6 +121,19 @@ func genScanResultsByBean(bean interface{}) (interface{}, bool, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// genRowsScanResults generating scan results according column types
|
||||
func genRowsScanResults(driver dialects.Driver, rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) {
|
||||
var scanResults = make([]interface{}, len(types))
|
||||
var err error
|
||||
for i, t := range types {
|
||||
scanResults[i], err = driver.GenScanResult(t.DatabaseTypeName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return scanResults, nil
|
||||
}
|
||||
|
||||
func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) {
|
||||
var scanResults = make([]interface{}, len(fields))
|
||||
for i := 0; i < len(fields); i++ {
|
||||
|
@ -142,6 +156,37 @@ func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func genColScanResult(driver dialects.Driver, fieldType reflect.Type, columnType *sql.ColumnType) (interface{}, error) {
|
||||
if fieldType.Implements(scannerType) || fieldType.Implements(conversionType) {
|
||||
return &sql.RawBytes{}, nil
|
||||
}
|
||||
switch fieldType.Kind() {
|
||||
case reflect.Ptr:
|
||||
return genColScanResult(driver, fieldType.Elem(), columnType)
|
||||
case reflect.Array, reflect.Slice:
|
||||
return &sql.RawBytes{}, nil
|
||||
default:
|
||||
return driver.GenScanResult(columnType.DatabaseTypeName())
|
||||
}
|
||||
}
|
||||
|
||||
func genScanResults(driver dialects.Driver, types []*sql.ColumnType, fields []string, table *schemas.Table) ([]interface{}, error) {
|
||||
var scanResults = make([]interface{}, 0, len(types))
|
||||
for i, tp := range types {
|
||||
col := table.GetColumn(fields[i])
|
||||
if col == nil {
|
||||
scanResults = append(scanResults, &sql.RawBytes{})
|
||||
continue
|
||||
}
|
||||
scanResult, err := genColScanResult(driver, col.Type, tp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scanResults = append(scanResults, scanResult)
|
||||
}
|
||||
return scanResults, nil
|
||||
}
|
||||
|
||||
func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string][]byte, error) {
|
||||
var scanResults = make([]interface{}, len(fields))
|
||||
for i := 0; i < len(fields); i++ {
|
||||
|
|
28
session.go
28
session.go
|
@ -399,7 +399,7 @@ func (session *Session) rows2Beans(rows *core.Rows, types []*sql.ColumnType, fie
|
|||
dataStruct := newValue.Elem()
|
||||
|
||||
// handle beforeClosures
|
||||
scanResults, err := session.row2Slice(rows, types, fields, bean)
|
||||
scanResults, err := session.row2Slice(rows, types, fields, bean, table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -418,24 +418,12 @@ func (session *Session) rows2Beans(rows *core.Rows, types []*sql.ColumnType, fie
|
|||
return nil
|
||||
}
|
||||
|
||||
func (session *Session) genScanResultsByTypes(types []*sql.ColumnType) ([]interface{}, error) {
|
||||
scanResults := make([]interface{}, len(types))
|
||||
for i := 0; i < len(types); i++ {
|
||||
result, err := session.engine.driver.GenScanResult(types[i].DatabaseTypeName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scanResults[i] = result
|
||||
}
|
||||
return scanResults, nil
|
||||
}
|
||||
|
||||
func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) ([]interface{}, error) {
|
||||
func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}, table *schemas.Table) ([]interface{}, error) {
|
||||
for _, closure := range session.beforeClosures {
|
||||
closure(bean)
|
||||
}
|
||||
|
||||
scanResults, err := session.genScanResultsByTypes(types)
|
||||
scanResults, err := genScanResults(session.engine.driver, types, fields, table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -608,16 +596,6 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
|
|||
}
|
||||
return nil
|
||||
case reflect.Slice, reflect.Array:
|
||||
switch t := src.(type) {
|
||||
case *sql.NullString:
|
||||
hasAssigned = true
|
||||
fmt.Printf("====== %#v <-------- %#v \n", fieldValue.Interface(), t)
|
||||
if t.Valid {
|
||||
if fieldType.Elem().Kind() == reflect.Uint8 {
|
||||
fieldValue.SetBytes([]byte(t.String))
|
||||
}
|
||||
}
|
||||
}
|
||||
switch rawValueType.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
switch rawValueType.Elem().Kind() {
|
||||
|
|
|
@ -243,13 +243,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
|
|||
}
|
||||
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
var newValue = newElemFunc(fields)
|
||||
dataStruct := utils.ReflectValue(newValue.Interface())
|
||||
tb, err := session.engine.tagParser.ParseWithCache(dataStruct)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = session.rows2Beans(rows, types, fields, tb, newElemFunc, containerValueSetFunc)
|
||||
err = session.rows2Beans(rows, types, fields, table, newElemFunc, containerValueSetFunc)
|
||||
rows.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -301,35 +301,10 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields
|
|||
}
|
||||
|
||||
func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) {
|
||||
var scanResults = make([]interface{}, 0, len(types))
|
||||
for i, tp := range types {
|
||||
col := table.GetColumn(fields[i])
|
||||
if col == nil {
|
||||
return true, fmt.Errorf("cannot find column named %v from columns %v", fields[i], table.ColumnsSeq())
|
||||
}
|
||||
if col.Type.Implements(scannerType) {
|
||||
scanResults = append(scanResults, &sql.RawBytes{})
|
||||
} else if col.Type.Implements(conversionType) {
|
||||
scanResults = append(scanResults, &sql.RawBytes{})
|
||||
} else {
|
||||
v, err := session.engine.driver.GenScanResult(tp.DatabaseTypeName())
|
||||
scanResults, err := session.row2Slice(rows, types, fields, bean, table)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
scanResults = append(scanResults, v)
|
||||
}
|
||||
}
|
||||
|
||||
for _, closure := range session.beforeClosures {
|
||||
closure(bean)
|
||||
}
|
||||
|
||||
err := session.engine.scan(rows, types, scanResults...)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
executeBeforeSet(bean, fields, scanResults)
|
||||
|
||||
// close it before convert data
|
||||
rows.Close()
|
||||
|
|
Loading…
Reference in New Issue