refactor more

This commit is contained in:
Lunny Xiao 2021-07-16 11:00:45 +08:00
parent 693d25be0e
commit 754998d8fc
5 changed files with 30 additions and 13 deletions

View File

@ -651,6 +651,10 @@ func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, conver
bs = []byte(t)
case []byte:
bs = t
case *sql.NullString:
if t.Valid {
bs = []byte(t.String)
}
}
if bs != nil {

View File

@ -125,12 +125,17 @@ func (rows *Rows) Scan(bean interface{}) error {
return err
}
types, err := rows.rows.ColumnTypes()
if err != nil {
return err
}
fields, err := rows.rows.Columns()
if err != nil {
return err
}
scanResults, err := rows.session.row2Slice(rows.rows, fields, bean)
scanResults, err := rows.session.row2Slice(rows.rows, types, fields, bean)
if err != nil {
return err
}

View File

@ -389,7 +389,7 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s
// Cell cell is a result of one column field
type Cell *interface{}
func (session *Session) rows2Beans(rows *core.Rows, fields []string,
func (session *Session) rows2Beans(rows *core.Rows, types []*sql.ColumnType, fields []string,
table *schemas.Table, newElemFunc func([]string) reflect.Value,
sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error {
for rows.Next() {
@ -398,7 +398,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string,
dataStruct := newValue.Elem()
// handle beforeClosures
scanResults, err := session.row2Slice(rows, fields, bean)
scanResults, err := session.row2Slice(rows, types, fields, bean)
if err != nil {
return err
}
@ -417,15 +417,18 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string,
return nil
}
func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interface{}) ([]interface{}, error) {
func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) ([]interface{}, error) {
for _, closure := range session.beforeClosures {
closure(bean)
}
scanResults := make([]interface{}, len(fields))
var scanResults = make([]interface{}, len(fields))
var err error
for i := 0; i < len(fields); i++ {
var cell interface{}
scanResults[i] = &cell
scanResults[i], err = session.engine.driver.GenScanResult(types[i].DatabaseTypeName())
if err != nil {
return nil, err
}
}
if err := rows.Scan(scanResults...); err != nil {
return nil, err
@ -468,8 +471,7 @@ func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Typ
return nil
}
func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value,
scanResult interface{}, table *schemas.Table) error {
func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, scanResult interface{}) error {
v, ok := scanResult.(*interface{})
if ok {
scanResult = *v
@ -525,7 +527,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
} else {
e = fieldValue.Elem()
}
if err := session.convertBeanField(col, &e, scanResult, table); err != nil {
if err := session.convertBeanField(col, &e, scanResult); err != nil {
return err
}
if fieldValue.IsNil() {
@ -554,6 +556,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
fieldValue.Set(x.Elem())
return nil
case reflect.Slice, reflect.Array:
fmt.Printf("======%T\n", scanResult)
switch rawValueType.Elem().Kind() {
case reflect.Uint8:
if fieldType.Elem().Kind() == reflect.Uint8 {
@ -711,7 +714,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
}
var scanResult = scanResults[ii]
if err := session.convertBeanField(col, fieldValue, scanResult, table); err != nil {
if err := session.convertBeanField(col, fieldValue, scanResult); err != nil {
return nil, err
}
if col.IsPrimaryKey {

View File

@ -167,6 +167,11 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
}
defer rows.Close()
types, err := rows.ColumnTypes()
if err != nil {
return err
}
fields, err := rows.Columns()
if err != nil {
return err
@ -241,7 +246,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
if err != nil {
return err
}
err = session.rows2Beans(rows, fields, tb, newElemFunc, containerValueSetFunc)
err = session.rows2Beans(rows, types, fields, tb, newElemFunc, containerValueSetFunc)
rows.Close()
if err != nil {
return err

View File

@ -268,7 +268,7 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields
}
func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) {
scanResults, err := session.row2Slice(rows, fields, bean)
scanResults, err := session.row2Slice(rows, types, fields, bean)
if err != nil {
return false, err
}