diff --git a/.drone.yml b/.drone.yml index 9b4ffe9a..4f84d7fa 100644 --- a/.drone.yml +++ b/.drone.yml @@ -249,11 +249,11 @@ volumes: services: - name: mssql pull: always - image: microsoft/mssql-server-linux:latest + image: mcr.microsoft.com/mssql/server:latest environment: ACCEPT_EULA: Y SA_PASSWORD: yourStrong(!)Password - MSSQL_PID: Developer + MSSQL_PID: Standard --- kind: pipeline diff --git a/.gitignore b/.gitignore index a3fbadd4..a183a295 100644 --- a/.gitignore +++ b/.gitignore @@ -36,4 +36,5 @@ test.db.sql *coverage.out test.db integrations/*.sql -integrations/test_sqlite* \ No newline at end of file +integrations/test_sqlite* +cover.out \ No newline at end of file diff --git a/.revive.toml b/.revive.toml index 6dec7465..9e3b629d 100644 --- a/.revive.toml +++ b/.revive.toml @@ -8,20 +8,22 @@ warningCode = 1 [rule.context-as-argument] [rule.context-keys-type] [rule.dot-imports] +[rule.empty-lines] +[rule.errorf] [rule.error-return] [rule.error-strings] [rule.error-naming] [rule.exported] [rule.if-return] [rule.increment-decrement] -[rule.var-naming] - arguments = [["ID", "UID", "UUID", "URL", "JSON"], []] -[rule.var-declaration] +[rule.indent-error-flow] [rule.package-comments] [rule.range] [rule.receiver-naming] +[rule.struct-tag] [rule.time-naming] [rule.unexported-return] -[rule.indent-error-flow] -[rule.errorf] -[rule.struct-tag] \ No newline at end of file +[rule.unnecessary-stmt] +[rule.var-declaration] +[rule.var-naming] + arguments = [["ID", "UID", "UUID", "URL", "JSON"], []] \ No newline at end of file diff --git a/convert.go b/convert.go index c19d30e0..b7f30cad 100644 --- a/convert.go +++ b/convert.go @@ -175,7 +175,10 @@ func convertAssign(dest, src interface{}) error { return nil } - dpv := reflect.ValueOf(dest) + return convertAssignV(reflect.ValueOf(dest), src) +} + +func convertAssignV(dpv reflect.Value, src interface{}) error { if dpv.Kind() != reflect.Ptr { return errors.New("destination not a pointer") } @@ -183,9 +186,7 @@ func convertAssign(dest, src interface{}) error { return errNilPtr } - if !sv.IsValid() { - sv = reflect.ValueOf(src) - } + var sv = reflect.ValueOf(src) dv := reflect.Indirect(dpv) if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { @@ -244,7 +245,7 @@ func convertAssign(dest, src interface{}) error { return nil } - return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dpv.Interface()) } func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { @@ -375,48 +376,3 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) { } return v.Interface(), nil } - -func int64ToIntValue(id int64, tp reflect.Type) reflect.Value { - var v interface{} - kind := tp.Kind() - - if kind == reflect.Ptr { - kind = tp.Elem().Kind() - } - - switch kind { - case reflect.Int16: - temp := int16(id) - v = &temp - case reflect.Int32: - temp := int32(id) - v = &temp - case reflect.Int: - temp := int(id) - v = &temp - case reflect.Int64: - temp := id - v = &temp - case reflect.Uint16: - temp := uint16(id) - v = &temp - case reflect.Uint32: - temp := uint32(id) - v = &temp - case reflect.Uint64: - temp := uint64(id) - v = &temp - case reflect.Uint: - temp := uint(id) - v = &temp - } - - if tp.Kind() == reflect.Ptr { - return reflect.ValueOf(v).Convert(tp) - } - return reflect.ValueOf(v).Elem().Convert(tp) -} - -func int64ToInt(id int64, tp reflect.Type) interface{} { - return int64ToIntValue(id, tp).Interface() -} diff --git a/dialects/filter.go b/dialects/filter.go index 2a36a731..bfe2e93e 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -23,13 +23,45 @@ type SeqFilter struct { func convertQuestionMark(sql, prefix string, start int) string { var buf strings.Builder var beginSingleQuote bool + var isLineComment bool + var isComment bool + var isMaybeLineComment bool + var isMaybeComment bool + var isMaybeCommentEnd bool var index = start for _, c := range sql { - if !beginSingleQuote && c == '?' { + if !beginSingleQuote && !isLineComment && !isComment && c == '?' { buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) index++ } else { - if c == '\'' { + if isMaybeLineComment { + if c == '-' { + isLineComment = true + } + isMaybeLineComment = false + } else if isMaybeComment { + if c == '*' { + isComment = true + } + isMaybeComment = false + } else if isMaybeCommentEnd { + if c == '/' { + isComment = false + } + isMaybeCommentEnd = false + } else if isLineComment { + if c == '\n' { + isLineComment = false + } + } else if isComment { + if c == '*' { + isMaybeCommentEnd = true + } + } else if !beginSingleQuote && c == '-' { + isMaybeLineComment = true + } else if !beginSingleQuote && c == '/' { + isMaybeComment = true + } else if c == '\'' { beginSingleQuote = !beginSingleQuote } buf.WriteRune(c) diff --git a/dialects/filter_test.go b/dialects/filter_test.go index 7e2ef0a2..15050656 100644 --- a/dialects/filter_test.go +++ b/dialects/filter_test.go @@ -19,3 +19,60 @@ func TestSeqFilter(t *testing.T) { assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) } } + +func TestSeqFilterLineComment(t *testing.T) { + var kases = map[string]string{ + `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=? -- it's a comment + AND b=?`: `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=$1 -- it's a comment + AND b=$2`, + `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=? -- it's a comment? + AND b=?`: `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=$1 -- it's a comment? + AND b=$2`, + `SELECT * + FROM TABLE1 + WHERE a=? -- it's a comment? and that's okay? + AND b=?`: `SELECT * + FROM TABLE1 + WHERE a=$1 -- it's a comment? and that's okay? + AND b=$2`, + } + for sql, result := range kases { + assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) + } +} + +func TestSeqFilterComment(t *testing.T) { + var kases = map[string]string{ + `SELECT * + FROM TABLE1 + WHERE a=? /* it's a comment */ + AND b=?`: `SELECT * + FROM TABLE1 + WHERE a=$1 /* it's a comment */ + AND b=$2`, + `SELECT /* it's a comment * ? + More comment on the next line! */ * + FROM TABLE1 + WHERE a=? /**/ + AND b=?`: `SELECT /* it's a comment * ? + More comment on the next line! */ * + FROM TABLE1 + WHERE a=$1 /**/ + AND b=$2`, + } + for sql, result := range kases { + assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) + } +} diff --git a/dialects/postgres.go b/dialects/postgres.go index 52c88567..9acf763a 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -1044,12 +1044,13 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} - s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, + s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, description, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey FROM pg_attribute f JOIN pg_class c ON c.oid = f.attrelid JOIN pg_type t ON t.oid = f.atttypid LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum + LEFT JOIN pg_description de ON f.attrelid=de.objoid AND f.attnum=de.objsubid LEFT JOIN pg_namespace n ON n.oid = c.relnamespace LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid @@ -1078,9 +1079,9 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A col.Indexes = make(map[string]int) var colName, isNullable, dataType string - var maxLenStr, colDefault *string + var maxLenStr, colDefault, description *string var isPK, isUnique bool - err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &isPK, &isUnique) + err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &description, &isPK, &isUnique) if err != nil { return nil, nil, err } @@ -1126,6 +1127,10 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A col.DefaultIsEmpty = true } + if description != nil { + col.Comment = *description + } + if isPK { col.IsPrimaryKey = true } diff --git a/engine.go b/engine.go index 649ec1a2..76ce8f1a 100644 --- a/engine.go +++ b/engine.go @@ -652,11 +652,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return errors.New("unknown column error") } - fields := strings.Split(col.FieldName, ".") - field := dataStruct - for _, fieldName := range fields { - field = field.FieldByName(fieldName) - } + field := dataStruct.FieldByIndex(col.FieldIndex) temp += "," + formatColumnValue(dstDialect, field.Interface(), col) } _, err = io.WriteString(w, temp[1:]+");\n") diff --git a/integrations/engine_test.go b/integrations/engine_test.go index 9b70f9b5..a06d91aa 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -176,6 +176,23 @@ func TestDumpTables(t *testing.T) { } } +func TestDumpTables2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestDumpTableStruct2 struct { + Id int64 + Created time.Time `xorm:"Default CURRENT_TIMESTAMP"` + } + + assertSync(t, new(TestDumpTableStruct2)) + + fp := fmt.Sprintf("./dump2-%v-table.sql", testEngine.Dialect().URI().DBType) + os.Remove(fp) + tb, err := testEngine.TableInfo(new(TestDumpTableStruct2)) + assert.NoError(t, err) + assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, fp)) +} + func TestSetSchema(t *testing.T) { assert.NoError(t, PrepareEngine()) @@ -209,3 +226,39 @@ func TestDBVersion(t *testing.T) { fmt.Println(testEngine.Dialect().URI().DBType, "version is", version) } + +func TestGetColumns(t *testing.T) { + if testEngine.Dialect().URI().DBType != schemas.POSTGRES { + t.Skip() + return + } + type TestCommentStruct struct { + HasComment int + NoComment int + } + + assertSync(t, new(TestCommentStruct)) + + comment := "this is a comment" + sql := fmt.Sprintf("comment on column %s.%s is '%s'", testEngine.TableName(new(TestCommentStruct), true), "has_comment", comment) + _, err := testEngine.Exec(sql) + assert.NoError(t, err) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + tableName := testEngine.GetColumnMapper().Obj2Table("TestCommentStruct") + var hasComment, noComment string + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("has_comment") + assert.NotNil(t, col) + hasComment = col.Comment + col2 := table.GetColumn("no_comment") + assert.NotNil(t, col2) + noComment = col2.Comment + break + } + } + assert.Equal(t, comment, hasComment) + assert.Zero(t, noComment) +} diff --git a/integrations/session_insert_test.go b/integrations/session_insert_test.go index eaa1b2c7..e5d880ae 100644 --- a/integrations/session_insert_test.go +++ b/integrations/session_insert_test.go @@ -32,7 +32,6 @@ func TestInsertOne(t *testing.T) { } func TestInsertMulti(t *testing.T) { - assert.NoError(t, PrepareEngine()) type TestMulti struct { Id int64 `xorm:"int(11) pk"` @@ -78,7 +77,6 @@ func insertMultiDatas(step int, datas interface{}) (num int64, err error) { } func callbackLooper(datas interface{}, step int, actionFunc func(interface{}) error) (err error) { - sliceValue := reflect.Indirect(reflect.ValueOf(datas)) if sliceValue.Kind() != reflect.Slice { return fmt.Errorf("not slice") diff --git a/integrations/session_update_test.go b/integrations/session_update_test.go index 15d2f694..796bfa0a 100644 --- a/integrations/session_update_test.go +++ b/integrations/session_update_test.go @@ -472,6 +472,11 @@ func TestUpdateIncrDecr(t *testing.T) { cnt, err = testEngine.ID(col1.Id).Cols(colName).Incr(colName).Update(col1) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + + testEngine.SetColumnMapper(testEngine.GetColumnMapper()) + cnt, err = testEngine.Cols(colName).Decr(colName, 2).ID(col1.Id).Update(new(UpdateIncr)) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) } type UpdatedUpdate struct { diff --git a/internal/statements/expr.go b/internal/statements/expr.go index b44c96ca..c2a2e1cc 100644 --- a/internal/statements/expr.go +++ b/internal/statements/expr.go @@ -27,6 +27,7 @@ type Expr struct { Arg interface{} } +// WriteArgs writes args to the writer func (expr *Expr) WriteArgs(w *builder.BytesWriter) error { switch arg := expr.Arg.(type) { case *builder.Builder: diff --git a/internal/statements/query.go b/internal/statements/query.go index e1091e9f..a972a8e0 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -343,7 +343,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac var args []interface{} var joinStr string var err error - var b interface{} = nil + var b interface{} if len(bean) > 0 { b = bean[0] beanValue := reflect.ValueOf(bean[0]) diff --git a/internal/statements/statement.go b/internal/statements/statement.go index a52c6ca2..b1a5ed3c 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -208,20 +208,18 @@ func (statement *Statement) quote(s string) string { // And add Where & and statement func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { - switch query.(type) { + switch qr := query.(type) { case string: - cond := builder.Expr(query.(string), args...) + cond := builder.Expr(qr, args...) statement.cond = statement.cond.And(cond) case map[string]interface{}: - queryMap := query.(map[string]interface{}) - newMap := make(map[string]interface{}) - for k, v := range queryMap { - newMap[statement.quote(k)] = v + cond := make(builder.Eq) + for k, v := range qr { + cond[statement.quote(k)] = v } - statement.cond = statement.cond.And(builder.Eq(newMap)) - case builder.Cond: - cond := query.(builder.Cond) statement.cond = statement.cond.And(cond) + case builder.Cond: + statement.cond = statement.cond.And(qr) for _, v := range args { if vv, ok := v.(builder.Cond); ok { statement.cond = statement.cond.And(vv) @@ -236,23 +234,25 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme // Or add Where & Or statement func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement { - switch query.(type) { + switch qr := query.(type) { case string: - cond := builder.Expr(query.(string), args...) + cond := builder.Expr(qr, args...) statement.cond = statement.cond.Or(cond) case map[string]interface{}: - cond := builder.Eq(query.(map[string]interface{})) + cond := make(builder.Eq) + for k, v := range qr { + cond[statement.quote(k)] = v + } statement.cond = statement.cond.Or(cond) case builder.Cond: - cond := query.(builder.Cond) - statement.cond = statement.cond.Or(cond) + statement.cond = statement.cond.Or(qr) for _, v := range args { if vv, ok := v.(builder.Cond); ok { statement.cond = statement.cond.Or(vv) } } default: - // TODO: not support condition type + statement.LastError = ErrConditionType } return statement } @@ -734,6 +734,8 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, //engine.logger.Warn(err) } continue + } else if fieldValuePtr == nil { + continue } if col.IsDeleted && !unscoped { // tag "deleted" is enabled diff --git a/internal/statements/statement_test.go b/internal/statements/statement_test.go index 15f446f4..ba92330e 100644 --- a/internal/statements/statement_test.go +++ b/internal/statements/statement_test.go @@ -78,7 +78,6 @@ func TestColumnsStringGeneration(t *testing.T) { } func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { - b.StopTimer() mapCols := make(map[string]bool) @@ -101,9 +100,7 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - for _, col := range cols { - if _, ok := getFlagForColumn(mapCols, col); !ok { b.Fatal("Unexpected result") } @@ -112,7 +109,6 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { } func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) { - b.StopTimer() mapCols := make(map[string]bool) @@ -131,9 +127,7 @@ func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - for _, col := range cols { - if _, ok := getFlagForColumn(mapCols, col); ok { b.Fatal("Unexpected result") } diff --git a/internal/statements/update.go b/internal/statements/update.go index 251880b2..06cf0689 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -88,6 +88,9 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value, if err != nil { return nil, nil, err } + if fieldValuePtr == nil { + continue + } fieldValue := *fieldValuePtr fieldType := reflect.TypeOf(fieldValue.Interface()) diff --git a/log/logger.go b/log/logger.go index eeb63693..3b6db34e 100644 --- a/log/logger.go +++ b/log/logger.go @@ -132,7 +132,6 @@ func (s *SimpleLogger) Error(v ...interface{}) { if s.level <= LOG_ERR { s.ERR.Output(2, fmt.Sprintln(v...)) } - return } // Errorf implement ILogger @@ -140,7 +139,6 @@ func (s *SimpleLogger) Errorf(format string, v ...interface{}) { if s.level <= LOG_ERR { s.ERR.Output(2, fmt.Sprintf(format, v...)) } - return } // Debug implement ILogger @@ -148,7 +146,6 @@ func (s *SimpleLogger) Debug(v ...interface{}) { if s.level <= LOG_DEBUG { s.DEBUG.Output(2, fmt.Sprintln(v...)) } - return } // Debugf implement ILogger @@ -156,7 +153,6 @@ func (s *SimpleLogger) Debugf(format string, v ...interface{}) { if s.level <= LOG_DEBUG { s.DEBUG.Output(2, fmt.Sprintf(format, v...)) } - return } // Info implement ILogger @@ -164,7 +160,6 @@ func (s *SimpleLogger) Info(v ...interface{}) { if s.level <= LOG_INFO { s.INFO.Output(2, fmt.Sprintln(v...)) } - return } // Infof implement ILogger @@ -172,7 +167,6 @@ func (s *SimpleLogger) Infof(format string, v ...interface{}) { if s.level <= LOG_INFO { s.INFO.Output(2, fmt.Sprintf(format, v...)) } - return } // Warn implement ILogger @@ -180,7 +174,6 @@ func (s *SimpleLogger) Warn(v ...interface{}) { if s.level <= LOG_WARNING { s.WARN.Output(2, fmt.Sprintln(v...)) } - return } // Warnf implement ILogger @@ -188,7 +181,6 @@ func (s *SimpleLogger) Warnf(format string, v ...interface{}) { if s.level <= LOG_WARNING { s.WARN.Output(2, fmt.Sprintf(format, v...)) } - return } // Level implement ILogger @@ -199,7 +191,6 @@ func (s *SimpleLogger) Level() LogLevel { // SetLevel implement ILogger func (s *SimpleLogger) SetLevel(l LogLevel) { s.level = l - return } // ShowSQL implement ILogger diff --git a/schemas/column.go b/schemas/column.go index 24b53802..4bbb6c2d 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -6,10 +6,8 @@ package schemas import ( "errors" - "fmt" "reflect" "strconv" - "strings" "time" ) @@ -25,6 +23,7 @@ type Column struct { Name string TableName string FieldName string // Available only when parsed from a struct + FieldIndex []int // Available only when parsed from a struct SQLType SQLType IsJSON bool Length int @@ -83,41 +82,17 @@ func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) { // ValueOfV returns column's filed of struct's value accept reflevt value func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { - var fieldValue reflect.Value - fieldPath := strings.Split(col.FieldName, ".") - - if dataStruct.Type().Kind() == reflect.Map { - keyValue := reflect.ValueOf(fieldPath[len(fieldPath)-1]) - fieldValue = dataStruct.MapIndex(keyValue) - return &fieldValue, nil - } else if dataStruct.Type().Kind() == reflect.Interface { - structValue := reflect.ValueOf(dataStruct.Interface()) - dataStruct = &structValue - } - - level := len(fieldPath) - fieldValue = dataStruct.FieldByName(fieldPath[0]) - for i := 0; i < level-1; i++ { - if !fieldValue.IsValid() { - break - } - if fieldValue.Kind() == reflect.Struct { - fieldValue = fieldValue.FieldByName(fieldPath[i+1]) - } else if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + var v = *dataStruct + for _, i := range col.FieldIndex { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) } - fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1]) - } else { - return nil, fmt.Errorf("field %v is not valid", col.FieldName) + v = v.Elem() } + v = v.FieldByIndex([]int{i}) } - - if !fieldValue.IsValid() { - return nil, fmt.Errorf("field %v is not valid", col.FieldName) - } - - return &fieldValue, nil + return &v, nil } // ConvertID converts id content to suitable type according column type diff --git a/schemas/table.go b/schemas/table.go index bfa517aa..91b33e06 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -5,7 +5,6 @@ package schemas import ( - "fmt" "reflect" "strconv" "strings" @@ -159,24 +158,8 @@ func (table *Table) IDOfV(rv reflect.Value) (PK, error) { for i, col := range table.PKColumns() { var err error - fieldName := col.FieldName - for { - parts := strings.SplitN(fieldName, ".", 2) - if len(parts) == 1 { - break - } + pkField := v.FieldByIndex(col.FieldIndex) - v = v.FieldByName(parts[0]) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - if v.Kind() != reflect.Struct { - return nil, fmt.Errorf("Unsupported read value of column %s from field %s", col.Name, col.FieldName) - } - fieldName = parts[1] - } - - pkField := v.FieldByName(fieldName) switch pkField.Kind() { case reflect.String: pk[i], err = col.ConvertID(pkField.String()) diff --git a/schemas/table_test.go b/schemas/table_test.go index 9bf10e33..0e35193f 100644 --- a/schemas/table_test.go +++ b/schemas/table_test.go @@ -27,7 +27,6 @@ var testsGetColumn = []struct { var table *Table func init() { - table = NewEmptyTable() var name string @@ -41,7 +40,6 @@ func init() { } func TestGetColumn(t *testing.T) { - for _, test := range testsGetColumn { if table.GetColumn(test.name) == nil { t.Error("Column not found!") @@ -50,7 +48,6 @@ func TestGetColumn(t *testing.T) { } func TestGetColumnIdx(t *testing.T) { - for _, test := range testsGetColumn { if table.GetColumnIdx(test.name, test.idx) == nil { t.Errorf("Column %s with idx %d not found!", test.name, test.idx) @@ -59,7 +56,6 @@ func TestGetColumnIdx(t *testing.T) { } func BenchmarkGetColumnWithToLower(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { @@ -71,7 +67,6 @@ func BenchmarkGetColumnWithToLower(b *testing.B) { } func BenchmarkGetColumnIdxWithToLower(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { @@ -89,7 +84,6 @@ func BenchmarkGetColumnIdxWithToLower(b *testing.B) { } func BenchmarkGetColumn(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { if table.GetColumn(test.name) == nil { @@ -100,7 +94,6 @@ func BenchmarkGetColumn(b *testing.B) { } func BenchmarkGetColumnIdx(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { if table.GetColumnIdx(test.name, test.idx) == nil { diff --git a/session.go b/session.go index d5ccb6dc..6df9e20d 100644 --- a/session.go +++ b/session.go @@ -375,6 +375,9 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s if err != nil { return nil, err } + if fieldValue == nil { + return nil, ErrFieldIsNotValid{key, table.Name} + } if !fieldValue.IsValid() || !fieldValue.CanSet() { return nil, ErrFieldIsNotValid{key, table.Name} diff --git a/session_convert.go b/session_convert.go index a6839947..b8218a77 100644 --- a/session_convert.go +++ b/session_convert.go @@ -35,27 +35,20 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time sd, err := strconv.ParseInt(sdata, 10, 64) if err == nil { x = time.Unix(sd, 0) - //session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) - } else { - //session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } } else if len(sdata) > 19 && strings.Contains(sdata, "-") { x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc) - session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) + session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.Name, x, sdata) if err != nil { x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc) - //session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } if err != nil { x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc) - //session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } } else if len(sdata) == 19 && strings.Contains(sdata, "-") { x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc) - //session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc) - //session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else if col.SQLType.Name == schemas.Time { if strings.Contains(sdata, " ") { ssd := strings.Split(sdata, " ") @@ -69,7 +62,6 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time st := fmt.Sprintf("2006-01-02 %v", sdata) x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc) - //session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else { outErr = fmt.Errorf("unsupported time format %v", sdata) return diff --git a/session_insert.go b/session_insert.go index 5f968151..82d91969 100644 --- a/session_insert.go +++ b/session_insert.go @@ -374,9 +374,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) - - return 1, nil + return 1, convertAssignV(aiValue.Addr(), id) } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || session.engine.dialect.URI().DBType == schemas.MSSQL) { res, err := session.queryBytes(sqlStr, args...) @@ -416,9 +414,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) - - return 1, nil + return 1, convertAssignV(aiValue.Addr(), id) } res, err := session.exec(sqlStr, args...) @@ -458,7 +454,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) + if err := convertAssignV(aiValue.Addr(), id); err != nil { + return 0, err + } return res.RowsAffected() } diff --git a/session_update.go b/session_update.go index f791bb2d..78907e43 100644 --- a/session_update.go +++ b/session_update.go @@ -280,15 +280,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 k = ct.Elem().Kind() } if k == reflect.Struct { - var refTable = session.statement.RefTable - if refTable == nil { - refTable, err = session.engine.TableInfo(condiBean[0]) - if err != nil { - return 0, err - } + condTable, err := session.engine.TableInfo(condiBean[0]) + if err != nil { + return 0, err } - var err error - autoCond, err = session.statement.BuildConds(refTable, condiBean[0], true, true, false, true, false) + + autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, false) if err != nil { return 0, err } @@ -457,7 +454,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 // FIXME: if bean is a map type, it will panic because map cannot be as map key session.afterUpdateBeans[bean] = &afterClosures } - } else { if _, ok := interface{}(bean).(AfterUpdateProcessor); ok { session.afterUpdateBeans[bean] = nil diff --git a/tags/parser.go b/tags/parser.go index ff329daa..d701e316 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -7,7 +7,6 @@ package tags import ( "encoding/gob" "errors" - "fmt" "reflect" "strings" "sync" @@ -23,7 +22,7 @@ import ( var ( // ErrUnsupportedType represents an unsupported type error - ErrUnsupportedType = errors.New("Unsupported type") + ErrUnsupportedType = errors.New("unsupported type") ) // Parser represents a parser for xorm tag @@ -125,6 +124,147 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index } } +var ErrIgnoreField = errors.New("field will be ignored") + +func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { + var sqlType schemas.SQLType + if fieldValue.CanAddr() { + if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { + sqlType = schemas.SQLType{Name: schemas.Text} + } + } + if _, ok := fieldValue.Interface().(convert.Conversion); ok { + sqlType = schemas.SQLType{Name: schemas.Text} + } else { + sqlType = schemas.Type2SQLType(field.Type) + } + col := schemas.NewColumn(parser.columnMapper.Obj2Table(field.Name), + field.Name, sqlType, sqlType.DefaultLength, + sqlType.DefaultLength2, true) + col.FieldIndex = []int{fieldIndex} + + if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { + col.IsAutoIncrement = true + col.IsPrimaryKey = true + col.Nullable = false + } + return col, nil +} + +func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) { + var col = &schemas.Column{ + FieldName: field.Name, + FieldIndex: []int{fieldIndex}, + Nullable: true, + IsPrimaryKey: false, + IsAutoIncrement: false, + MapType: schemas.TWOSIDES, + Indexes: make(map[string]int), + DefaultIsEmpty: true, + } + + var ctx = Context{ + table: table, + col: col, + fieldValue: fieldValue, + indexNames: make(map[string]int), + parser: parser, + } + + for j, tag := range tags { + if ctx.ignoreNext { + ctx.ignoreNext = false + continue + } + + ctx.tag = tag + ctx.tagUname = strings.ToUpper(tag.name) + + if j > 0 { + ctx.preTag = strings.ToUpper(tags[j-1].name) + } + if j < len(tags)-1 { + ctx.nextTag = tags[j+1].name + } else { + ctx.nextTag = "" + } + + if h, ok := parser.handlers[ctx.tagUname]; ok { + if err := h(&ctx); err != nil { + return nil, err + } + } else { + if strings.HasPrefix(ctx.tag.name, "'") && strings.HasSuffix(ctx.tag.name, "'") { + col.Name = ctx.tag.name[1 : len(ctx.tag.name)-1] + } else { + col.Name = ctx.tag.name + } + } + + if ctx.hasCacheTag { + if parser.cacherMgr.GetDefaultCacher() != nil { + parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher()) + } else { + parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)) + } + } + if ctx.hasNoCacheTag { + parser.cacherMgr.SetCacher(table.Name, nil) + } + } + + if col.SQLType.Name == "" { + col.SQLType = schemas.Type2SQLType(field.Type) + } + parser.dialect.SQLType(col) + if col.Length == 0 { + col.Length = col.SQLType.DefaultLength + } + if col.Length2 == 0 { + col.Length2 = col.SQLType.DefaultLength2 + } + if col.Name == "" { + col.Name = parser.columnMapper.Obj2Table(field.Name) + } + + if ctx.isUnique { + ctx.indexNames[col.Name] = schemas.UniqueType + } else if ctx.isIndex { + ctx.indexNames[col.Name] = schemas.IndexType + } + + for indexName, indexType := range ctx.indexNames { + addIndex(indexName, table, col, indexType) + } + + return col, nil +} + +func (parser *Parser) parseField(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { + var ( + tag = field.Tag + ormTagStr = strings.TrimSpace(tag.Get(parser.identifier)) + ) + if ormTagStr == "-" { + return nil, ErrIgnoreField + } + if ormTagStr == "" { + return parser.parseFieldWithNoTag(fieldIndex, field, fieldValue) + } + tags, err := splitTag(ormTagStr) + if err != nil { + return nil, err + } + return parser.parseFieldWithTags(table, fieldIndex, field, fieldValue, tags) +} + +func isNotTitle(n string) bool { + for _, c := range n { + return unicode.IsLower(c) + } + return true +} + // Parse parses a struct as a table information func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { t := v.Type() @@ -140,193 +280,21 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { table.Type = t table.Name = names.GetTableName(parser.tableMapper, v) - var idFieldColName string - var hasCacheTag, hasNoCacheTag bool - for i := 0; i < t.NumField(); i++ { - var isUnexportField bool - for _, c := range t.Field(i).Name { - if unicode.IsLower(c) { - isUnexportField = true - } - break - } - if isUnexportField { + var field = t.Field(i) + if isNotTitle(field.Name) { continue } - tag := t.Field(i).Tag - ormTagStr := tag.Get(parser.identifier) - var col *schemas.Column - fieldValue := v.Field(i) - fieldType := fieldValue.Type() - - if ormTagStr != "" { - col = &schemas.Column{ - FieldName: t.Field(i).Name, - Nullable: true, - IsPrimaryKey: false, - IsAutoIncrement: false, - MapType: schemas.TWOSIDES, - Indexes: make(map[string]int), - DefaultIsEmpty: true, - } - tags := splitTag(ormTagStr) - - if len(tags) > 0 { - if tags[0] == "-" { - continue - } - - var ctx = Context{ - table: table, - col: col, - fieldValue: fieldValue, - indexNames: make(map[string]int), - parser: parser, - } - - if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") { - pStart := strings.Index(tags[0], "(") - if pStart > -1 && strings.HasSuffix(tags[0], ")") { - var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool { - return r == '\'' || r == '"' - }) - - ctx.params = []string{tagPrefix} - } - - if err := ExtendsTagHandler(&ctx); err != nil { - return nil, err - } - continue - } - - for j, key := range tags { - if ctx.ignoreNext { - ctx.ignoreNext = false - continue - } - - k := strings.ToUpper(key) - ctx.tagName = k - ctx.params = []string{} - - pStart := strings.Index(k, "(") - if pStart == 0 { - return nil, errors.New("( could not be the first character") - } - if pStart > -1 { - if !strings.HasSuffix(k, ")") { - return nil, fmt.Errorf("field %s tag %s cannot match ) character", col.FieldName, key) - } - - ctx.tagName = k[:pStart] - ctx.params = strings.Split(key[pStart+1:len(k)-1], ",") - } - - if j > 0 { - ctx.preTag = strings.ToUpper(tags[j-1]) - } - if j < len(tags)-1 { - ctx.nextTag = tags[j+1] - } else { - ctx.nextTag = "" - } - - if h, ok := parser.handlers[ctx.tagName]; ok { - if err := h(&ctx); err != nil { - return nil, err - } - } else { - if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") { - col.Name = key[1 : len(key)-1] - } else { - col.Name = key - } - } - - if ctx.hasCacheTag { - hasCacheTag = true - } - if ctx.hasNoCacheTag { - hasNoCacheTag = true - } - } - - if col.SQLType.Name == "" { - col.SQLType = schemas.Type2SQLType(fieldType) - } - parser.dialect.SQLType(col) - if col.Length == 0 { - col.Length = col.SQLType.DefaultLength - } - if col.Length2 == 0 { - col.Length2 = col.SQLType.DefaultLength2 - } - if col.Name == "" { - col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name) - } - - if ctx.isUnique { - ctx.indexNames[col.Name] = schemas.UniqueType - } else if ctx.isIndex { - ctx.indexNames[col.Name] = schemas.IndexType - } - - for indexName, indexType := range ctx.indexNames { - addIndex(indexName, table, col, indexType) - } - } - } else { - var sqlType schemas.SQLType - if fieldValue.CanAddr() { - if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - sqlType = schemas.SQLType{Name: schemas.Text} - } - } - if _, ok := fieldValue.Interface().(convert.Conversion); ok { - sqlType = schemas.SQLType{Name: schemas.Text} - } else { - sqlType = schemas.Type2SQLType(fieldType) - } - col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name), - t.Field(i).Name, sqlType, sqlType.DefaultLength, - sqlType.DefaultLength2, true) - - if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { - idFieldColName = col.Name - } - } - if col.IsAutoIncrement { - col.Nullable = false + col, err := parser.parseField(table, i, field, v.Field(i)) + if err == ErrIgnoreField { + continue + } else if err != nil { + return nil, err } table.AddColumn(col) } // end for - if idFieldColName != "" && len(table.PrimaryKeys) == 0 { - col := table.GetColumn(idFieldColName) - col.IsPrimaryKey = true - col.IsAutoIncrement = true - col.Nullable = false - table.PrimaryKeys = append(table.PrimaryKeys, col.Name) - table.AutoIncrement = col.Name - } - - if hasCacheTag { - if parser.cacherMgr.GetDefaultCacher() != nil { // !nash! use engine's cacher if provided - //engine.logger.Info("enable cache on table:", table.Name) - parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher()) - } else { - //engine.logger.Info("enable LRU cache on table:", table.Name) - parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)) - } - } - if hasNoCacheTag { - //engine.logger.Info("disable cache on table:", table.Name) - parser.cacherMgr.SetCacher(table.Name, nil) - } - return table, nil } diff --git a/tags/parser_test.go b/tags/parser_test.go index 5add1e13..70c57692 100644 --- a/tags/parser_test.go +++ b/tags/parser_test.go @@ -6,12 +6,16 @@ package tags import ( "reflect" + "strings" "testing" + "time" - "github.com/stretchr/testify/assert" "xorm.io/xorm/caches" "xorm.io/xorm/dialects" "xorm.io/xorm/names" + "xorm.io/xorm/schemas" + + "github.com/stretchr/testify/assert" ) type ParseTableName1 struct{} @@ -80,7 +84,7 @@ func TestParseWithOtherIdentifier(t *testing.T) { parser := NewParser( "xorm", dialects.QueryDialect("mysql"), - names.GonicMapper{}, + names.SameMapper{}, names.SnakeMapper{}, caches.NewManager(), ) @@ -88,13 +92,461 @@ func TestParseWithOtherIdentifier(t *testing.T) { type StructWithDBTag struct { FieldFoo string `db:"foo"` } + parser.SetIdentifier("db") table, err := parser.Parse(reflect.ValueOf(new(StructWithDBTag))) assert.NoError(t, err) - assert.EqualValues(t, "struct_with_db_tag", table.Name) + assert.EqualValues(t, "StructWithDBTag", table.Name) assert.EqualValues(t, 1, len(table.Columns())) for _, col := range table.Columns() { assert.EqualValues(t, "foo", col.Name) } } + +func TestParseWithIgnore(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SameMapper{}, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithIgnoreTag struct { + FieldFoo string `db:"-"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithIgnoreTag))) + assert.NoError(t, err) + assert.EqualValues(t, "StructWithIgnoreTag", table.Name) + assert.EqualValues(t, 0, len(table.Columns())) +} + +func TestParseWithAutoincrement(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithAutoIncrement struct { + ID int64 + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_auto_increment", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "id", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsAutoIncrement) + assert.True(t, table.Columns()[0].IsPrimaryKey) +} + +func TestParseWithAutoincrement2(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithAutoIncrement2 struct { + ID int64 `db:"pk autoincr"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement2))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_auto_increment2", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "id", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsAutoIncrement) + assert.True(t, table.Columns()[0].IsPrimaryKey) + assert.False(t, table.Columns()[0].Nullable) +} + +func TestParseWithNullable(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithNullable struct { + Name string `db:"notnull"` + FullName string `db:"null comment('column comment,字段注释')"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithNullable))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_nullable", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "full_name", table.Columns()[1].Name) + assert.False(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.EqualValues(t, "column comment,字段注释", table.Columns()[1].Comment) +} + +func TestParseWithTimes(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithTimes struct { + Name string `db:"notnull"` + CreatedAt time.Time `db:"created"` + UpdatedAt time.Time `db:"updated"` + DeletedAt time.Time `db:"deleted"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithTimes))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_times", table.Name) + assert.EqualValues(t, 4, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "created_at", table.Columns()[1].Name) + assert.EqualValues(t, "updated_at", table.Columns()[2].Name) + assert.EqualValues(t, "deleted_at", table.Columns()[3].Name) + assert.False(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[1].IsCreated) + assert.True(t, table.Columns()[2].Nullable) + assert.True(t, table.Columns()[2].IsUpdated) + assert.True(t, table.Columns()[3].Nullable) + assert.True(t, table.Columns()[3].IsDeleted) +} + +func TestParseWithExtends(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithEmbed struct { + Name string + CreatedAt time.Time `db:"created"` + UpdatedAt time.Time `db:"updated"` + DeletedAt time.Time `db:"deleted"` + } + + type StructWithExtends struct { + SW StructWithEmbed `db:"extends"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithExtends))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_extends", table.Name) + assert.EqualValues(t, 4, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "created_at", table.Columns()[1].Name) + assert.EqualValues(t, "updated_at", table.Columns()[2].Name) + assert.EqualValues(t, "deleted_at", table.Columns()[3].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[1].IsCreated) + assert.True(t, table.Columns()[2].Nullable) + assert.True(t, table.Columns()[2].IsUpdated) + assert.True(t, table.Columns()[3].Nullable) + assert.True(t, table.Columns()[3].IsDeleted) +} + +func TestParseWithCache(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithCache struct { + Name string `db:"cache"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithCache))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_cache", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + cacher := parser.cacherMgr.GetCacher(table.Name) + assert.NotNil(t, cacher) +} + +func TestParseWithNoCache(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithNoCache struct { + Name string `db:"nocache"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithNoCache))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_no_cache", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + cacher := parser.cacherMgr.GetCacher(table.Name) + assert.Nil(t, cacher) +} + +func TestParseWithEnum(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithEnum struct { + Name string `db:"enum('alice', 'bob')"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithEnum))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_enum", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.EqualValues(t, schemas.Enum, strings.ToUpper(table.Columns()[0].SQLType.Name)) + assert.EqualValues(t, map[string]int{ + "alice": 0, + "bob": 1, + }, table.Columns()[0].EnumOptions) +} + +func TestParseWithSet(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithSet struct { + Name string `db:"set('alice', 'bob')"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithSet))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_set", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.EqualValues(t, schemas.Set, strings.ToUpper(table.Columns()[0].SQLType.Name)) + assert.EqualValues(t, map[string]int{ + "alice": 0, + "bob": 1, + }, table.Columns()[0].SetOptions) +} + +func TestParseWithIndex(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithIndex struct { + Name string `db:"index"` + Name2 string `db:"index(s)"` + Name3 string `db:"unique"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithIndex))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_index", table.Name) + assert.EqualValues(t, 3, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "name2", table.Columns()[1].Name) + assert.EqualValues(t, "name3", table.Columns()[2].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[2].Nullable) + assert.EqualValues(t, 1, len(table.Columns()[0].Indexes)) + assert.EqualValues(t, 1, len(table.Columns()[1].Indexes)) + assert.EqualValues(t, 1, len(table.Columns()[2].Indexes)) +} + +func TestParseWithVersion(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithVersion struct { + Name string + Version int `db:"version"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithVersion))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_version", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "version", table.Columns()[1].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[1].IsVersion) +} + +func TestParseWithLocale(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithLocale struct { + UTCLocale time.Time `db:"utc"` + LocalLocale time.Time `db:"local"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithLocale))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_locale", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "utc_locale", table.Columns()[0].Name) + assert.EqualValues(t, "local_locale", table.Columns()[1].Name) + assert.EqualValues(t, time.UTC, table.Columns()[0].TimeZone) + assert.EqualValues(t, time.Local, table.Columns()[1].TimeZone) +} + +func TestParseWithDefault(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithDefault struct { + Default1 time.Time `db:"default '1970-01-01 00:00:00'"` + Default2 time.Time `db:"default(CURRENT_TIMESTAMP)"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithDefault))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_default", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.EqualValues(t, "default2", table.Columns()[1].Name) + assert.EqualValues(t, "'1970-01-01 00:00:00'", table.Columns()[0].Default) + assert.EqualValues(t, "CURRENT_TIMESTAMP", table.Columns()[1].Default) +} + +func TestParseWithOnlyToDB(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.GonicMapper{ + "DB": true, + }, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithOnlyToDB struct { + Default1 time.Time `db:"->"` + Default2 time.Time `db:"<-"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithOnlyToDB))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_only_to_db", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.EqualValues(t, "default2", table.Columns()[1].Name) + assert.EqualValues(t, schemas.ONLYTODB, table.Columns()[0].MapType) + assert.EqualValues(t, schemas.ONLYFROMDB, table.Columns()[1].MapType) +} + +func TestParseWithJSON(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.GonicMapper{ + "JSON": true, + }, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithJSON struct { + Default1 []string `db:"json"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithJSON))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_json", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsJSON) +} + +func TestParseWithSQLType(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.GonicMapper{ + "SQL": true, + }, + names.GonicMapper{ + "UUID": true, + }, + caches.NewManager(), + ) + + type StructWithSQLType struct { + Col1 string `db:"varchar(32)"` + Col2 string `db:"char(32)"` + Int int64 `db:"bigint"` + DateTime time.Time `db:"datetime"` + UUID string `db:"uuid"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithSQLType))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_sql_type", table.Name) + assert.EqualValues(t, 5, len(table.Columns())) + assert.EqualValues(t, "col1", table.Columns()[0].Name) + assert.EqualValues(t, "col2", table.Columns()[1].Name) + assert.EqualValues(t, "int", table.Columns()[2].Name) + assert.EqualValues(t, "date_time", table.Columns()[3].Name) + assert.EqualValues(t, "uuid", table.Columns()[4].Name) + + assert.EqualValues(t, "VARCHAR", table.Columns()[0].SQLType.Name) + assert.EqualValues(t, "CHAR", table.Columns()[1].SQLType.Name) + assert.EqualValues(t, "BIGINT", table.Columns()[2].SQLType.Name) + assert.EqualValues(t, "DATETIME", table.Columns()[3].SQLType.Name) + assert.EqualValues(t, "UUID", table.Columns()[4].SQLType.Name) +} diff --git a/tags/tag.go b/tags/tag.go index bb5b5838..4a39ba54 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -14,30 +14,74 @@ import ( "xorm.io/xorm/schemas" ) -func splitTag(tag string) (tags []string) { - tag = strings.TrimSpace(tag) - var hasQuote = false - var lastIdx = 0 - for i, t := range tag { - if t == '\'' { - hasQuote = !hasQuote - } else if t == ' ' { - if lastIdx < i && !hasQuote { - tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) - lastIdx = i + 1 +type tag struct { + name string + params []string +} + +func splitTag(tagStr string) ([]tag, error) { + tagStr = strings.TrimSpace(tagStr) + var ( + inQuote bool + inBigQuote bool + lastIdx int + curTag tag + paramStart int + tags []tag + ) + for i, t := range tagStr { + switch t { + case '\'': + inQuote = !inQuote + case ' ': + if !inQuote && !inBigQuote { + if lastIdx < i { + if curTag.name == "" { + curTag.name = tagStr[lastIdx:i] + } + tags = append(tags, curTag) + lastIdx = i + 1 + curTag = tag{} + } else if lastIdx == i { + lastIdx = i + 1 + } + } else if inBigQuote && !inQuote { + paramStart = i + 1 + } + case ',': + if !inQuote && !inBigQuote { + return nil, fmt.Errorf("comma[%d] of %s should be in quote or big quote", i, tagStr) + } + if !inQuote && inBigQuote { + curTag.params = append(curTag.params, strings.TrimSpace(tagStr[paramStart:i])) + paramStart = i + 1 + } + case '(': + inBigQuote = true + if !inQuote { + curTag.name = tagStr[lastIdx:i] + paramStart = i + 1 + } + case ')': + inBigQuote = false + if !inQuote { + curTag.params = append(curTag.params, tagStr[paramStart:i]) } } } - if lastIdx < len(tag) { - tags = append(tags, strings.TrimSpace(tag[lastIdx:])) + if lastIdx < len(tagStr) { + if curTag.name == "" { + curTag.name = tagStr[lastIdx:] + } + tags = append(tags, curTag) } - return + return tags, nil } // Context represents a context for xorm tag parse. type Context struct { - tagName string - params []string + tag + tagUname string preTag, nextTag string table *schemas.Table col *schemas.Column @@ -76,6 +120,7 @@ var ( "CACHE": CacheTagHandler, "NOCACHE": NoCacheTagHandler, "COMMENT": CommentTagHandler, + "EXTENDS": ExtendsTagHandler, } ) @@ -124,6 +169,7 @@ func NotNullTagHandler(ctx *Context) error { // AutoIncrTagHandler describes autoincr tag handler func AutoIncrTagHandler(ctx *Context) error { ctx.col.IsAutoIncrement = true + ctx.col.Nullable = false /* if len(ctx.params) > 0 { autoStartInt, err := strconv.Atoi(ctx.params[0]) @@ -225,41 +271,44 @@ func CommentTagHandler(ctx *Context) error { // SQLTypeTagHandler describes SQL Type tag handler func SQLTypeTagHandler(ctx *Context) error { - ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName} - if strings.EqualFold(ctx.tagName, "JSON") { + ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname} + if ctx.tagUname == "JSON" { ctx.col.IsJSON = true } - if len(ctx.params) > 0 { - if ctx.tagName == schemas.Enum { - ctx.col.EnumOptions = make(map[string]int) - for k, v := range ctx.params { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - ctx.col.EnumOptions[v] = k + if len(ctx.params) == 0 { + return nil + } + + switch ctx.tagUname { + case schemas.Enum: + ctx.col.EnumOptions = make(map[string]int) + for k, v := range ctx.params { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + ctx.col.EnumOptions[v] = k + } + case schemas.Set: + ctx.col.SetOptions = make(map[string]int) + for k, v := range ctx.params { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + ctx.col.SetOptions[v] = k + } + default: + var err error + if len(ctx.params) == 2 { + ctx.col.Length, err = strconv.Atoi(ctx.params[0]) + if err != nil { + return err } - } else if ctx.tagName == schemas.Set { - ctx.col.SetOptions = make(map[string]int) - for k, v := range ctx.params { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - ctx.col.SetOptions[v] = k + ctx.col.Length2, err = strconv.Atoi(ctx.params[1]) + if err != nil { + return err } - } else { - var err error - if len(ctx.params) == 2 { - ctx.col.Length, err = strconv.Atoi(ctx.params[0]) - if err != nil { - return err - } - ctx.col.Length2, err = strconv.Atoi(ctx.params[1]) - if err != nil { - return err - } - } else if len(ctx.params) == 1 { - ctx.col.Length, err = strconv.Atoi(ctx.params[0]) - if err != nil { - return err - } + } else if len(ctx.params) == 1 { + ctx.col.Length, err = strconv.Atoi(ctx.params[0]) + if err != nil { + return err } } } @@ -289,11 +338,12 @@ func ExtendsTagHandler(ctx *Context) error { } for _, col := range parentTable.Columns() { col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName) + col.FieldIndex = append(ctx.col.FieldIndex, col.FieldIndex...) var tagPrefix = ctx.col.FieldName if len(ctx.params) > 0 { col.Nullable = isPtr - tagPrefix = ctx.params[0] + tagPrefix = strings.Trim(ctx.params[0], "'") if col.IsPrimaryKey { col.Name = ctx.col.FieldName col.IsPrimaryKey = false @@ -315,7 +365,7 @@ func ExtendsTagHandler(ctx *Context) error { default: //TODO: warning } - return nil + return ErrIgnoreField } // CacheTagHandler describes cache tag handler diff --git a/tags/tag_test.go b/tags/tag_test.go index 5775b40a..3ceeefd1 100644 --- a/tags/tag_test.go +++ b/tags/tag_test.go @@ -7,24 +7,83 @@ package tags import ( "testing" - "xorm.io/xorm/internal/utils" + "github.com/stretchr/testify/assert" ) func TestSplitTag(t *testing.T) { var cases = []struct { tag string - tags []string + tags []tag }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, - {"TEXT", []string{"TEXT"}}, - {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, - {"json binary", []string{"json", "binary"}}, + {"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{ + { + name: "not", + }, + { + name: "null", + }, + { + name: "default", + }, + { + name: "'2000-01-01 00:00:00'", + }, + { + name: "TIMESTAMP", + }, + }, + }, + {"TEXT", []tag{ + { + name: "TEXT", + }, + }, + }, + {"default('2000-01-01 00:00:00')", []tag{ + { + name: "default", + params: []string{ + "'2000-01-01 00:00:00'", + }, + }, + }, + }, + {"json binary", []tag{ + { + name: "json", + }, + { + name: "binary", + }, + }, + }, + {"numeric(10, 2)", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + }, + }, + {"numeric(10, 2) notnull", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + { + name: "notnull", + }, + }, + }, } for _, kase := range cases { - tags := splitTag(kase.tag) - if !utils.SliceEq(tags, kase.tags) { - t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) - } + t.Run(kase.tag, func(t *testing.T) { + tags, err := splitTag(kase.tag) + assert.NoError(t, err) + assert.EqualValues(t, len(tags), len(kase.tags)) + for i := 0; i < len(tags); i++ { + assert.Equal(t, tags[i], kase.tags[i]) + } + }) } }