refactor insert condition generation

This commit is contained in:
Lunny Xiao 2021-07-19 14:35:55 +08:00
parent 86775af2ec
commit cf9383c808
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
2 changed files with 149 additions and 133 deletions

View File

@ -818,8 +818,9 @@ func TestGetBigFloat(t *testing.T) {
}
type GetBigFloat2 struct {
Id int64
Money *big.Float `xorm:"decimal(22,2)"`
Id int64
Money *big.Float `xorm:"decimal(22,2)"`
Money2 big.Float `xorm:"decimal(22,2)"`
}
assert.NoError(t, PrepareEngine())
@ -827,7 +828,8 @@ func TestGetBigFloat(t *testing.T) {
{
var gf2 = GetBigFloat2{
Money: big.NewFloat(9999999.99),
Money: big.NewFloat(9999999.99),
Money2: *big.NewFloat(99.99),
}
_, err := testEngine.Insert(&gf2)
assert.NoError(t, err)
@ -845,12 +847,14 @@ func TestGetBigFloat(t *testing.T) {
assert.NoError(t, err)
assert.True(t, has)
assert.True(t, gf3.Money.String() == gf2.Money.String(), "%v != %v", gf3.Money.String(), gf2.Money.String())
assert.True(t, gf3.Money2.String() == gf2.Money2.String(), "%v != %v", gf3.Money2.String(), gf2.Money2.String())
var gfs []GetBigFloat2
err = testEngine.Find(&gfs)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(gfs))
assert.True(t, gfs[0].Money.String() == gf2.Money.String(), "%v != %v", gfs[0].Money.String(), gf2.Money.String())
assert.True(t, gfs[0].Money2.String() == gf2.Money2.String(), "%v != %v", gfs[0].Money2.String(), gf2.Money2.String())
}
}

View File

@ -8,6 +8,7 @@ import (
"database/sql/driver"
"errors"
"fmt"
"math/big"
"reflect"
"strings"
"time"
@ -662,10 +663,6 @@ func (statement *Statement) GenIndexSQL() []string {
return sqls
}
func uniqueName(tableName, uqeName string) string {
return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
}
// GenUniqueSQL generates unique SQL
func (statement *Statement) GenUniqueSQL() []string {
var sqls []string
@ -693,6 +690,141 @@ func (statement *Statement) GenDelIndexSQL() []string {
return sqls
}
func (statement *Statement) asDBCond(fieldValue reflect.Value, fieldType reflect.Type, col *schemas.Column, allUseBool, requiredField bool) (interface{}, bool, error) {
switch fieldType.Kind() {
case reflect.Ptr:
if fieldValue.IsNil() {
return nil, true, nil
}
return statement.asDBCond(fieldValue.Elem(), fieldType.Elem(), col, allUseBool, requiredField)
case reflect.Bool:
if allUseBool || requiredField {
return fieldValue.Interface(), true, nil
}
// if a bool in a struct, it will not be as a condition because it default is false,
// please use Where() instead
return nil, false, nil
case reflect.String:
if !requiredField && fieldValue.String() == "" {
return nil, false, nil
}
// for MyString, should convert to string or panic
if fieldType.String() != reflect.String.String() {
return fieldValue.String(), true, nil
}
return fieldValue.Interface(), true, nil
case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
if !requiredField && fieldValue.Int() == 0 {
return nil, false, nil
}
return fieldValue.Interface(), true, nil
case reflect.Float32, reflect.Float64:
if !requiredField && fieldValue.Float() == 0.0 {
return nil, false, nil
}
return fieldValue.Interface(), true, nil
case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
if !requiredField && fieldValue.Uint() == 0 {
return nil, false, nil
}
return fieldValue.Interface(), true, nil
case reflect.Struct:
if fieldType.ConvertibleTo(schemas.TimeType) {
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
return nil, false, nil
}
return dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t), true, nil
} else if fieldType.ConvertibleTo(schemas.BigFloatType) {
t := fieldValue.Convert(schemas.BigFloatType).Interface().(big.Float)
v := t.String()
if v == "0" {
return nil, false, nil
}
return t.String(), true, nil
} else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok {
return nil, false, nil
} else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ := valNul.Value()
if val == nil && !requiredField {
return nil, false, nil
}
return val, true, nil
} else {
if col.IsJSON {
if col.SQLType.IsText() {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return nil, false, err
}
return string(bytes), true, nil
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return nil, false, err
}
return bytes, true, nil
}
} else {
table, err := statement.tagParser.ParseWithCache(fieldValue)
if err != nil {
return fieldValue.Interface(), true, nil
}
if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
// fix non-int pk issues
//if pkField.Int() != 0 {
if pkField.IsValid() && !utils.IsZero(pkField.Interface()) {
return pkField.Interface(), true, nil
}
return nil, false, nil
}
return nil, false, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys)
}
}
case reflect.Array:
return nil, false, nil
case reflect.Slice, reflect.Map:
if fieldValue == reflect.Zero(fieldType) {
return nil, false, nil
}
if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
return nil, false, nil
}
if col.SQLType.IsText() {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return nil, false, err
}
return string(bytes), true, nil
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
fieldType.Elem().Kind() == reflect.Uint8 {
if fieldValue.Len() > 0 {
return fieldValue.Bytes(), true, nil
}
return nil, false, nil
} else {
bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return nil, false, err
}
return bytes, true, nil
}
} else {
return nil, false, nil
}
}
return fieldValue.Interface(), true, nil
}
func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool,
includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
@ -747,9 +879,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
continue
}
fieldType := reflect.TypeOf(fieldValue.Interface())
requiredField := useAllCols
if b, ok := getFlagForColumn(mustColumnMap, col); ok {
if b {
requiredField = true
@ -758,6 +888,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
}
}
fieldType := reflect.TypeOf(fieldValue.Interface())
if fieldType.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
if includeNil {
@ -774,131 +905,12 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
}
}
var val interface{}
switch fieldType.Kind() {
case reflect.Bool:
if allUseBool || requiredField {
val = fieldValue.Interface()
} else {
// if a bool in a struct, it will not be as a condition because it default is false,
// please use Where() instead
continue
}
case reflect.String:
if !requiredField && fieldValue.String() == "" {
continue
}
// for MyString, should convert to string or panic
if fieldType.String() != reflect.String.String() {
val = fieldValue.String()
} else {
val = fieldValue.Interface()
}
case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
if !requiredField && fieldValue.Int() == 0 {
continue
}
val = fieldValue.Interface()
case reflect.Float32, reflect.Float64:
if !requiredField && fieldValue.Float() == 0.0 {
continue
}
val = fieldValue.Interface()
case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
if !requiredField && fieldValue.Uint() == 0 {
continue
}
val = fieldValue.Interface()
case reflect.Struct:
if fieldType.ConvertibleTo(schemas.TimeType) {
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
continue
}
val = dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t)
} else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok {
continue
} else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ = valNul.Value()
if val == nil && !requiredField {
continue
}
} else {
if col.IsJSON {
if col.SQLType.IsText() {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return nil, err
}
val = string(bytes)
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return nil, err
}
val = bytes
}
} else {
table, err := statement.tagParser.ParseWithCache(fieldValue)
if err != nil {
val = fieldValue.Interface()
} else {
if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
// fix non-int pk issues
//if pkField.Int() != 0 {
if pkField.IsValid() && !utils.IsZero(pkField.Interface()) {
val = pkField.Interface()
} else {
continue
}
} else {
//TODO: how to handler?
return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys)
}
}
}
}
case reflect.Array:
val, ok, err := statement.asDBCond(fieldValue, fieldType, col, allUseBool, requiredField)
if err != nil {
return nil, err
}
if !ok {
continue
case reflect.Slice, reflect.Map:
if fieldValue == reflect.Zero(fieldType) {
continue
}
if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
continue
}
if col.SQLType.IsText() {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return nil, err
}
val = string(bytes)
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
fieldType.Elem().Kind() == reflect.Uint8 {
if fieldValue.Len() > 0 {
val = fieldValue.Bytes()
} else {
continue
}
} else {
bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return nil, err
}
val = bytes
}
} else {
continue
}
default:
val = fieldValue.Interface()
}
conds = append(conds, builder.Eq{colName: val})