From a74f8db2326dc0c6ab55cfaa3f58f9068c70965d Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Sun, 8 Dec 2013 04:27:48 +0800 Subject: [PATCH] update statement.buildConditions method to support pointer values update --- base_test.go | 162 +++++++++++++++++++++++++++++++++++++++++++++++---- session.go | 12 ++-- statement.go | 56 ++++++++++-------- 3 files changed, 188 insertions(+), 42 deletions(-) diff --git a/base_test.go b/base_test.go index e95983aa..8ad6b7c3 100644 --- a/base_test.go +++ b/base_test.go @@ -2488,8 +2488,8 @@ type NullData struct { RunePtr *rune Float32Ptr *float32 Float64Ptr *float64 - // Complex64Ptr *complex64 - // Complex128Ptr *complex128 + // Complex64Ptr *complex64 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' + // Complex128Ptr *complex128 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' TimePtr *time.Time } @@ -2512,8 +2512,8 @@ type NullData2 struct { RunePtr rune Float32Ptr float32 Float64Ptr float64 - //Complex64Ptr complex64 - //Complex128Ptr complex128 + // Complex64Ptr complex64 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' + // Complex128Ptr complex128 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' TimePtr time.Time } @@ -2554,8 +2554,8 @@ func testPointerData(engine *Engine, t *testing.T) { RunePtr: new(rune), Float32Ptr: new(float32), Float64Ptr: new(float64), - // Complex64Ptr :new(complex64), - // Complex128Ptr :new(complex128), + // Complex64Ptr: new(complex64), + // Complex128Ptr: new(complex128), TimePtr: new(time.Time), } @@ -2576,8 +2576,8 @@ func testPointerData(engine *Engine, t *testing.T) { *nullData.RunePtr = 1 *nullData.Float32Ptr = -1.2 *nullData.Float64Ptr = -1.1 - // *nullData.Complex64Ptr :new(complex64), - // *nullData.Complex128Ptr :new(complex128), + // *nullData.Complex64Ptr = 123456789012345678901234567890 + // *nullData.Complex128Ptr = 123456789012345678901234567890123456789012345678901234567890 *nullData.TimePtr = time.Now() cnt, err := engine.Insert(&nullData) @@ -2672,9 +2672,18 @@ func testPointerData(engine *Engine, t *testing.T) { t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr))) } + // if *nullDataGet.Complex64Ptr != *nullData.Complex64Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex64Ptr))) + // } + + // if *nullDataGet.Complex128Ptr != *nullData.Complex128Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex128Ptr))) + // } + if (*nullDataGet.TimePtr).Unix() != (*nullData.TimePtr).Unix() { t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr))) } else { + // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr) fmt.Println() } @@ -2755,9 +2764,18 @@ func testPointerData(engine *Engine, t *testing.T) { t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Float64Ptr))) } + // if nullData2Get.Complex64Ptr != *nullData.Complex64Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Complex64Ptr))) + // } + + // if nullData2Get.Complex128Ptr != *nullData.Complex128Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Complex128Ptr))) + // } + if nullData2Get.TimePtr.Unix() != (*nullData.TimePtr).Unix() { t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr))) } else { + // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver fmt.Printf("time value: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr) fmt.Println() } @@ -2872,6 +2890,14 @@ func testNullValue(engine *Engine, t *testing.T) { t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) } + // if nullDataGet.Complex64Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Complex64Ptr))) + // } + + // if nullDataGet.Complex128Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Complex128Ptr))) + // } + if nullDataGet.TimePtr != nil { t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.TimePtr))) } @@ -2894,8 +2920,8 @@ func testNullValue(engine *Engine, t *testing.T) { RunePtr: new(rune), Float32Ptr: new(float32), Float64Ptr: new(float64), - // Complex64Ptr :new(complex64), - // Complex128Ptr :new(complex128), + // Complex64Ptr: new(complex64), + // Complex128Ptr: new(complex128), TimePtr: new(time.Time), } @@ -2916,8 +2942,8 @@ func testNullValue(engine *Engine, t *testing.T) { *nullDataUpdate.RunePtr = 1 *nullDataUpdate.Float32Ptr = -1.2 *nullDataUpdate.Float64Ptr = -1.1 - // *nullDataUpdate.Complex64Ptr :new(complex64), - // *nullDataUpdate.Complex128Ptr :new(complex128), + // *nullDataUpdate.Complex64Ptr = 123456789012345678901234567890 + // *nullDataUpdate.Complex128Ptr = 123456789012345678901234567890123456789012345678901234567890 *nullDataUpdate.TimePtr = time.Now() cnt, err = engine.Id(nullData.Id).Update(&nullDataUpdate) @@ -3004,14 +3030,126 @@ func testNullValue(engine *Engine, t *testing.T) { t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr))) } + // if *nullDataGet.Complex64Ptr != *nullDataUpdate.Complex64Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex64Ptr))) + // } + + // if *nullDataGet.Complex128Ptr != *nullDataUpdate.Complex128Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex128Ptr))) + // } + if (*nullDataGet.TimePtr).Unix() != (*nullDataUpdate.TimePtr).Unix() { t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr))) } else { + // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr) fmt.Println() } // -- + // update to null values + nullDataUpdate = NullData{} + + cnt, err = engine.Id(nullData.Id).Update(&nullDataUpdate) + if err != nil { + t.Error(err) + panic(err) + } else if cnt != 1 { + t.Error(errors.New("update count == 0, how can this happen!?")) + return + } + + // verify get values + nullDataGet = NullData{} + has, err = engine.Id(nullData.Id).Get(&nullDataGet) + if err != nil { + t.Error(err) + return + } else if !has { + t.Error(errors.New("ID not found")) + return + } + + fmt.Printf("%+v", nullDataGet) + fmt.Println() + + if nullDataGet.StringPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr))) + } + + if nullDataGet.StringPtr2 != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr2))) + } + + if nullDataGet.BoolPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%t]", *nullDataGet.BoolPtr))) + } + + if nullDataGet.UintPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.UintPtr))) + } + + if nullDataGet.Uint8Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint8Ptr))) + } + + if nullDataGet.Uint16Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint16Ptr))) + } + + if nullDataGet.Uint32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint32Ptr))) + } + + if nullDataGet.Uint64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint64Ptr))) + } + + if nullDataGet.IntPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.IntPtr))) + } + + if nullDataGet.Int8Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int8Ptr))) + } + + if nullDataGet.Int16Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int16Ptr))) + } + + if nullDataGet.Int32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int32Ptr))) + } + + if nullDataGet.Int64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int64Ptr))) + } + + if nullDataGet.RunePtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.RunePtr))) + } + + if nullDataGet.Float32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float32Ptr))) + } + + if nullDataGet.Float64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) + } + + // if nullDataGet.Complex64Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) + // } + + // if nullDataGet.Complex128Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) + // } + + if nullDataGet.TimePtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.TimePtr))) + } + // -- + } func testAll(engine *Engine, t *testing.T) { diff --git a/session.go b/session.go index eef9f5f6..be024a44 100644 --- a/session.go +++ b/session.go @@ -875,6 +875,7 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { // get retrieve one record from database, bean's non-empty fields // will be as conditions func (session *Session) Get(bean interface{}) (bool, error) { + err := session.newDb() if err != nil { return false, err @@ -889,6 +890,7 @@ func (session *Session) Get(bean interface{}) (bool, error) { var sql string var args []interface{} session.Statement.RefTable = session.Engine.autoMap(bean) + if session.Statement.RawSQL == "" { sql, args = session.Statement.genGetSql(bean) } else { @@ -1000,7 +1002,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) if len(condiBean) > 0 { colNames, args := buildConditions(session.Engine, table, condiBean[0], true, - session.Statement.allUseBool, session.Statement.boolColumnMap) + session.Statement.allUseBool, false, session.Statement.boolColumnMap) session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.BeanArgs = args } @@ -1724,7 +1726,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data } } case reflect.Ptr: - // TODO merge duplicated codes above + // !nashtsai! TODO merge duplicated codes above typeStr := fieldType.String() switch typeStr { case "*string": @@ -2498,7 +2500,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if session.Statement.ColumnStr == "" { colNames, args = buildConditions(session.Engine, table, bean, false, - session.Statement.allUseBool, session.Statement.boolColumnMap) + session.Statement.allUseBool, true, session.Statement.boolColumnMap) } else { colNames, args, err = table.genCols(session, bean, true, true) if err != nil { @@ -2532,7 +2534,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condiBean) > 0 { condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true, - session.Statement.allUseBool, session.Statement.boolColumnMap) + session.Statement.allUseBool, false, session.Statement.boolColumnMap) } var condition = "" @@ -2698,7 +2700,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { table := session.Engine.autoMap(bean) session.Statement.RefTable = table colNames, args := buildConditions(session.Engine, table, bean, true, - session.Statement.allUseBool, session.Statement.boolColumnMap) + session.Statement.allUseBool, false, session.Statement.boolColumnMap) var condition = "" if session.Statement.WhereStr != "" { diff --git a/statement.go b/statement.go index 0e51c6ee..b48c0f1e 100644 --- a/statement.go +++ b/statement.go @@ -233,7 +233,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { }*/ // Auto generating conditions according a struct -func buildConditions(engine *Engine, table *Table, bean interface{}, includeVersion bool, allUseBool bool, boolColumnMap map[string]bool) ([]string, []interface{}) { +func buildConditions(engine *Engine, table *Table, bean interface{}, includeVersion bool, allUseBool bool, includeNil bool, boolColumnMap map[string]bool) ([]string, []interface{}) { colNames := make([]string, 0) var args = make([]interface{}, 0) for _, col := range table.Columns { @@ -242,10 +242,29 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers } fieldValue := col.ValueOf(bean) fieldType := reflect.TypeOf(fieldValue.Interface()) + + requiredField := false + if fieldType.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + if includeNil { + args = append(args, nil) + colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name))) + } + continue + } else if !fieldValue.IsValid() { + continue + } else { + // dereference ptr type to instance type + fieldValue = fieldValue.Elem() + fieldType = reflect.TypeOf(fieldValue.Interface()) + requiredField = true + } + } + var val interface{} switch fieldType.Kind() { case reflect.Bool: - if allUseBool { + if allUseBool || requiredField { val = fieldValue.Interface() } else if _, ok := boolColumnMap[col.Name]; ok { val = fieldValue.Interface() @@ -255,7 +274,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers continue } case reflect.String: - if fieldValue.String() == "" { + if !requiredField && fieldValue.String() == "" { continue } // for MyString, should convert to string or panic @@ -265,24 +284,24 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers val = fieldValue.Interface() } case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if fieldValue.Int() == 0 { + if !requiredField && fieldValue.Int() == 0 { continue } val = fieldValue.Interface() case reflect.Float32, reflect.Float64: - if fieldValue.Float() == 0.0 { + if !requiredField && fieldValue.Float() == 0.0 { continue } val = fieldValue.Interface() case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if fieldValue.Uint() == 0 { + if !requiredField && fieldValue.Uint() == 0 { continue } val = fieldValue.Interface() case reflect.Struct: if fieldType == reflect.TypeOf(time.Now()) { t := fieldValue.Interface().(time.Time) - if t.IsZero() || !fieldValue.IsValid() { + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { continue } var str string @@ -344,22 +363,6 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers } else { continue } - case reflect.Ptr: - if fieldValue.IsNil() || !fieldValue.IsValid() { - continue - } else { - typeStr := fieldType.String() - switch typeStr { - case "*string", "*bool", "*float32", "*float64", "*int64", "*uint64", "*int", "*int16", "*int32 ", "*int8 ", "*uint", "*uint16", "*uint32", "*uint8": - val = fieldValue.Elem() - case "*complex64", "*complex128": - continue // TODO - case "*time.Time": - continue // TODO - default: - continue // TODO - } - } default: val = fieldValue.Interface() } @@ -598,12 +601,14 @@ func (s *Statement) genDropSQL() string { return sql } +// !nashtsai! REVIEW, Statement is a huge struct why is this method not passing *Statement? func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) { table := statement.Engine.autoMap(bean) statement.RefTable = table colNames, args := buildConditions(statement.Engine, table, bean, true, - statement.allUseBool, statement.boolColumnMap) + statement.allUseBool, false, statement.boolColumnMap) + statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args @@ -640,7 +645,8 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{}) table := statement.Engine.autoMap(bean) statement.RefTable = table - colNames, args := buildConditions(statement.Engine, table, bean, true, statement.allUseBool, statement.boolColumnMap) + colNames, args := buildConditions(statement.Engine, table, bean, true, + statement.allUseBool, false, statement.boolColumnMap) statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args var id string = "*"