From de95ea6bb0cf71b2c4ef42b45653a1be86840335 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 5 Aug 2021 11:05:49 +0800 Subject: [PATCH] Fix tests --- integrations/session_update_test.go | 4 +- schemas/type.go | 130 ++++++++++------------------ tags/parser.go | 32 ++++++- 3 files changed, 78 insertions(+), 88 deletions(-) diff --git a/integrations/session_update_test.go b/integrations/session_update_test.go index cc1042b6..bbcc7600 100644 --- a/integrations/session_update_test.go +++ b/integrations/session_update_test.go @@ -349,7 +349,7 @@ func TestUpdate1(t *testing.T) { And("height = ?", user.Height). And("departname = ?", ""). And("detail_id = ?", 0). - And("is_man = ?", 0). + And("is_man = ?", false). Get(&Userinfo{}) assert.NoError(t, err) assert.True(t, has, "cannot insert properly") @@ -825,7 +825,7 @@ func TestNewUpdate(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 0, af) - af, err = testEngine.Table(new(TbUserInfo)).Where("phone=?", 13126564922).Update(&changeUsr) + af, err = testEngine.Table(new(TbUserInfo)).Where("phone=?", "13126564922").Update(&changeUsr) assert.NoError(t, err) assert.EqualValues(t, 0, af) } diff --git a/schemas/type.go b/schemas/type.go index d799db08..d065abb0 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -5,9 +5,9 @@ package schemas import ( + "database/sql" "math/big" "reflect" - "sort" "strings" "time" ) @@ -229,88 +229,40 @@ var ( Array: ARRAY_TYPE, } - - intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"} - uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"} -) - -// !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparison -var ( - emptyString string - boolDefault bool - byteDefault byte - complex64Default complex64 - complex128Default complex128 - float32Default float32 - float64Default float64 - int64Default int64 - uint64Default uint64 - int32Default int32 - uint32Default uint32 - int16Default int16 - uint16Default uint16 - int8Default int8 - uint8Default uint8 - intDefault int - uintDefault uint - timeDefault time.Time - bigFloatDefault big.Float ) // enumerates all types var ( - IntType = reflect.TypeOf(intDefault) - Int8Type = reflect.TypeOf(int8Default) - Int16Type = reflect.TypeOf(int16Default) - Int32Type = reflect.TypeOf(int32Default) - Int64Type = reflect.TypeOf(int64Default) + IntType = reflect.TypeOf((*int)(nil)).Elem() + Int8Type = reflect.TypeOf((*int8)(nil)).Elem() + Int16Type = reflect.TypeOf((*int16)(nil)).Elem() + Int32Type = reflect.TypeOf((*int32)(nil)).Elem() + Int64Type = reflect.TypeOf((*int64)(nil)).Elem() - UintType = reflect.TypeOf(uintDefault) - Uint8Type = reflect.TypeOf(uint8Default) - Uint16Type = reflect.TypeOf(uint16Default) - Uint32Type = reflect.TypeOf(uint32Default) - Uint64Type = reflect.TypeOf(uint64Default) + UintType = reflect.TypeOf((*uint)(nil)).Elem() + Uint8Type = reflect.TypeOf((*uint8)(nil)).Elem() + Uint16Type = reflect.TypeOf((*uint16)(nil)).Elem() + Uint32Type = reflect.TypeOf((*uint32)(nil)).Elem() + Uint64Type = reflect.TypeOf((*uint64)(nil)).Elem() - Float32Type = reflect.TypeOf(float32Default) - Float64Type = reflect.TypeOf(float64Default) + Float32Type = reflect.TypeOf((*float32)(nil)).Elem() + Float64Type = reflect.TypeOf((*float64)(nil)).Elem() - Complex64Type = reflect.TypeOf(complex64Default) - Complex128Type = reflect.TypeOf(complex128Default) + Complex64Type = reflect.TypeOf((*complex64)(nil)).Elem() + Complex128Type = reflect.TypeOf((*complex128)(nil)).Elem() - StringType = reflect.TypeOf(emptyString) - BoolType = reflect.TypeOf(boolDefault) - ByteType = reflect.TypeOf(byteDefault) + StringType = reflect.TypeOf((*string)(nil)).Elem() + BoolType = reflect.TypeOf((*bool)(nil)).Elem() + ByteType = reflect.TypeOf((*byte)(nil)).Elem() BytesType = reflect.SliceOf(ByteType) - TimeType = reflect.TypeOf(timeDefault) - BigFloatType = reflect.TypeOf(bigFloatDefault) -) - -// enumerates all types -var ( - PtrIntType = reflect.PtrTo(IntType) - PtrInt8Type = reflect.PtrTo(Int8Type) - PtrInt16Type = reflect.PtrTo(Int16Type) - PtrInt32Type = reflect.PtrTo(Int32Type) - PtrInt64Type = reflect.PtrTo(Int64Type) - - PtrUintType = reflect.PtrTo(UintType) - PtrUint8Type = reflect.PtrTo(Uint8Type) - PtrUint16Type = reflect.PtrTo(Uint16Type) - PtrUint32Type = reflect.PtrTo(Uint32Type) - PtrUint64Type = reflect.PtrTo(Uint64Type) - - PtrFloat32Type = reflect.PtrTo(Float32Type) - PtrFloat64Type = reflect.PtrTo(Float64Type) - - PtrComplex64Type = reflect.PtrTo(Complex64Type) - PtrComplex128Type = reflect.PtrTo(Complex128Type) - - PtrStringType = reflect.PtrTo(StringType) - PtrBoolType = reflect.PtrTo(BoolType) - PtrByteType = reflect.PtrTo(ByteType) - - PtrTimeType = reflect.PtrTo(TimeType) + TimeType = reflect.TypeOf((*time.Time)(nil)).Elem() + BigFloatType = reflect.TypeOf((*big.Float)(nil)).Elem() + NullFloat64Type = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem() + NullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem() + NullInt32Type = reflect.TypeOf((*sql.NullInt32)(nil)).Elem() + NullInt64Type = reflect.TypeOf((*sql.NullInt64)(nil)).Elem() + NullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem() ) // Type2SQLType generate SQLType acorrding Go's type @@ -331,7 +283,7 @@ func Type2SQLType(t reflect.Type) (st SQLType) { case reflect.Complex64, reflect.Complex128: st = SQLType{Varchar, 64, 0} case reflect.Array, reflect.Slice, reflect.Map: - if t.Elem() == reflect.TypeOf(byteDefault) { + if t.Elem() == ByteType { st = SQLType{Blob, 0, 0} } else { st = SQLType{Text, 0, 0} @@ -343,6 +295,16 @@ func Type2SQLType(t reflect.Type) (st SQLType) { case reflect.Struct: if t.ConvertibleTo(TimeType) { st = SQLType{DateTime, 0, 0} + } else if t.ConvertibleTo(NullFloat64Type) { + st = SQLType{Double, 0, 0} + } else if t.ConvertibleTo(NullStringType) { + st = SQLType{Varchar, 255, 0} + } else if t.ConvertibleTo(NullInt32Type) { + st = SQLType{Integer, 0, 0} + } else if t.ConvertibleTo(NullInt64Type) { + st = SQLType{BigInt, 0, 0} + } else if t.ConvertibleTo(NullBoolType) { + st = SQLType{Boolean, 0, 0} } else { // TODO need to handle association struct st = SQLType{Text, 0, 0} @@ -360,25 +322,25 @@ func SQLType2Type(st SQLType) reflect.Type { name := strings.ToUpper(st.Name) switch name { case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial: - return reflect.TypeOf(1) + return IntType case BigInt, BigSerial: - return reflect.TypeOf(int64(1)) + return Int64Type case Float, Real: - return reflect.TypeOf(float32(1)) + return Float32Type case Double: - return reflect.TypeOf(float64(1)) + return Float64Type case Char, NChar, Varchar, NVarchar, TinyText, Text, NText, MediumText, LongText, Enum, Set, Uuid, Clob, SysName: - return reflect.TypeOf("") + return StringType case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary, UniqueIdentifier: - return reflect.TypeOf([]byte{}) + return BytesType case Bool: - return reflect.TypeOf(true) + return BoolType case DateTime, Date, Time, TimeStamp, TimeStampz, SmallDateTime, Year: - return reflect.TypeOf(timeDefault) + return TimeType case Decimal, Numeric, Money, SmallMoney: - return reflect.TypeOf("") + return StringType default: - return reflect.TypeOf("") + return StringType } } diff --git a/tags/parser.go b/tags/parser.go index 5f9fd528..989ac7f9 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -7,6 +7,7 @@ package tags import ( "encoding/gob" "errors" + "fmt" "reflect" "strings" "sync" @@ -127,6 +128,25 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index // ErrIgnoreField represents an error to ignore field var ErrIgnoreField = errors.New("field will be ignored") +func (parser *Parser) getSQLTypeByType(t reflect.Type) (schemas.SQLType, error) { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.Struct { + v, ok := parser.tableCache.Load(t) + if ok { + pkCols := v.(*schemas.Table).PKColumns() + if len(pkCols) == 1 { + return pkCols[0].SQLType, nil + } + if len(pkCols) > 1 { + return schemas.SQLType{}, fmt.Errorf("unsupported mulitiple primary key on cascade") + } + } + } + return schemas.Type2SQLType(t), nil +} + func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { var sqlType schemas.SQLType if fieldValue.CanAddr() { @@ -137,7 +157,11 @@ func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructFi if _, ok := fieldValue.Interface().(convert.Conversion); ok { sqlType = schemas.SQLType{Name: schemas.Text} } else { - sqlType = schemas.Type2SQLType(field.Type) + var err error + sqlType, err = parser.getSQLTypeByType(field.Type) + if err != nil { + return nil, err + } } col := schemas.NewColumn(parser.columnMapper.Obj2Table(field.Name), field.Name, sqlType, sqlType.DefaultLength, @@ -215,7 +239,11 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, f } if col.SQLType.Name == "" { - col.SQLType = schemas.Type2SQLType(field.Type) + var err error + col.SQLType, err = parser.getSQLTypeByType(field.Type) + if err != nil { + return nil, err + } } if ctx.isUnsigned && col.SQLType.IsNumeric() && !strings.HasPrefix(col.SQLType.Name, "UNSIGNED") { col.SQLType.Name = "UNSIGNED " + col.SQLType.Name