diff --git a/README.md b/README.md index ad399a72..d67951f9 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,12 @@ Or # Discuss -Please visit [Xorm on Google Groups](https://groups.google.com/forum/#!forum/xorm) +Please visit [Xorm on Google Groups](https://groups.google.com/forum/#!forum/xorm) + +# Contributors + +* [Lunny](https://github.com/lunny) +* [Nashtsai](https://github.com/nashtsai) # LICENSE diff --git a/README_CN.md b/README_CN.md index b2fe1f9b..febc65f9 100644 --- a/README_CN.md +++ b/README_CN.md @@ -81,6 +81,10 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 请加入QQ群:280360085 进行讨论。 +# 贡献者 + +* [Lunny](https://github.com/lunny) +* [Nashtsai](https://github.com/nashtsai) ## LICENSE diff --git a/base_test.go b/base_test.go index cd75f86a..070eba6e 100644 --- a/base_test.go +++ b/base_test.go @@ -2469,6 +2469,689 @@ func testProcessorsTx(engine *Engine, t *testing.T) { // -- } +type NullData struct { + Id int64 + StringPtr *string + StringPtr2 *string `xorm:"text"` + BoolPtr *bool + BytePtr *byte + UintPtr *uint + Uint8Ptr *uint8 + Uint16Ptr *uint16 + Uint32Ptr *uint32 + Uint64Ptr *uint64 + IntPtr *int + Int8Ptr *int8 + Int16Ptr *int16 + Int32Ptr *int32 + Int64Ptr *int64 + RunePtr *rune + Float32Ptr *float32 + Float64Ptr *float64 + // Complex64Ptr *complex64 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' + // Complex128Ptr *complex128 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' + TimePtr *time.Time +} + +type NullData2 struct { + Id int64 + StringPtr string + StringPtr2 string `xorm:"text"` + BoolPtr bool + BytePtr byte + UintPtr uint + Uint8Ptr uint8 + Uint16Ptr uint16 + Uint32Ptr uint32 + Uint64Ptr uint64 + IntPtr int + Int8Ptr int8 + Int16Ptr int16 + Int32Ptr int32 + Int64Ptr int64 + RunePtr rune + Float32Ptr float32 + Float64Ptr float64 + // Complex64Ptr complex64 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' + // Complex128Ptr complex128 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' + TimePtr time.Time +} + +type NullData3 struct { + Id int64 + StringPtr *string +} + +func testPointerData(engine *Engine, t *testing.T) { + + err := engine.DropTables(&NullData{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&NullData{}) + if err != nil { + t.Error(err) + panic(err) + } + + nullData := NullData{ + StringPtr: new(string), + StringPtr2: new(string), + BoolPtr: new(bool), + BytePtr: new(byte), + UintPtr: new(uint), + Uint8Ptr: new(uint8), + Uint16Ptr: new(uint16), + Uint32Ptr: new(uint32), + Uint64Ptr: new(uint64), + IntPtr: new(int), + Int8Ptr: new(int8), + Int16Ptr: new(int16), + Int32Ptr: new(int32), + Int64Ptr: new(int64), + RunePtr: new(rune), + Float32Ptr: new(float32), + Float64Ptr: new(float64), + // Complex64Ptr: new(complex64), + // Complex128Ptr: new(complex128), + TimePtr: new(time.Time), + } + + *nullData.StringPtr = "abc" + *nullData.StringPtr2 = "123" + *nullData.BoolPtr = true + *nullData.BytePtr = 1 + *nullData.UintPtr = 1 + *nullData.Uint8Ptr = 1 + *nullData.Uint16Ptr = 1 + *nullData.Uint32Ptr = 1 + *nullData.Uint64Ptr = 1 + *nullData.IntPtr = -1 + *nullData.Int8Ptr = -1 + *nullData.Int16Ptr = -1 + *nullData.Int32Ptr = -1 + *nullData.Int64Ptr = -1 + *nullData.RunePtr = 1 + *nullData.Float32Ptr = -1.2 + *nullData.Float64Ptr = -1.1 + // *nullData.Complex64Ptr = 123456789012345678901234567890 + // *nullData.Complex128Ptr = 123456789012345678901234567890123456789012345678901234567890 + *nullData.TimePtr = time.Now() + + cnt, err := engine.Insert(&nullData) + fmt.Println(nullData.Id) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } + if nullData.Id <= 0 { + err = errors.New("not return id error") + t.Error(err) + panic(err) + } + + // verify get values + nullDataGet := NullData{} + has, err := engine.Id(nullData.Id).Get(&nullDataGet) + if err != nil { + t.Error(err) + panic(err) + } else if !has { + t.Error(errors.New("ID not found")) + } + + if *nullDataGet.StringPtr != *nullData.StringPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr))) + } + + if *nullDataGet.StringPtr2 != *nullData.StringPtr2 { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr2))) + } + + if *nullDataGet.BoolPtr != *nullData.BoolPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%t]", *nullDataGet.BoolPtr))) + } + + if *nullDataGet.UintPtr != *nullData.UintPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.UintPtr))) + } + + if *nullDataGet.Uint8Ptr != *nullData.Uint8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint8Ptr))) + } + + if *nullDataGet.Uint16Ptr != *nullData.Uint16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint16Ptr))) + } + + if *nullDataGet.Uint32Ptr != *nullData.Uint32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint32Ptr))) + } + + if *nullDataGet.Uint64Ptr != *nullData.Uint64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint64Ptr))) + } + + if *nullDataGet.IntPtr != *nullData.IntPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.IntPtr))) + } + + if *nullDataGet.Int8Ptr != *nullData.Int8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int8Ptr))) + } + + if *nullDataGet.Int16Ptr != *nullData.Int16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int16Ptr))) + } + + if *nullDataGet.Int32Ptr != *nullData.Int32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int32Ptr))) + } + + if *nullDataGet.Int64Ptr != *nullData.Int64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int64Ptr))) + } + + if *nullDataGet.RunePtr != *nullData.RunePtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.RunePtr))) + } + + if *nullDataGet.Float32Ptr != *nullData.Float32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float32Ptr))) + } + + if *nullDataGet.Float64Ptr != *nullData.Float64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr))) + } + + // if *nullDataGet.Complex64Ptr != *nullData.Complex64Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex64Ptr))) + // } + + // if *nullDataGet.Complex128Ptr != *nullData.Complex128Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex128Ptr))) + // } + + /*if (*nullDataGet.TimePtr).Unix() != (*nullData.TimePtr).Unix() { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr))) + } else { + // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver + fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr) + fmt.Println() + }*/ + // -- + + // using instance type should just work too + nullData2Get := NullData2{} + + has, err = engine.Table("null_data").Id(nullData.Id).Get(&nullData2Get) + if err != nil { + t.Error(err) + panic(err) + } else if !has { + t.Error(errors.New("ID not found")) + } + + if nullData2Get.StringPtr != *nullData.StringPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.StringPtr))) + } + + if nullData2Get.StringPtr2 != *nullData.StringPtr2 { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.StringPtr2))) + } + + if nullData2Get.BoolPtr != *nullData.BoolPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%t]", nullData2Get.BoolPtr))) + } + + if nullData2Get.UintPtr != *nullData.UintPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.UintPtr))) + } + + if nullData2Get.Uint8Ptr != *nullData.Uint8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint8Ptr))) + } + + if nullData2Get.Uint16Ptr != *nullData.Uint16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint16Ptr))) + } + + if nullData2Get.Uint32Ptr != *nullData.Uint32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint32Ptr))) + } + + if nullData2Get.Uint64Ptr != *nullData.Uint64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint64Ptr))) + } + + if nullData2Get.IntPtr != *nullData.IntPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.IntPtr))) + } + + if nullData2Get.Int8Ptr != *nullData.Int8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int8Ptr))) + } + + if nullData2Get.Int16Ptr != *nullData.Int16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int16Ptr))) + } + + if nullData2Get.Int32Ptr != *nullData.Int32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int32Ptr))) + } + + if nullData2Get.Int64Ptr != *nullData.Int64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int64Ptr))) + } + + if nullData2Get.RunePtr != *nullData.RunePtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.RunePtr))) + } + + if nullData2Get.Float32Ptr != *nullData.Float32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Float32Ptr))) + } + + if nullData2Get.Float64Ptr != *nullData.Float64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Float64Ptr))) + } + + // if nullData2Get.Complex64Ptr != *nullData.Complex64Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Complex64Ptr))) + // } + + // if nullData2Get.Complex128Ptr != *nullData.Complex128Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Complex128Ptr))) + // } + + /*if nullData2Get.TimePtr.Unix() != (*nullData.TimePtr).Unix() { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr))) + } else { + // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver + fmt.Printf("time value: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr) + fmt.Println() + }*/ + // -- +} + +func testNullValue(engine *Engine, t *testing.T) { + + err := engine.DropTables(&NullData{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&NullData{}) + if err != nil { + t.Error(err) + panic(err) + } + + nullData := NullData{} + + cnt, err := engine.Insert(&nullData) + fmt.Println(nullData.Id) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } + if nullData.Id <= 0 { + err = errors.New("not return id error") + t.Error(err) + panic(err) + } + + nullDataGet := NullData{} + + has, err := engine.Id(nullData.Id).Get(&nullDataGet) + if err != nil { + t.Error(err) + panic(err) + } else if !has { + t.Error(errors.New("ID not found")) + } + + if nullDataGet.StringPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr))) + } + + if nullDataGet.StringPtr2 != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr2))) + } + + if nullDataGet.BoolPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%t]", *nullDataGet.BoolPtr))) + } + + if nullDataGet.UintPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.UintPtr))) + } + + if nullDataGet.Uint8Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint8Ptr))) + } + + if nullDataGet.Uint16Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint16Ptr))) + } + + if nullDataGet.Uint32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint32Ptr))) + } + + if nullDataGet.Uint64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint64Ptr))) + } + + if nullDataGet.IntPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.IntPtr))) + } + + if nullDataGet.Int8Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int8Ptr))) + } + + if nullDataGet.Int16Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int16Ptr))) + } + + if nullDataGet.Int32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int32Ptr))) + } + + if nullDataGet.Int64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int64Ptr))) + } + + if nullDataGet.RunePtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.RunePtr))) + } + + if nullDataGet.Float32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float32Ptr))) + } + + if nullDataGet.Float64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) + } + + // if nullDataGet.Complex64Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Complex64Ptr))) + // } + + // if nullDataGet.Complex128Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Complex128Ptr))) + // } + + if nullDataGet.TimePtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.TimePtr))) + } + + nullDataUpdate := NullData{ + StringPtr: new(string), + StringPtr2: new(string), + BoolPtr: new(bool), + BytePtr: new(byte), + UintPtr: new(uint), + Uint8Ptr: new(uint8), + Uint16Ptr: new(uint16), + Uint32Ptr: new(uint32), + Uint64Ptr: new(uint64), + IntPtr: new(int), + Int8Ptr: new(int8), + Int16Ptr: new(int16), + Int32Ptr: new(int32), + Int64Ptr: new(int64), + RunePtr: new(rune), + Float32Ptr: new(float32), + Float64Ptr: new(float64), + // Complex64Ptr: new(complex64), + // Complex128Ptr: new(complex128), + TimePtr: new(time.Time), + } + + *nullDataUpdate.StringPtr = "abc" + *nullDataUpdate.StringPtr2 = "123" + *nullDataUpdate.BoolPtr = true + *nullDataUpdate.BytePtr = 1 + *nullDataUpdate.UintPtr = 1 + *nullDataUpdate.Uint8Ptr = 1 + *nullDataUpdate.Uint16Ptr = 1 + *nullDataUpdate.Uint32Ptr = 1 + *nullDataUpdate.Uint64Ptr = 1 + *nullDataUpdate.IntPtr = -1 + *nullDataUpdate.Int8Ptr = -1 + *nullDataUpdate.Int16Ptr = -1 + *nullDataUpdate.Int32Ptr = -1 + *nullDataUpdate.Int64Ptr = -1 + *nullDataUpdate.RunePtr = 1 + *nullDataUpdate.Float32Ptr = -1.2 + *nullDataUpdate.Float64Ptr = -1.1 + // *nullDataUpdate.Complex64Ptr = 123456789012345678901234567890 + // *nullDataUpdate.Complex128Ptr = 123456789012345678901234567890123456789012345678901234567890 + *nullDataUpdate.TimePtr = time.Now() + + cnt, err = engine.Id(nullData.Id).Update(&nullDataUpdate) + if err != nil { + t.Error(err) + panic(err) + } else if cnt != 1 { + t.Error(errors.New("update count == 0, how can this happen!?")) + return + } + + // verify get values + nullDataGet = NullData{} + has, err = engine.Id(nullData.Id).Get(&nullDataGet) + if err != nil { + t.Error(err) + return + } else if !has { + t.Error(errors.New("ID not found")) + return + } + + if *nullDataGet.StringPtr != *nullDataUpdate.StringPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr))) + } + + if *nullDataGet.StringPtr2 != *nullDataUpdate.StringPtr2 { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr2))) + } + + if *nullDataGet.BoolPtr != *nullDataUpdate.BoolPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%t]", *nullDataGet.BoolPtr))) + } + + if *nullDataGet.UintPtr != *nullDataUpdate.UintPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.UintPtr))) + } + + if *nullDataGet.Uint8Ptr != *nullDataUpdate.Uint8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint8Ptr))) + } + + if *nullDataGet.Uint16Ptr != *nullDataUpdate.Uint16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint16Ptr))) + } + + if *nullDataGet.Uint32Ptr != *nullDataUpdate.Uint32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint32Ptr))) + } + + if *nullDataGet.Uint64Ptr != *nullDataUpdate.Uint64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint64Ptr))) + } + + if *nullDataGet.IntPtr != *nullDataUpdate.IntPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.IntPtr))) + } + + if *nullDataGet.Int8Ptr != *nullDataUpdate.Int8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int8Ptr))) + } + + if *nullDataGet.Int16Ptr != *nullDataUpdate.Int16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int16Ptr))) + } + + if *nullDataGet.Int32Ptr != *nullDataUpdate.Int32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int32Ptr))) + } + + if *nullDataGet.Int64Ptr != *nullDataUpdate.Int64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int64Ptr))) + } + + if *nullDataGet.RunePtr != *nullDataUpdate.RunePtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.RunePtr))) + } + + if *nullDataGet.Float32Ptr != *nullDataUpdate.Float32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float32Ptr))) + } + + if *nullDataGet.Float64Ptr != *nullDataUpdate.Float64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr))) + } + + // if *nullDataGet.Complex64Ptr != *nullDataUpdate.Complex64Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex64Ptr))) + // } + + // if *nullDataGet.Complex128Ptr != *nullDataUpdate.Complex128Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex128Ptr))) + // } + + /*if (*nullDataGet.TimePtr).Unix() != (*nullDataUpdate.TimePtr).Unix() { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr))) + } else { + // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver + fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr) + fmt.Println() + }*/ + // -- + + // update to null values + /*nullDataUpdate = NullData{} + + cnt, err = engine.Id(nullData.Id).Update(&nullDataUpdate) + if err != nil { + t.Error(err) + panic(err) + } else if cnt != 1 { + t.Error(errors.New("update count == 0, how can this happen!?")) + return + }*/ + + // verify get values + /*nullDataGet = NullData{} + has, err = engine.Id(nullData.Id).Get(&nullDataGet) + if err != nil { + t.Error(err) + return + } else if !has { + t.Error(errors.New("ID not found")) + return + } + + fmt.Printf("%+v", nullDataGet) + fmt.Println() + + if nullDataGet.StringPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr))) + } + + if nullDataGet.StringPtr2 != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr2))) + } + + if nullDataGet.BoolPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%t]", *nullDataGet.BoolPtr))) + } + + if nullDataGet.UintPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.UintPtr))) + } + + if nullDataGet.Uint8Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint8Ptr))) + } + + if nullDataGet.Uint16Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint16Ptr))) + } + + if nullDataGet.Uint32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint32Ptr))) + } + + if nullDataGet.Uint64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint64Ptr))) + } + + if nullDataGet.IntPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.IntPtr))) + } + + if nullDataGet.Int8Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int8Ptr))) + } + + if nullDataGet.Int16Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int16Ptr))) + } + + if nullDataGet.Int32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int32Ptr))) + } + + if nullDataGet.Int64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int64Ptr))) + } + + if nullDataGet.RunePtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.RunePtr))) + } + + if nullDataGet.Float32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float32Ptr))) + } + + if nullDataGet.Float64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) + } + + // if nullDataGet.Complex64Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) + // } + + // if nullDataGet.Complex128Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) + // } + + if nullDataGet.TimePtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.TimePtr))) + }*/ + // -- + +} + func testAll(engine *Engine, t *testing.T) { fmt.Println("-------------- directCreateTable --------------") directCreateTable(engine, t) @@ -2567,8 +3250,16 @@ func testAll2(engine *Engine, t *testing.T) { testCreatedUpdated(engine, t) fmt.Println("-------------- processors --------------") testProcessors(engine, t) - fmt.Println("-------------- processors TX --------------") - testProcessorsTx(engine, t) fmt.Println("-------------- transaction --------------") transaction(engine, t) } + +// !nash! the 3rd set of the test is intended for non-cache enabled engine +func testAll3(engine *Engine, t *testing.T) { + fmt.Println("-------------- processors TX --------------") + testProcessorsTx(engine, t) + fmt.Println("-------------- insert pointer data --------------") + testPointerData(engine, t) + fmt.Println("-------------- insert null data --------------") + testNullValue(engine, t) +} diff --git a/mymysql_test.go b/mymysql_test.go index 2e4526cf..333f37f3 100644 --- a/mymysql_test.go +++ b/mymysql_test.go @@ -13,7 +13,7 @@ utf8 COLLATE utf8_general_ci; var showTestSql bool = true func TestMyMysql(t *testing.T) { - err := mysqlDdlImport() + err := mymysqlDdlImport() if err != nil { t.Error(err) return @@ -34,10 +34,11 @@ func TestMyMysql(t *testing.T) { testAll(engine, t) testAll2(engine, t) + testAll3(engine, t) } func TestMyMysqlWithCache(t *testing.T) { - err := mysqlDdlImport() + err := mymysqlDdlImport() if err != nil { t.Error(err) return @@ -65,7 +66,7 @@ func newMyMysqlEngine() (*Engine, error) { return NewEngine("mymysql", "xorm_test2/root/") } -func mysqlDdlImport() error { +func mymysqlDdlImport() error { engine, err := NewEngine("mymysql", "/root/") if err != nil { return err diff --git a/mysql_test.go b/mysql_test.go index 1451c08f..87b166da 100644 --- a/mysql_test.go +++ b/mysql_test.go @@ -10,43 +10,74 @@ CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET utf8 COLLATE utf8_general_ci; */ -func newMysqlEngine() (*Engine, error) { - return NewEngine("mysql", "root:@/xorm_test?charset=utf8") -} +var mysqlShowTestSql bool = true func TestMysql(t *testing.T) { - engine, err := newMysqlEngine() + err := mysqlDdlImport() + if err != nil { + t.Error(err) + return + } + + engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8") defer engine.Close() if err != nil { t.Error(err) return } - engine.ShowSQL = showTestSql - engine.ShowErr = showTestSql - engine.ShowWarn = showTestSql - engine.ShowDebug = showTestSql + engine.ShowSQL = mysqlShowTestSql + engine.ShowErr = mysqlShowTestSql + engine.ShowWarn = mysqlShowTestSql + engine.ShowDebug = mysqlShowTestSql testAll(engine, t) testAll2(engine, t) + testAll3(engine, t) } func TestMysqlWithCache(t *testing.T) { - engine, err := newMysqlEngine() + err := mysqlDdlImport() + if err != nil { + t.Error(err) + return + } + + engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8") defer engine.Close() if err != nil { t.Error(err) return } engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) - engine.ShowSQL = showTestSql - engine.ShowErr = showTestSql - engine.ShowWarn = showTestSql - engine.ShowDebug = showTestSql + engine.ShowSQL = mysqlShowTestSql + engine.ShowErr = mysqlShowTestSql + engine.ShowWarn = mysqlShowTestSql + engine.ShowDebug = mysqlShowTestSql testAll(engine, t) testAll2(engine, t) } +func newMysqlEngine() (*Engine, error) { + return NewEngine("mysql", "root:@/xorm_test?charset=utf8") +} + +func mysqlDdlImport() error { + engine, err := NewEngine("mysql", "root:@/?charset=utf8") + if err != nil { + return err + } + engine.ShowSQL = mysqlShowTestSql + engine.ShowErr = mysqlShowTestSql + engine.ShowWarn = mysqlShowTestSql + engine.ShowDebug = mysqlShowTestSql + + sqlResults, _ := engine.Import("tests/mysql_ddl.sql") + engine.LogDebug("sql results: %v", sqlResults) + engine.Close() + return nil +} + func BenchmarkMysqlNoCacheInsert(t *testing.B) { engine, err := newMysqlEngine() defer engine.Close() diff --git a/postgres_test.go b/postgres_test.go index 8a9b2567..9a84adac 100644 --- a/postgres_test.go +++ b/postgres_test.go @@ -25,6 +25,7 @@ func TestPostgres(t *testing.T) { testAll(engine, t) testAll2(engine, t) + testAll3(engine, t) } func TestPostgresWithCache(t *testing.T) { diff --git a/session.go b/session.go index 1df486f9..34a0feb7 100644 --- a/session.go +++ b/session.go @@ -875,6 +875,7 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { // get retrieve one record from database, bean's non-empty fields // will be as conditions func (session *Session) Get(bean interface{}) (bool, error) { + err := session.newDb() if err != nil { return false, err @@ -889,6 +890,7 @@ func (session *Session) Get(bean interface{}) (bool, error) { var sql string var args []interface{} session.Statement.RefTable = session.Engine.autoMap(bean) + if session.Statement.RawSQL == "" { sql, args = session.Statement.genGetSql(bean) } else { @@ -1000,7 +1002,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) if len(condiBean) > 0 { colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true, - session.Statement.allUseBool, session.Statement.boolColumnMap) + false, session.Statement.allUseBool, session.Statement.boolColumnMap) session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.BeanArgs = args } @@ -1660,7 +1662,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data fieldValue.SetUint(x) //Now only support Time type case reflect.Struct: - if fieldValue.Type().String() == "time.Time" { + if fieldType.String() == "time.Time" { sdata := strings.TrimSpace(string(data)) var x time.Time var err error @@ -1723,6 +1725,265 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data return errors.New("unsupported struct type in Scan: " + fieldValue.Type().String()) } } + case reflect.Ptr: + // !nashtsai! TODO merge duplicated codes above + typeStr := fieldType.String() + switch typeStr { + case "*string": + x := string(data) + fieldValue.Set(reflect.ValueOf(&x)) + case "*bool": + d := string(data) + v, err := strconv.ParseBool(d) + if err != nil { + return errors.New("arg " + key + " as bool: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&v)) + case "*complex64": + var x complex64 + err := json.Unmarshal(data, &x) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(reflect.ValueOf(&x)) + case "*complex128": + var x complex128 + err := json.Unmarshal(data, &x) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(reflect.ValueOf(&x)) + case "*float64": + x, err := strconv.ParseFloat(string(data), 64) + if err != nil { + return errors.New("arg " + key + " as float64: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&x)) + case "*float32": + var x float32 + x1, err := strconv.ParseFloat(string(data), 32) + if err != nil { + return errors.New("arg " + key + " as float32: " + err.Error()) + } + x = float32(x1) + fieldValue.Set(reflect.ValueOf(&x)) + case "*time.Time": + sdata := strings.TrimSpace(string(data)) + var x time.Time + var err error + + if sdata == "0000-00-00 00:00:00" || + sdata == "0001-01-01 00:00:00" { + } else if !strings.ContainsAny(sdata, "- :") { + // time stamp + sd, err := strconv.ParseInt(sdata, 10, 64) + if err == nil { + x = time.Unix(0, sd) + } + } else if len(sdata) > 19 { + x, err = time.Parse(time.RFC3339Nano, sdata) + if err != nil { + x, err = time.Parse("2006-01-02 15:04:05.999999999", sdata) + } + } else if len(sdata) == 19 { + x, err = time.Parse("2006-01-02 15:04:05", sdata) + } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { + x, err = time.Parse("2006-01-02", sdata) + } else if col.SQLType.Name == Time { + if len(sdata) > 8 { + sdata = sdata[len(sdata)-8:] + } + st := fmt.Sprintf("2006-01-02 %v", sdata) + x, err = time.Parse("2006-01-02 15:04:05", st) + } else { + return errors.New(fmt.Sprintf("unsupported time format %v", string(data))) + } + if err != nil { + return errors.New(fmt.Sprintf("unsupported time format %v: %v", string(data), err)) + } + + v = x + fieldValue.Set(reflect.ValueOf(&x)) + case "*uint64": + var x uint64 + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&x)) + case "*uint": + var x uint + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint(x1) + fieldValue.Set(reflect.ValueOf(&x)) + case "*uint32": + var x uint32 + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint32(x1) + fieldValue.Set(reflect.ValueOf(&x)) + case "*uint8": + var x uint8 + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint8(x1) + fieldValue.Set(reflect.ValueOf(&x)) + case "*uint16": + var x uint16 + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint16(x1) + fieldValue.Set(reflect.ValueOf(&x)) + case "*int64": + sdata := string(data) + var x int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int64(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x, err = strconv.ParseInt(sdata, 16, 64) + } else if strings.HasPrefix(sdata, "0") { + x, err = strconv.ParseInt(sdata, 8, 64) + } else { + x, err = strconv.ParseInt(sdata, 10, 64) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&x)) + case "*int": + sdata := string(data) + var x int + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&x)) + case "*int32": + sdata := string(data) + var x int32 + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int32(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int32(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int32(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int32(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&x)) + case "*int8": + sdata := string(data) + var x int8 + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int8(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int8(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int8(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int8(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&x)) + case "*int16": + sdata := string(data) + var x int16 + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int16(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int16(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int16(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int16(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&x)) + default: + return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) + } default: return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) } @@ -1742,8 +2003,22 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( } } } + fieldType := fieldValue.Type() + k := fieldType.Kind() + if k == reflect.Ptr { + if fieldValue.IsNil() { + return nil, nil + } else if !fieldValue.IsValid() { + session.Engine.LogWarn("the field[", col.FieldName, "] is invalid") + return nil, nil + } else { + // !nashtsai! deference pointer type to instance type + fieldValue = fieldValue.Elem() + fieldType = fieldValue.Type() + k = fieldType.Kind() + } + } - k := fieldValue.Type().Kind() switch k { case reflect.Bool: if fieldValue.Bool() { @@ -1754,7 +2029,7 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( case reflect.String: return fieldValue.String(), nil case reflect.Struct: - if fieldValue.Type().String() == "time.Time" { + if fieldType.String() == "time.Time" { if col.SQLType.Name == Time { //s := fieldValue.Interface().(time.Time).Format("2006-01-02 15:04:05 -0700") s := fieldValue.Interface().(time.Time).Format(time.RFC3339) @@ -2178,7 +2453,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if session.Statement.ColumnStr == "" { colNames, args = buildConditions(session.Engine, table, bean, false, false, - session.Statement.allUseBool, session.Statement.boolColumnMap) + false, session.Statement.allUseBool, session.Statement.boolColumnMap) } else { colNames, args, err = table.genCols(session, bean, true, true) if err != nil { @@ -2212,7 +2487,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condiBean) > 0 { condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true, true, - session.Statement.allUseBool, session.Statement.boolColumnMap) + false, session.Statement.allUseBool, session.Statement.boolColumnMap) } var condition = "" @@ -2378,7 +2653,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { table := session.Engine.autoMap(bean) session.Statement.RefTable = table colNames, args := buildConditions(session.Engine, table, bean, true, true, - session.Statement.allUseBool, session.Statement.boolColumnMap) + false, session.Statement.allUseBool, session.Statement.boolColumnMap) var condition = "" if session.Statement.WhereStr != "" { diff --git a/sqlite3_test.go b/sqlite3_test.go index 82df0508..531f9ad5 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -26,6 +26,7 @@ func TestSqlite3(t *testing.T) { testAll(engine, t) testAll2(engine, t) + testAll3(engine, t) } func TestSqlite3WithCache(t *testing.T) { diff --git a/statement.go b/statement.go index a0cdfaa9..87c3f7b3 100644 --- a/statement.go +++ b/statement.go @@ -234,8 +234,9 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { // Auto generating conditions according a struct func buildConditions(engine *Engine, table *Table, bean interface{}, - includeVersion bool, includeUpdated bool, allUseBool bool, + includeVersion bool, includeUpdated bool, includeNil bool, allUseBool bool, boolColumnMap map[string]bool) ([]string, []interface{}) { + colNames := make([]string, 0) var args = make([]interface{}, 0) for _, col := range table.Columns { @@ -247,10 +248,29 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, } fieldValue := col.ValueOf(bean) fieldType := reflect.TypeOf(fieldValue.Interface()) + + requiredField := false + if fieldType.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + if includeNil { + args = append(args, nil) + colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name))) + } + continue + } else if !fieldValue.IsValid() { + continue + } else { + // dereference ptr type to instance type + fieldValue = fieldValue.Elem() + fieldType = reflect.TypeOf(fieldValue.Interface()) + requiredField = true + } + } + var val interface{} switch fieldType.Kind() { case reflect.Bool: - if allUseBool { + if allUseBool || requiredField { val = fieldValue.Interface() } else if _, ok := boolColumnMap[col.Name]; ok { val = fieldValue.Interface() @@ -260,7 +280,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, continue } case reflect.String: - if fieldValue.String() == "" { + if !requiredField && fieldValue.String() == "" { continue } // for MyString, should convert to string or panic @@ -270,24 +290,24 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, val = fieldValue.Interface() } case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if fieldValue.Int() == 0 { + if !requiredField && fieldValue.Int() == 0 { continue } val = fieldValue.Interface() case reflect.Float32, reflect.Float64: - if fieldValue.Float() == 0.0 { + if !requiredField && fieldValue.Float() == 0.0 { continue } val = fieldValue.Interface() case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if fieldValue.Uint() == 0 { + if !requiredField && fieldValue.Uint() == 0 { continue } val = fieldValue.Interface() case reflect.Struct: if fieldType == reflect.TypeOf(time.Now()) { t := fieldValue.Interface().(time.Time) - if t.IsZero() || !fieldValue.IsValid() { + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { continue } var str string @@ -587,12 +607,14 @@ func (s *Statement) genDropSQL() string { return sql } +// !nashtsai! REVIEW, Statement is a huge struct why is this method not passing *Statement? func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) { table := statement.Engine.autoMap(bean) statement.RefTable = table colNames, args := buildConditions(statement.Engine, table, bean, true, true, - statement.allUseBool, statement.boolColumnMap) + false, statement.allUseBool, statement.boolColumnMap) + statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args @@ -629,7 +651,9 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{}) table := statement.Engine.autoMap(bean) statement.RefTable = table - colNames, args := buildConditions(statement.Engine, table, bean, true, true, statement.allUseBool, statement.boolColumnMap) + colNames, args := buildConditions(statement.Engine, table, bean, true, true, false, + statement.allUseBool, statement.boolColumnMap) + statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args var id string = "*" diff --git a/table.go b/table.go index 012c4030..901c9ecc 100644 --- a/table.go +++ b/table.go @@ -2,6 +2,7 @@ package xorm import ( "reflect" + "sort" "strings" "time" ) @@ -24,6 +25,8 @@ func (s *SQLType) IsBlob() bool { s.Name == Binary || s.Name == VarBinary || s.Name == Bytea } +const () + var ( Bit = "BIT" TinyInt = "TINYINT" @@ -107,6 +110,9 @@ var ( Serial: true, BigSerial: true, } + + intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"} + uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"} ) var b byte @@ -140,12 +146,41 @@ func Type2SQLType(t reflect.Type) (st SQLType) { } else { st = SQLType{Text, 0, 0} } + case reflect.Ptr: + st, _ = ptrType2SQLType(t) default: st = SQLType{Text, 0, 0} } return } +func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) { + typeStr := t.String() + has = true + + switch typeStr { + case "*string": + st = SQLType{Varchar, 255, 0} + case "*bool": + st = SQLType{Bool, 0, 0} + case "*complex64", "*complex128": + st = SQLType{Varchar, 64, 0} + case "*float32": + st = SQLType{Float, 0, 0} + case "*float64": + st = SQLType{Varchar, 64, 0} + case "*int64", "*uint64": + st = SQLType{BigInt, 0, 0} + case "*time.Time": + st = SQLType{DateTime, 0, 0} + case "*int", "*int16", "*int32", "*int8", "*uint", "*uint16", "*uint32", "*uint8": + st = SQLType{Int, 0, 0} + default: + has = false + } + return +} + // default sql type change to go types func SQLType2Type(st SQLType) reflect.Type { name := strings.ToUpper(st.Name) diff --git a/xorm/reverse.go b/xorm/reverse.go index f473ad91..03c07c7d 100644 --- a/xorm/reverse.go +++ b/xorm/reverse.go @@ -132,7 +132,7 @@ func runReverse(cmd *Command, args []string) { } if langTmpl, ok = langTmpls[lang]; !ok { - fmt.Println("Unsupported lang", lang) + fmt.Println("Unsupported programing language", lang) return } diff --git a/xorm/shell.go b/xorm/shell.go index 7f4cb200..cd6462ce 100644 --- a/xorm/shell.go +++ b/xorm/shell.go @@ -24,6 +24,19 @@ func init() { var engine *xorm.Engine +func help() { + fmt.Println(` + show tables show all tables + columns show table's column info + indexes show table's index info + exit exit shell + source exec sql file to current database + dump [-nodata] dump structs or records to sql file + help show this document + SQL statement + `) +} + func runShell(cmd *Command, args []string) { if len(args) != 2 { fmt.Println("params error, please see xorm help shell") @@ -37,6 +50,12 @@ func runShell(cmd *Command, args []string) { return } + err = engine.Ping() + if err != nil { + fmt.Println(err) + return + } + var scmd string fmt.Print("xorm$ ") for { @@ -107,6 +126,13 @@ func runShell(cmd *Command, args []string) { //fmt.Println(res) } } + } else if lcmd == "show tables;" { + tables, err := engine.DBMetas() + if err != nil { + fmt.Println(err) + } else { + + } } else { cnt, err := engine.Exec(scmd) if err != nil {