diff --git a/session.go b/session.go index 1e33541a..3fb92991 100644 --- a/session.go +++ b/session.go @@ -437,7 +437,7 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa } func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, - scanResult interface{}, dataStruct *reflect.Value, table *schemas.Table) error { + scanResult interface{}, table *schemas.Table) error { rawValue := reflect.Indirect(reflect.ValueOf(scanResult)) // if row is null then ignore @@ -474,7 +474,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec vv := reflect.ValueOf(rawValue.Interface()) fieldType := fieldValue.Type() - hasAssigned := false if col.IsJSON { var bs []byte @@ -486,8 +485,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec return fmt.Errorf("unsupported database data type: %s %v", col.Name, rawValueType.Kind()) } - hasAssigned = true - if len(bs) > 0 { if fieldType.Kind() == reflect.String { fieldValue.SetString(string(bs)) @@ -507,7 +504,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec fieldValue.Set(x.Elem()) } } - return nil } @@ -521,7 +517,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec bs = vv.Bytes() } - hasAssigned = true if len(bs) > 0 { if fieldValue.CanAddr() { err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) @@ -537,13 +532,13 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec fieldValue.Set(x.Elem()) } } + return nil case reflect.Slice, reflect.Array: switch rawValueType.Kind() { case reflect.Slice, reflect.Array: switch rawValueType.Elem().Kind() { case reflect.Uint8: if fieldType.Elem().Kind() == reflect.Uint8 { - hasAssigned = true if col.SQLType.IsText() { x := reflect.New(fieldType) err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) @@ -564,39 +559,40 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } } } + return nil } } } case reflect.String: if rawValueType.Kind() == reflect.String { - hasAssigned = true fieldValue.SetString(vv.String()) + return nil } case reflect.Bool: if rawValueType.Kind() == reflect.Bool { - hasAssigned = true fieldValue.SetBool(vv.Bool()) + return nil } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: switch rawValueType.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - hasAssigned = true fieldValue.SetInt(vv.Int()) + return nil } case reflect.Float32, reflect.Float64: switch rawValueType.Kind() { case reflect.Float32, reflect.Float64: - hasAssigned = true fieldValue.SetFloat(vv.Float()) + return nil } case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: switch rawValueType.Kind() { case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - hasAssigned = true fieldValue.SetUint(vv.Uint()) + return nil case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - hasAssigned = true fieldValue.SetUint(uint64(vv.Int())) + return nil } case reflect.Struct: if fieldType.ConvertibleTo(schemas.TimeType) { @@ -606,8 +602,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } if rawValueType == schemas.TimeType { - hasAssigned = true - t := vv.Convert(schemas.TimeType).Interface().(time.Time) z, _ := t.Zone() @@ -620,45 +614,42 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec t = t.In(session.engine.TZLocation) fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + return nil } else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type || rawValueType == schemas.Int32Type { - hasAssigned = true - t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + return nil } else { if d, ok := vv.Interface().([]uint8); ok { - hasAssigned = true t, err := session.byte2Time(col, d) if err != nil { session.engine.logger.Errorf("byte2Time error: %v", err) - hasAssigned = false } else { fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + return nil } + } else if d, ok := vv.Interface().(string); ok { - hasAssigned = true t, err := session.str2Time(col, d) if err != nil { session.engine.logger.Errorf("byte2Time error: %v", err) - hasAssigned = false } else { fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + return nil } } else { return fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) } } } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - // !! 增加支持sql.Scanner接口的结构,如sql.NullString - hasAssigned = true - if err := nulVal.Scan(vv.Interface()); err != nil { - session.engine.logger.Errorf("sql.Sanner error: %v", err) - hasAssigned = false + err := nulVal.Scan(vv.Interface()) + if err == nil { + return nil } + session.engine.logger.Errorf("sql.Sanner error: %v", err) } else if col.IsJSON { if rawValueType.Kind() == reflect.String { - hasAssigned = true x := reflect.New(fieldType) if len([]byte(vv.String())) > 0 { err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) @@ -667,8 +658,8 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } fieldValue.Set(x.Elem()) } + return nil } else if rawValueType.Kind() == reflect.Slice { - hasAssigned = true x := reflect.New(fieldType) if len(vv.Bytes()) > 0 { err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) @@ -677,6 +668,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } fieldValue.Set(x.Elem()) } + return nil } } else if session.statement.UseCascade { table, err := session.engine.tagParser.ParseWithCache(*fieldValue) @@ -684,7 +676,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec return err } - hasAssigned = true if len(table.PrimaryKeys) != 1 { return errors.New("unsupported non or composited primary key cascade") } @@ -709,6 +700,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec return errors.New("cascade obj is not exist") } } + return nil } case reflect.Ptr: // !nashtsai! TODO merge duplicated codes above @@ -717,92 +709,92 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec case schemas.PtrStringType: if rawValueType.Kind() == reflect.String { x := vv.String() - hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) + return nil } case schemas.PtrBoolType: if rawValueType.Kind() == reflect.Bool { x := vv.Bool() - hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) + return nil } case schemas.PtrTimeType: if rawValueType == schemas.PtrTimeType { - hasAssigned = true var x = rawValue.Interface().(time.Time) fieldValue.Set(reflect.ValueOf(&x)) + return nil } case schemas.PtrFloat64Type: if rawValueType.Kind() == reflect.Float64 { x := vv.Float() - hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) + return nil } case schemas.PtrUint64Type: if rawValueType.Kind() == reflect.Int64 { var x = uint64(vv.Int()) - hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) + return nil } case schemas.PtrInt64Type: if rawValueType.Kind() == reflect.Int64 { x := vv.Int() - hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) + return nil } case schemas.PtrFloat32Type: if rawValueType.Kind() == reflect.Float64 { var x = float32(vv.Float()) - hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) + return nil } case schemas.PtrIntType: if rawValueType.Kind() == reflect.Int64 { var x = int(vv.Int()) - hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) + return nil } case schemas.PtrInt32Type: if rawValueType.Kind() == reflect.Int64 { var x = int32(vv.Int()) - hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) + return nil } case schemas.PtrInt8Type: if rawValueType.Kind() == reflect.Int64 { var x = int8(vv.Int()) - hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) + return nil } case schemas.PtrInt16Type: if rawValueType.Kind() == reflect.Int64 { var x = int16(vv.Int()) - hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) + return nil } case schemas.PtrUintType: if rawValueType.Kind() == reflect.Int64 { var x = uint(vv.Int()) - hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) + return nil } case schemas.PtrUint32Type: if rawValueType.Kind() == reflect.Int64 { var x = uint32(vv.Int()) - hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) + return nil } case schemas.Uint8Type: if rawValueType.Kind() == reflect.Int64 { var x = uint8(vv.Int()) - hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) + return nil } case schemas.Uint16Type: if rawValueType.Kind() == reflect.Int64 { var x = uint16(vv.Int()) - hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) + return nil } case schemas.Complex64Type: var x complex64 @@ -813,7 +805,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } fieldValue.Set(reflect.ValueOf(&x)) } - hasAssigned = true + return nil case schemas.Complex128Type: var x complex128 if len([]byte(vv.String())) > 0 { @@ -823,23 +815,16 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } fieldValue.Set(reflect.ValueOf(&x)) } - hasAssigned = true + return nil } // switch fieldType } // switch fieldType.Kind() - if hasAssigned { - return nil - } - data, err := value2Bytes(&rawValue) if err != nil { return err } - if err = session.bytes2Value(col, fieldValue, data); err != nil { - return err - } - return nil + return session.bytes2Value(col, fieldValue, data) } func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) { @@ -878,7 +863,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b continue } - if err := session.convertBeanField(col, fieldValue, scanResult, dataStruct, table); err != nil { + if err := session.convertBeanField(col, fieldValue, scanResult, table); err != nil { return nil, err } if col.IsPrimaryKey {