improve code

This commit is contained in:
Lunny Xiao 2021-07-07 12:24:57 +08:00
parent 70756b0eef
commit 10dee93362
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
1 changed files with 40 additions and 55 deletions

View File

@ -437,7 +437,7 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa
}
func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value,
scanResult interface{}, dataStruct *reflect.Value, table *schemas.Table) error {
scanResult interface{}, table *schemas.Table) error {
rawValue := reflect.Indirect(reflect.ValueOf(scanResult))
// if row is null then ignore
@ -474,7 +474,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
vv := reflect.ValueOf(rawValue.Interface())
fieldType := fieldValue.Type()
hasAssigned := false
if col.IsJSON {
var bs []byte
@ -486,8 +485,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
return fmt.Errorf("unsupported database data type: %s %v", col.Name, rawValueType.Kind())
}
hasAssigned = true
if len(bs) > 0 {
if fieldType.Kind() == reflect.String {
fieldValue.SetString(string(bs))
@ -507,7 +504,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
fieldValue.Set(x.Elem())
}
}
return nil
}
@ -521,7 +517,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
bs = vv.Bytes()
}
hasAssigned = true
if len(bs) > 0 {
if fieldValue.CanAddr() {
err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface())
@ -537,13 +532,13 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
fieldValue.Set(x.Elem())
}
}
return nil
case reflect.Slice, reflect.Array:
switch rawValueType.Kind() {
case reflect.Slice, reflect.Array:
switch rawValueType.Elem().Kind() {
case reflect.Uint8:
if fieldType.Elem().Kind() == reflect.Uint8 {
hasAssigned = true
if col.SQLType.IsText() {
x := reflect.New(fieldType)
err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface())
@ -564,39 +559,40 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
}
}
}
return nil
}
}
}
case reflect.String:
if rawValueType.Kind() == reflect.String {
hasAssigned = true
fieldValue.SetString(vv.String())
return nil
}
case reflect.Bool:
if rawValueType.Kind() == reflect.Bool {
hasAssigned = true
fieldValue.SetBool(vv.Bool())
return nil
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch rawValueType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
hasAssigned = true
fieldValue.SetInt(vv.Int())
return nil
}
case reflect.Float32, reflect.Float64:
switch rawValueType.Kind() {
case reflect.Float32, reflect.Float64:
hasAssigned = true
fieldValue.SetFloat(vv.Float())
return nil
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
switch rawValueType.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
hasAssigned = true
fieldValue.SetUint(vv.Uint())
return nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
hasAssigned = true
fieldValue.SetUint(uint64(vv.Int()))
return nil
}
case reflect.Struct:
if fieldType.ConvertibleTo(schemas.TimeType) {
@ -606,8 +602,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
}
if rawValueType == schemas.TimeType {
hasAssigned = true
t := vv.Convert(schemas.TimeType).Interface().(time.Time)
z, _ := t.Zone()
@ -620,45 +614,42 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
t = t.In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
return nil
} else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type ||
rawValueType == schemas.Int32Type {
hasAssigned = true
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
return nil
} else {
if d, ok := vv.Interface().([]uint8); ok {
hasAssigned = true
t, err := session.byte2Time(col, d)
if err != nil {
session.engine.logger.Errorf("byte2Time error: %v", err)
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
return nil
}
} else if d, ok := vv.Interface().(string); ok {
hasAssigned = true
t, err := session.str2Time(col, d)
if err != nil {
session.engine.logger.Errorf("byte2Time error: %v", err)
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
return nil
}
} else {
return fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface())
}
}
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
// !<winxxp>! 增加支持sql.Scanner接口的结构如sql.NullString
hasAssigned = true
if err := nulVal.Scan(vv.Interface()); err != nil {
session.engine.logger.Errorf("sql.Sanner error: %v", err)
hasAssigned = false
err := nulVal.Scan(vv.Interface())
if err == nil {
return nil
}
session.engine.logger.Errorf("sql.Sanner error: %v", err)
} else if col.IsJSON {
if rawValueType.Kind() == reflect.String {
hasAssigned = true
x := reflect.New(fieldType)
if len([]byte(vv.String())) > 0 {
err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface())
@ -667,8 +658,8 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
}
fieldValue.Set(x.Elem())
}
return nil
} else if rawValueType.Kind() == reflect.Slice {
hasAssigned = true
x := reflect.New(fieldType)
if len(vv.Bytes()) > 0 {
err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface())
@ -677,6 +668,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
}
fieldValue.Set(x.Elem())
}
return nil
}
} else if session.statement.UseCascade {
table, err := session.engine.tagParser.ParseWithCache(*fieldValue)
@ -684,7 +676,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
return err
}
hasAssigned = true
if len(table.PrimaryKeys) != 1 {
return errors.New("unsupported non or composited primary key cascade")
}
@ -709,6 +700,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
return errors.New("cascade obj is not exist")
}
}
return nil
}
case reflect.Ptr:
// !nashtsai! TODO merge duplicated codes above
@ -717,92 +709,92 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
case schemas.PtrStringType:
if rawValueType.Kind() == reflect.String {
x := vv.String()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
return nil
}
case schemas.PtrBoolType:
if rawValueType.Kind() == reflect.Bool {
x := vv.Bool()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
return nil
}
case schemas.PtrTimeType:
if rawValueType == schemas.PtrTimeType {
hasAssigned = true
var x = rawValue.Interface().(time.Time)
fieldValue.Set(reflect.ValueOf(&x))
return nil
}
case schemas.PtrFloat64Type:
if rawValueType.Kind() == reflect.Float64 {
x := vv.Float()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
return nil
}
case schemas.PtrUint64Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint64(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
return nil
}
case schemas.PtrInt64Type:
if rawValueType.Kind() == reflect.Int64 {
x := vv.Int()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
return nil
}
case schemas.PtrFloat32Type:
if rawValueType.Kind() == reflect.Float64 {
var x = float32(vv.Float())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
return nil
}
case schemas.PtrIntType:
if rawValueType.Kind() == reflect.Int64 {
var x = int(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
return nil
}
case schemas.PtrInt32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
return nil
}
case schemas.PtrInt8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
return nil
}
case schemas.PtrInt16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
return nil
}
case schemas.PtrUintType:
if rawValueType.Kind() == reflect.Int64 {
var x = uint(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
return nil
}
case schemas.PtrUint32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
return nil
}
case schemas.Uint8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
return nil
}
case schemas.Uint16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
return nil
}
case schemas.Complex64Type:
var x complex64
@ -813,7 +805,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
}
fieldValue.Set(reflect.ValueOf(&x))
}
hasAssigned = true
return nil
case schemas.Complex128Type:
var x complex128
if len([]byte(vv.String())) > 0 {
@ -823,23 +815,16 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
}
fieldValue.Set(reflect.ValueOf(&x))
}
hasAssigned = true
return nil
} // switch fieldType
} // switch fieldType.Kind()
if hasAssigned {
return nil
}
data, err := value2Bytes(&rawValue)
if err != nil {
return err
}
if err = session.bytes2Value(col, fieldValue, data); err != nil {
return err
}
return nil
return session.bytes2Value(col, fieldValue, data)
}
func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) {
@ -878,7 +863,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
continue
}
if err := session.convertBeanField(col, fieldValue, scanResult, dataStruct, table); err != nil {
if err := session.convertBeanField(col, fieldValue, scanResult, table); err != nil {
return nil, err
}
if col.IsPrimaryKey {