From 46fd8f58b3b925ac7305b879e0c1f4a2fc8ad140 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 7 Jul 2021 15:46:21 +0800 Subject: [PATCH] Get struct and Find support big.Float (#1976) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1976 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- convert.go | 89 ++++++++++++++++++++++++++++---- integrations/session_get_test.go | 12 +++++ schemas/type.go | 5 +- session.go | 34 ++++++++---- 4 files changed, 120 insertions(+), 20 deletions(-) diff --git a/convert.go b/convert.go index 491626a8..20a6e373 100644 --- a/convert.go +++ b/convert.go @@ -104,9 +104,7 @@ func asInt64(src interface{}) (int64, error) { return rv.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return int64(rv.Uint()), nil - case reflect.Float64: - return int64(rv.Float()), nil - case reflect.Float32: + case reflect.Float64, reflect.Float32: return int64(rv.Float()), nil case reflect.String: return strconv.ParseInt(rv.String(), 10, 64) @@ -154,9 +152,7 @@ func asUint64(src interface{}) (uint64, error) { return uint64(rv.Int()), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return uint64(rv.Uint()), nil - case reflect.Float64: - return uint64(rv.Float()), nil - case reflect.Float32: + case reflect.Float64, reflect.Float32: return uint64(rv.Float()), nil case reflect.String: return strconv.ParseUint(rv.String(), 10, 64) @@ -204,9 +200,7 @@ func asFloat64(src interface{}) (float64, error) { return float64(rv.Int()), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return float64(rv.Uint()), nil - case reflect.Float64: - return float64(rv.Float()), nil - case reflect.Float32: + case reflect.Float64, reflect.Float32: return float64(rv.Float()), nil case reflect.String: return strconv.ParseFloat(rv.String(), 64) @@ -214,6 +208,83 @@ func asFloat64(src interface{}) (float64, error) { return 0, fmt.Errorf("unsupported value %T as int64", src) } +func asBigFloat(src interface{}) (*big.Float, error) { + res := big.NewFloat(0) + switch v := src.(type) { + case int: + res.SetInt64(int64(v)) + return res, nil + case int16: + res.SetInt64(int64(v)) + return res, nil + case int32: + res.SetInt64(int64(v)) + return res, nil + case int8: + res.SetInt64(int64(v)) + return res, nil + case int64: + res.SetInt64(int64(v)) + return res, nil + case uint: + res.SetUint64(uint64(v)) + return res, nil + case uint8: + res.SetUint64(uint64(v)) + return res, nil + case uint16: + res.SetUint64(uint64(v)) + return res, nil + case uint32: + res.SetUint64(uint64(v)) + return res, nil + case uint64: + res.SetUint64(uint64(v)) + return res, nil + case []byte: + res.SetString(string(v)) + return res, nil + case string: + res.SetString(v) + return res, nil + case *sql.NullString: + if v.Valid { + res.SetString(v.String) + return res, nil + } + return nil, nil + case *sql.NullInt32: + if v.Valid { + res.SetInt64(int64(v.Int32)) + return res, nil + } + return nil, nil + case *sql.NullInt64: + if v.Valid { + res.SetInt64(int64(v.Int64)) + return res, nil + } + return nil, nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + res.SetInt64(rv.Int()) + return res, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + res.SetUint64(rv.Uint()) + return res, nil + case reflect.Float64, reflect.Float32: + res.SetFloat64(rv.Float()) + return res, nil + case reflect.String: + res.SetString(rv.String()) + return res, nil + } + return nil, fmt.Errorf("unsupported value %T as big.Float", src) +} + func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { switch rv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index 6fc202bc..02b060b1 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -815,5 +815,17 @@ func TestGetBigFloat(t *testing.T) { assert.True(t, m2.String() == gf2.Money.String(), "%v != %v", m2.String(), gf2.Money.String()) //fmt.Println(m.Cmp(gf.Money)) //assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) + + var gf3 GetBigFloat2 + has, err = testEngine.ID(gf2.Id).Get(&gf3) + assert.NoError(t, err) + assert.True(t, has) + assert.True(t, gf3.Money.String() == gf2.Money.String(), "%v != %v", gf3.Money.String(), gf2.Money.String()) + + var gfs []GetBigFloat2 + err = testEngine.Find(&gfs) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(gfs)) + assert.True(t, gfs[0].Money.String() == gf2.Money.String(), "%v != %v", gfs[0].Money.String(), gf2.Money.String()) } } diff --git a/schemas/type.go b/schemas/type.go index fc02f015..3846b5ee 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -5,6 +5,7 @@ package schemas import ( + "math/big" "reflect" "sort" "strings" @@ -240,6 +241,7 @@ var ( intDefault int uintDefault uint timeDefault time.Time + bigFloatDefault big.Float ) // enumerates all types @@ -267,7 +269,8 @@ var ( ByteType = reflect.TypeOf(byteDefault) BytesType = reflect.SliceOf(ByteType) - TimeType = reflect.TypeOf(timeDefault) + TimeType = reflect.TypeOf(timeDefault) + BigFloatType = reflect.TypeOf(bigFloatDefault) ) // enumerates all types diff --git a/session.go b/session.go index a3b11889..64b1758a 100644 --- a/session.go +++ b/session.go @@ -438,8 +438,15 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, scanResult interface{}, table *schemas.Table) error { - rawValue := reflect.Indirect(reflect.ValueOf(scanResult)) + v, ok := scanResult.(*interface{}) + if ok { + scanResult = *v + } + if scanResult == nil { + return nil + } + rawValue := reflect.Indirect(reflect.ValueOf(scanResult)) // if row is null then ignore if rawValue.Interface() == nil { return nil @@ -508,21 +515,19 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec switch fieldType.Kind() { case reflect.Ptr: - if scanResult == nil { - return nil - } - if v, ok := scanResult.(*interface{}); ok && v == nil { - return nil - } - var e reflect.Value if fieldValue.IsNil() { e = reflect.New(fieldType.Elem()).Elem() } else { e = fieldValue.Elem() } - - return session.convertBeanField(col, &e, scanResult, table) + if err := session.convertBeanField(col, &e, scanResult, table); err != nil { + return err + } + if fieldValue.IsNil() { + fieldValue.Set(e.Addr()) + } + return nil case reflect.Complex64, reflect.Complex128: // TODO: reimplement this var bs []byte @@ -610,6 +615,15 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec return nil } case reflect.Struct: + if fieldType.ConvertibleTo(schemas.BigFloatType) { + v, err := asBigFloat(scanResult) + if err != nil { + return err + } + fieldValue.Set(reflect.ValueOf(v).Elem().Convert(fieldType)) + return nil + } + if fieldType.ConvertibleTo(schemas.TimeType) { dbTZ := session.engine.DatabaseTZ if col.TimeZone != nil {