Support big.Float (#1973)

Now you can use big.Float for numeric type.

```go
type MyMoney struct {
	Id int64
    Money big.Float `xorm:"numeric(22,2)"`
}
```

Reviewed-on: https://gitea.com/xorm/xorm/pulls/1973
Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com>
Co-committed-by: Lunny Xiao <xiaolunwen@gmail.com>
This commit is contained in:
Lunny Xiao 2021-07-07 14:00:16 +08:00
parent 54bbead2be
commit b754e78269
6 changed files with 102 additions and 45 deletions

View File

@ -9,6 +9,7 @@ import (
"database/sql/driver" "database/sql/driver"
"errors" "errors"
"fmt" "fmt"
"math/big"
"reflect" "reflect"
"strconv" "strconv"
"time" "time"
@ -310,10 +311,12 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve
if s.Valid { if s.Valid {
*d, _ = strconv.Atoi(s.String) *d, _ = strconv.Atoi(s.String)
} }
return nil
case *int64: case *int64:
if s.Valid { if s.Valid {
*d, _ = strconv.ParseInt(s.String, 10, 64) *d, _ = strconv.ParseInt(s.String, 10, 64)
} }
return nil
case *string: case *string:
if s.Valid { if s.Valid {
*d = s.String *d = s.String
@ -339,6 +342,15 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve
d.Valid = true d.Valid = true
d.Time = *dt d.Time = *dt
} }
return nil
case *big.Float:
if s.Valid {
if d == nil {
d = big.NewFloat(0)
}
d.SetString(s.String)
}
return nil
} }
case *sql.NullInt32: case *sql.NullInt32:
switch d := dest.(type) { switch d := dest.(type) {

View File

@ -565,7 +565,7 @@ func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) {
case "REAL": case "REAL":
var s sql.NullFloat64 var s sql.NullFloat64
return &s, nil return &s, nil
case "NUMERIC": case "NUMERIC", "DECIMAL":
var s sql.NullString var s sql.NullString
return &s, nil return &s, nil
case "BLOB": case "BLOB":

View File

@ -8,6 +8,7 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"math/big"
"strconv" "strconv"
"testing" "testing"
"time" "time"
@ -766,3 +767,53 @@ func TestGetNil(t *testing.T) {
assert.True(t, errors.Is(err, xorm.ErrObjectIsNil)) assert.True(t, errors.Is(err, xorm.ErrObjectIsNil))
assert.False(t, has) assert.False(t, has)
} }
func TestGetBigFloat(t *testing.T) {
type GetBigFloat struct {
Id int64
Money *big.Float `xorm:"numeric(22,2)"`
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(GetBigFloat))
{
var gf = GetBigFloat{
Money: big.NewFloat(999999.99),
}
_, err := testEngine.Insert(&gf)
assert.NoError(t, err)
var m big.Float
has, err := testEngine.Table("get_big_float").Cols("money").Where("id=?", gf.Id).Get(&m)
assert.NoError(t, err)
assert.True(t, has)
assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String())
//fmt.Println(m.Cmp(gf.Money))
//assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String())
}
type GetBigFloat2 struct {
Id int64
Money *big.Float `xorm:"decimal(22,2)"`
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(GetBigFloat2))
{
var gf2 = GetBigFloat2{
Money: big.NewFloat(9999999.99),
}
_, err := testEngine.Insert(&gf2)
assert.NoError(t, err)
var m2 big.Float
has, err := testEngine.Table("get_big_float2").Cols("money").Where("id=?", gf2.Id).Get(&m2)
assert.NoError(t, err)
assert.True(t, has)
assert.True(t, m2.String() == gf2.Money.String(), "%v != %v", m2.String(), gf2.Money.String())
//fmt.Println(m.Cmp(gf.Money))
//assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String())
}
}

View File

@ -8,6 +8,7 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"math/big"
"reflect" "reflect"
"time" "time"
@ -19,6 +20,7 @@ import (
var ( var (
nullFloatType = reflect.TypeOf(sql.NullFloat64{}) nullFloatType = reflect.TypeOf(sql.NullFloat64{})
bigFloatType = reflect.TypeOf(big.Float{})
) )
// Value2Interface convert a field value of a struct to interface for puting into database // Value2Interface convert a field value of a struct to interface for puting into database
@ -84,6 +86,9 @@ func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue refl
return nil, nil return nil, nil
} }
return t.Float64, nil return t.Float64, nil
} else if fieldType.ConvertibleTo(bigFloatType) {
t := fieldValue.Convert(bigFloatType).Interface().(big.Float)
return t.String(), nil
} }
if !col.IsJSON { if !col.IsJSON {

24
scan.go
View File

@ -7,6 +7,7 @@ package xorm
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"math/big"
"reflect" "reflect"
"time" "time"
@ -182,13 +183,21 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column
for _, v := range vv { for _, v := range vv {
var replaced bool var replaced bool
var scanResult interface{} var scanResult interface{}
if _, ok := v.(sql.Scanner); !ok { switch t := v.(type) {
case sql.Scanner:
scanResult = t
case convert.Conversion:
scanResult = &sql.RawBytes{}
replaced = true
case *big.Float:
scanResult = &sql.NullString{}
replaced = true
default:
var useNullable = true var useNullable = true
if engine.driver.Features().SupportNullable { if engine.driver.Features().SupportNullable {
nullable, ok := types[0].Nullable() nullable, ok := types[0].Nullable()
useNullable = ok && nullable useNullable = ok && nullable
} }
if useNullable { if useNullable {
scanResult, replaced, err = genScanResultsByBeanNullable(v) scanResult, replaced, err = genScanResultsByBeanNullable(v)
} else { } else {
@ -197,25 +206,22 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column
if err != nil { if err != nil {
return err return err
} }
} else {
scanResult = v
} }
scanResults = append(scanResults, scanResult) scanResults = append(scanResults, scanResult)
replaces = append(replaces, replaced) replaces = append(replaces, replaced)
} }
var scanCtx = dialects.ScanContext{ if err = engine.driver.Scan(&dialects.ScanContext{
DBLocation: engine.DatabaseTZ, DBLocation: engine.DatabaseTZ,
UserLocation: engine.TZLocation, UserLocation: engine.TZLocation,
} }, rows, types, scanResults...); err != nil {
if err = engine.driver.Scan(&scanCtx, rows, types, scanResults...); err != nil {
return err return err
} }
for i, replaced := range replaces { for i, replaced := range replaces {
if replaced { if replaced {
if err = convertAssign(vv[i], scanResults[i], scanCtx.DBLocation, engine.TZLocation); err != nil { if err = convertAssign(vv[i], scanResults[i], engine.DatabaseTZ, engine.TZLocation); err != nil {
return err return err
} }
} }

View File

@ -9,6 +9,7 @@ import (
"database/sql/driver" "database/sql/driver"
"errors" "errors"
"fmt" "fmt"
"math/big"
"reflect" "reflect"
"strconv" "strconv"
"time" "time"
@ -123,6 +124,20 @@ var (
conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem() conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem()
) )
func isScannableStruct(bean interface{}, typeLen int) bool {
switch bean.(type) {
case *time.Time:
return false
case sql.Scanner:
return false
case convert.Conversion:
return typeLen > 1
case *big.Float:
return false
}
return true
}
func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
rows, err := session.queryRows(sqlStr, args...) rows, err := session.queryRows(sqlStr, args...)
if err != nil { if err != nil {
@ -148,13 +163,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table,
} }
switch beanKind { switch beanKind {
case reflect.Struct: case reflect.Struct:
if _, ok := bean.(*time.Time); ok { if !isScannableStruct(bean, len(types)) {
break
}
if _, ok := bean.(sql.Scanner); ok {
break
}
if _, ok := bean.(convert.Conversion); len(types) == 1 && ok {
break break
} }
return session.getStruct(rows, types, fields, table, bean) return session.getStruct(rows, types, fields, table, bean)
@ -240,35 +249,9 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields
if len(beans) != len(types) { if len(beans) != len(types) {
return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans)) return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans))
} }
var scanResults = make([]interface{}, 0, len(types))
var replaceds = make([]bool, 0, len(types))
for _, bean := range beans {
switch t := bean.(type) {
case sql.Scanner:
scanResults = append(scanResults, t)
replaceds = append(replaceds, false)
case convert.Conversion:
scanResults = append(scanResults, &sql.RawBytes{})
replaceds = append(replaceds, true)
default:
scanResults = append(scanResults, bean)
replaceds = append(replaceds, false)
}
}
err := session.engine.scan(rows, fields, types, scanResults...) err := session.engine.scan(rows, fields, types, beans...)
if err != nil { return true, err
return true, err
}
for i, replaced := range replaceds {
if replaced {
err = convertAssign(beans[i], scanResults[i], session.engine.DatabaseTZ, session.engine.TZLocation)
if err != nil {
return true, err
}
}
}
return true, nil
} }
func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) {