Improve code

This commit is contained in:
Lunny Xiao 2021-06-25 22:15:49 +08:00
parent 004408dd44
commit f22f863fc7
3 changed files with 93 additions and 23 deletions

13
scan.go
View File

@ -170,6 +170,19 @@ func genScanResult(driver dialects.Driver, fieldType reflect.Type, columnType *s
} }
} }
// genScanResults generating scan results according column types
func genScanResults(driver dialects.Driver, types []*sql.ColumnType) ([]interface{}, error) {
var scanResults = make([]interface{}, len(types))
var err error
for i, t := range types {
scanResults[i], err = driver.GenScanResult(t.DatabaseTypeName())
if err != nil {
return nil, err
}
}
return scanResults, nil
}
func genScanResultsWithTable(driver dialects.Driver, types []*sql.ColumnType, fields []string, table *schemas.Table) ([]interface{}, error) { func genScanResultsWithTable(driver dialects.Driver, types []*sql.ColumnType, fields []string, table *schemas.Table) ([]interface{}, error) {
var scanResults = make([]interface{}, 0, len(types)) var scanResults = make([]interface{}, 0, len(types))
for i, tp := range types { for i, tp := range types {

View File

@ -452,9 +452,9 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
return nil return nil
} }
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { /*if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem())) fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
} }*/
fmt.Printf("----- %v <------ %v \n", fieldValue.Type(), rawValue.Type()) fmt.Printf("----- %v <------ %v \n", fieldValue.Type(), rawValue.Type())
if fieldValue.Type() == rawValue.Type() { if fieldValue.Type() == rawValue.Type() {
@ -464,8 +464,16 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
if fieldValue.CanAddr() { if fieldValue.CanAddr() {
if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
fmt.Printf("%s, ===========00000000000 %#v <----- %#v \n", columnName, fieldValue.Addr().Interface(), src) switch t := src.(type) {
return scanner.Scan(src) case *sql.NullInt64:
if t.Valid {
return scanner.Scan(t.Int64)
}
return nil
default:
fmt.Printf("%s, ===========00000000000 %#v <----- %#v \n", columnName, fieldValue.Addr().Interface(), src)
return scanner.Scan(src)
}
} }
if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
@ -521,9 +529,24 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
rawValueType := reflect.TypeOf(rawValue.Interface()) rawValueType := reflect.TypeOf(rawValue.Interface())
vv := reflect.ValueOf(rawValue.Interface()) vv := reflect.ValueOf(rawValue.Interface())
fieldType := fieldValue.Type() fieldType := fieldValue.Type()
hasAssigned := false
if col.IsJSON { var hasAssigned bool
var isJSON = col.IsJSON
var kind = fieldType.Kind()
if reflect.Ptr == kind {
kind = fieldType.Elem().Kind()
}
if !isJSON {
switch kind {
case reflect.Map:
switch src.(type) {
case *sql.NullString:
isJSON = true
}
}
}
if isJSON {
var bs []byte var bs []byte
switch t := src.(type) { switch t := src.(type) {
case *sql.NullString: case *sql.NullString:
@ -564,11 +587,6 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
return nil return nil
} }
var kind = fieldType.Kind()
if reflect.Ptr == kind {
kind = fieldType.Elem().Kind()
}
switch kind { switch kind {
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
// TODO: reimplement this // TODO: reimplement this
@ -596,12 +614,37 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
} }
return nil return nil
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
switch t := src.(type) {
case *sql.RawBytes:
if col.SQLType.IsText() {
x := reflect.New(fieldType)
err := json.DefaultJSONHandler.Unmarshal(*t, x.Interface())
if err != nil {
return err
}
fieldValue.Set(x.Elem())
} else {
l := len(*t)
if fieldValue.Len() > 0 {
for i := 0; i < fieldValue.Len(); i++ {
if i < l {
fieldValue.Index(i).Set(reflect.ValueOf((*t)[i]))
}
}
} else {
for i := 0; i < vv.Len(); i++ {
fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i)))
}
}
}
return nil
}
switch rawValueType.Kind() { switch rawValueType.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
switch rawValueType.Elem().Kind() { switch rawValueType.Elem().Kind() {
case reflect.Uint8: case reflect.Uint8:
if fieldType.Elem().Kind() == reflect.Uint8 { if fieldType.Elem().Kind() == reflect.Uint8 {
hasAssigned = true
if col.SQLType.IsText() { if col.SQLType.IsText() {
x := reflect.New(fieldType) x := reflect.New(fieldType)
err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface())
@ -622,18 +665,22 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
} }
} }
} }
return nil
} }
} }
} }
case reflect.String: case reflect.String:
fmt.Printf("==================111111, %#v,,,,,,%#v\n", fieldValue.Interface(), src)
switch t := src.(type) { switch t := src.(type) {
case *sql.NullString: case *sql.NullString:
if t.Valid { if t.Valid {
fmt.Printf("0000000000000,,, %#v\n", t)
fieldValue.SetString(t.String) fieldValue.SetString(t.String)
} }
return nil return nil
case sql.NullString: case sql.NullString:
if t.Valid { if t.Valid {
fmt.Printf("111111111,,, %#v\n", t)
fieldValue.SetString(t.String) fieldValue.SetString(t.String)
} }
return nil return nil
@ -839,10 +886,27 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
} }
} }
} else if session.statement.UseCascade { } else if session.statement.UseCascade {
fmt.Printf("5565666======= %#v \n", *fieldValue) t := fieldValue.Type()
table, err := session.engine.tagParser.ParseWithCache(*fieldValue) var isPtr = t.Kind() == reflect.Ptr
if err != nil { if isPtr {
return err t = t.Elem()
}
var table *schemas.Table
var err error
if !(isPtr && fieldValue.IsNil()) {
fmt.Printf("5565666======= %#v \n", *fieldValue)
table, err = session.engine.tagParser.ParseWithCache(*fieldValue)
if err != nil {
return err
}
} else {
structInter := reflect.New(t)
table, err = session.engine.tagParser.ParseWithCache(structInter)
if err != nil {
return err
}
} }
hasAssigned = true hasAssigned = true
@ -868,11 +932,6 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily // property to be fetched lazily
t := fieldValue.Type()
var isPtr = t.Kind() == reflect.Ptr
if isPtr {
t = t.Elem()
}
structInter := reflect.New(t) structInter := reflect.New(t)
has, err := session.ID(pk).NoCascade().get(structInter.Interface()) has, err := session.ID(pk).NoCascade().get(structInter.Interface())
if err != nil { if err != nil {

View File

@ -41,8 +41,6 @@ func (session *Session) get(bean interface{}) (bool, error) {
return false, session.statement.LastError return false, session.statement.LastError
} }
fmt.Printf("========11111,,, %#v \n", bean)
beanValue := reflect.ValueOf(bean) beanValue := reflect.ValueOf(bean)
if beanValue.Kind() != reflect.Ptr { if beanValue.Kind() != reflect.Ptr {
return false, errors.New("needs a pointer to a value") return false, errors.New("needs a pointer to a value")