From 2ea34841f077d0f736106a9715b17689b2bfc05c Mon Sep 17 00:00:00 2001 From: Oleh Herych Date: Wed, 26 Jul 2017 17:55:11 +0300 Subject: [PATCH] added test for issue with custom type and null --- helpers.go | 484 -------------------------------------------------- types_test.go | 369 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 369 insertions(+), 484 deletions(-) delete mode 100644 helpers.go create mode 100644 types_test.go diff --git a/helpers.go b/helpers.go deleted file mode 100644 index b80d1e53..00000000 --- a/helpers.go +++ /dev/null @@ -1,484 +0,0 @@ -// Copyright 2015 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "errors" - "fmt" - "reflect" - "sort" - "strconv" - "strings" - "time" - - "github.com/go-xorm/core" -) - -// str2PK convert string value to primary key value according to tp -func str2PKValue(s string, tp reflect.Type) (reflect.Value, error) { - var err error - var result interface{} - var defReturn = reflect.Zero(tp) - - switch tp.Kind() { - case reflect.Int: - result, err = strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int: %s", s, err.Error()) - } - case reflect.Int8: - x, err := strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int8: %s", s, err.Error()) - } - result = int8(x) - case reflect.Int16: - x, err := strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int16: %s", s, err.Error()) - } - result = int16(x) - case reflect.Int32: - x, err := strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int32: %s", s, err.Error()) - } - result = int32(x) - case reflect.Int64: - result, err = strconv.ParseInt(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int64: %s", s, err.Error()) - } - case reflect.Uint: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint: %s", s, err.Error()) - } - result = uint(x) - case reflect.Uint8: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint8: %s", s, err.Error()) - } - result = uint8(x) - case reflect.Uint16: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint16: %s", s, err.Error()) - } - result = uint16(x) - case reflect.Uint32: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint32: %s", s, err.Error()) - } - result = uint32(x) - case reflect.Uint64: - result, err = strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint64: %s", s, err.Error()) - } - case reflect.String: - result = s - default: - return defReturn, errors.New("unsupported convert type") - } - return reflect.ValueOf(result).Convert(tp), nil -} - -func str2PK(s string, tp reflect.Type) (interface{}, error) { - v, err := str2PKValue(s, tp) - if err != nil { - return nil, err - } - return v.Interface(), nil -} - -func splitTag(tag string) (tags []string) { - tag = strings.TrimSpace(tag) - var hasQuote = false - var lastIdx = 0 - for i, t := range tag { - if t == '\'' { - hasQuote = !hasQuote - } else if t == ' ' { - if lastIdx < i && !hasQuote { - tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) - lastIdx = i + 1 - } - } - } - if lastIdx < len(tag) { - tags = append(tags, strings.TrimSpace(tag[lastIdx:])) - } - return -} - -type zeroable interface { - IsZero() bool -} - -func isZero(k interface{}) bool { - switch k.(type) { - case int: - return k.(int) == 0 - case int8: - return k.(int8) == 0 - case int16: - return k.(int16) == 0 - case int32: - return k.(int32) == 0 - case int64: - return k.(int64) == 0 - case uint: - return k.(uint) == 0 - case uint8: - return k.(uint8) == 0 - case uint16: - return k.(uint16) == 0 - case uint32: - return k.(uint32) == 0 - case uint64: - return k.(uint64) == 0 - case float32: - return k.(float32) == 0 - case float64: - return k.(float64) == 0 - case bool: - return k.(bool) == false - case string: - return k.(string) == "" - case zeroable: - return k.(zeroable).IsZero() - } - return false -} - -func isZeroValue(v reflect.Value) bool { - if isZero(v.Interface()) { - return true - } - switch v.Kind() { - case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: - return v.IsNil() - } - return false -} - -func isStructZero(v reflect.Value) bool { - if !v.IsValid() { - return true - } - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - switch field.Kind() { - case reflect.Ptr: - field = field.Elem() - fallthrough - case reflect.Struct: - if !isStructZero(field) { - return false - } - default: - if field.CanInterface() && !isZero(field.Interface()) { - return false - } - } - } - return true -} - -func isArrayValueZero(v reflect.Value) bool { - if !v.IsValid() || v.Len() == 0 { - return true - } - - for i := 0; i < v.Len(); i++ { - if !isZero(v.Index(i).Interface()) { - return false - } - } - - return true -} - -func int64ToIntValue(id int64, tp reflect.Type) reflect.Value { - var v interface{} - kind := tp.Kind() - - if kind == reflect.Ptr { - kind = tp.Elem().Kind() - } - - switch kind { - case reflect.Int16: - temp := int16(id) - v = &temp - case reflect.Int32: - temp := int32(id) - v = &temp - case reflect.Int: - temp := int(id) - v = &temp - case reflect.Int64: - temp := id - v = &temp - case reflect.Uint16: - temp := uint16(id) - v = &temp - case reflect.Uint32: - temp := uint32(id) - v = &temp - case reflect.Uint64: - temp := uint64(id) - v = &temp - case reflect.Uint: - temp := uint(id) - v = &temp - } - - if tp.Kind() == reflect.Ptr { - return reflect.ValueOf(v).Convert(tp) - } - return reflect.ValueOf(v).Elem().Convert(tp) -} - -func int64ToInt(id int64, tp reflect.Type) interface{} { - return int64ToIntValue(id, tp).Interface() -} - -func isPKZero(pk core.PK) bool { - for _, k := range pk { - if isZero(k) { - return true - } - } - return false -} - -func indexNoCase(s, sep string) int { - return strings.Index(strings.ToLower(s), strings.ToLower(sep)) -} - -func splitNoCase(s, sep string) []string { - idx := indexNoCase(s, sep) - if idx < 0 { - return []string{s} - } - return strings.Split(s, s[idx:idx+len(sep)]) -} - -func splitNNoCase(s, sep string, n int) []string { - idx := indexNoCase(s, sep) - if idx < 0 { - return []string{s} - } - return strings.SplitN(s, s[idx:idx+len(sep)], n) -} - -func makeArray(elem string, count int) []string { - res := make([]string, count) - for i := 0; i < count; i++ { - res[i] = elem - } - return res -} - -func rValue(bean interface{}) reflect.Value { - return reflect.Indirect(reflect.ValueOf(bean)) -} - -func rType(bean interface{}) reflect.Type { - sliceValue := reflect.Indirect(reflect.ValueOf(bean)) - //return reflect.TypeOf(sliceValue.Interface()) - return sliceValue.Type() -} - -func structName(v reflect.Type) string { - for v.Kind() == reflect.Ptr { - v = v.Elem() - } - return v.Name() -} - -func col2NewCols(columns ...string) []string { - newColumns := make([]string, 0, len(columns)) - for _, col := range columns { - col = strings.Replace(col, "`", "", -1) - col = strings.Replace(col, `"`, "", -1) - ccols := strings.Split(col, ",") - for _, c := range ccols { - newColumns = append(newColumns, strings.TrimSpace(c)) - } - } - return newColumns -} - -func sliceEq(left, right []string) bool { - if len(left) != len(right) { - return false - } - sort.Sort(sort.StringSlice(left)) - sort.Sort(sort.StringSlice(right)) - for i := 0; i < len(left); i++ { - if left[i] != right[i] { - return false - } - } - return true -} - -func setColumnInt(bean interface{}, col *core.Column, t int64) { - v, err := col.ValueOf(bean) - if err != nil { - return - } - if v.CanSet() { - switch v.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32: - v.SetInt(t) - case reflect.Uint, reflect.Uint64, reflect.Uint32: - v.SetUint(uint64(t)) - } - } -} - -func setColumnTime(bean interface{}, col *core.Column, t time.Time) { - v, err := col.ValueOf(bean) - if err != nil { - return - } - if v.CanSet() { - switch v.Type().Kind() { - case reflect.Struct: - v.Set(reflect.ValueOf(t).Convert(v.Type())) - case reflect.Int, reflect.Int64, reflect.Int32: - v.SetInt(t.Unix()) - case reflect.Uint, reflect.Uint64, reflect.Uint32: - v.SetUint(uint64(t.Unix())) - } - } -} - -func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) { - colNames := make([]string, 0, len(table.ColumnsSeq())) - args := make([]interface{}, 0, len(table.ColumnsSeq())) - - for _, col := range table.Columns() { - if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated { - if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok { - continue - } - } - if col.MapType == core.ONLYFROMDB { - continue - } - - fieldValuePtr, err := col.ValueOf(bean) - if err != nil { - return nil, nil, err - } - fieldValue := *fieldValuePtr - - if col.IsAutoIncrement { - switch fieldValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: - if fieldValue.Int() == 0 { - continue - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: - if fieldValue.Uint() == 0 { - continue - } - case reflect.String: - if len(fieldValue.String()) == 0 { - continue - } - case reflect.Ptr: - if fieldValue.Pointer() == 0 { - continue - } - } - } - - if col.IsDeleted { - continue - } - - if session.Statement.ColumnStr != "" { - if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok { - continue - } else if _, ok := session.Statement.incrColumns[col.Name]; ok { - continue - } else if _, ok := session.Statement.decrColumns[col.Name]; ok { - continue - } - } - if session.Statement.OmitStr != "" { - if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok { - continue - } - } - - // !evalphobia! set fieldValue as nil when column is nullable and zero-value - if _, ok := getFlagForColumn(session.Statement.nullableMap, col); ok { - if col.Nullable && isZeroValue(fieldValue) { - var nilValue *int - fieldValue = reflect.ValueOf(nilValue) - } - } - - if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { - // if time is non-empty, then set to auto time - val, t := session.Engine.NowTime2(col.SQLType.Name) - args = append(args, val) - - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnTime(bean, col, t) - }) - } else if col.IsVersion && session.Statement.checkVersion { - args = append(args, 1) - } else { - arg, err := session.value2Interface(col, fieldValue) - if err != nil { - return colNames, args, err - } - args = append(args, arg) - } - - if includeQuote { - colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?") - } else { - colNames = append(colNames, col.Name) - } - } - return colNames, args, nil -} - -func indexName(tableName, idxName string) string { - return fmt.Sprintf("IDX_%v_%v", tableName, idxName) -} - -func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) { - if len(m) == 0 { - return false, false - } - - n := len(col.Name) - - for mk := range m { - if len(mk) != n { - continue - } - if strings.EqualFold(mk, col.Name) { - return m[mk], true - } - } - - return false, false -} diff --git a/types_test.go b/types_test.go new file mode 100644 index 00000000..21931762 --- /dev/null +++ b/types_test.go @@ -0,0 +1,369 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "encoding/json" + "errors" + "fmt" + "testing" + + "github.com/go-xorm/core" + "github.com/stretchr/testify/assert" +) + +func TestArrayField(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type ArrayStruct struct { + Id int64 + Name [20]byte `xorm:"char(80)"` + } + + assert.NoError(t, testEngine.Sync2(new(ArrayStruct))) + + var as = ArrayStruct{ + Name: [20]byte{ + 96, 96, 96, 96, 96, + 96, 96, 96, 96, 96, + 96, 96, 96, 96, 96, + 96, 96, 96, 96, 96, + }, + } + cnt, err := testEngine.Insert(&as) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var arr ArrayStruct + has, err := testEngine.Id(1).Get(&arr) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, as.Name, arr.Name) + + var arrs []ArrayStruct + err = testEngine.Find(&arrs) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(arrs)) + assert.Equal(t, as.Name, arrs[0].Name) + + var newName = [20]byte{ + 90, 96, 96, 96, 96, + 96, 96, 96, 96, 96, + 96, 96, 96, 96, 96, + 96, 96, 96, 96, 96, + } + + cnt, err = testEngine.ID(1).Update(&ArrayStruct{ + Name: newName, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var newArr ArrayStruct + has, err = testEngine.ID(1).Get(&newArr) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, newName, newArr.Name) + + cnt, err = testEngine.ID(1).Delete(new(ArrayStruct)) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var cfgArr ArrayStruct + has, err = testEngine.ID(1).Get(&cfgArr) + assert.NoError(t, err) + assert.Equal(t, false, has) +} + +func TestGetBytes(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type Varbinary struct { + Data []byte `xorm:"VARBINARY(250)"` + } + + err := testEngine.Sync2(new(Varbinary)) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&Varbinary{ + Data: []byte("test"), + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var b Varbinary + has, err := testEngine.Get(&b) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, "test", string(b.Data)) +} + +type ConvString string + +func (s *ConvString) FromDB(data []byte) error { + *s = ConvString("prefix---" + string(data)) + return nil +} + +func (s *ConvString) ToDB() ([]byte, error) { + return []byte(string(*s)), nil +} + +type ConvConfig struct { + Name string + Id int64 +} + +func (s *ConvConfig) FromDB(data []byte) error { + return json.Unmarshal(data, s) +} + +func (s *ConvConfig) ToDB() ([]byte, error) { + return json.Marshal(s) +} + +type SliceType []*ConvConfig + +func (s *SliceType) FromDB(data []byte) error { + return json.Unmarshal(data, s) +} + +func (s *SliceType) ToDB() ([]byte, error) { + return json.Marshal(s) +} + +type Nullable struct { + Data string +} + +func (s *Nullable) FromDB(data []byte) error { + + if data == nil { + return nil + } + + fmt.Println("--sd", data) + + *s = Nullable{ + Data: string(data), + } + + return nil +} + +func (s *Nullable) ToDB() ([]byte, error) { + if s == nil { + return nil, nil + } + + return []byte(s.Data), nil +} + +type ConvStruct struct { + Conv ConvString + Conv2 *ConvString + Cfg1 ConvConfig + Cfg2 *ConvConfig `xorm:"TEXT"` + Cfg3 core.Conversion `xorm:"BLOB"` + Slice SliceType + Nullable1 *Nullable `xorm:"null"` + Nullable2 *Nullable `xorm:"null"` +} + +func (c *ConvStruct) BeforeSet(name string, cell Cell) { + if name == "cfg3" || name == "Cfg3" { + c.Cfg3 = new(ConvConfig) + } +} + +func TestConversion(t *testing.T) { + assert.NoError(t, prepareEngine()) + + c := new(ConvStruct) + assert.NoError(t, testEngine.DropTables(c)) + assert.NoError(t, testEngine.Sync(c)) + + var s ConvString = "sssss" + c.Conv = "tttt" + c.Conv2 = &s + c.Cfg1 = ConvConfig{"mm", 1} + c.Cfg2 = &ConvConfig{"xx", 2} + c.Cfg3 = &ConvConfig{"zz", 3} + c.Slice = []*ConvConfig{{"yy", 4}, {"ff", 5}} + c.Nullable1 = &Nullable{Data: "test"} + c.Nullable2 = nil + + _, err := testEngine.Nullable("nullable2").Insert(c) + assert.NoError(t, err) + + c1 := new(ConvStruct) + has, err := testEngine.Get(c1) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "prefix---tttt", string(c1.Conv)) + assert.NotNil(t, c1.Conv2) + assert.EqualValues(t, "prefix---"+s, *c1.Conv2) + assert.EqualValues(t, c.Cfg1, c1.Cfg1) + assert.NotNil(t, c1.Cfg2) + assert.EqualValues(t, *c.Cfg2, *c1.Cfg2) + assert.NotNil(t, c1.Cfg3) + assert.EqualValues(t, *c.Cfg3.(*ConvConfig), *c1.Cfg3.(*ConvConfig)) + assert.EqualValues(t, 2, len(c1.Slice)) + assert.EqualValues(t, *c.Slice[0], *c1.Slice[0]) + assert.EqualValues(t, *c.Slice[1], *c1.Slice[1]) + assert.NotNil(t, c1.Nullable1) + assert.Equal(t, c1.Nullable1.Data, "test") + assert.Nil(t, c1.Nullable2) +} + +type MyInt int +type MyUInt uint +type MyFloat float64 + +type MyStruct struct { + Type MyInt + U MyUInt + F MyFloat + S MyString + IA []MyInt + UA []MyUInt + FA []MyFloat + SA []MyString + NameArray []string + Name string + UIA []uint + UIA8 []uint8 + UIA16 []uint16 + UIA32 []uint32 + UIA64 []uint64 + UI uint + //C64 complex64 + MSS map[string]string +} + +func TestCustomType1(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.DropTables(&MyStruct{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&MyStruct{}) + assert.NoError(t, err) + + i := MyStruct{Name: "Test", Type: MyInt(1)} + i.U = 23 + i.F = 1.34 + i.S = "fafdsafdsaf" + i.UI = 2 + i.IA = []MyInt{1, 3, 5} + i.UIA = []uint{1, 3} + i.UIA16 = []uint16{2} + i.UIA32 = []uint32{4, 5} + i.UIA64 = []uint64{6, 7, 9} + i.UIA8 = []uint8{1, 2, 3, 4} + i.NameArray = []string{"ssss", "fsdf", "lllll, ss"} + i.MSS = map[string]string{"s": "sfds,ss", "x": "lfjljsl"} + + cnt, err := testEngine.Insert(&i) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + fmt.Println(i) + i.NameArray = []string{} + i.MSS = map[string]string{} + i.F = 0 + has, err := testEngine.Get(&i) + assert.NoError(t, err) + assert.True(t, has) + + ss := []MyStruct{} + err = testEngine.Find(&ss) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(ss)) + assert.EqualValues(t, i, ss[0]) + + sss := MyStruct{} + has, err = testEngine.Get(&sss) + assert.NoError(t, err) + assert.True(t, has) + + sss.NameArray = []string{} + sss.MSS = map[string]string{} + cnt, err = testEngine.Delete(&sss) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +type Status struct { + Name string + Color string +} + +var ( + _ core.Conversion = &Status{} + Registed Status = Status{"Registed", "white"} + Approved Status = Status{"Approved", "green"} + Removed Status = Status{"Removed", "red"} + Statuses map[string]Status = map[string]Status{ + Registed.Name: Registed, + Approved.Name: Approved, + Removed.Name: Removed, + } +) + +func (s *Status) FromDB(bytes []byte) error { + if r, ok := Statuses[string(bytes)]; ok { + *s = r + return nil + } else { + return errors.New("no this data") + } +} + +func (s *Status) ToDB() ([]byte, error) { + return []byte(s.Name), nil +} + +type UserCus struct { + Id int64 + Name string + Status Status `xorm:"varchar(40)"` +} + +func TestCustomType2(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.CreateTables(&UserCus{}) + assert.NoError(t, err) + + tableName := testEngine.TableMapper.Obj2Table("UserCus") + _, err = testEngine.Exec("delete from " + testEngine.Quote(tableName)) + assert.NoError(t, err) + + if testEngine.Dialect().DBType() == core.MSSQL { + return + /*_, err = engine.Exec("set IDENTITY_INSERT " + tableName + " on") + if err != nil { + t.Fatal(err) + }*/ + } + + _, err = testEngine.Insert(&UserCus{1, "xlw", Registed}) + assert.NoError(t, err) + + user := UserCus{} + exist, err := testEngine.Id(1).Get(&user) + assert.NoError(t, err) + assert.True(t, exist) + + fmt.Println(user) + + users := make([]UserCus, 0) + err = testEngine.Where("`"+testEngine.ColumnMapper.Obj2Table("Status")+"` = ?", "Registed").Find(&users) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users)) + + fmt.Println(users) +}