refactor more

This commit is contained in:
Lunny Xiao 2021-07-16 11:00:45 +08:00
parent 693d25be0e
commit 754998d8fc
5 changed files with 30 additions and 13 deletions

View File

@ -651,6 +651,10 @@ func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, conver
bs = []byte(t) bs = []byte(t)
case []byte: case []byte:
bs = t bs = t
case *sql.NullString:
if t.Valid {
bs = []byte(t.String)
}
} }
if bs != nil { if bs != nil {

View File

@ -125,12 +125,17 @@ func (rows *Rows) Scan(bean interface{}) error {
return err return err
} }
types, err := rows.rows.ColumnTypes()
if err != nil {
return err
}
fields, err := rows.rows.Columns() fields, err := rows.rows.Columns()
if err != nil { if err != nil {
return err return err
} }
scanResults, err := rows.session.row2Slice(rows.rows, fields, bean) scanResults, err := rows.session.row2Slice(rows.rows, types, fields, bean)
if err != nil { if err != nil {
return err return err
} }

View File

@ -389,7 +389,7 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s
// Cell cell is a result of one column field // Cell cell is a result of one column field
type Cell *interface{} type Cell *interface{}
func (session *Session) rows2Beans(rows *core.Rows, fields []string, func (session *Session) rows2Beans(rows *core.Rows, types []*sql.ColumnType, fields []string,
table *schemas.Table, newElemFunc func([]string) reflect.Value, table *schemas.Table, newElemFunc func([]string) reflect.Value,
sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error {
for rows.Next() { for rows.Next() {
@ -398,7 +398,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string,
dataStruct := newValue.Elem() dataStruct := newValue.Elem()
// handle beforeClosures // handle beforeClosures
scanResults, err := session.row2Slice(rows, fields, bean) scanResults, err := session.row2Slice(rows, types, fields, bean)
if err != nil { if err != nil {
return err return err
} }
@ -417,15 +417,18 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string,
return nil return nil
} }
func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interface{}) ([]interface{}, error) { func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) ([]interface{}, error) {
for _, closure := range session.beforeClosures { for _, closure := range session.beforeClosures {
closure(bean) closure(bean)
} }
scanResults := make([]interface{}, len(fields)) var scanResults = make([]interface{}, len(fields))
var err error
for i := 0; i < len(fields); i++ { for i := 0; i < len(fields); i++ {
var cell interface{} scanResults[i], err = session.engine.driver.GenScanResult(types[i].DatabaseTypeName())
scanResults[i] = &cell if err != nil {
return nil, err
}
} }
if err := rows.Scan(scanResults...); err != nil { if err := rows.Scan(scanResults...); err != nil {
return nil, err return nil, err
@ -468,8 +471,7 @@ func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Typ
return nil return nil
} }
func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, scanResult interface{}) error {
scanResult interface{}, table *schemas.Table) error {
v, ok := scanResult.(*interface{}) v, ok := scanResult.(*interface{})
if ok { if ok {
scanResult = *v scanResult = *v
@ -525,7 +527,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
} else { } else {
e = fieldValue.Elem() e = fieldValue.Elem()
} }
if err := session.convertBeanField(col, &e, scanResult, table); err != nil { if err := session.convertBeanField(col, &e, scanResult); err != nil {
return err return err
} }
if fieldValue.IsNil() { if fieldValue.IsNil() {
@ -554,6 +556,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
fieldValue.Set(x.Elem()) fieldValue.Set(x.Elem())
return nil return nil
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
fmt.Printf("======%T\n", scanResult)
switch rawValueType.Elem().Kind() { switch rawValueType.Elem().Kind() {
case reflect.Uint8: case reflect.Uint8:
if fieldType.Elem().Kind() == reflect.Uint8 { if fieldType.Elem().Kind() == reflect.Uint8 {
@ -711,7 +714,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
} }
var scanResult = scanResults[ii] var scanResult = scanResults[ii]
if err := session.convertBeanField(col, fieldValue, scanResult, table); err != nil { if err := session.convertBeanField(col, fieldValue, scanResult); err != nil {
return nil, err return nil, err
} }
if col.IsPrimaryKey { if col.IsPrimaryKey {

View File

@ -167,6 +167,11 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
} }
defer rows.Close() defer rows.Close()
types, err := rows.ColumnTypes()
if err != nil {
return err
}
fields, err := rows.Columns() fields, err := rows.Columns()
if err != nil { if err != nil {
return err return err
@ -241,7 +246,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
if err != nil { if err != nil {
return err return err
} }
err = session.rows2Beans(rows, fields, tb, newElemFunc, containerValueSetFunc) err = session.rows2Beans(rows, types, fields, tb, newElemFunc, containerValueSetFunc)
rows.Close() rows.Close()
if err != nil { if err != nil {
return err return err

View File

@ -268,7 +268,7 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields
} }
func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) {
scanResults, err := session.row2Slice(rows, fields, bean) scanResults, err := session.row2Slice(rows, types, fields, bean)
if err != nil { if err != nil {
return false, err return false, err
} }