update statement.buildConditions method to support pointer values update
This commit is contained in:
parent
772a2c7baa
commit
a74f8db232
154
base_test.go
154
base_test.go
|
@ -2488,8 +2488,8 @@ type NullData struct {
|
||||||
RunePtr *rune
|
RunePtr *rune
|
||||||
Float32Ptr *float32
|
Float32Ptr *float32
|
||||||
Float64Ptr *float64
|
Float64Ptr *float64
|
||||||
// Complex64Ptr *complex64
|
// Complex64Ptr *complex64 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128'
|
||||||
// Complex128Ptr *complex128
|
// Complex128Ptr *complex128 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128'
|
||||||
TimePtr *time.Time
|
TimePtr *time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2512,8 +2512,8 @@ type NullData2 struct {
|
||||||
RunePtr rune
|
RunePtr rune
|
||||||
Float32Ptr float32
|
Float32Ptr float32
|
||||||
Float64Ptr float64
|
Float64Ptr float64
|
||||||
//Complex64Ptr complex64
|
// Complex64Ptr complex64 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128'
|
||||||
//Complex128Ptr complex128
|
// Complex128Ptr complex128 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128'
|
||||||
TimePtr time.Time
|
TimePtr time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2576,8 +2576,8 @@ func testPointerData(engine *Engine, t *testing.T) {
|
||||||
*nullData.RunePtr = 1
|
*nullData.RunePtr = 1
|
||||||
*nullData.Float32Ptr = -1.2
|
*nullData.Float32Ptr = -1.2
|
||||||
*nullData.Float64Ptr = -1.1
|
*nullData.Float64Ptr = -1.1
|
||||||
// *nullData.Complex64Ptr :new(complex64),
|
// *nullData.Complex64Ptr = 123456789012345678901234567890
|
||||||
// *nullData.Complex128Ptr :new(complex128),
|
// *nullData.Complex128Ptr = 123456789012345678901234567890123456789012345678901234567890
|
||||||
*nullData.TimePtr = time.Now()
|
*nullData.TimePtr = time.Now()
|
||||||
|
|
||||||
cnt, err := engine.Insert(&nullData)
|
cnt, err := engine.Insert(&nullData)
|
||||||
|
@ -2672,9 +2672,18 @@ func testPointerData(engine *Engine, t *testing.T) {
|
||||||
t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr)))
|
t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if *nullDataGet.Complex64Ptr != *nullData.Complex64Ptr {
|
||||||
|
// t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex64Ptr)))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if *nullDataGet.Complex128Ptr != *nullData.Complex128Ptr {
|
||||||
|
// t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex128Ptr)))
|
||||||
|
// }
|
||||||
|
|
||||||
if (*nullDataGet.TimePtr).Unix() != (*nullData.TimePtr).Unix() {
|
if (*nullDataGet.TimePtr).Unix() != (*nullData.TimePtr).Unix() {
|
||||||
t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr)))
|
t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr)))
|
||||||
} else {
|
} else {
|
||||||
|
// !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver
|
||||||
fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr)
|
fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr)
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
}
|
}
|
||||||
|
@ -2755,9 +2764,18 @@ func testPointerData(engine *Engine, t *testing.T) {
|
||||||
t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Float64Ptr)))
|
t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Float64Ptr)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if nullData2Get.Complex64Ptr != *nullData.Complex64Ptr {
|
||||||
|
// t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Complex64Ptr)))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if nullData2Get.Complex128Ptr != *nullData.Complex128Ptr {
|
||||||
|
// t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Complex128Ptr)))
|
||||||
|
// }
|
||||||
|
|
||||||
if nullData2Get.TimePtr.Unix() != (*nullData.TimePtr).Unix() {
|
if nullData2Get.TimePtr.Unix() != (*nullData.TimePtr).Unix() {
|
||||||
t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr)))
|
t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr)))
|
||||||
} else {
|
} else {
|
||||||
|
// !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver
|
||||||
fmt.Printf("time value: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr)
|
fmt.Printf("time value: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr)
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
}
|
}
|
||||||
|
@ -2872,6 +2890,14 @@ func testNullValue(engine *Engine, t *testing.T) {
|
||||||
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr)))
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if nullDataGet.Complex64Ptr != nil {
|
||||||
|
// t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Complex64Ptr)))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if nullDataGet.Complex128Ptr != nil {
|
||||||
|
// t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Complex128Ptr)))
|
||||||
|
// }
|
||||||
|
|
||||||
if nullDataGet.TimePtr != nil {
|
if nullDataGet.TimePtr != nil {
|
||||||
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.TimePtr)))
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.TimePtr)))
|
||||||
}
|
}
|
||||||
|
@ -2916,8 +2942,8 @@ func testNullValue(engine *Engine, t *testing.T) {
|
||||||
*nullDataUpdate.RunePtr = 1
|
*nullDataUpdate.RunePtr = 1
|
||||||
*nullDataUpdate.Float32Ptr = -1.2
|
*nullDataUpdate.Float32Ptr = -1.2
|
||||||
*nullDataUpdate.Float64Ptr = -1.1
|
*nullDataUpdate.Float64Ptr = -1.1
|
||||||
// *nullDataUpdate.Complex64Ptr :new(complex64),
|
// *nullDataUpdate.Complex64Ptr = 123456789012345678901234567890
|
||||||
// *nullDataUpdate.Complex128Ptr :new(complex128),
|
// *nullDataUpdate.Complex128Ptr = 123456789012345678901234567890123456789012345678901234567890
|
||||||
*nullDataUpdate.TimePtr = time.Now()
|
*nullDataUpdate.TimePtr = time.Now()
|
||||||
|
|
||||||
cnt, err = engine.Id(nullData.Id).Update(&nullDataUpdate)
|
cnt, err = engine.Id(nullData.Id).Update(&nullDataUpdate)
|
||||||
|
@ -3004,14 +3030,126 @@ func testNullValue(engine *Engine, t *testing.T) {
|
||||||
t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr)))
|
t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if *nullDataGet.Complex64Ptr != *nullDataUpdate.Complex64Ptr {
|
||||||
|
// t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex64Ptr)))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if *nullDataGet.Complex128Ptr != *nullDataUpdate.Complex128Ptr {
|
||||||
|
// t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex128Ptr)))
|
||||||
|
// }
|
||||||
|
|
||||||
if (*nullDataGet.TimePtr).Unix() != (*nullDataUpdate.TimePtr).Unix() {
|
if (*nullDataGet.TimePtr).Unix() != (*nullDataUpdate.TimePtr).Unix() {
|
||||||
t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr)))
|
t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr)))
|
||||||
} else {
|
} else {
|
||||||
|
// !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver
|
||||||
fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr)
|
fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr)
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
}
|
}
|
||||||
// --
|
// --
|
||||||
|
|
||||||
|
// update to null values
|
||||||
|
nullDataUpdate = NullData{}
|
||||||
|
|
||||||
|
cnt, err = engine.Id(nullData.Id).Update(&nullDataUpdate)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
panic(err)
|
||||||
|
} else if cnt != 1 {
|
||||||
|
t.Error(errors.New("update count == 0, how can this happen!?"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify get values
|
||||||
|
nullDataGet = NullData{}
|
||||||
|
has, err = engine.Id(nullData.Id).Get(&nullDataGet)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
} else if !has {
|
||||||
|
t.Error(errors.New("ID not found"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("%+v", nullDataGet)
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
if nullDataGet.StringPtr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullDataGet.StringPtr2 != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr2)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullDataGet.BoolPtr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%t]", *nullDataGet.BoolPtr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullDataGet.UintPtr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.UintPtr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullDataGet.Uint8Ptr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint8Ptr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullDataGet.Uint16Ptr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint16Ptr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullDataGet.Uint32Ptr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint32Ptr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullDataGet.Uint64Ptr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint64Ptr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullDataGet.IntPtr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.IntPtr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullDataGet.Int8Ptr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int8Ptr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullDataGet.Int16Ptr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int16Ptr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullDataGet.Int32Ptr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int32Ptr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullDataGet.Int64Ptr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int64Ptr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullDataGet.RunePtr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.RunePtr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullDataGet.Float32Ptr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float32Ptr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullDataGet.Float64Ptr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// if nullDataGet.Complex64Ptr != nil {
|
||||||
|
// t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr)))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if nullDataGet.Complex128Ptr != nil {
|
||||||
|
// t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr)))
|
||||||
|
// }
|
||||||
|
|
||||||
|
if nullDataGet.TimePtr != nil {
|
||||||
|
t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.TimePtr)))
|
||||||
|
}
|
||||||
|
// --
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func testAll(engine *Engine, t *testing.T) {
|
func testAll(engine *Engine, t *testing.T) {
|
||||||
|
|
12
session.go
12
session.go
|
@ -875,6 +875,7 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error {
|
||||||
// get retrieve one record from database, bean's non-empty fields
|
// get retrieve one record from database, bean's non-empty fields
|
||||||
// will be as conditions
|
// will be as conditions
|
||||||
func (session *Session) Get(bean interface{}) (bool, error) {
|
func (session *Session) Get(bean interface{}) (bool, error) {
|
||||||
|
|
||||||
err := session.newDb()
|
err := session.newDb()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
@ -889,6 +890,7 @@ func (session *Session) Get(bean interface{}) (bool, error) {
|
||||||
var sql string
|
var sql string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
session.Statement.RefTable = session.Engine.autoMap(bean)
|
session.Statement.RefTable = session.Engine.autoMap(bean)
|
||||||
|
|
||||||
if session.Statement.RawSQL == "" {
|
if session.Statement.RawSQL == "" {
|
||||||
sql, args = session.Statement.genGetSql(bean)
|
sql, args = session.Statement.genGetSql(bean)
|
||||||
} else {
|
} else {
|
||||||
|
@ -1000,7 +1002,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
|
||||||
|
|
||||||
if len(condiBean) > 0 {
|
if len(condiBean) > 0 {
|
||||||
colNames, args := buildConditions(session.Engine, table, condiBean[0], true,
|
colNames, args := buildConditions(session.Engine, table, condiBean[0], true,
|
||||||
session.Statement.allUseBool, session.Statement.boolColumnMap)
|
session.Statement.allUseBool, false, session.Statement.boolColumnMap)
|
||||||
session.Statement.ConditionStr = strings.Join(colNames, " AND ")
|
session.Statement.ConditionStr = strings.Join(colNames, " AND ")
|
||||||
session.Statement.BeanArgs = args
|
session.Statement.BeanArgs = args
|
||||||
}
|
}
|
||||||
|
@ -1724,7 +1726,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case reflect.Ptr:
|
case reflect.Ptr:
|
||||||
// TODO merge duplicated codes above
|
// !nashtsai! TODO merge duplicated codes above
|
||||||
typeStr := fieldType.String()
|
typeStr := fieldType.String()
|
||||||
switch typeStr {
|
switch typeStr {
|
||||||
case "*string":
|
case "*string":
|
||||||
|
@ -2498,7 +2500,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
|
|
||||||
if session.Statement.ColumnStr == "" {
|
if session.Statement.ColumnStr == "" {
|
||||||
colNames, args = buildConditions(session.Engine, table, bean, false,
|
colNames, args = buildConditions(session.Engine, table, bean, false,
|
||||||
session.Statement.allUseBool, session.Statement.boolColumnMap)
|
session.Statement.allUseBool, true, session.Statement.boolColumnMap)
|
||||||
} else {
|
} else {
|
||||||
colNames, args, err = table.genCols(session, bean, true, true)
|
colNames, args, err = table.genCols(session, bean, true, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2532,7 +2534,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
|
|
||||||
if len(condiBean) > 0 {
|
if len(condiBean) > 0 {
|
||||||
condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true,
|
condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true,
|
||||||
session.Statement.allUseBool, session.Statement.boolColumnMap)
|
session.Statement.allUseBool, false, session.Statement.boolColumnMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
var condition = ""
|
var condition = ""
|
||||||
|
@ -2698,7 +2700,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
|
||||||
table := session.Engine.autoMap(bean)
|
table := session.Engine.autoMap(bean)
|
||||||
session.Statement.RefTable = table
|
session.Statement.RefTable = table
|
||||||
colNames, args := buildConditions(session.Engine, table, bean, true,
|
colNames, args := buildConditions(session.Engine, table, bean, true,
|
||||||
session.Statement.allUseBool, session.Statement.boolColumnMap)
|
session.Statement.allUseBool, false, session.Statement.boolColumnMap)
|
||||||
|
|
||||||
var condition = ""
|
var condition = ""
|
||||||
if session.Statement.WhereStr != "" {
|
if session.Statement.WhereStr != "" {
|
||||||
|
|
56
statement.go
56
statement.go
|
@ -233,7 +233,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
|
||||||
}*/
|
}*/
|
||||||
|
|
||||||
// Auto generating conditions according a struct
|
// Auto generating conditions according a struct
|
||||||
func buildConditions(engine *Engine, table *Table, bean interface{}, includeVersion bool, allUseBool bool, boolColumnMap map[string]bool) ([]string, []interface{}) {
|
func buildConditions(engine *Engine, table *Table, bean interface{}, includeVersion bool, allUseBool bool, includeNil bool, boolColumnMap map[string]bool) ([]string, []interface{}) {
|
||||||
colNames := make([]string, 0)
|
colNames := make([]string, 0)
|
||||||
var args = make([]interface{}, 0)
|
var args = make([]interface{}, 0)
|
||||||
for _, col := range table.Columns {
|
for _, col := range table.Columns {
|
||||||
|
@ -242,10 +242,29 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers
|
||||||
}
|
}
|
||||||
fieldValue := col.ValueOf(bean)
|
fieldValue := col.ValueOf(bean)
|
||||||
fieldType := reflect.TypeOf(fieldValue.Interface())
|
fieldType := reflect.TypeOf(fieldValue.Interface())
|
||||||
|
|
||||||
|
requiredField := false
|
||||||
|
if fieldType.Kind() == reflect.Ptr {
|
||||||
|
if fieldValue.IsNil() {
|
||||||
|
if includeNil {
|
||||||
|
args = append(args, nil)
|
||||||
|
colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
} else if !fieldValue.IsValid() {
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
// dereference ptr type to instance type
|
||||||
|
fieldValue = fieldValue.Elem()
|
||||||
|
fieldType = reflect.TypeOf(fieldValue.Interface())
|
||||||
|
requiredField = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var val interface{}
|
var val interface{}
|
||||||
switch fieldType.Kind() {
|
switch fieldType.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
if allUseBool {
|
if allUseBool || requiredField {
|
||||||
val = fieldValue.Interface()
|
val = fieldValue.Interface()
|
||||||
} else if _, ok := boolColumnMap[col.Name]; ok {
|
} else if _, ok := boolColumnMap[col.Name]; ok {
|
||||||
val = fieldValue.Interface()
|
val = fieldValue.Interface()
|
||||||
|
@ -255,7 +274,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
if fieldValue.String() == "" {
|
if !requiredField && fieldValue.String() == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// for MyString, should convert to string or panic
|
// for MyString, should convert to string or panic
|
||||||
|
@ -265,24 +284,24 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers
|
||||||
val = fieldValue.Interface()
|
val = fieldValue.Interface()
|
||||||
}
|
}
|
||||||
case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
|
case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
|
||||||
if fieldValue.Int() == 0 {
|
if !requiredField && fieldValue.Int() == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
val = fieldValue.Interface()
|
val = fieldValue.Interface()
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
if fieldValue.Float() == 0.0 {
|
if !requiredField && fieldValue.Float() == 0.0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
val = fieldValue.Interface()
|
val = fieldValue.Interface()
|
||||||
case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
|
case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
|
||||||
if fieldValue.Uint() == 0 {
|
if !requiredField && fieldValue.Uint() == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
val = fieldValue.Interface()
|
val = fieldValue.Interface()
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if fieldType == reflect.TypeOf(time.Now()) {
|
if fieldType == reflect.TypeOf(time.Now()) {
|
||||||
t := fieldValue.Interface().(time.Time)
|
t := fieldValue.Interface().(time.Time)
|
||||||
if t.IsZero() || !fieldValue.IsValid() {
|
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var str string
|
var str string
|
||||||
|
@ -344,22 +363,6 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers
|
||||||
} else {
|
} else {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
case reflect.Ptr:
|
|
||||||
if fieldValue.IsNil() || !fieldValue.IsValid() {
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
typeStr := fieldType.String()
|
|
||||||
switch typeStr {
|
|
||||||
case "*string", "*bool", "*float32", "*float64", "*int64", "*uint64", "*int", "*int16", "*int32 ", "*int8 ", "*uint", "*uint16", "*uint32", "*uint8":
|
|
||||||
val = fieldValue.Elem()
|
|
||||||
case "*complex64", "*complex128":
|
|
||||||
continue // TODO
|
|
||||||
case "*time.Time":
|
|
||||||
continue // TODO
|
|
||||||
default:
|
|
||||||
continue // TODO
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
val = fieldValue.Interface()
|
val = fieldValue.Interface()
|
||||||
}
|
}
|
||||||
|
@ -598,12 +601,14 @@ func (s *Statement) genDropSQL() string {
|
||||||
return sql
|
return sql
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// !nashtsai! REVIEW, Statement is a huge struct why is this method not passing *Statement?
|
||||||
func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) {
|
func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) {
|
||||||
table := statement.Engine.autoMap(bean)
|
table := statement.Engine.autoMap(bean)
|
||||||
statement.RefTable = table
|
statement.RefTable = table
|
||||||
|
|
||||||
colNames, args := buildConditions(statement.Engine, table, bean, true,
|
colNames, args := buildConditions(statement.Engine, table, bean, true,
|
||||||
statement.allUseBool, statement.boolColumnMap)
|
statement.allUseBool, false, statement.boolColumnMap)
|
||||||
|
|
||||||
statement.ConditionStr = strings.Join(colNames, " AND ")
|
statement.ConditionStr = strings.Join(colNames, " AND ")
|
||||||
statement.BeanArgs = args
|
statement.BeanArgs = args
|
||||||
|
|
||||||
|
@ -640,7 +645,8 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{})
|
||||||
table := statement.Engine.autoMap(bean)
|
table := statement.Engine.autoMap(bean)
|
||||||
statement.RefTable = table
|
statement.RefTable = table
|
||||||
|
|
||||||
colNames, args := buildConditions(statement.Engine, table, bean, true, statement.allUseBool, statement.boolColumnMap)
|
colNames, args := buildConditions(statement.Engine, table, bean, true,
|
||||||
|
statement.allUseBool, false, statement.boolColumnMap)
|
||||||
statement.ConditionStr = strings.Join(colNames, " AND ")
|
statement.ConditionStr = strings.Join(colNames, " AND ")
|
||||||
statement.BeanArgs = args
|
statement.BeanArgs = args
|
||||||
var id string = "*"
|
var id string = "*"
|
||||||
|
|
Loading…
Reference in New Issue