diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index b1dffe14..d3ce2a11 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -818,8 +818,9 @@ func TestGetBigFloat(t *testing.T) { } type GetBigFloat2 struct { - Id int64 - Money *big.Float `xorm:"decimal(22,2)"` + Id int64 + Money *big.Float `xorm:"decimal(22,2)"` + Money2 big.Float `xorm:"decimal(22,2)"` } assert.NoError(t, PrepareEngine()) @@ -827,7 +828,8 @@ func TestGetBigFloat(t *testing.T) { { var gf2 = GetBigFloat2{ - Money: big.NewFloat(9999999.99), + Money: big.NewFloat(9999999.99), + Money2: *big.NewFloat(99.99), } _, err := testEngine.Insert(&gf2) assert.NoError(t, err) @@ -845,12 +847,14 @@ func TestGetBigFloat(t *testing.T) { assert.NoError(t, err) assert.True(t, has) assert.True(t, gf3.Money.String() == gf2.Money.String(), "%v != %v", gf3.Money.String(), gf2.Money.String()) + assert.True(t, gf3.Money2.String() == gf2.Money2.String(), "%v != %v", gf3.Money2.String(), gf2.Money2.String()) var gfs []GetBigFloat2 err = testEngine.Find(&gfs) assert.NoError(t, err) assert.EqualValues(t, 1, len(gfs)) assert.True(t, gfs[0].Money.String() == gf2.Money.String(), "%v != %v", gfs[0].Money.String(), gf2.Money.String()) + assert.True(t, gfs[0].Money2.String() == gf2.Money2.String(), "%v != %v", gfs[0].Money2.String(), gf2.Money2.String()) } } diff --git a/internal/statements/statement.go b/internal/statements/statement.go index bfe9987f..a1cff7c5 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -8,6 +8,7 @@ import ( "database/sql/driver" "errors" "fmt" + "math/big" "reflect" "strings" "time" @@ -662,10 +663,6 @@ func (statement *Statement) GenIndexSQL() []string { return sqls } -func uniqueName(tableName, uqeName string) string { - return fmt.Sprintf("UQE_%v_%v", tableName, uqeName) -} - // GenUniqueSQL generates unique SQL func (statement *Statement) GenUniqueSQL() []string { var sqls []string @@ -693,6 +690,141 @@ func (statement *Statement) GenDelIndexSQL() []string { return sqls } +func (statement *Statement) asDBCond(fieldValue reflect.Value, fieldType reflect.Type, col *schemas.Column, allUseBool, requiredField bool) (interface{}, bool, error) { + switch fieldType.Kind() { + case reflect.Ptr: + if fieldValue.IsNil() { + return nil, true, nil + } + return statement.asDBCond(fieldValue.Elem(), fieldType.Elem(), col, allUseBool, requiredField) + case reflect.Bool: + if allUseBool || requiredField { + return fieldValue.Interface(), true, nil + } + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + return nil, false, nil + case reflect.String: + if !requiredField && fieldValue.String() == "" { + return nil, false, nil + } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + return fieldValue.String(), true, nil + } + return fieldValue.Interface(), true, nil + case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: + if !requiredField && fieldValue.Int() == 0 { + return nil, false, nil + } + return fieldValue.Interface(), true, nil + case reflect.Float32, reflect.Float64: + if !requiredField && fieldValue.Float() == 0.0 { + return nil, false, nil + } + return fieldValue.Interface(), true, nil + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + if !requiredField && fieldValue.Uint() == 0 { + return nil, false, nil + } + return fieldValue.Interface(), true, nil + case reflect.Struct: + if fieldType.ConvertibleTo(schemas.TimeType) { + t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { + return nil, false, nil + } + return dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t), true, nil + } else if fieldType.ConvertibleTo(schemas.BigFloatType) { + t := fieldValue.Convert(schemas.BigFloatType).Interface().(big.Float) + v := t.String() + if v == "0" { + return nil, false, nil + } + return t.String(), true, nil + } else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok { + return nil, false, nil + } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ := valNul.Value() + if val == nil && !requiredField { + return nil, false, nil + } + return val, true, nil + } else { + if col.IsJSON { + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return string(bytes), true, nil + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return bytes, true, nil + } + } else { + table, err := statement.tagParser.ParseWithCache(fieldValue) + if err != nil { + return fieldValue.Interface(), true, nil + } + + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + //if pkField.Int() != 0 { + if pkField.IsValid() && !utils.IsZero(pkField.Interface()) { + return pkField.Interface(), true, nil + } + return nil, false, nil + + } + return nil, false, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) + } + } + case reflect.Array: + return nil, false, nil + case reflect.Slice, reflect.Map: + if fieldValue == reflect.Zero(fieldType) { + return nil, false, nil + } + if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { + return nil, false, nil + } + + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return string(bytes), true, nil + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && + fieldType.Elem().Kind() == reflect.Uint8 { + if fieldValue.Len() > 0 { + return fieldValue.Bytes(), true, nil + } + return nil, false, nil + } else { + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return bytes, true, nil + } + } else { + return nil, false, nil + } + } + return fieldValue.Interface(), true, nil +} + func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, @@ -747,9 +879,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, continue } - fieldType := reflect.TypeOf(fieldValue.Interface()) requiredField := useAllCols - if b, ok := getFlagForColumn(mustColumnMap, col); ok { if b { requiredField = true @@ -758,6 +888,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, } } + fieldType := reflect.TypeOf(fieldValue.Interface()) if fieldType.Kind() == reflect.Ptr { if fieldValue.IsNil() { if includeNil { @@ -774,131 +905,12 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, } } - var val interface{} - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool || requiredField { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if !requiredField && fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if !requiredField && fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if !requiredField && fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if !requiredField && fieldValue.Uint() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Struct: - if fieldType.ConvertibleTo(schemas.TimeType) { - t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) - if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { - continue - } - val = dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) - } else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok { - continue - } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { - val, _ = valNul.Value() - if val == nil && !requiredField { - continue - } - } else { - if col.IsJSON { - if col.SQLType.IsText() { - bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = bytes - } - } else { - table, err := statement.tagParser.ParseWithCache(fieldValue) - if err != nil { - val = fieldValue.Interface() - } else { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - //if pkField.Int() != 0 { - if pkField.IsValid() && !utils.IsZero(pkField.Interface()) { - val = pkField.Interface() - } else { - continue - } - } else { - //TODO: how to handler? - return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) - } - } - } - } - case reflect.Array: + val, ok, err := statement.asDBCond(fieldValue, fieldType, col, allUseBool, requiredField) + if err != nil { + return nil, err + } + if !ok { continue - case reflect.Slice, reflect.Map: - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { - continue - } - - if col.SQLType.IsText() { - bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else { - bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() } conds = append(conds, builder.Eq{colName: val})