From e6a2c62ab1e64bc39688c47343f8fb4835bfc1f6 Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Fri, 6 Dec 2013 10:31:38 +0800 Subject: [PATCH 01/13] committed partially done on null value support --- base_test.go | 105 +++++++++++++++++ session.go | 320 ++++++++++++++++++++++++++++++++++++++++++++++++++- table.go | 32 ++++++ 3 files changed, 453 insertions(+), 4 deletions(-) diff --git a/base_test.go b/base_test.go index 6b3040ea..995b5773 100644 --- a/base_test.go +++ b/base_test.go @@ -2469,6 +2469,109 @@ func testProcessorsTx(engine *Engine, t *testing.T) { // -- } +type NullData struct { + Id int64 + StringPtr *string + StringPtr2 *string `xorm:"text"` + BoolPtr *bool + BytePtr *byte + UintPtr *uint + Uint8Ptr *uint8 + Uint16Ptr *uint16 + Uint32Ptr *uint32 + UInt64Ptr *uint64 + IntPtr *int + Int8Ptr *int8 + Int16Ptr *int16 + Int32Ptr *int32 + Int64Ptr *int64 + RunePtr *rune + Float32Ptr *float32 + Float64Ptr *float64 + // Complex64Ptr *complex64 + // Complex128Ptr *complex128 + TimePtr *time.Time +} + +type NullData2 struct { + Id int64 + StringPtr string + StringPtr2 string `xorm:"text"` + BoolPtr bool + BytePtr byte + UintPtr uint + Uint8Ptr uint8 + Uint16Ptr uint16 + Uint32Ptr uint32 + UInt64Ptr uint64 + IntPtr int + Int8Ptr int8 + Int16Ptr int16 + Int32Ptr int32 + Int64Ptr int64 + RunePtr rune + Float32Ptr float32 + Float64Ptr float64 + //Complex64Ptr complex64 + //Complex128Ptr complex128 + TimePtr time.Time +} + +type NullData3 struct { + Id int64 + StringPtr *string +} + +func insertNullData(engine *Engine, t *testing.T) { + + err := engine.DropTables(&NullData{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&NullData{}) + if err != nil { + t.Error(err) + panic(err) + } + + nullData := NullData{BoolPtr: new(bool)} + *nullData.BoolPtr = true + cnt, err := engine.Insert(&nullData) + fmt.Println(nullData.Id) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } + if nullData.Id <= 0 { + err = errors.New("not return id error") + t.Error(err) + panic(err) + } + + nullDataGet := NullData{} + + has, err := engine.Table("null_data").Id(nullData.Id).Get(&nullDataGet) + if err != nil { + t.Error(err) + panic(err) + } else if !has { + t.Error(errors.New("ID not found")) + } + + // if nullData2.BoolPtr == nil || !*(nullData2.BoolPtr) { + // t.Error(errors.New("BoolPtr wrong value")) + // } + +} + func testAll(engine *Engine, t *testing.T) { fmt.Println("-------------- directCreateTable --------------") directCreateTable(engine, t) @@ -2571,4 +2674,6 @@ func testAll2(engine *Engine, t *testing.T) { testProcessorsTx(engine, t) fmt.Println("-------------- transaction --------------") transaction(engine, t) + fmt.Println("-------------- insert null data --------------") + insertNullData(engine, t) } diff --git a/session.go b/session.go index 197773db..d56843e3 100644 --- a/session.go +++ b/session.go @@ -1660,7 +1660,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data fieldValue.SetUint(x) //Now only support Time type case reflect.Struct: - if fieldValue.Type().String() == "time.Time" { + if fieldType.String() == "time.Time" { sdata := strings.TrimSpace(string(data)) var x time.Time var err error @@ -1723,6 +1723,264 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data return errors.New("unsupported struct type in Scan: " + fieldValue.Type().String()) } } + case reflect.Ptr: + // TODO merge duplicated codes above + typeStr := fieldType.String() + switch typeStr { + case "*string": + x := string(data) + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*bool": + d := string(data) + v, err := strconv.ParseBool(d) + if err != nil { + return errors.New("arg " + key + " as bool: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(v).Addr()) + case "*complex64": + var x complex64 + err := json.Unmarshal(data, &x) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*complex128": + var x complex128 + err := json.Unmarshal(data, &x) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*float64": + x, err := strconv.ParseFloat(string(data), 64) + if err != nil { + return errors.New("arg " + key + " as float64: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*float32": + var x float32 + x1, err := strconv.ParseFloat(string(data), 32) + if err != nil { + return errors.New("arg " + key + " as float32: " + err.Error()) + } + x = float32(x1) + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*time.Time": + sdata := strings.TrimSpace(string(data)) + var x time.Time + var err error + + if sdata == "0000-00-00 00:00:00" || + sdata == "0001-01-01 00:00:00" { + } else if !strings.ContainsAny(sdata, "- :") { + // time stamp + sd, err := strconv.ParseInt(sdata, 10, 64) + if err == nil { + x = time.Unix(0, sd) + } + } else if len(sdata) > 19 { + x, err = time.Parse(time.RFC3339Nano, sdata) + if err != nil { + x, err = time.Parse("2006-01-02 15:04:05.999999999", sdata) + } + } else if len(sdata) == 19 { + x, err = time.Parse("2006-01-02 15:04:05", sdata) + } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { + x, err = time.Parse("2006-01-02", sdata) + } else if col.SQLType.Name == Time { + if len(sdata) > 8 { + sdata = sdata[len(sdata)-8:] + } + st := fmt.Sprintf("2006-01-02 %v", sdata) + x, err = time.Parse("2006-01-02 15:04:05", st) + } else { + return errors.New(fmt.Sprintf("unsupported time format %v", string(data))) + } + if err != nil { + return errors.New(fmt.Sprintf("unsupported time format %v: %v", string(data), err)) + } + + v = x + fieldValue.Set(reflect.ValueOf(v).Addr()) + case "*uint64": + var x uint64 + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*uint": + var x uint + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint(x1) + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*uint32": + var x uint32 + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint32(x1) + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*uint8": + var x uint8 + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint8(x1) + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*uint16": + var x uint16 + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint16(x1) + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*int64": + sdata := string(data) + var x int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int64(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x, err = strconv.ParseInt(sdata, 16, 64) + } else if strings.HasPrefix(sdata, "0") { + x, err = strconv.ParseInt(sdata, 8, 64) + } else { + x, err = strconv.ParseInt(sdata, 10, 64) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*int": + sdata := string(data) + var x int + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*int32": + sdata := string(data) + var x int32 + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int32(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int32(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int32(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int32(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*int8": + sdata := string(data) + var x int8 + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int8(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int8(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int8(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int8(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*int16": + sdata := string(data) + var x int16 + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int16(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int16(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int16(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int16(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + } + fallthrough default: return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) } @@ -1742,8 +2000,8 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( } } } - - k := fieldValue.Type().Kind() + fieldType := fieldValue.Type() + k := fieldType.Kind() switch k { case reflect.Bool: if fieldValue.Bool() { @@ -1754,7 +2012,7 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( case reflect.String: return fieldValue.String(), nil case reflect.Struct: - if fieldValue.Type().String() == "time.Time" { + if fieldType.String() == "time.Time" { if col.SQLType.Name == Time { //s := fieldValue.Interface().(time.Time).Format("2006-01-02 15:04:05 -0700") s := fieldValue.Interface().(time.Time).Format(time.RFC3339) @@ -1812,6 +2070,60 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( } else { return nil, ErrUnSupportedType } + case reflect.Ptr: + typeStr := fieldType.String() + if typeStr == "*string" { + if fieldValue.IsNil() { + return nil, nil + } else { + return fieldValue.Elem().String(), nil + } + } else if typeStr == "*bool" { + if fieldValue.IsNil() { + return nil, nil + } else { + return fieldValue.Elem().Bool(), nil + } + } else if typeStr == "*complex64" || typeStr == "*complex128" { + if fieldValue.IsNil() { + return nil, nil + } else { + bytes, err := json.Marshal(fieldValue.Elem().Complex()) + if err != nil { + session.Engine.LogSQL(err) + return 0, err + } + return string(bytes), nil + } + } else if typeStr == "*float32" || typeStr == "*float64" { + if fieldValue.IsNil() { + return nil, nil + } else { + return fieldValue.Elem().Float(), nil + } + } else if typeStr == "*time.Time" { + if fieldValue.IsNil() { + return nil, nil + } else { + if col.SQLType.Name == Time { + //s := fieldValue.Interface().(time.Time).Format("2006-01-02 15:04:05 -0700") + s := fieldValue.Elem().Interface().(time.Time).Format(time.RFC3339) + return s[11:19], nil + } else if col.SQLType.Name == Date { + return fieldValue.Elem().Interface().(time.Time).Format("2006-01-02"), nil + } else if col.SQLType.Name == TimeStampz { + return fieldValue.Elem().Interface().(time.Time).Format(time.RFC3339Nano), nil + } + return fieldValue.Elem().Interface(), nil + } + } else if typeStr == "*int64" || typeStr == "*uint64" || intTypes.Search(typeStr) < len(intTypes) { + if fieldValue.IsNil() { + return nil, nil + } else { + return fieldValue.Elem().Int(), nil + } + } + fallthrough default: return fieldValue.Interface(), nil } diff --git a/table.go b/table.go index 012c4030..479e060d 100644 --- a/table.go +++ b/table.go @@ -2,6 +2,7 @@ package xorm import ( "reflect" + "sort" "strings" "time" ) @@ -24,6 +25,8 @@ func (s *SQLType) IsBlob() bool { s.Name == Binary || s.Name == VarBinary || s.Name == Bytea } +const () + var ( Bit = "BIT" TinyInt = "TINYINT" @@ -107,6 +110,8 @@ var ( Serial: true, BigSerial: true, } + + intTypes = sort.StringSlice{"*int", "*int16", "*int32 ", "*int8 ", "*uint", "*uint16", "*uint32", "*uint8"} ) var b byte @@ -140,12 +145,39 @@ func Type2SQLType(t reflect.Type) (st SQLType) { } else { st = SQLType{Text, 0, 0} } + case reflect.Ptr: + st, _ = ptrType2SQLType(t) default: st = SQLType{Text, 0, 0} } return } +func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) { + typeStr := t.String() + has = true + if typeStr == "*string" { + st = SQLType{Varchar, 255, 0} + } else if typeStr == "*bool" { + st = SQLType{Bool, 0, 0} + } else if typeStr == "*complex64" || typeStr == "*complex128" { + st = SQLType{Varchar, 64, 0} + } else if typeStr == "*float32" { + st = SQLType{Float, 0, 0} + } else if typeStr == "*float64" { + st = SQLType{Varchar, 64, 0} + } else if typeStr == "*int64" || typeStr == "*uint64" { + st = SQLType{BigInt, 0, 0} + } else if typeStr == "*time.Time" { + st = SQLType{DateTime, 0, 0} + } else if intTypes.Search(typeStr) < len(intTypes) { + st = SQLType{Int, 0, 0} + } else { + has = false + } + return +} + // default sql type change to go types func SQLType2Type(st SQLType) reflect.Type { name := strings.ToUpper(st.Name) From 8ad1baa4cc7b696aa7f508d955d23e7ff3dde9c8 Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Fri, 6 Dec 2013 11:06:02 +0800 Subject: [PATCH 02/13] tidy naming up for mymysql_test --- mymysql_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mymysql_test.go b/mymysql_test.go index 49366c14..872fc16c 100644 --- a/mymysql_test.go +++ b/mymysql_test.go @@ -13,7 +13,7 @@ utf8 COLLATE utf8_general_ci; var showTestSql bool = true func TestMyMysql(t *testing.T) { - err := mysqlDdlImport() + err := mymysqlDdlImport() if err != nil { t.Error(err) return @@ -37,7 +37,7 @@ func TestMyMysql(t *testing.T) { } func TestMyMysqlWithCache(t *testing.T) { - err := mysqlDdlImport() + err := mymysqlDdlImport() if err != nil { t.Error(err) return @@ -61,7 +61,7 @@ func TestMyMysqlWithCache(t *testing.T) { testAll2(engine, t) } -func mysqlDdlImport() error { +func mymysqlDdlImport() error { engine, err := NewEngine("mymysql", "/root/") if err != nil { return err From b71b3f0ad3e676c4d15064cf3a751790c375c439 Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Fri, 6 Dec 2013 11:09:34 +0800 Subject: [PATCH 03/13] add mysqlDdlImport to mysql_test.go --- mysql_test.go | 46 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/mysql_test.go b/mysql_test.go index 106e898e..a8d0c803 100644 --- a/mysql_test.go +++ b/mysql_test.go @@ -10,23 +10,37 @@ CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET utf8 COLLATE utf8_general_ci; */ +var mysqlShowTestSql bool = true + func TestMysql(t *testing.T) { + err := mysqlDdlImport() + if err != nil { + t.Error(err) + return + } + engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8") defer engine.Close() if err != nil { t.Error(err) return } - engine.ShowSQL = showTestSql - engine.ShowErr = showTestSql - engine.ShowWarn = showTestSql - engine.ShowDebug = showTestSql + engine.ShowSQL = mysqlShowTestSql + engine.ShowErr = mysqlShowTestSql + engine.ShowWarn = mysqlShowTestSql + engine.ShowDebug = mysqlShowTestSql testAll(engine, t) testAll2(engine, t) } func TestMysqlWithCache(t *testing.T) { + err := mysqlDdlImport() + if err != nil { + t.Error(err) + return + } + engine, err := NewEngine("mysql", "root:@/xorm_test2?charset=utf8") defer engine.Close() if err != nil { @@ -34,15 +48,31 @@ func TestMysqlWithCache(t *testing.T) { return } engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) - engine.ShowSQL = showTestSql - engine.ShowErr = showTestSql - engine.ShowWarn = showTestSql - engine.ShowDebug = showTestSql + engine.ShowSQL = mysqlShowTestSql + engine.ShowErr = mysqlShowTestSql + engine.ShowWarn = mysqlShowTestSql + engine.ShowDebug = mysqlShowTestSql testAll(engine, t) testAll2(engine, t) } +func mysqlDdlImport() error { + engine, err := NewEngine("mysql", "root:@/?charset=utf8") + if err != nil { + return err + } + engine.ShowSQL = mysqlShowTestSql + engine.ShowErr = mysqlShowTestSql + engine.ShowWarn = mysqlShowTestSql + engine.ShowDebug = mysqlShowTestSql + + sqlResults, _ := engine.Import("tests/mysql_ddl.sql") + engine.LogDebug("sql results: %v", sqlResults) + engine.Close() + return nil +} + func BenchmarkMysqlNoCache(t *testing.B) { engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8") defer engine.Close() From 667dcd039f76699e6004c216fb2d0c9a76fa298d Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Fri, 6 Dec 2013 15:17:50 +0800 Subject: [PATCH 04/13] completed db null value and pointer type testing --- base_test.go | 238 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 230 insertions(+), 8 deletions(-) diff --git a/base_test.go b/base_test.go index 995b5773..b70d1a90 100644 --- a/base_test.go +++ b/base_test.go @@ -2479,7 +2479,7 @@ type NullData struct { Uint8Ptr *uint8 Uint16Ptr *uint16 Uint32Ptr *uint32 - UInt64Ptr *uint64 + Uint64Ptr *uint64 IntPtr *int Int8Ptr *int8 Int16Ptr *int16 @@ -2503,7 +2503,7 @@ type NullData2 struct { Uint8Ptr uint8 Uint16Ptr uint16 Uint32Ptr uint32 - UInt64Ptr uint64 + Uint64Ptr uint64 IntPtr int Int8Ptr int8 Int16Ptr int16 @@ -2522,7 +2522,7 @@ type NullData3 struct { StringPtr *string } -func insertNullData(engine *Engine, t *testing.T) { +func testPointerData(engine *Engine, t *testing.T) { err := engine.DropTables(&NullData{}) if err != nil { @@ -2536,8 +2536,50 @@ func insertNullData(engine *Engine, t *testing.T) { panic(err) } - nullData := NullData{BoolPtr: new(bool)} + nullData := NullData{ + StringPtr: new(string), + StringPtr2: new(string), + BoolPtr: new(bool), + BytePtr: new(byte), + UintPtr: new(uint), + Uint8Ptr: new(uint8), + Uint16Ptr: new(uint16), + Uint32Ptr: new(uint32), + Uint64Ptr: new(uint64), + IntPtr: new(int), + Int8Ptr: new(int8), + Int16Ptr: new(int16), + Int32Ptr: new(int32), + Int64Ptr: new(int64), + RunePtr: new(rune), + Float32Ptr: new(float32), + Float64Ptr: new(float64), + // Complex64Ptr :new(complex64), + // Complex128Ptr :new(complex128), + TimePtr: new(time.Time), + } + + *nullData.StringPtr = "abc" + *nullData.StringPtr2 = "123" *nullData.BoolPtr = true + *nullData.BytePtr = 1 + *nullData.UintPtr = 1 + *nullData.Uint8Ptr = 1 + *nullData.Uint16Ptr = 1 + *nullData.Uint32Ptr = 1 + *nullData.Uint64Ptr = 1 + *nullData.IntPtr = -1 + *nullData.Int8Ptr = -1 + *nullData.Int16Ptr = -1 + *nullData.Int32Ptr = -1 + *nullData.Int64Ptr = -1 + *nullData.RunePtr = 1 + *nullData.Float32Ptr = -1.2 + *nullData.Float64Ptr = -1.1 + // *nullData.Complex64Ptr :new(complex64), + // *nullData.Complex128Ptr :new(complex128), + *nullData.TimePtr = time.Now() + cnt, err := engine.Insert(&nullData) fmt.Println(nullData.Id) if err != nil { @@ -2566,10 +2608,188 @@ func insertNullData(engine *Engine, t *testing.T) { t.Error(errors.New("ID not found")) } - // if nullData2.BoolPtr == nil || !*(nullData2.BoolPtr) { - // t.Error(errors.New("BoolPtr wrong value")) - // } + if *nullDataGet.StringPtr != *nullData.StringPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr))) + } + if *nullDataGet.StringPtr2 != *nullData.StringPtr2 { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr2))) + } + + if *nullDataGet.BoolPtr != *nullData.BoolPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%t]", *nullDataGet.BoolPtr))) + } + + if *nullDataGet.UintPtr != *nullData.UintPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.UintPtr))) + } + + if *nullDataGet.Uint8Ptr != *nullData.Uint8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint8Ptr))) + } + + if *nullDataGet.Uint16Ptr != *nullData.Uint16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint16Ptr))) + } + + if *nullDataGet.Uint32Ptr != *nullData.Uint32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint32Ptr))) + } + + if *nullDataGet.Uint64Ptr != *nullData.Uint64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint64Ptr))) + } + + if *nullDataGet.IntPtr != *nullData.IntPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.IntPtr))) + } + + if *nullDataGet.Int8Ptr != *nullData.Int8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int8Ptr))) + } + + if *nullDataGet.Int16Ptr != *nullData.Int16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int16Ptr))) + } + + if *nullDataGet.Int32Ptr != *nullData.Int32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int32Ptr))) + } + + if *nullDataGet.Int64Ptr != *nullData.Int64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int64Ptr))) + } + + if *nullDataGet.RunePtr != *nullData.RunePtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.RunePtr))) + } + + if *nullDataGet.Float32Ptr != *nullData.Float32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float32Ptr))) + } + + if *nullDataGet.Float64Ptr != *nullData.Float64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr))) + } + + if (*nullDataGet.TimePtr).Unix() != (*nullData.TimePtr).Unix() { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr))) + } +} + +func testNullValue(engine *Engine, t *testing.T) { + + err := engine.DropTables(&NullData{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&NullData{}) + if err != nil { + t.Error(err) + panic(err) + } + + nullData := NullData{} + + cnt, err := engine.Insert(&nullData) + fmt.Println(nullData.Id) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } + if nullData.Id <= 0 { + err = errors.New("not return id error") + t.Error(err) + panic(err) + } + + nullDataGet := NullData{} + + has, err := engine.Table("null_data").Id(nullData.Id).Get(&nullDataGet) + if err != nil { + t.Error(err) + panic(err) + } else if !has { + t.Error(errors.New("ID not found")) + } + + fmt.Printf("val: %+v\n", nullDataGet) + + if nullDataGet.StringPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr))) + } + + if nullDataGet.StringPtr2 != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr2))) + } + + if nullDataGet.BoolPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%t]", *nullDataGet.BoolPtr))) + } + + if nullDataGet.UintPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.UintPtr))) + } + + if nullDataGet.Uint8Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint8Ptr))) + } + + if nullDataGet.Uint16Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint16Ptr))) + } + + if nullDataGet.Uint32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint32Ptr))) + } + + if nullDataGet.Uint64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint64Ptr))) + } + + if nullDataGet.IntPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.IntPtr))) + } + + if nullDataGet.Int8Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int8Ptr))) + } + + if nullDataGet.Int16Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int16Ptr))) + } + + if nullDataGet.Int32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int32Ptr))) + } + + if nullDataGet.Int64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int64Ptr))) + } + + if nullDataGet.RunePtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.RunePtr))) + } + + if nullDataGet.Float32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float32Ptr))) + } + + if nullDataGet.Float64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) + } + + if nullDataGet.TimePtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.TimePtr))) + } } func testAll(engine *Engine, t *testing.T) { @@ -2674,6 +2894,8 @@ func testAll2(engine *Engine, t *testing.T) { testProcessorsTx(engine, t) fmt.Println("-------------- transaction --------------") transaction(engine, t) + fmt.Println("-------------- insert pointer data --------------") + testPointerData(engine, t) fmt.Println("-------------- insert null data --------------") - insertNullData(engine, t) + testNullValue(engine, t) } From f44f70acd82fa665c318db433602cf812678c3ec Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Fri, 6 Dec 2013 15:19:48 +0800 Subject: [PATCH 05/13] add nil value filter for buildConditions --- statement.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/statement.go b/statement.go index 90837a3b..060c64ae 100644 --- a/statement.go +++ b/statement.go @@ -344,6 +344,10 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers } else { continue } + case reflect.Ptr: + if fieldValue.IsNil() || !fieldValue.IsValid() { + continue + } default: val = fieldValue.Interface() } From dfcf1d749884634ec577a45938134db4da757944 Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Fri, 6 Dec 2013 15:20:25 +0800 Subject: [PATCH 06/13] completed value2Interface implementation for Ptr type --- session.go | 46 +++++++++++++++++++++++++++------------------- table.go | 5 +++-- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/session.go b/session.go index d56843e3..eef9f5f6 100644 --- a/session.go +++ b/session.go @@ -1729,14 +1729,14 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data switch typeStr { case "*string": x := string(data) - fieldValue.Set(reflect.ValueOf(x).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) case "*bool": d := string(data) v, err := strconv.ParseBool(d) if err != nil { return errors.New("arg " + key + " as bool: " + err.Error()) } - fieldValue.Set(reflect.ValueOf(v).Addr()) + fieldValue.Set(reflect.ValueOf(&v)) case "*complex64": var x complex64 err := json.Unmarshal(data, &x) @@ -1744,7 +1744,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data session.Engine.LogSQL(err) return err } - fieldValue.Set(reflect.ValueOf(x).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) case "*complex128": var x complex128 err := json.Unmarshal(data, &x) @@ -1752,13 +1752,13 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data session.Engine.LogSQL(err) return err } - fieldValue.Set(reflect.ValueOf(x).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) case "*float64": x, err := strconv.ParseFloat(string(data), 64) if err != nil { return errors.New("arg " + key + " as float64: " + err.Error()) } - fieldValue.Set(reflect.ValueOf(x).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) case "*float32": var x float32 x1, err := strconv.ParseFloat(string(data), 32) @@ -1766,7 +1766,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data return errors.New("arg " + key + " as float32: " + err.Error()) } x = float32(x1) - fieldValue.Set(reflect.ValueOf(x).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) case "*time.Time": sdata := strings.TrimSpace(string(data)) var x time.Time @@ -1803,14 +1803,14 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data } v = x - fieldValue.Set(reflect.ValueOf(v).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) case "*uint64": var x uint64 x, err := strconv.ParseUint(string(data), 10, 64) if err != nil { return errors.New("arg " + key + " as int: " + err.Error()) } - fieldValue.Set(reflect.ValueOf(x).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) case "*uint": var x uint x1, err := strconv.ParseUint(string(data), 10, 64) @@ -1818,7 +1818,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data return errors.New("arg " + key + " as int: " + err.Error()) } x = uint(x1) - fieldValue.Set(reflect.ValueOf(x).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) case "*uint32": var x uint32 x1, err := strconv.ParseUint(string(data), 10, 64) @@ -1826,7 +1826,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data return errors.New("arg " + key + " as int: " + err.Error()) } x = uint32(x1) - fieldValue.Set(reflect.ValueOf(x).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) case "*uint8": var x uint8 x1, err := strconv.ParseUint(string(data), 10, 64) @@ -1834,7 +1834,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data return errors.New("arg " + key + " as int: " + err.Error()) } x = uint8(x1) - fieldValue.Set(reflect.ValueOf(x).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) case "*uint16": var x uint16 x1, err := strconv.ParseUint(string(data), 10, 64) @@ -1842,7 +1842,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data return errors.New("arg " + key + " as int: " + err.Error()) } x = uint16(x1) - fieldValue.Set(reflect.ValueOf(x).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) case "*int64": sdata := string(data) var x int64 @@ -1866,7 +1866,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data if err != nil { return errors.New("arg " + key + " as int: " + err.Error()) } - fieldValue.Set(reflect.ValueOf(x).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) case "*int": sdata := string(data) var x int @@ -1894,7 +1894,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data if err != nil { return errors.New("arg " + key + " as int: " + err.Error()) } - fieldValue.Set(reflect.ValueOf(x).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) case "*int32": sdata := string(data) var x int32 @@ -1922,7 +1922,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data if err != nil { return errors.New("arg " + key + " as int: " + err.Error()) } - fieldValue.Set(reflect.ValueOf(x).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) case "*int8": sdata := string(data) var x int8 @@ -1950,7 +1950,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data if err != nil { return errors.New("arg " + key + " as int: " + err.Error()) } - fieldValue.Set(reflect.ValueOf(x).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) case "*int16": sdata := string(data) var x int16 @@ -1978,9 +1978,10 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data if err != nil { return errors.New("arg " + key + " as int: " + err.Error()) } - fieldValue.Set(reflect.ValueOf(x).Addr()) + fieldValue.Set(reflect.ValueOf(&x)) + default: + return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) } - fallthrough default: return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) } @@ -2116,13 +2117,20 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( } return fieldValue.Elem().Interface(), nil } - } else if typeStr == "*int64" || typeStr == "*uint64" || intTypes.Search(typeStr) < len(intTypes) { + } else if typeStr == "*int64" || intTypes.Search(typeStr) < len(intTypes) { if fieldValue.IsNil() { return nil, nil } else { return fieldValue.Elem().Int(), nil } + } else if typeStr == "*uint64" || uintTypes.Search(typeStr) < len(uintTypes) { + if fieldValue.IsNil() { + return nil, nil + } else { + return fieldValue.Elem().Uint(), nil + } } + fallthrough default: return fieldValue.Interface(), nil diff --git a/table.go b/table.go index 479e060d..5091d2b4 100644 --- a/table.go +++ b/table.go @@ -111,7 +111,8 @@ var ( BigSerial: true, } - intTypes = sort.StringSlice{"*int", "*int16", "*int32 ", "*int8 ", "*uint", "*uint16", "*uint32", "*uint8"} + intTypes = sort.StringSlice{"*int", "*int16", "*int32 ", "*int8 "} + uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"} ) var b byte @@ -170,7 +171,7 @@ func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) { st = SQLType{BigInt, 0, 0} } else if typeStr == "*time.Time" { st = SQLType{DateTime, 0, 0} - } else if intTypes.Search(typeStr) < len(intTypes) { + } else if intTypes.Search(typeStr) < len(intTypes) || uintTypes.Search(typeStr) < len(uintTypes) { st = SQLType{Int, 0, 0} } else { has = false From aae8a6c90e0681a1b5654f8cf093ce9366b6c42f Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Fri, 6 Dec 2013 19:55:47 +0800 Subject: [PATCH 07/13] create 3rd base test to void use of cache --- base_test.go | 238 +++++++++++++++++++++++++++++++++++++++++++++-- mymysql_test.go | 1 + mysql_test.go | 1 + postgres_test.go | 1 + sqlite3_test.go | 1 + 5 files changed, 235 insertions(+), 7 deletions(-) diff --git a/base_test.go b/base_test.go index b70d1a90..dadbe304 100644 --- a/base_test.go +++ b/base_test.go @@ -2598,9 +2598,9 @@ func testPointerData(engine *Engine, t *testing.T) { panic(err) } + // verify get values nullDataGet := NullData{} - - has, err := engine.Table("null_data").Id(nullData.Id).Get(&nullDataGet) + has, err := engine.Id(nullData.Id).Get(&nullDataGet) if err != nil { t.Error(err) panic(err) @@ -2674,7 +2674,94 @@ func testPointerData(engine *Engine, t *testing.T) { if (*nullDataGet.TimePtr).Unix() != (*nullData.TimePtr).Unix() { t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr))) + } else { + fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr) + fmt.Println() } + // -- + + // using instance type should just work too + nullData2Get := NullData2{} + + has, err = engine.Table("null_data").Id(nullData.Id).Get(&nullData2Get) + if err != nil { + t.Error(err) + panic(err) + } else if !has { + t.Error(errors.New("ID not found")) + } + + if nullData2Get.StringPtr != *nullData.StringPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.StringPtr))) + } + + if nullData2Get.StringPtr2 != *nullData.StringPtr2 { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.StringPtr2))) + } + + if nullData2Get.BoolPtr != *nullData.BoolPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%t]", nullData2Get.BoolPtr))) + } + + if nullData2Get.UintPtr != *nullData.UintPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.UintPtr))) + } + + if nullData2Get.Uint8Ptr != *nullData.Uint8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint8Ptr))) + } + + if nullData2Get.Uint16Ptr != *nullData.Uint16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint16Ptr))) + } + + if nullData2Get.Uint32Ptr != *nullData.Uint32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint32Ptr))) + } + + if nullData2Get.Uint64Ptr != *nullData.Uint64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint64Ptr))) + } + + if nullData2Get.IntPtr != *nullData.IntPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.IntPtr))) + } + + if nullData2Get.Int8Ptr != *nullData.Int8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int8Ptr))) + } + + if nullData2Get.Int16Ptr != *nullData.Int16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int16Ptr))) + } + + if nullData2Get.Int32Ptr != *nullData.Int32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int32Ptr))) + } + + if nullData2Get.Int64Ptr != *nullData.Int64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int64Ptr))) + } + + if nullData2Get.RunePtr != *nullData.RunePtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.RunePtr))) + } + + if nullData2Get.Float32Ptr != *nullData.Float32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Float32Ptr))) + } + + if nullData2Get.Float64Ptr != *nullData.Float64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Float64Ptr))) + } + + if nullData2Get.TimePtr.Unix() != (*nullData.TimePtr).Unix() { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr))) + } else { + fmt.Printf("time value: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr) + fmt.Println() + } + // -- } func testNullValue(engine *Engine, t *testing.T) { @@ -2713,7 +2800,7 @@ func testNullValue(engine *Engine, t *testing.T) { nullDataGet := NullData{} - has, err := engine.Table("null_data").Id(nullData.Id).Get(&nullDataGet) + has, err := engine.Id(nullData.Id).Get(&nullDataGet) if err != nil { t.Error(err) panic(err) @@ -2721,8 +2808,6 @@ func testNullValue(engine *Engine, t *testing.T) { t.Error(errors.New("ID not found")) } - fmt.Printf("val: %+v\n", nullDataGet) - if nullDataGet.StringPtr != nil { t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr))) } @@ -2790,6 +2875,141 @@ func testNullValue(engine *Engine, t *testing.T) { if nullDataGet.TimePtr != nil { t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.TimePtr))) } + + nullDataUpdate := NullData{ + StringPtr: new(string), + StringPtr2: new(string), + BoolPtr: new(bool), + BytePtr: new(byte), + UintPtr: new(uint), + Uint8Ptr: new(uint8), + Uint16Ptr: new(uint16), + Uint32Ptr: new(uint32), + Uint64Ptr: new(uint64), + IntPtr: new(int), + Int8Ptr: new(int8), + Int16Ptr: new(int16), + Int32Ptr: new(int32), + Int64Ptr: new(int64), + RunePtr: new(rune), + Float32Ptr: new(float32), + Float64Ptr: new(float64), + // Complex64Ptr :new(complex64), + // Complex128Ptr :new(complex128), + TimePtr: new(time.Time), + } + + *nullDataUpdate.StringPtr = "abc" + *nullDataUpdate.StringPtr2 = "123" + *nullDataUpdate.BoolPtr = true + *nullDataUpdate.BytePtr = 1 + *nullDataUpdate.UintPtr = 1 + *nullDataUpdate.Uint8Ptr = 1 + *nullDataUpdate.Uint16Ptr = 1 + *nullDataUpdate.Uint32Ptr = 1 + *nullDataUpdate.Uint64Ptr = 1 + *nullDataUpdate.IntPtr = -1 + *nullDataUpdate.Int8Ptr = -1 + *nullDataUpdate.Int16Ptr = -1 + *nullDataUpdate.Int32Ptr = -1 + *nullDataUpdate.Int64Ptr = -1 + *nullDataUpdate.RunePtr = 1 + *nullDataUpdate.Float32Ptr = -1.2 + *nullDataUpdate.Float64Ptr = -1.1 + // *nullDataUpdate.Complex64Ptr :new(complex64), + // *nullDataUpdate.Complex128Ptr :new(complex128), + *nullDataUpdate.TimePtr = time.Now() + + cnt, err = engine.Id(nullData.Id).Update(&nullDataUpdate) + if err != nil { + t.Error(err) + panic(err) + } else if cnt != 1 { + t.Error(errors.New("update count == 0, how can this happen!?")) + } + + // verify get values + nullDataGet = NullData{} + has, err = engine.Id(nullData.Id).Get(&nullDataGet) + if err != nil { + t.Error(err) + panic(err) + } else if !has { + t.Error(errors.New("ID not found")) + } + + if *nullDataGet.StringPtr != *nullDataUpdate.StringPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr))) + } + + if *nullDataGet.StringPtr2 != *nullDataUpdate.StringPtr2 { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr2))) + } + + if *nullDataGet.BoolPtr != *nullDataUpdate.BoolPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%t]", *nullDataGet.BoolPtr))) + } + + if *nullDataGet.UintPtr != *nullDataUpdate.UintPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.UintPtr))) + } + + if *nullDataGet.Uint8Ptr != *nullDataUpdate.Uint8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint8Ptr))) + } + + if *nullDataGet.Uint16Ptr != *nullDataUpdate.Uint16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint16Ptr))) + } + + if *nullDataGet.Uint32Ptr != *nullDataUpdate.Uint32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint32Ptr))) + } + + if *nullDataGet.Uint64Ptr != *nullDataUpdate.Uint64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint64Ptr))) + } + + if *nullDataGet.IntPtr != *nullDataUpdate.IntPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.IntPtr))) + } + + if *nullDataGet.Int8Ptr != *nullDataUpdate.Int8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int8Ptr))) + } + + if *nullDataGet.Int16Ptr != *nullDataUpdate.Int16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int16Ptr))) + } + + if *nullDataGet.Int32Ptr != *nullDataUpdate.Int32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int32Ptr))) + } + + if *nullDataGet.Int64Ptr != *nullDataUpdate.Int64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int64Ptr))) + } + + if *nullDataGet.RunePtr != *nullDataUpdate.RunePtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.RunePtr))) + } + + if *nullDataGet.Float32Ptr != *nullDataUpdate.Float32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float32Ptr))) + } + + if *nullDataGet.Float64Ptr != *nullDataUpdate.Float64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr))) + } + + if (*nullDataGet.TimePtr).Unix() != (*nullDataUpdate.TimePtr).Unix() { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr))) + } else { + fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr) + fmt.Println() + } + // -- + } func testAll(engine *Engine, t *testing.T) { @@ -2890,10 +3110,14 @@ func testAll2(engine *Engine, t *testing.T) { //testCreatedUpdated(engine, t) fmt.Println("-------------- processors --------------") testProcessors(engine, t) - fmt.Println("-------------- processors TX --------------") - testProcessorsTx(engine, t) fmt.Println("-------------- transaction --------------") transaction(engine, t) +} + +// !nash! the 3rd set of the test is intended for non-cache enabled engine +func testAll3(engine *Engine, t *testing.T) { + fmt.Println("-------------- processors TX --------------") + testProcessorsTx(engine, t) fmt.Println("-------------- insert pointer data --------------") testPointerData(engine, t) fmt.Println("-------------- insert null data --------------") diff --git a/mymysql_test.go b/mymysql_test.go index 872fc16c..6b6b4dda 100644 --- a/mymysql_test.go +++ b/mymysql_test.go @@ -34,6 +34,7 @@ func TestMyMysql(t *testing.T) { testAll(engine, t) testAll2(engine, t) + testAll3(engine, t) } func TestMyMysqlWithCache(t *testing.T) { diff --git a/mysql_test.go b/mysql_test.go index a8d0c803..58a2374a 100644 --- a/mysql_test.go +++ b/mysql_test.go @@ -32,6 +32,7 @@ func TestMysql(t *testing.T) { testAll(engine, t) testAll2(engine, t) + testAll3(engine, t) } func TestMysqlWithCache(t *testing.T) { diff --git a/postgres_test.go b/postgres_test.go index 3cea129a..b59507a6 100644 --- a/postgres_test.go +++ b/postgres_test.go @@ -21,6 +21,7 @@ func TestPostgres(t *testing.T) { testAll(engine, t) testAll2(engine, t) + testAll3(engine, t) } func TestPostgresWithCache(t *testing.T) { diff --git a/sqlite3_test.go b/sqlite3_test.go index 0a1c33a8..f0b9c364 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -18,6 +18,7 @@ func TestSqlite3(t *testing.T) { testAll(engine, t) testAll2(engine, t) + testAll3(engine, t) } func BenchmarkSqlite3NoCache(t *testing.B) { From 95eccb04eb5fca8f1f9dde76f54eb0c0cfdd43ec Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Fri, 6 Dec 2013 20:13:10 +0800 Subject: [PATCH 08/13] add null value update record testing --- base_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/base_test.go b/base_test.go index dadbe304..f6a83c78 100644 --- a/base_test.go +++ b/base_test.go @@ -3116,8 +3116,8 @@ func testAll2(engine *Engine, t *testing.T) { // !nash! the 3rd set of the test is intended for non-cache enabled engine func testAll3(engine *Engine, t *testing.T) { - fmt.Println("-------------- processors TX --------------") - testProcessorsTx(engine, t) + // fmt.Println("-------------- processors TX --------------") + // testProcessorsTx(engine, t) fmt.Println("-------------- insert pointer data --------------") testPointerData(engine, t) fmt.Println("-------------- insert null data --------------") From 0c9b7b274f219d9be3886f29e5737794dcb8c6a7 Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Fri, 6 Dec 2013 20:49:11 +0800 Subject: [PATCH 09/13] get elem value() for Ptr type --- statement.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/statement.go b/statement.go index 060c64ae..d1be6e9a 100644 --- a/statement.go +++ b/statement.go @@ -347,6 +347,9 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers case reflect.Ptr: if fieldValue.IsNil() || !fieldValue.IsValid() { continue + } else { + // TODO need to filter support types + val = fieldValue.Elem() } default: val = fieldValue.Interface() From 68862870a99c7bd31e3b8ca10d8c27215c3d0ca7 Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Fri, 6 Dec 2013 21:04:27 +0800 Subject: [PATCH 10/13] code tidy up for Type2SQLType and buildConditions --- statement.go | 13 +++++++++++-- table.go | 22 ++++++++++++---------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/statement.go b/statement.go index d1be6e9a..0e51c6ee 100644 --- a/statement.go +++ b/statement.go @@ -348,8 +348,17 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers if fieldValue.IsNil() || !fieldValue.IsValid() { continue } else { - // TODO need to filter support types - val = fieldValue.Elem() + typeStr := fieldType.String() + switch typeStr { + case "*string", "*bool", "*float32", "*float64", "*int64", "*uint64", "*int", "*int16", "*int32 ", "*int8 ", "*uint", "*uint16", "*uint32", "*uint8": + val = fieldValue.Elem() + case "*complex64", "*complex128": + continue // TODO + case "*time.Time": + continue // TODO + default: + continue // TODO + } } default: val = fieldValue.Interface() diff --git a/table.go b/table.go index 5091d2b4..901c9ecc 100644 --- a/table.go +++ b/table.go @@ -111,7 +111,7 @@ var ( BigSerial: true, } - intTypes = sort.StringSlice{"*int", "*int16", "*int32 ", "*int8 "} + intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"} uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"} ) @@ -157,23 +157,25 @@ func Type2SQLType(t reflect.Type) (st SQLType) { func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) { typeStr := t.String() has = true - if typeStr == "*string" { + + switch typeStr { + case "*string": st = SQLType{Varchar, 255, 0} - } else if typeStr == "*bool" { + case "*bool": st = SQLType{Bool, 0, 0} - } else if typeStr == "*complex64" || typeStr == "*complex128" { + case "*complex64", "*complex128": st = SQLType{Varchar, 64, 0} - } else if typeStr == "*float32" { + case "*float32": st = SQLType{Float, 0, 0} - } else if typeStr == "*float64" { + case "*float64": st = SQLType{Varchar, 64, 0} - } else if typeStr == "*int64" || typeStr == "*uint64" { + case "*int64", "*uint64": st = SQLType{BigInt, 0, 0} - } else if typeStr == "*time.Time" { + case "*time.Time": st = SQLType{DateTime, 0, 0} - } else if intTypes.Search(typeStr) < len(intTypes) || uintTypes.Search(typeStr) < len(uintTypes) { + case "*int", "*int16", "*int32", "*int8", "*uint", "*uint16", "*uint32", "*uint8": st = SQLType{Int, 0, 0} - } else { + default: has = false } return From 772a2c7baa225e73b6beb1f8578feb093927bee3 Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Fri, 6 Dec 2013 21:04:56 +0800 Subject: [PATCH 11/13] uncomment testProcessorsTx --- base_test.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/base_test.go b/base_test.go index f6a83c78..e95983aa 100644 --- a/base_test.go +++ b/base_test.go @@ -2926,6 +2926,7 @@ func testNullValue(engine *Engine, t *testing.T) { panic(err) } else if cnt != 1 { t.Error(errors.New("update count == 0, how can this happen!?")) + return } // verify get values @@ -2933,9 +2934,10 @@ func testNullValue(engine *Engine, t *testing.T) { has, err = engine.Id(nullData.Id).Get(&nullDataGet) if err != nil { t.Error(err) - panic(err) + return } else if !has { t.Error(errors.New("ID not found")) + return } if *nullDataGet.StringPtr != *nullDataUpdate.StringPtr { @@ -3116,8 +3118,8 @@ func testAll2(engine *Engine, t *testing.T) { // !nash! the 3rd set of the test is intended for non-cache enabled engine func testAll3(engine *Engine, t *testing.T) { - // fmt.Println("-------------- processors TX --------------") - // testProcessorsTx(engine, t) + fmt.Println("-------------- processors TX --------------") + testProcessorsTx(engine, t) fmt.Println("-------------- insert pointer data --------------") testPointerData(engine, t) fmt.Println("-------------- insert null data --------------") From a74f8db2326dc0c6ab55cfaa3f58f9068c70965d Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Sun, 8 Dec 2013 04:27:48 +0800 Subject: [PATCH 12/13] update statement.buildConditions method to support pointer values update --- base_test.go | 162 +++++++++++++++++++++++++++++++++++++++++++++++---- session.go | 12 ++-- statement.go | 56 ++++++++++-------- 3 files changed, 188 insertions(+), 42 deletions(-) diff --git a/base_test.go b/base_test.go index e95983aa..8ad6b7c3 100644 --- a/base_test.go +++ b/base_test.go @@ -2488,8 +2488,8 @@ type NullData struct { RunePtr *rune Float32Ptr *float32 Float64Ptr *float64 - // Complex64Ptr *complex64 - // Complex128Ptr *complex128 + // Complex64Ptr *complex64 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' + // Complex128Ptr *complex128 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' TimePtr *time.Time } @@ -2512,8 +2512,8 @@ type NullData2 struct { RunePtr rune Float32Ptr float32 Float64Ptr float64 - //Complex64Ptr complex64 - //Complex128Ptr complex128 + // Complex64Ptr complex64 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' + // Complex128Ptr complex128 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' TimePtr time.Time } @@ -2554,8 +2554,8 @@ func testPointerData(engine *Engine, t *testing.T) { RunePtr: new(rune), Float32Ptr: new(float32), Float64Ptr: new(float64), - // Complex64Ptr :new(complex64), - // Complex128Ptr :new(complex128), + // Complex64Ptr: new(complex64), + // Complex128Ptr: new(complex128), TimePtr: new(time.Time), } @@ -2576,8 +2576,8 @@ func testPointerData(engine *Engine, t *testing.T) { *nullData.RunePtr = 1 *nullData.Float32Ptr = -1.2 *nullData.Float64Ptr = -1.1 - // *nullData.Complex64Ptr :new(complex64), - // *nullData.Complex128Ptr :new(complex128), + // *nullData.Complex64Ptr = 123456789012345678901234567890 + // *nullData.Complex128Ptr = 123456789012345678901234567890123456789012345678901234567890 *nullData.TimePtr = time.Now() cnt, err := engine.Insert(&nullData) @@ -2672,9 +2672,18 @@ func testPointerData(engine *Engine, t *testing.T) { t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr))) } + // if *nullDataGet.Complex64Ptr != *nullData.Complex64Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex64Ptr))) + // } + + // if *nullDataGet.Complex128Ptr != *nullData.Complex128Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex128Ptr))) + // } + if (*nullDataGet.TimePtr).Unix() != (*nullData.TimePtr).Unix() { t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr))) } else { + // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr) fmt.Println() } @@ -2755,9 +2764,18 @@ func testPointerData(engine *Engine, t *testing.T) { t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Float64Ptr))) } + // if nullData2Get.Complex64Ptr != *nullData.Complex64Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Complex64Ptr))) + // } + + // if nullData2Get.Complex128Ptr != *nullData.Complex128Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Complex128Ptr))) + // } + if nullData2Get.TimePtr.Unix() != (*nullData.TimePtr).Unix() { t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr))) } else { + // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver fmt.Printf("time value: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr) fmt.Println() } @@ -2872,6 +2890,14 @@ func testNullValue(engine *Engine, t *testing.T) { t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) } + // if nullDataGet.Complex64Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Complex64Ptr))) + // } + + // if nullDataGet.Complex128Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Complex128Ptr))) + // } + if nullDataGet.TimePtr != nil { t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.TimePtr))) } @@ -2894,8 +2920,8 @@ func testNullValue(engine *Engine, t *testing.T) { RunePtr: new(rune), Float32Ptr: new(float32), Float64Ptr: new(float64), - // Complex64Ptr :new(complex64), - // Complex128Ptr :new(complex128), + // Complex64Ptr: new(complex64), + // Complex128Ptr: new(complex128), TimePtr: new(time.Time), } @@ -2916,8 +2942,8 @@ func testNullValue(engine *Engine, t *testing.T) { *nullDataUpdate.RunePtr = 1 *nullDataUpdate.Float32Ptr = -1.2 *nullDataUpdate.Float64Ptr = -1.1 - // *nullDataUpdate.Complex64Ptr :new(complex64), - // *nullDataUpdate.Complex128Ptr :new(complex128), + // *nullDataUpdate.Complex64Ptr = 123456789012345678901234567890 + // *nullDataUpdate.Complex128Ptr = 123456789012345678901234567890123456789012345678901234567890 *nullDataUpdate.TimePtr = time.Now() cnt, err = engine.Id(nullData.Id).Update(&nullDataUpdate) @@ -3004,14 +3030,126 @@ func testNullValue(engine *Engine, t *testing.T) { t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr))) } + // if *nullDataGet.Complex64Ptr != *nullDataUpdate.Complex64Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex64Ptr))) + // } + + // if *nullDataGet.Complex128Ptr != *nullDataUpdate.Complex128Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex128Ptr))) + // } + if (*nullDataGet.TimePtr).Unix() != (*nullDataUpdate.TimePtr).Unix() { t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr))) } else { + // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr) fmt.Println() } // -- + // update to null values + nullDataUpdate = NullData{} + + cnt, err = engine.Id(nullData.Id).Update(&nullDataUpdate) + if err != nil { + t.Error(err) + panic(err) + } else if cnt != 1 { + t.Error(errors.New("update count == 0, how can this happen!?")) + return + } + + // verify get values + nullDataGet = NullData{} + has, err = engine.Id(nullData.Id).Get(&nullDataGet) + if err != nil { + t.Error(err) + return + } else if !has { + t.Error(errors.New("ID not found")) + return + } + + fmt.Printf("%+v", nullDataGet) + fmt.Println() + + if nullDataGet.StringPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr))) + } + + if nullDataGet.StringPtr2 != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr2))) + } + + if nullDataGet.BoolPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%t]", *nullDataGet.BoolPtr))) + } + + if nullDataGet.UintPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.UintPtr))) + } + + if nullDataGet.Uint8Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint8Ptr))) + } + + if nullDataGet.Uint16Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint16Ptr))) + } + + if nullDataGet.Uint32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint32Ptr))) + } + + if nullDataGet.Uint64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint64Ptr))) + } + + if nullDataGet.IntPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.IntPtr))) + } + + if nullDataGet.Int8Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int8Ptr))) + } + + if nullDataGet.Int16Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int16Ptr))) + } + + if nullDataGet.Int32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int32Ptr))) + } + + if nullDataGet.Int64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int64Ptr))) + } + + if nullDataGet.RunePtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.RunePtr))) + } + + if nullDataGet.Float32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float32Ptr))) + } + + if nullDataGet.Float64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) + } + + // if nullDataGet.Complex64Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) + // } + + // if nullDataGet.Complex128Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) + // } + + if nullDataGet.TimePtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.TimePtr))) + } + // -- + } func testAll(engine *Engine, t *testing.T) { diff --git a/session.go b/session.go index eef9f5f6..be024a44 100644 --- a/session.go +++ b/session.go @@ -875,6 +875,7 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { // get retrieve one record from database, bean's non-empty fields // will be as conditions func (session *Session) Get(bean interface{}) (bool, error) { + err := session.newDb() if err != nil { return false, err @@ -889,6 +890,7 @@ func (session *Session) Get(bean interface{}) (bool, error) { var sql string var args []interface{} session.Statement.RefTable = session.Engine.autoMap(bean) + if session.Statement.RawSQL == "" { sql, args = session.Statement.genGetSql(bean) } else { @@ -1000,7 +1002,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) if len(condiBean) > 0 { colNames, args := buildConditions(session.Engine, table, condiBean[0], true, - session.Statement.allUseBool, session.Statement.boolColumnMap) + session.Statement.allUseBool, false, session.Statement.boolColumnMap) session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.BeanArgs = args } @@ -1724,7 +1726,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data } } case reflect.Ptr: - // TODO merge duplicated codes above + // !nashtsai! TODO merge duplicated codes above typeStr := fieldType.String() switch typeStr { case "*string": @@ -2498,7 +2500,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if session.Statement.ColumnStr == "" { colNames, args = buildConditions(session.Engine, table, bean, false, - session.Statement.allUseBool, session.Statement.boolColumnMap) + session.Statement.allUseBool, true, session.Statement.boolColumnMap) } else { colNames, args, err = table.genCols(session, bean, true, true) if err != nil { @@ -2532,7 +2534,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condiBean) > 0 { condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true, - session.Statement.allUseBool, session.Statement.boolColumnMap) + session.Statement.allUseBool, false, session.Statement.boolColumnMap) } var condition = "" @@ -2698,7 +2700,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { table := session.Engine.autoMap(bean) session.Statement.RefTable = table colNames, args := buildConditions(session.Engine, table, bean, true, - session.Statement.allUseBool, session.Statement.boolColumnMap) + session.Statement.allUseBool, false, session.Statement.boolColumnMap) var condition = "" if session.Statement.WhereStr != "" { diff --git a/statement.go b/statement.go index 0e51c6ee..b48c0f1e 100644 --- a/statement.go +++ b/statement.go @@ -233,7 +233,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { }*/ // Auto generating conditions according a struct -func buildConditions(engine *Engine, table *Table, bean interface{}, includeVersion bool, allUseBool bool, boolColumnMap map[string]bool) ([]string, []interface{}) { +func buildConditions(engine *Engine, table *Table, bean interface{}, includeVersion bool, allUseBool bool, includeNil bool, boolColumnMap map[string]bool) ([]string, []interface{}) { colNames := make([]string, 0) var args = make([]interface{}, 0) for _, col := range table.Columns { @@ -242,10 +242,29 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers } fieldValue := col.ValueOf(bean) fieldType := reflect.TypeOf(fieldValue.Interface()) + + requiredField := false + if fieldType.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + if includeNil { + args = append(args, nil) + colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name))) + } + continue + } else if !fieldValue.IsValid() { + continue + } else { + // dereference ptr type to instance type + fieldValue = fieldValue.Elem() + fieldType = reflect.TypeOf(fieldValue.Interface()) + requiredField = true + } + } + var val interface{} switch fieldType.Kind() { case reflect.Bool: - if allUseBool { + if allUseBool || requiredField { val = fieldValue.Interface() } else if _, ok := boolColumnMap[col.Name]; ok { val = fieldValue.Interface() @@ -255,7 +274,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers continue } case reflect.String: - if fieldValue.String() == "" { + if !requiredField && fieldValue.String() == "" { continue } // for MyString, should convert to string or panic @@ -265,24 +284,24 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers val = fieldValue.Interface() } case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if fieldValue.Int() == 0 { + if !requiredField && fieldValue.Int() == 0 { continue } val = fieldValue.Interface() case reflect.Float32, reflect.Float64: - if fieldValue.Float() == 0.0 { + if !requiredField && fieldValue.Float() == 0.0 { continue } val = fieldValue.Interface() case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if fieldValue.Uint() == 0 { + if !requiredField && fieldValue.Uint() == 0 { continue } val = fieldValue.Interface() case reflect.Struct: if fieldType == reflect.TypeOf(time.Now()) { t := fieldValue.Interface().(time.Time) - if t.IsZero() || !fieldValue.IsValid() { + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { continue } var str string @@ -344,22 +363,6 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers } else { continue } - case reflect.Ptr: - if fieldValue.IsNil() || !fieldValue.IsValid() { - continue - } else { - typeStr := fieldType.String() - switch typeStr { - case "*string", "*bool", "*float32", "*float64", "*int64", "*uint64", "*int", "*int16", "*int32 ", "*int8 ", "*uint", "*uint16", "*uint32", "*uint8": - val = fieldValue.Elem() - case "*complex64", "*complex128": - continue // TODO - case "*time.Time": - continue // TODO - default: - continue // TODO - } - } default: val = fieldValue.Interface() } @@ -598,12 +601,14 @@ func (s *Statement) genDropSQL() string { return sql } +// !nashtsai! REVIEW, Statement is a huge struct why is this method not passing *Statement? func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) { table := statement.Engine.autoMap(bean) statement.RefTable = table colNames, args := buildConditions(statement.Engine, table, bean, true, - statement.allUseBool, statement.boolColumnMap) + statement.allUseBool, false, statement.boolColumnMap) + statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args @@ -640,7 +645,8 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{}) table := statement.Engine.autoMap(bean) statement.RefTable = table - colNames, args := buildConditions(statement.Engine, table, bean, true, statement.allUseBool, statement.boolColumnMap) + colNames, args := buildConditions(statement.Engine, table, bean, true, + statement.allUseBool, false, statement.boolColumnMap) statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args var id string = "*" From 1f54b910769ef015d58401acd0c5fe424e8e2272 Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Sun, 8 Dec 2013 12:56:10 +0800 Subject: [PATCH 13/13] code tidy up --- session.go | 75 ++++++++++-------------------------------------------- 1 file changed, 14 insertions(+), 61 deletions(-) diff --git a/session.go b/session.go index be024a44..22b1951d 100644 --- a/session.go +++ b/session.go @@ -2005,6 +2005,20 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( } fieldType := fieldValue.Type() k := fieldType.Kind() + if k == reflect.Ptr { + if fieldValue.IsNil() { + return nil, nil + } else if !fieldValue.IsValid() { + session.Engine.LogWarn("the field[", col.FieldName, "] is invalid") + return nil, nil + } else { + // !nashtsai! deference pointer type to instance type + fieldValue = fieldValue.Elem() + fieldType = fieldValue.Type() + k = fieldType.Kind() + } + } + switch k { case reflect.Bool: if fieldValue.Bool() { @@ -2073,67 +2087,6 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( } else { return nil, ErrUnSupportedType } - case reflect.Ptr: - typeStr := fieldType.String() - if typeStr == "*string" { - if fieldValue.IsNil() { - return nil, nil - } else { - return fieldValue.Elem().String(), nil - } - } else if typeStr == "*bool" { - if fieldValue.IsNil() { - return nil, nil - } else { - return fieldValue.Elem().Bool(), nil - } - } else if typeStr == "*complex64" || typeStr == "*complex128" { - if fieldValue.IsNil() { - return nil, nil - } else { - bytes, err := json.Marshal(fieldValue.Elem().Complex()) - if err != nil { - session.Engine.LogSQL(err) - return 0, err - } - return string(bytes), nil - } - } else if typeStr == "*float32" || typeStr == "*float64" { - if fieldValue.IsNil() { - return nil, nil - } else { - return fieldValue.Elem().Float(), nil - } - } else if typeStr == "*time.Time" { - if fieldValue.IsNil() { - return nil, nil - } else { - if col.SQLType.Name == Time { - //s := fieldValue.Interface().(time.Time).Format("2006-01-02 15:04:05 -0700") - s := fieldValue.Elem().Interface().(time.Time).Format(time.RFC3339) - return s[11:19], nil - } else if col.SQLType.Name == Date { - return fieldValue.Elem().Interface().(time.Time).Format("2006-01-02"), nil - } else if col.SQLType.Name == TimeStampz { - return fieldValue.Elem().Interface().(time.Time).Format(time.RFC3339Nano), nil - } - return fieldValue.Elem().Interface(), nil - } - } else if typeStr == "*int64" || intTypes.Search(typeStr) < len(intTypes) { - if fieldValue.IsNil() { - return nil, nil - } else { - return fieldValue.Elem().Int(), nil - } - } else if typeStr == "*uint64" || uintTypes.Search(typeStr) < len(uintTypes) { - if fieldValue.IsNil() { - return nil, nil - } else { - return fieldValue.Elem().Uint(), nil - } - } - - fallthrough default: return fieldValue.Interface(), nil }