From f22f863fc743105e7f5a438adaa7bc7f8b6ef037 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 25 Jun 2021 22:15:49 +0800 Subject: [PATCH] Improve code --- scan.go | 13 +++++++ session.go | 101 +++++++++++++++++++++++++++++++++++++++---------- session_get.go | 2 - 3 files changed, 93 insertions(+), 23 deletions(-) diff --git a/scan.go b/scan.go index 7f6971e3..c50712db 100644 --- a/scan.go +++ b/scan.go @@ -170,6 +170,19 @@ func genScanResult(driver dialects.Driver, fieldType reflect.Type, columnType *s } } +// genScanResults generating scan results according column types +func genScanResults(driver dialects.Driver, types []*sql.ColumnType) ([]interface{}, error) { + var scanResults = make([]interface{}, len(types)) + var err error + for i, t := range types { + scanResults[i], err = driver.GenScanResult(t.DatabaseTypeName()) + if err != nil { + return nil, err + } + } + return scanResults, nil +} + func genScanResultsWithTable(driver dialects.Driver, types []*sql.ColumnType, fields []string, table *schemas.Table) ([]interface{}, error) { var scanResults = make([]interface{}, 0, len(types)) for i, tp := range types { diff --git a/session.go b/session.go index e8ee8015..791eb0bb 100644 --- a/session.go +++ b/session.go @@ -452,9 +452,9 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri return nil } - if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + /*if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { fieldValue.Set(reflect.New(fieldValue.Type().Elem())) - } + }*/ fmt.Printf("----- %v <------ %v \n", fieldValue.Type(), rawValue.Type()) if fieldValue.Type() == rawValue.Type() { @@ -464,8 +464,16 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri if fieldValue.CanAddr() { if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - fmt.Printf("%s, ===========00000000000 %#v <----- %#v \n", columnName, fieldValue.Addr().Interface(), src) - return scanner.Scan(src) + switch t := src.(type) { + case *sql.NullInt64: + if t.Valid { + return scanner.Scan(t.Int64) + } + return nil + default: + fmt.Printf("%s, ===========00000000000 %#v <----- %#v \n", columnName, fieldValue.Addr().Interface(), src) + return scanner.Scan(src) + } } if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { @@ -521,9 +529,24 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri rawValueType := reflect.TypeOf(rawValue.Interface()) vv := reflect.ValueOf(rawValue.Interface()) fieldType := fieldValue.Type() - hasAssigned := false - if col.IsJSON { + var hasAssigned bool + var isJSON = col.IsJSON + var kind = fieldType.Kind() + if reflect.Ptr == kind { + kind = fieldType.Elem().Kind() + } + if !isJSON { + switch kind { + case reflect.Map: + switch src.(type) { + case *sql.NullString: + isJSON = true + } + } + } + + if isJSON { var bs []byte switch t := src.(type) { case *sql.NullString: @@ -564,11 +587,6 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri return nil } - var kind = fieldType.Kind() - if reflect.Ptr == kind { - kind = fieldType.Elem().Kind() - } - switch kind { case reflect.Complex64, reflect.Complex128: // TODO: reimplement this @@ -596,12 +614,37 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri } return nil case reflect.Slice, reflect.Array: + switch t := src.(type) { + case *sql.RawBytes: + if col.SQLType.IsText() { + x := reflect.New(fieldType) + err := json.DefaultJSONHandler.Unmarshal(*t, x.Interface()) + if err != nil { + return err + } + fieldValue.Set(x.Elem()) + } else { + l := len(*t) + if fieldValue.Len() > 0 { + for i := 0; i < fieldValue.Len(); i++ { + if i < l { + fieldValue.Index(i).Set(reflect.ValueOf((*t)[i])) + } + } + } else { + for i := 0; i < vv.Len(); i++ { + fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) + } + } + } + return nil + } + switch rawValueType.Kind() { case reflect.Slice, reflect.Array: switch rawValueType.Elem().Kind() { case reflect.Uint8: if fieldType.Elem().Kind() == reflect.Uint8 { - hasAssigned = true if col.SQLType.IsText() { x := reflect.New(fieldType) err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) @@ -622,18 +665,22 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri } } } + return nil } } } case reflect.String: + fmt.Printf("==================111111, %#v,,,,,,%#v\n", fieldValue.Interface(), src) switch t := src.(type) { case *sql.NullString: if t.Valid { + fmt.Printf("0000000000000,,, %#v\n", t) fieldValue.SetString(t.String) } return nil case sql.NullString: if t.Valid { + fmt.Printf("111111111,,, %#v\n", t) fieldValue.SetString(t.String) } return nil @@ -839,10 +886,27 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri } } } else if session.statement.UseCascade { - fmt.Printf("5565666======= %#v \n", *fieldValue) - table, err := session.engine.tagParser.ParseWithCache(*fieldValue) - if err != nil { - return err + t := fieldValue.Type() + var isPtr = t.Kind() == reflect.Ptr + if isPtr { + t = t.Elem() + } + + var table *schemas.Table + var err error + if !(isPtr && fieldValue.IsNil()) { + fmt.Printf("5565666======= %#v \n", *fieldValue) + + table, err = session.engine.tagParser.ParseWithCache(*fieldValue) + if err != nil { + return err + } + } else { + structInter := reflect.New(t) + table, err = session.engine.tagParser.ParseWithCache(structInter) + if err != nil { + return err + } } hasAssigned = true @@ -868,11 +932,6 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // property to be fetched lazily - t := fieldValue.Type() - var isPtr = t.Kind() == reflect.Ptr - if isPtr { - t = t.Elem() - } structInter := reflect.New(t) has, err := session.ID(pk).NoCascade().get(structInter.Interface()) if err != nil { diff --git a/session_get.go b/session_get.go index db586c04..79112eb0 100644 --- a/session_get.go +++ b/session_get.go @@ -41,8 +41,6 @@ func (session *Session) get(bean interface{}) (bool, error) { return false, session.statement.LastError } - fmt.Printf("========11111,,, %#v \n", bean) - beanValue := reflect.ValueOf(bean) if beanValue.Kind() != reflect.Ptr { return false, errors.New("needs a pointer to a value")