diff --git a/base_test.go b/base_test.go index 6b3040ea..995b5773 100644 --- a/base_test.go +++ b/base_test.go @@ -2469,6 +2469,109 @@ func testProcessorsTx(engine *Engine, t *testing.T) { // -- } +type NullData struct { + Id int64 + StringPtr *string + StringPtr2 *string `xorm:"text"` + BoolPtr *bool + BytePtr *byte + UintPtr *uint + Uint8Ptr *uint8 + Uint16Ptr *uint16 + Uint32Ptr *uint32 + UInt64Ptr *uint64 + IntPtr *int + Int8Ptr *int8 + Int16Ptr *int16 + Int32Ptr *int32 + Int64Ptr *int64 + RunePtr *rune + Float32Ptr *float32 + Float64Ptr *float64 + // Complex64Ptr *complex64 + // Complex128Ptr *complex128 + TimePtr *time.Time +} + +type NullData2 struct { + Id int64 + StringPtr string + StringPtr2 string `xorm:"text"` + BoolPtr bool + BytePtr byte + UintPtr uint + Uint8Ptr uint8 + Uint16Ptr uint16 + Uint32Ptr uint32 + UInt64Ptr uint64 + IntPtr int + Int8Ptr int8 + Int16Ptr int16 + Int32Ptr int32 + Int64Ptr int64 + RunePtr rune + Float32Ptr float32 + Float64Ptr float64 + //Complex64Ptr complex64 + //Complex128Ptr complex128 + TimePtr time.Time +} + +type NullData3 struct { + Id int64 + StringPtr *string +} + +func insertNullData(engine *Engine, t *testing.T) { + + err := engine.DropTables(&NullData{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&NullData{}) + if err != nil { + t.Error(err) + panic(err) + } + + nullData := NullData{BoolPtr: new(bool)} + *nullData.BoolPtr = true + cnt, err := engine.Insert(&nullData) + fmt.Println(nullData.Id) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } + if nullData.Id <= 0 { + err = errors.New("not return id error") + t.Error(err) + panic(err) + } + + nullDataGet := NullData{} + + has, err := engine.Table("null_data").Id(nullData.Id).Get(&nullDataGet) + if err != nil { + t.Error(err) + panic(err) + } else if !has { + t.Error(errors.New("ID not found")) + } + + // if nullData2.BoolPtr == nil || !*(nullData2.BoolPtr) { + // t.Error(errors.New("BoolPtr wrong value")) + // } + +} + func testAll(engine *Engine, t *testing.T) { fmt.Println("-------------- directCreateTable --------------") directCreateTable(engine, t) @@ -2571,4 +2674,6 @@ func testAll2(engine *Engine, t *testing.T) { testProcessorsTx(engine, t) fmt.Println("-------------- transaction --------------") transaction(engine, t) + fmt.Println("-------------- insert null data --------------") + insertNullData(engine, t) } diff --git a/session.go b/session.go index 197773db..d56843e3 100644 --- a/session.go +++ b/session.go @@ -1660,7 +1660,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data fieldValue.SetUint(x) //Now only support Time type case reflect.Struct: - if fieldValue.Type().String() == "time.Time" { + if fieldType.String() == "time.Time" { sdata := strings.TrimSpace(string(data)) var x time.Time var err error @@ -1723,6 +1723,264 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data return errors.New("unsupported struct type in Scan: " + fieldValue.Type().String()) } } + case reflect.Ptr: + // TODO merge duplicated codes above + typeStr := fieldType.String() + switch typeStr { + case "*string": + x := string(data) + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*bool": + d := string(data) + v, err := strconv.ParseBool(d) + if err != nil { + return errors.New("arg " + key + " as bool: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(v).Addr()) + case "*complex64": + var x complex64 + err := json.Unmarshal(data, &x) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*complex128": + var x complex128 + err := json.Unmarshal(data, &x) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*float64": + x, err := strconv.ParseFloat(string(data), 64) + if err != nil { + return errors.New("arg " + key + " as float64: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*float32": + var x float32 + x1, err := strconv.ParseFloat(string(data), 32) + if err != nil { + return errors.New("arg " + key + " as float32: " + err.Error()) + } + x = float32(x1) + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*time.Time": + sdata := strings.TrimSpace(string(data)) + var x time.Time + var err error + + if sdata == "0000-00-00 00:00:00" || + sdata == "0001-01-01 00:00:00" { + } else if !strings.ContainsAny(sdata, "- :") { + // time stamp + sd, err := strconv.ParseInt(sdata, 10, 64) + if err == nil { + x = time.Unix(0, sd) + } + } else if len(sdata) > 19 { + x, err = time.Parse(time.RFC3339Nano, sdata) + if err != nil { + x, err = time.Parse("2006-01-02 15:04:05.999999999", sdata) + } + } else if len(sdata) == 19 { + x, err = time.Parse("2006-01-02 15:04:05", sdata) + } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { + x, err = time.Parse("2006-01-02", sdata) + } else if col.SQLType.Name == Time { + if len(sdata) > 8 { + sdata = sdata[len(sdata)-8:] + } + st := fmt.Sprintf("2006-01-02 %v", sdata) + x, err = time.Parse("2006-01-02 15:04:05", st) + } else { + return errors.New(fmt.Sprintf("unsupported time format %v", string(data))) + } + if err != nil { + return errors.New(fmt.Sprintf("unsupported time format %v: %v", string(data), err)) + } + + v = x + fieldValue.Set(reflect.ValueOf(v).Addr()) + case "*uint64": + var x uint64 + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*uint": + var x uint + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint(x1) + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*uint32": + var x uint32 + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint32(x1) + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*uint8": + var x uint8 + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint8(x1) + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*uint16": + var x uint16 + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint16(x1) + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*int64": + sdata := string(data) + var x int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int64(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x, err = strconv.ParseInt(sdata, 16, 64) + } else if strings.HasPrefix(sdata, "0") { + x, err = strconv.ParseInt(sdata, 8, 64) + } else { + x, err = strconv.ParseInt(sdata, 10, 64) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*int": + sdata := string(data) + var x int + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*int32": + sdata := string(data) + var x int32 + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int32(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int32(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int32(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int32(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*int8": + sdata := string(data) + var x int8 + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int8(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int8(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int8(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int8(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + case "*int16": + sdata := string(data) + var x int16 + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int16(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int16(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int16(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int16(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(x).Addr()) + } + fallthrough default: return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) } @@ -1742,8 +2000,8 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( } } } - - k := fieldValue.Type().Kind() + fieldType := fieldValue.Type() + k := fieldType.Kind() switch k { case reflect.Bool: if fieldValue.Bool() { @@ -1754,7 +2012,7 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( case reflect.String: return fieldValue.String(), nil case reflect.Struct: - if fieldValue.Type().String() == "time.Time" { + if fieldType.String() == "time.Time" { if col.SQLType.Name == Time { //s := fieldValue.Interface().(time.Time).Format("2006-01-02 15:04:05 -0700") s := fieldValue.Interface().(time.Time).Format(time.RFC3339) @@ -1812,6 +2070,60 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( } else { return nil, ErrUnSupportedType } + case reflect.Ptr: + typeStr := fieldType.String() + if typeStr == "*string" { + if fieldValue.IsNil() { + return nil, nil + } else { + return fieldValue.Elem().String(), nil + } + } else if typeStr == "*bool" { + if fieldValue.IsNil() { + return nil, nil + } else { + return fieldValue.Elem().Bool(), nil + } + } else if typeStr == "*complex64" || typeStr == "*complex128" { + if fieldValue.IsNil() { + return nil, nil + } else { + bytes, err := json.Marshal(fieldValue.Elem().Complex()) + if err != nil { + session.Engine.LogSQL(err) + return 0, err + } + return string(bytes), nil + } + } else if typeStr == "*float32" || typeStr == "*float64" { + if fieldValue.IsNil() { + return nil, nil + } else { + return fieldValue.Elem().Float(), nil + } + } else if typeStr == "*time.Time" { + if fieldValue.IsNil() { + return nil, nil + } else { + if col.SQLType.Name == Time { + //s := fieldValue.Interface().(time.Time).Format("2006-01-02 15:04:05 -0700") + s := fieldValue.Elem().Interface().(time.Time).Format(time.RFC3339) + return s[11:19], nil + } else if col.SQLType.Name == Date { + return fieldValue.Elem().Interface().(time.Time).Format("2006-01-02"), nil + } else if col.SQLType.Name == TimeStampz { + return fieldValue.Elem().Interface().(time.Time).Format(time.RFC3339Nano), nil + } + return fieldValue.Elem().Interface(), nil + } + } else if typeStr == "*int64" || typeStr == "*uint64" || intTypes.Search(typeStr) < len(intTypes) { + if fieldValue.IsNil() { + return nil, nil + } else { + return fieldValue.Elem().Int(), nil + } + } + fallthrough default: return fieldValue.Interface(), nil } diff --git a/table.go b/table.go index 012c4030..479e060d 100644 --- a/table.go +++ b/table.go @@ -2,6 +2,7 @@ package xorm import ( "reflect" + "sort" "strings" "time" ) @@ -24,6 +25,8 @@ func (s *SQLType) IsBlob() bool { s.Name == Binary || s.Name == VarBinary || s.Name == Bytea } +const () + var ( Bit = "BIT" TinyInt = "TINYINT" @@ -107,6 +110,8 @@ var ( Serial: true, BigSerial: true, } + + intTypes = sort.StringSlice{"*int", "*int16", "*int32 ", "*int8 ", "*uint", "*uint16", "*uint32", "*uint8"} ) var b byte @@ -140,12 +145,39 @@ func Type2SQLType(t reflect.Type) (st SQLType) { } else { st = SQLType{Text, 0, 0} } + case reflect.Ptr: + st, _ = ptrType2SQLType(t) default: st = SQLType{Text, 0, 0} } return } +func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) { + typeStr := t.String() + has = true + if typeStr == "*string" { + st = SQLType{Varchar, 255, 0} + } else if typeStr == "*bool" { + st = SQLType{Bool, 0, 0} + } else if typeStr == "*complex64" || typeStr == "*complex128" { + st = SQLType{Varchar, 64, 0} + } else if typeStr == "*float32" { + st = SQLType{Float, 0, 0} + } else if typeStr == "*float64" { + st = SQLType{Varchar, 64, 0} + } else if typeStr == "*int64" || typeStr == "*uint64" { + st = SQLType{BigInt, 0, 0} + } else if typeStr == "*time.Time" { + st = SQLType{DateTime, 0, 0} + } else if intTypes.Search(typeStr) < len(intTypes) { + st = SQLType{Int, 0, 0} + } else { + has = false + } + return +} + // default sql type change to go types func SQLType2Type(st SQLType) reflect.Type { name := strings.ToUpper(st.Name)