From 00ee06fdd5b83e68eff442a2e2a211a8bc68b9f2 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 12 Jun 2021 20:27:49 +0800 Subject: [PATCH 1/8] Add test for dump table with default value (#1950) Confirm #1391 resolved. Reviewed-on: https://gitea.com/xorm/xorm/pulls/1950 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- integrations/engine_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/integrations/engine_test.go b/integrations/engine_test.go index 9b70f9b5..344e95a8 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -176,6 +176,23 @@ func TestDumpTables(t *testing.T) { } } +func TestDumpTables2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestDumpTableStruct2 struct { + Id int64 + Created time.Time `xorm:"Default CURRENT_TIMESTAMP"` + } + + assertSync(t, new(TestDumpTableStruct2)) + + fp := fmt.Sprintf("./dump2-%v-table.sql", testEngine.Dialect().URI().DBType) + os.Remove(fp) + tb, err := testEngine.TableInfo(new(TestDumpTableStruct2)) + assert.NoError(t, err) + assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, fp)) +} + func TestSetSchema(t *testing.T) { assert.NoError(t, PrepareEngine()) From e1422f183c89cb984b11db108911c170c2ed4ffb Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 12 Jun 2021 20:35:22 +0800 Subject: [PATCH 2/8] Add test to confirm #1247 resolved (#1951) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1951 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- integrations/session_update_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/integrations/session_update_test.go b/integrations/session_update_test.go index 15d2f694..796bfa0a 100644 --- a/integrations/session_update_test.go +++ b/integrations/session_update_test.go @@ -472,6 +472,11 @@ func TestUpdateIncrDecr(t *testing.T) { cnt, err = testEngine.ID(col1.Id).Cols(colName).Incr(colName).Update(col1) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + + testEngine.SetColumnMapper(testEngine.GetColumnMapper()) + cnt, err = testEngine.Cols(colName).Decr(colName, 2).ID(col1.Id).Update(new(UpdateIncr)) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) } type UpdatedUpdate struct { From 94614619671583f31a1ab6e1710723cd0ac1ae0d Mon Sep 17 00:00:00 2001 From: knice88 Date: Sat, 12 Jun 2021 22:47:15 +0800 Subject: [PATCH 3/8] fix pg GetColumns missing comment (#1949) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit xorm reverse生成的结构体缺少备注信息 Co-authored-by: chad Reviewed-on: https://gitea.com/xorm/xorm/pulls/1949 Reviewed-by: Lunny Xiao Co-authored-by: knice88 Co-committed-by: knice88 --- dialects/postgres.go | 11 ++++++++--- integrations/engine_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/dialects/postgres.go b/dialects/postgres.go index 52c88567..9acf763a 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -1044,12 +1044,13 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} - s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, + s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, description, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey FROM pg_attribute f JOIN pg_class c ON c.oid = f.attrelid JOIN pg_type t ON t.oid = f.atttypid LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum + LEFT JOIN pg_description de ON f.attrelid=de.objoid AND f.attnum=de.objsubid LEFT JOIN pg_namespace n ON n.oid = c.relnamespace LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid @@ -1078,9 +1079,9 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A col.Indexes = make(map[string]int) var colName, isNullable, dataType string - var maxLenStr, colDefault *string + var maxLenStr, colDefault, description *string var isPK, isUnique bool - err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &isPK, &isUnique) + err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &description, &isPK, &isUnique) if err != nil { return nil, nil, err } @@ -1126,6 +1127,10 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A col.DefaultIsEmpty = true } + if description != nil { + col.Comment = *description + } + if isPK { col.IsPrimaryKey = true } diff --git a/integrations/engine_test.go b/integrations/engine_test.go index 344e95a8..a06d91aa 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -226,3 +226,39 @@ func TestDBVersion(t *testing.T) { fmt.Println(testEngine.Dialect().URI().DBType, "version is", version) } + +func TestGetColumns(t *testing.T) { + if testEngine.Dialect().URI().DBType != schemas.POSTGRES { + t.Skip() + return + } + type TestCommentStruct struct { + HasComment int + NoComment int + } + + assertSync(t, new(TestCommentStruct)) + + comment := "this is a comment" + sql := fmt.Sprintf("comment on column %s.%s is '%s'", testEngine.TableName(new(TestCommentStruct), true), "has_comment", comment) + _, err := testEngine.Exec(sql) + assert.NoError(t, err) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + tableName := testEngine.GetColumnMapper().Obj2Table("TestCommentStruct") + var hasComment, noComment string + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("has_comment") + assert.NotNil(t, col) + hasComment = col.Comment + col2 := table.GetColumn("no_comment") + assert.NotNil(t, col2) + noComment = col2.Comment + break + } + } + assert.Equal(t, comment, hasComment) + assert.Zero(t, noComment) +} From bc25b4128bc7396d34d1414e6355bec7d1d15a52 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 14 Jun 2021 11:23:05 +0800 Subject: [PATCH 4/8] Fix #1663 (#1952) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1952 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- internal/statements/statement.go | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/internal/statements/statement.go b/internal/statements/statement.go index a52c6ca2..ca59817b 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -208,20 +208,18 @@ func (statement *Statement) quote(s string) string { // And add Where & and statement func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { - switch query.(type) { + switch qr := query.(type) { case string: - cond := builder.Expr(query.(string), args...) + cond := builder.Expr(qr, args...) statement.cond = statement.cond.And(cond) case map[string]interface{}: - queryMap := query.(map[string]interface{}) - newMap := make(map[string]interface{}) - for k, v := range queryMap { - newMap[statement.quote(k)] = v + cond := make(builder.Eq) + for k, v := range qr { + cond[statement.quote(k)] = v } - statement.cond = statement.cond.And(builder.Eq(newMap)) - case builder.Cond: - cond := query.(builder.Cond) statement.cond = statement.cond.And(cond) + case builder.Cond: + statement.cond = statement.cond.And(qr) for _, v := range args { if vv, ok := v.(builder.Cond); ok { statement.cond = statement.cond.And(vv) @@ -236,23 +234,25 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme // Or add Where & Or statement func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement { - switch query.(type) { + switch qr := query.(type) { case string: - cond := builder.Expr(query.(string), args...) + cond := builder.Expr(qr, args...) statement.cond = statement.cond.Or(cond) case map[string]interface{}: - cond := builder.Eq(query.(map[string]interface{})) + cond := make(builder.Eq) + for k, v := range qr { + cond[statement.quote(k)] = v + } statement.cond = statement.cond.Or(cond) case builder.Cond: - cond := query.(builder.Cond) - statement.cond = statement.cond.Or(cond) + statement.cond = statement.cond.Or(qr) for _, v := range args { if vv, ok := v.(builder.Cond); ok { statement.cond = statement.cond.Or(vv) } } default: - // TODO: not support condition type + statement.LastError = ErrConditionType } return statement } From 5a58a272bc86d9281a1766f5681df49d67526bee Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 15 Jun 2021 20:28:49 +0800 Subject: [PATCH 5/8] fix lint (#1953) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1953 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- .revive.toml | 14 ++++++++------ convert.go | 4 ---- integrations/session_insert_test.go | 2 -- internal/statements/expr.go | 1 + internal/statements/query.go | 2 +- internal/statements/statement_test.go | 6 ------ log/logger.go | 9 --------- schemas/table_test.go | 7 ------- session_update.go | 1 - 9 files changed, 10 insertions(+), 36 deletions(-) diff --git a/.revive.toml b/.revive.toml index 6dec7465..9e3b629d 100644 --- a/.revive.toml +++ b/.revive.toml @@ -8,20 +8,22 @@ warningCode = 1 [rule.context-as-argument] [rule.context-keys-type] [rule.dot-imports] +[rule.empty-lines] +[rule.errorf] [rule.error-return] [rule.error-strings] [rule.error-naming] [rule.exported] [rule.if-return] [rule.increment-decrement] -[rule.var-naming] - arguments = [["ID", "UID", "UUID", "URL", "JSON"], []] -[rule.var-declaration] +[rule.indent-error-flow] [rule.package-comments] [rule.range] [rule.receiver-naming] +[rule.struct-tag] [rule.time-naming] [rule.unexported-return] -[rule.indent-error-flow] -[rule.errorf] -[rule.struct-tag] \ No newline at end of file +[rule.unnecessary-stmt] +[rule.var-declaration] +[rule.var-naming] + arguments = [["ID", "UID", "UUID", "URL", "JSON"], []] \ No newline at end of file diff --git a/convert.go b/convert.go index c19d30e0..ee5b6029 100644 --- a/convert.go +++ b/convert.go @@ -416,7 +416,3 @@ func int64ToIntValue(id int64, tp reflect.Type) reflect.Value { } return reflect.ValueOf(v).Elem().Convert(tp) } - -func int64ToInt(id int64, tp reflect.Type) interface{} { - return int64ToIntValue(id, tp).Interface() -} diff --git a/integrations/session_insert_test.go b/integrations/session_insert_test.go index eaa1b2c7..e5d880ae 100644 --- a/integrations/session_insert_test.go +++ b/integrations/session_insert_test.go @@ -32,7 +32,6 @@ func TestInsertOne(t *testing.T) { } func TestInsertMulti(t *testing.T) { - assert.NoError(t, PrepareEngine()) type TestMulti struct { Id int64 `xorm:"int(11) pk"` @@ -78,7 +77,6 @@ func insertMultiDatas(step int, datas interface{}) (num int64, err error) { } func callbackLooper(datas interface{}, step int, actionFunc func(interface{}) error) (err error) { - sliceValue := reflect.Indirect(reflect.ValueOf(datas)) if sliceValue.Kind() != reflect.Slice { return fmt.Errorf("not slice") diff --git a/internal/statements/expr.go b/internal/statements/expr.go index b44c96ca..c2a2e1cc 100644 --- a/internal/statements/expr.go +++ b/internal/statements/expr.go @@ -27,6 +27,7 @@ type Expr struct { Arg interface{} } +// WriteArgs writes args to the writer func (expr *Expr) WriteArgs(w *builder.BytesWriter) error { switch arg := expr.Arg.(type) { case *builder.Builder: diff --git a/internal/statements/query.go b/internal/statements/query.go index e1091e9f..a972a8e0 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -343,7 +343,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac var args []interface{} var joinStr string var err error - var b interface{} = nil + var b interface{} if len(bean) > 0 { b = bean[0] beanValue := reflect.ValueOf(bean[0]) diff --git a/internal/statements/statement_test.go b/internal/statements/statement_test.go index 15f446f4..ba92330e 100644 --- a/internal/statements/statement_test.go +++ b/internal/statements/statement_test.go @@ -78,7 +78,6 @@ func TestColumnsStringGeneration(t *testing.T) { } func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { - b.StopTimer() mapCols := make(map[string]bool) @@ -101,9 +100,7 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - for _, col := range cols { - if _, ok := getFlagForColumn(mapCols, col); !ok { b.Fatal("Unexpected result") } @@ -112,7 +109,6 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { } func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) { - b.StopTimer() mapCols := make(map[string]bool) @@ -131,9 +127,7 @@ func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - for _, col := range cols { - if _, ok := getFlagForColumn(mapCols, col); ok { b.Fatal("Unexpected result") } diff --git a/log/logger.go b/log/logger.go index eeb63693..3b6db34e 100644 --- a/log/logger.go +++ b/log/logger.go @@ -132,7 +132,6 @@ func (s *SimpleLogger) Error(v ...interface{}) { if s.level <= LOG_ERR { s.ERR.Output(2, fmt.Sprintln(v...)) } - return } // Errorf implement ILogger @@ -140,7 +139,6 @@ func (s *SimpleLogger) Errorf(format string, v ...interface{}) { if s.level <= LOG_ERR { s.ERR.Output(2, fmt.Sprintf(format, v...)) } - return } // Debug implement ILogger @@ -148,7 +146,6 @@ func (s *SimpleLogger) Debug(v ...interface{}) { if s.level <= LOG_DEBUG { s.DEBUG.Output(2, fmt.Sprintln(v...)) } - return } // Debugf implement ILogger @@ -156,7 +153,6 @@ func (s *SimpleLogger) Debugf(format string, v ...interface{}) { if s.level <= LOG_DEBUG { s.DEBUG.Output(2, fmt.Sprintf(format, v...)) } - return } // Info implement ILogger @@ -164,7 +160,6 @@ func (s *SimpleLogger) Info(v ...interface{}) { if s.level <= LOG_INFO { s.INFO.Output(2, fmt.Sprintln(v...)) } - return } // Infof implement ILogger @@ -172,7 +167,6 @@ func (s *SimpleLogger) Infof(format string, v ...interface{}) { if s.level <= LOG_INFO { s.INFO.Output(2, fmt.Sprintf(format, v...)) } - return } // Warn implement ILogger @@ -180,7 +174,6 @@ func (s *SimpleLogger) Warn(v ...interface{}) { if s.level <= LOG_WARNING { s.WARN.Output(2, fmt.Sprintln(v...)) } - return } // Warnf implement ILogger @@ -188,7 +181,6 @@ func (s *SimpleLogger) Warnf(format string, v ...interface{}) { if s.level <= LOG_WARNING { s.WARN.Output(2, fmt.Sprintf(format, v...)) } - return } // Level implement ILogger @@ -199,7 +191,6 @@ func (s *SimpleLogger) Level() LogLevel { // SetLevel implement ILogger func (s *SimpleLogger) SetLevel(l LogLevel) { s.level = l - return } // ShowSQL implement ILogger diff --git a/schemas/table_test.go b/schemas/table_test.go index 9bf10e33..0e35193f 100644 --- a/schemas/table_test.go +++ b/schemas/table_test.go @@ -27,7 +27,6 @@ var testsGetColumn = []struct { var table *Table func init() { - table = NewEmptyTable() var name string @@ -41,7 +40,6 @@ func init() { } func TestGetColumn(t *testing.T) { - for _, test := range testsGetColumn { if table.GetColumn(test.name) == nil { t.Error("Column not found!") @@ -50,7 +48,6 @@ func TestGetColumn(t *testing.T) { } func TestGetColumnIdx(t *testing.T) { - for _, test := range testsGetColumn { if table.GetColumnIdx(test.name, test.idx) == nil { t.Errorf("Column %s with idx %d not found!", test.name, test.idx) @@ -59,7 +56,6 @@ func TestGetColumnIdx(t *testing.T) { } func BenchmarkGetColumnWithToLower(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { @@ -71,7 +67,6 @@ func BenchmarkGetColumnWithToLower(b *testing.B) { } func BenchmarkGetColumnIdxWithToLower(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { @@ -89,7 +84,6 @@ func BenchmarkGetColumnIdxWithToLower(b *testing.B) { } func BenchmarkGetColumn(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { if table.GetColumn(test.name) == nil { @@ -100,7 +94,6 @@ func BenchmarkGetColumn(b *testing.B) { } func BenchmarkGetColumnIdx(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { if table.GetColumnIdx(test.name, test.idx) == nil { diff --git a/session_update.go b/session_update.go index f791bb2d..d96226da 100644 --- a/session_update.go +++ b/session_update.go @@ -457,7 +457,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 // FIXME: if bean is a map type, it will panic because map cannot be as map key session.afterUpdateBeans[bean] = &afterClosures } - } else { if _, ok := interface{}(bean).(AfterUpdateProcessor); ok { session.afterUpdateBeans[bean] = nil From 44f892fddca72e496e13e947cf4c28e2348bd2ba Mon Sep 17 00:00:00 2001 From: antialiasis Date: Sat, 26 Jun 2021 19:19:13 +0800 Subject: [PATCH 6/8] Ignore comments when deciding when to replace question marks. #1954 (#1955) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This should solve #1954 and adds some tests for it. I will note I'm not 100% clear on whether there are other edge cases that should be covered here. From what I understand the only standard SQL way to escape single quotes is to double them, which shouldn't cause any problems with this, but if some SQL flavors allow other kinds of escaping, for instance, that would probably need to be covered too for ideal results. Co-authored-by: Hlín Önnudóttir Co-authored-by: Lunny Xiao Reviewed-on: https://gitea.com/xorm/xorm/pulls/1955 Reviewed-by: Lunny Xiao Co-authored-by: antialiasis Co-committed-by: antialiasis --- dialects/filter.go | 36 ++++++++++++++++++++++++-- dialects/filter_test.go | 57 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 2 deletions(-) diff --git a/dialects/filter.go b/dialects/filter.go index 2a36a731..bfe2e93e 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -23,13 +23,45 @@ type SeqFilter struct { func convertQuestionMark(sql, prefix string, start int) string { var buf strings.Builder var beginSingleQuote bool + var isLineComment bool + var isComment bool + var isMaybeLineComment bool + var isMaybeComment bool + var isMaybeCommentEnd bool var index = start for _, c := range sql { - if !beginSingleQuote && c == '?' { + if !beginSingleQuote && !isLineComment && !isComment && c == '?' { buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) index++ } else { - if c == '\'' { + if isMaybeLineComment { + if c == '-' { + isLineComment = true + } + isMaybeLineComment = false + } else if isMaybeComment { + if c == '*' { + isComment = true + } + isMaybeComment = false + } else if isMaybeCommentEnd { + if c == '/' { + isComment = false + } + isMaybeCommentEnd = false + } else if isLineComment { + if c == '\n' { + isLineComment = false + } + } else if isComment { + if c == '*' { + isMaybeCommentEnd = true + } + } else if !beginSingleQuote && c == '-' { + isMaybeLineComment = true + } else if !beginSingleQuote && c == '/' { + isMaybeComment = true + } else if c == '\'' { beginSingleQuote = !beginSingleQuote } buf.WriteRune(c) diff --git a/dialects/filter_test.go b/dialects/filter_test.go index 7e2ef0a2..15050656 100644 --- a/dialects/filter_test.go +++ b/dialects/filter_test.go @@ -19,3 +19,60 @@ func TestSeqFilter(t *testing.T) { assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) } } + +func TestSeqFilterLineComment(t *testing.T) { + var kases = map[string]string{ + `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=? -- it's a comment + AND b=?`: `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=$1 -- it's a comment + AND b=$2`, + `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=? -- it's a comment? + AND b=?`: `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=$1 -- it's a comment? + AND b=$2`, + `SELECT * + FROM TABLE1 + WHERE a=? -- it's a comment? and that's okay? + AND b=?`: `SELECT * + FROM TABLE1 + WHERE a=$1 -- it's a comment? and that's okay? + AND b=$2`, + } + for sql, result := range kases { + assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) + } +} + +func TestSeqFilterComment(t *testing.T) { + var kases = map[string]string{ + `SELECT * + FROM TABLE1 + WHERE a=? /* it's a comment */ + AND b=?`: `SELECT * + FROM TABLE1 + WHERE a=$1 /* it's a comment */ + AND b=$2`, + `SELECT /* it's a comment * ? + More comment on the next line! */ * + FROM TABLE1 + WHERE a=? /**/ + AND b=?`: `SELECT /* it's a comment * ? + More comment on the next line! */ * + FROM TABLE1 + WHERE a=$1 /**/ + AND b=$2`, + } + for sql, result := range kases { + assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) + } +} From 053a0947404ffb51ad35184676e461c29b06e06d Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 28 Jun 2021 22:41:54 +0800 Subject: [PATCH 7/8] refactor splitTag function (#1960) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1960 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- .gitignore | 3 +- tags/parser.go | 330 +++++++++++++++---------------- tags/parser_test.go | 458 +++++++++++++++++++++++++++++++++++++++++++- tags/tag.go | 147 +++++++++----- tags/tag_test.go | 79 +++++++- 5 files changed, 774 insertions(+), 243 deletions(-) diff --git a/.gitignore b/.gitignore index a3fbadd4..a183a295 100644 --- a/.gitignore +++ b/.gitignore @@ -36,4 +36,5 @@ test.db.sql *coverage.out test.db integrations/*.sql -integrations/test_sqlite* \ No newline at end of file +integrations/test_sqlite* +cover.out \ No newline at end of file diff --git a/tags/parser.go b/tags/parser.go index ff329daa..599e9e0e 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -7,7 +7,6 @@ package tags import ( "encoding/gob" "errors" - "fmt" "reflect" "strings" "sync" @@ -23,7 +22,7 @@ import ( var ( // ErrUnsupportedType represents an unsupported type error - ErrUnsupportedType = errors.New("Unsupported type") + ErrUnsupportedType = errors.New("unsupported type") ) // Parser represents a parser for xorm tag @@ -125,6 +124,145 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index } } +var ErrIgnoreField = errors.New("field will be ignored") + +func (parser *Parser) parseFieldWithNoTag(field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { + var sqlType schemas.SQLType + if fieldValue.CanAddr() { + if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { + sqlType = schemas.SQLType{Name: schemas.Text} + } + } + if _, ok := fieldValue.Interface().(convert.Conversion); ok { + sqlType = schemas.SQLType{Name: schemas.Text} + } else { + sqlType = schemas.Type2SQLType(field.Type) + } + col := schemas.NewColumn(parser.columnMapper.Obj2Table(field.Name), + field.Name, sqlType, sqlType.DefaultLength, + sqlType.DefaultLength2, true) + + if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { + col.IsAutoIncrement = true + col.IsPrimaryKey = true + col.Nullable = false + } + return col, nil +} + +func (parser *Parser) parseFieldWithTags(table *schemas.Table, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) { + var col = &schemas.Column{ + FieldName: field.Name, + Nullable: true, + IsPrimaryKey: false, + IsAutoIncrement: false, + MapType: schemas.TWOSIDES, + Indexes: make(map[string]int), + DefaultIsEmpty: true, + } + + var ctx = Context{ + table: table, + col: col, + fieldValue: fieldValue, + indexNames: make(map[string]int), + parser: parser, + } + + for j, tag := range tags { + if ctx.ignoreNext { + ctx.ignoreNext = false + continue + } + + ctx.tag = tag + ctx.tagUname = strings.ToUpper(tag.name) + + if j > 0 { + ctx.preTag = strings.ToUpper(tags[j-1].name) + } + if j < len(tags)-1 { + ctx.nextTag = tags[j+1].name + } else { + ctx.nextTag = "" + } + + if h, ok := parser.handlers[ctx.tagUname]; ok { + if err := h(&ctx); err != nil { + return nil, err + } + } else { + if strings.HasPrefix(ctx.tag.name, "'") && strings.HasSuffix(ctx.tag.name, "'") { + col.Name = ctx.tag.name[1 : len(ctx.tag.name)-1] + } else { + col.Name = ctx.tag.name + } + } + + if ctx.hasCacheTag { + if parser.cacherMgr.GetDefaultCacher() != nil { + parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher()) + } else { + parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)) + } + } + if ctx.hasNoCacheTag { + parser.cacherMgr.SetCacher(table.Name, nil) + } + } + + if col.SQLType.Name == "" { + col.SQLType = schemas.Type2SQLType(field.Type) + } + parser.dialect.SQLType(col) + if col.Length == 0 { + col.Length = col.SQLType.DefaultLength + } + if col.Length2 == 0 { + col.Length2 = col.SQLType.DefaultLength2 + } + if col.Name == "" { + col.Name = parser.columnMapper.Obj2Table(field.Name) + } + + if ctx.isUnique { + ctx.indexNames[col.Name] = schemas.UniqueType + } else if ctx.isIndex { + ctx.indexNames[col.Name] = schemas.IndexType + } + + for indexName, indexType := range ctx.indexNames { + addIndex(indexName, table, col, indexType) + } + + return col, nil +} + +func (parser *Parser) parseField(table *schemas.Table, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { + var ( + tag = field.Tag + ormTagStr = strings.TrimSpace(tag.Get(parser.identifier)) + ) + if ormTagStr == "-" { + return nil, ErrIgnoreField + } + if ormTagStr == "" { + return parser.parseFieldWithNoTag(field, fieldValue) + } + tags, err := splitTag(ormTagStr) + if err != nil { + return nil, err + } + return parser.parseFieldWithTags(table, field, fieldValue, tags) +} + +func isNotTitle(n string) bool { + for _, c := range n { + return unicode.IsLower(c) + } + return true +} + // Parse parses a struct as a table information func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { t := v.Type() @@ -140,193 +278,25 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { table.Type = t table.Name = names.GetTableName(parser.tableMapper, v) - var idFieldColName string - var hasCacheTag, hasNoCacheTag bool - for i := 0; i < t.NumField(); i++ { - var isUnexportField bool - for _, c := range t.Field(i).Name { - if unicode.IsLower(c) { - isUnexportField = true - } - break - } - if isUnexportField { + if isNotTitle(t.Field(i).Name) { continue } - tag := t.Field(i).Tag - ormTagStr := tag.Get(parser.identifier) - var col *schemas.Column - fieldValue := v.Field(i) - fieldType := fieldValue.Type() + var ( + field = t.Field(i) + fieldValue = v.Field(i) + ) - if ormTagStr != "" { - col = &schemas.Column{ - FieldName: t.Field(i).Name, - Nullable: true, - IsPrimaryKey: false, - IsAutoIncrement: false, - MapType: schemas.TWOSIDES, - Indexes: make(map[string]int), - DefaultIsEmpty: true, - } - tags := splitTag(ormTagStr) - - if len(tags) > 0 { - if tags[0] == "-" { - continue - } - - var ctx = Context{ - table: table, - col: col, - fieldValue: fieldValue, - indexNames: make(map[string]int), - parser: parser, - } - - if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") { - pStart := strings.Index(tags[0], "(") - if pStart > -1 && strings.HasSuffix(tags[0], ")") { - var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool { - return r == '\'' || r == '"' - }) - - ctx.params = []string{tagPrefix} - } - - if err := ExtendsTagHandler(&ctx); err != nil { - return nil, err - } - continue - } - - for j, key := range tags { - if ctx.ignoreNext { - ctx.ignoreNext = false - continue - } - - k := strings.ToUpper(key) - ctx.tagName = k - ctx.params = []string{} - - pStart := strings.Index(k, "(") - if pStart == 0 { - return nil, errors.New("( could not be the first character") - } - if pStart > -1 { - if !strings.HasSuffix(k, ")") { - return nil, fmt.Errorf("field %s tag %s cannot match ) character", col.FieldName, key) - } - - ctx.tagName = k[:pStart] - ctx.params = strings.Split(key[pStart+1:len(k)-1], ",") - } - - if j > 0 { - ctx.preTag = strings.ToUpper(tags[j-1]) - } - if j < len(tags)-1 { - ctx.nextTag = tags[j+1] - } else { - ctx.nextTag = "" - } - - if h, ok := parser.handlers[ctx.tagName]; ok { - if err := h(&ctx); err != nil { - return nil, err - } - } else { - if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") { - col.Name = key[1 : len(key)-1] - } else { - col.Name = key - } - } - - if ctx.hasCacheTag { - hasCacheTag = true - } - if ctx.hasNoCacheTag { - hasNoCacheTag = true - } - } - - if col.SQLType.Name == "" { - col.SQLType = schemas.Type2SQLType(fieldType) - } - parser.dialect.SQLType(col) - if col.Length == 0 { - col.Length = col.SQLType.DefaultLength - } - if col.Length2 == 0 { - col.Length2 = col.SQLType.DefaultLength2 - } - if col.Name == "" { - col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name) - } - - if ctx.isUnique { - ctx.indexNames[col.Name] = schemas.UniqueType - } else if ctx.isIndex { - ctx.indexNames[col.Name] = schemas.IndexType - } - - for indexName, indexType := range ctx.indexNames { - addIndex(indexName, table, col, indexType) - } - } - } else { - var sqlType schemas.SQLType - if fieldValue.CanAddr() { - if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - sqlType = schemas.SQLType{Name: schemas.Text} - } - } - if _, ok := fieldValue.Interface().(convert.Conversion); ok { - sqlType = schemas.SQLType{Name: schemas.Text} - } else { - sqlType = schemas.Type2SQLType(fieldType) - } - col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name), - t.Field(i).Name, sqlType, sqlType.DefaultLength, - sqlType.DefaultLength2, true) - - if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { - idFieldColName = col.Name - } - } - if col.IsAutoIncrement { - col.Nullable = false + col, err := parser.parseField(table, field, fieldValue) + if err == ErrIgnoreField { + continue + } else if err != nil { + return nil, err } table.AddColumn(col) } // end for - if idFieldColName != "" && len(table.PrimaryKeys) == 0 { - col := table.GetColumn(idFieldColName) - col.IsPrimaryKey = true - col.IsAutoIncrement = true - col.Nullable = false - table.PrimaryKeys = append(table.PrimaryKeys, col.Name) - table.AutoIncrement = col.Name - } - - if hasCacheTag { - if parser.cacherMgr.GetDefaultCacher() != nil { // !nash! use engine's cacher if provided - //engine.logger.Info("enable cache on table:", table.Name) - parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher()) - } else { - //engine.logger.Info("enable LRU cache on table:", table.Name) - parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)) - } - } - if hasNoCacheTag { - //engine.logger.Info("disable cache on table:", table.Name) - parser.cacherMgr.SetCacher(table.Name, nil) - } - return table, nil } diff --git a/tags/parser_test.go b/tags/parser_test.go index 5add1e13..70c57692 100644 --- a/tags/parser_test.go +++ b/tags/parser_test.go @@ -6,12 +6,16 @@ package tags import ( "reflect" + "strings" "testing" + "time" - "github.com/stretchr/testify/assert" "xorm.io/xorm/caches" "xorm.io/xorm/dialects" "xorm.io/xorm/names" + "xorm.io/xorm/schemas" + + "github.com/stretchr/testify/assert" ) type ParseTableName1 struct{} @@ -80,7 +84,7 @@ func TestParseWithOtherIdentifier(t *testing.T) { parser := NewParser( "xorm", dialects.QueryDialect("mysql"), - names.GonicMapper{}, + names.SameMapper{}, names.SnakeMapper{}, caches.NewManager(), ) @@ -88,13 +92,461 @@ func TestParseWithOtherIdentifier(t *testing.T) { type StructWithDBTag struct { FieldFoo string `db:"foo"` } + parser.SetIdentifier("db") table, err := parser.Parse(reflect.ValueOf(new(StructWithDBTag))) assert.NoError(t, err) - assert.EqualValues(t, "struct_with_db_tag", table.Name) + assert.EqualValues(t, "StructWithDBTag", table.Name) assert.EqualValues(t, 1, len(table.Columns())) for _, col := range table.Columns() { assert.EqualValues(t, "foo", col.Name) } } + +func TestParseWithIgnore(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SameMapper{}, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithIgnoreTag struct { + FieldFoo string `db:"-"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithIgnoreTag))) + assert.NoError(t, err) + assert.EqualValues(t, "StructWithIgnoreTag", table.Name) + assert.EqualValues(t, 0, len(table.Columns())) +} + +func TestParseWithAutoincrement(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithAutoIncrement struct { + ID int64 + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_auto_increment", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "id", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsAutoIncrement) + assert.True(t, table.Columns()[0].IsPrimaryKey) +} + +func TestParseWithAutoincrement2(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithAutoIncrement2 struct { + ID int64 `db:"pk autoincr"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement2))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_auto_increment2", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "id", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsAutoIncrement) + assert.True(t, table.Columns()[0].IsPrimaryKey) + assert.False(t, table.Columns()[0].Nullable) +} + +func TestParseWithNullable(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithNullable struct { + Name string `db:"notnull"` + FullName string `db:"null comment('column comment,字段注释')"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithNullable))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_nullable", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "full_name", table.Columns()[1].Name) + assert.False(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.EqualValues(t, "column comment,字段注释", table.Columns()[1].Comment) +} + +func TestParseWithTimes(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithTimes struct { + Name string `db:"notnull"` + CreatedAt time.Time `db:"created"` + UpdatedAt time.Time `db:"updated"` + DeletedAt time.Time `db:"deleted"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithTimes))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_times", table.Name) + assert.EqualValues(t, 4, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "created_at", table.Columns()[1].Name) + assert.EqualValues(t, "updated_at", table.Columns()[2].Name) + assert.EqualValues(t, "deleted_at", table.Columns()[3].Name) + assert.False(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[1].IsCreated) + assert.True(t, table.Columns()[2].Nullable) + assert.True(t, table.Columns()[2].IsUpdated) + assert.True(t, table.Columns()[3].Nullable) + assert.True(t, table.Columns()[3].IsDeleted) +} + +func TestParseWithExtends(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithEmbed struct { + Name string + CreatedAt time.Time `db:"created"` + UpdatedAt time.Time `db:"updated"` + DeletedAt time.Time `db:"deleted"` + } + + type StructWithExtends struct { + SW StructWithEmbed `db:"extends"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithExtends))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_extends", table.Name) + assert.EqualValues(t, 4, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "created_at", table.Columns()[1].Name) + assert.EqualValues(t, "updated_at", table.Columns()[2].Name) + assert.EqualValues(t, "deleted_at", table.Columns()[3].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[1].IsCreated) + assert.True(t, table.Columns()[2].Nullable) + assert.True(t, table.Columns()[2].IsUpdated) + assert.True(t, table.Columns()[3].Nullable) + assert.True(t, table.Columns()[3].IsDeleted) +} + +func TestParseWithCache(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithCache struct { + Name string `db:"cache"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithCache))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_cache", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + cacher := parser.cacherMgr.GetCacher(table.Name) + assert.NotNil(t, cacher) +} + +func TestParseWithNoCache(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithNoCache struct { + Name string `db:"nocache"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithNoCache))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_no_cache", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + cacher := parser.cacherMgr.GetCacher(table.Name) + assert.Nil(t, cacher) +} + +func TestParseWithEnum(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithEnum struct { + Name string `db:"enum('alice', 'bob')"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithEnum))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_enum", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.EqualValues(t, schemas.Enum, strings.ToUpper(table.Columns()[0].SQLType.Name)) + assert.EqualValues(t, map[string]int{ + "alice": 0, + "bob": 1, + }, table.Columns()[0].EnumOptions) +} + +func TestParseWithSet(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithSet struct { + Name string `db:"set('alice', 'bob')"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithSet))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_set", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.EqualValues(t, schemas.Set, strings.ToUpper(table.Columns()[0].SQLType.Name)) + assert.EqualValues(t, map[string]int{ + "alice": 0, + "bob": 1, + }, table.Columns()[0].SetOptions) +} + +func TestParseWithIndex(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithIndex struct { + Name string `db:"index"` + Name2 string `db:"index(s)"` + Name3 string `db:"unique"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithIndex))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_index", table.Name) + assert.EqualValues(t, 3, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "name2", table.Columns()[1].Name) + assert.EqualValues(t, "name3", table.Columns()[2].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[2].Nullable) + assert.EqualValues(t, 1, len(table.Columns()[0].Indexes)) + assert.EqualValues(t, 1, len(table.Columns()[1].Indexes)) + assert.EqualValues(t, 1, len(table.Columns()[2].Indexes)) +} + +func TestParseWithVersion(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithVersion struct { + Name string + Version int `db:"version"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithVersion))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_version", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "version", table.Columns()[1].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[1].IsVersion) +} + +func TestParseWithLocale(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithLocale struct { + UTCLocale time.Time `db:"utc"` + LocalLocale time.Time `db:"local"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithLocale))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_locale", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "utc_locale", table.Columns()[0].Name) + assert.EqualValues(t, "local_locale", table.Columns()[1].Name) + assert.EqualValues(t, time.UTC, table.Columns()[0].TimeZone) + assert.EqualValues(t, time.Local, table.Columns()[1].TimeZone) +} + +func TestParseWithDefault(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithDefault struct { + Default1 time.Time `db:"default '1970-01-01 00:00:00'"` + Default2 time.Time `db:"default(CURRENT_TIMESTAMP)"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithDefault))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_default", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.EqualValues(t, "default2", table.Columns()[1].Name) + assert.EqualValues(t, "'1970-01-01 00:00:00'", table.Columns()[0].Default) + assert.EqualValues(t, "CURRENT_TIMESTAMP", table.Columns()[1].Default) +} + +func TestParseWithOnlyToDB(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.GonicMapper{ + "DB": true, + }, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithOnlyToDB struct { + Default1 time.Time `db:"->"` + Default2 time.Time `db:"<-"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithOnlyToDB))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_only_to_db", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.EqualValues(t, "default2", table.Columns()[1].Name) + assert.EqualValues(t, schemas.ONLYTODB, table.Columns()[0].MapType) + assert.EqualValues(t, schemas.ONLYFROMDB, table.Columns()[1].MapType) +} + +func TestParseWithJSON(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.GonicMapper{ + "JSON": true, + }, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithJSON struct { + Default1 []string `db:"json"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithJSON))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_json", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsJSON) +} + +func TestParseWithSQLType(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.GonicMapper{ + "SQL": true, + }, + names.GonicMapper{ + "UUID": true, + }, + caches.NewManager(), + ) + + type StructWithSQLType struct { + Col1 string `db:"varchar(32)"` + Col2 string `db:"char(32)"` + Int int64 `db:"bigint"` + DateTime time.Time `db:"datetime"` + UUID string `db:"uuid"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithSQLType))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_sql_type", table.Name) + assert.EqualValues(t, 5, len(table.Columns())) + assert.EqualValues(t, "col1", table.Columns()[0].Name) + assert.EqualValues(t, "col2", table.Columns()[1].Name) + assert.EqualValues(t, "int", table.Columns()[2].Name) + assert.EqualValues(t, "date_time", table.Columns()[3].Name) + assert.EqualValues(t, "uuid", table.Columns()[4].Name) + + assert.EqualValues(t, "VARCHAR", table.Columns()[0].SQLType.Name) + assert.EqualValues(t, "CHAR", table.Columns()[1].SQLType.Name) + assert.EqualValues(t, "BIGINT", table.Columns()[2].SQLType.Name) + assert.EqualValues(t, "DATETIME", table.Columns()[3].SQLType.Name) + assert.EqualValues(t, "UUID", table.Columns()[4].SQLType.Name) +} diff --git a/tags/tag.go b/tags/tag.go index bb5b5838..d8d9bb46 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -14,30 +14,74 @@ import ( "xorm.io/xorm/schemas" ) -func splitTag(tag string) (tags []string) { - tag = strings.TrimSpace(tag) - var hasQuote = false - var lastIdx = 0 - for i, t := range tag { - if t == '\'' { - hasQuote = !hasQuote - } else if t == ' ' { - if lastIdx < i && !hasQuote { - tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) - lastIdx = i + 1 +type tag struct { + name string + params []string +} + +func splitTag(tagStr string) ([]tag, error) { + tagStr = strings.TrimSpace(tagStr) + var ( + inQuote bool + inBigQuote bool + lastIdx int + curTag tag + paramStart int + tags []tag + ) + for i, t := range tagStr { + switch t { + case '\'': + inQuote = !inQuote + case ' ': + if !inQuote && !inBigQuote { + if lastIdx < i { + if curTag.name == "" { + curTag.name = tagStr[lastIdx:i] + } + tags = append(tags, curTag) + lastIdx = i + 1 + curTag = tag{} + } else if lastIdx == i { + lastIdx = i + 1 + } + } else if inBigQuote && !inQuote { + paramStart = i + 1 + } + case ',': + if !inQuote && !inBigQuote { + return nil, fmt.Errorf("comma[%d] of %s should be in quote or big quote", i, tagStr) + } + if !inQuote && inBigQuote { + curTag.params = append(curTag.params, strings.TrimSpace(tagStr[paramStart:i])) + paramStart = i + 1 + } + case '(': + inBigQuote = true + if !inQuote { + curTag.name = tagStr[lastIdx:i] + paramStart = i + 1 + } + case ')': + inBigQuote = false + if !inQuote { + curTag.params = append(curTag.params, tagStr[paramStart:i]) } } } - if lastIdx < len(tag) { - tags = append(tags, strings.TrimSpace(tag[lastIdx:])) + if lastIdx < len(tagStr) { + if curTag.name == "" { + curTag.name = tagStr[lastIdx:] + } + tags = append(tags, curTag) } - return + return tags, nil } // Context represents a context for xorm tag parse. type Context struct { - tagName string - params []string + tag + tagUname string preTag, nextTag string table *schemas.Table col *schemas.Column @@ -76,6 +120,7 @@ var ( "CACHE": CacheTagHandler, "NOCACHE": NoCacheTagHandler, "COMMENT": CommentTagHandler, + "EXTENDS": ExtendsTagHandler, } ) @@ -124,6 +169,7 @@ func NotNullTagHandler(ctx *Context) error { // AutoIncrTagHandler describes autoincr tag handler func AutoIncrTagHandler(ctx *Context) error { ctx.col.IsAutoIncrement = true + ctx.col.Nullable = false /* if len(ctx.params) > 0 { autoStartInt, err := strconv.Atoi(ctx.params[0]) @@ -225,41 +271,44 @@ func CommentTagHandler(ctx *Context) error { // SQLTypeTagHandler describes SQL Type tag handler func SQLTypeTagHandler(ctx *Context) error { - ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName} - if strings.EqualFold(ctx.tagName, "JSON") { + ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname} + if ctx.tagUname == "JSON" { ctx.col.IsJSON = true } - if len(ctx.params) > 0 { - if ctx.tagName == schemas.Enum { - ctx.col.EnumOptions = make(map[string]int) - for k, v := range ctx.params { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - ctx.col.EnumOptions[v] = k + if len(ctx.params) == 0 { + return nil + } + + switch ctx.tagUname { + case schemas.Enum: + ctx.col.EnumOptions = make(map[string]int) + for k, v := range ctx.params { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + ctx.col.EnumOptions[v] = k + } + case schemas.Set: + ctx.col.SetOptions = make(map[string]int) + for k, v := range ctx.params { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + ctx.col.SetOptions[v] = k + } + default: + var err error + if len(ctx.params) == 2 { + ctx.col.Length, err = strconv.Atoi(ctx.params[0]) + if err != nil { + return err } - } else if ctx.tagName == schemas.Set { - ctx.col.SetOptions = make(map[string]int) - for k, v := range ctx.params { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - ctx.col.SetOptions[v] = k + ctx.col.Length2, err = strconv.Atoi(ctx.params[1]) + if err != nil { + return err } - } else { - var err error - if len(ctx.params) == 2 { - ctx.col.Length, err = strconv.Atoi(ctx.params[0]) - if err != nil { - return err - } - ctx.col.Length2, err = strconv.Atoi(ctx.params[1]) - if err != nil { - return err - } - } else if len(ctx.params) == 1 { - ctx.col.Length, err = strconv.Atoi(ctx.params[0]) - if err != nil { - return err - } + } else if len(ctx.params) == 1 { + ctx.col.Length, err = strconv.Atoi(ctx.params[0]) + if err != nil { + return err } } } @@ -293,7 +342,7 @@ func ExtendsTagHandler(ctx *Context) error { var tagPrefix = ctx.col.FieldName if len(ctx.params) > 0 { col.Nullable = isPtr - tagPrefix = ctx.params[0] + tagPrefix = strings.Trim(ctx.params[0], "'") if col.IsPrimaryKey { col.Name = ctx.col.FieldName col.IsPrimaryKey = false @@ -315,7 +364,7 @@ func ExtendsTagHandler(ctx *Context) error { default: //TODO: warning } - return nil + return ErrIgnoreField } // CacheTagHandler describes cache tag handler diff --git a/tags/tag_test.go b/tags/tag_test.go index 5775b40a..3ceeefd1 100644 --- a/tags/tag_test.go +++ b/tags/tag_test.go @@ -7,24 +7,83 @@ package tags import ( "testing" - "xorm.io/xorm/internal/utils" + "github.com/stretchr/testify/assert" ) func TestSplitTag(t *testing.T) { var cases = []struct { tag string - tags []string + tags []tag }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, - {"TEXT", []string{"TEXT"}}, - {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, - {"json binary", []string{"json", "binary"}}, + {"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{ + { + name: "not", + }, + { + name: "null", + }, + { + name: "default", + }, + { + name: "'2000-01-01 00:00:00'", + }, + { + name: "TIMESTAMP", + }, + }, + }, + {"TEXT", []tag{ + { + name: "TEXT", + }, + }, + }, + {"default('2000-01-01 00:00:00')", []tag{ + { + name: "default", + params: []string{ + "'2000-01-01 00:00:00'", + }, + }, + }, + }, + {"json binary", []tag{ + { + name: "json", + }, + { + name: "binary", + }, + }, + }, + {"numeric(10, 2)", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + }, + }, + {"numeric(10, 2) notnull", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + { + name: "notnull", + }, + }, + }, } for _, kase := range cases { - tags := splitTag(kase.tag) - if !utils.SliceEq(tags, kase.tags) { - t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) - } + t.Run(kase.tag, func(t *testing.T) { + tags, err := splitTag(kase.tag) + assert.NoError(t, err) + assert.EqualValues(t, len(tags), len(kase.tags)) + for i := 0; i < len(tags); i++ { + assert.Equal(t, tags[i], kase.tags[i]) + } + }) } } From 8f8195a86b7ae503a7250b0a92a47b3d35b32488 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 29 Jun 2021 14:32:29 +0800 Subject: [PATCH 8/8] Improve get field value of bean (#1961) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1961 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- .drone.yml | 4 +-- convert.go | 52 ++++---------------------------- engine.go | 6 +--- internal/statements/statement.go | 2 ++ internal/statements/update.go | 3 ++ schemas/column.go | 43 ++++++-------------------- schemas/table.go | 19 +----------- session.go | 3 ++ session_convert.go | 10 +----- session_insert.go | 12 +++----- session_update.go | 13 +++----- tags/parser.go | 22 ++++++-------- tags/tag.go | 1 + 13 files changed, 49 insertions(+), 141 deletions(-) diff --git a/.drone.yml b/.drone.yml index 9b4ffe9a..4f84d7fa 100644 --- a/.drone.yml +++ b/.drone.yml @@ -249,11 +249,11 @@ volumes: services: - name: mssql pull: always - image: microsoft/mssql-server-linux:latest + image: mcr.microsoft.com/mssql/server:latest environment: ACCEPT_EULA: Y SA_PASSWORD: yourStrong(!)Password - MSSQL_PID: Developer + MSSQL_PID: Standard --- kind: pipeline diff --git a/convert.go b/convert.go index ee5b6029..b7f30cad 100644 --- a/convert.go +++ b/convert.go @@ -175,7 +175,10 @@ func convertAssign(dest, src interface{}) error { return nil } - dpv := reflect.ValueOf(dest) + return convertAssignV(reflect.ValueOf(dest), src) +} + +func convertAssignV(dpv reflect.Value, src interface{}) error { if dpv.Kind() != reflect.Ptr { return errors.New("destination not a pointer") } @@ -183,9 +186,7 @@ func convertAssign(dest, src interface{}) error { return errNilPtr } - if !sv.IsValid() { - sv = reflect.ValueOf(src) - } + var sv = reflect.ValueOf(src) dv := reflect.Indirect(dpv) if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { @@ -244,7 +245,7 @@ func convertAssign(dest, src interface{}) error { return nil } - return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dpv.Interface()) } func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { @@ -375,44 +376,3 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) { } return v.Interface(), nil } - -func int64ToIntValue(id int64, tp reflect.Type) reflect.Value { - var v interface{} - kind := tp.Kind() - - if kind == reflect.Ptr { - kind = tp.Elem().Kind() - } - - switch kind { - case reflect.Int16: - temp := int16(id) - v = &temp - case reflect.Int32: - temp := int32(id) - v = &temp - case reflect.Int: - temp := int(id) - v = &temp - case reflect.Int64: - temp := id - v = &temp - case reflect.Uint16: - temp := uint16(id) - v = &temp - case reflect.Uint32: - temp := uint32(id) - v = &temp - case reflect.Uint64: - temp := uint64(id) - v = &temp - case reflect.Uint: - temp := uint(id) - v = &temp - } - - if tp.Kind() == reflect.Ptr { - return reflect.ValueOf(v).Convert(tp) - } - return reflect.ValueOf(v).Elem().Convert(tp) -} diff --git a/engine.go b/engine.go index 649ec1a2..76ce8f1a 100644 --- a/engine.go +++ b/engine.go @@ -652,11 +652,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return errors.New("unknown column error") } - fields := strings.Split(col.FieldName, ".") - field := dataStruct - for _, fieldName := range fields { - field = field.FieldByName(fieldName) - } + field := dataStruct.FieldByIndex(col.FieldIndex) temp += "," + formatColumnValue(dstDialect, field.Interface(), col) } _, err = io.WriteString(w, temp[1:]+");\n") diff --git a/internal/statements/statement.go b/internal/statements/statement.go index ca59817b..b1a5ed3c 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -734,6 +734,8 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, //engine.logger.Warn(err) } continue + } else if fieldValuePtr == nil { + continue } if col.IsDeleted && !unscoped { // tag "deleted" is enabled diff --git a/internal/statements/update.go b/internal/statements/update.go index 251880b2..06cf0689 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -88,6 +88,9 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value, if err != nil { return nil, nil, err } + if fieldValuePtr == nil { + continue + } fieldValue := *fieldValuePtr fieldType := reflect.TypeOf(fieldValue.Interface()) diff --git a/schemas/column.go b/schemas/column.go index 24b53802..4bbb6c2d 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -6,10 +6,8 @@ package schemas import ( "errors" - "fmt" "reflect" "strconv" - "strings" "time" ) @@ -25,6 +23,7 @@ type Column struct { Name string TableName string FieldName string // Available only when parsed from a struct + FieldIndex []int // Available only when parsed from a struct SQLType SQLType IsJSON bool Length int @@ -83,41 +82,17 @@ func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) { // ValueOfV returns column's filed of struct's value accept reflevt value func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { - var fieldValue reflect.Value - fieldPath := strings.Split(col.FieldName, ".") - - if dataStruct.Type().Kind() == reflect.Map { - keyValue := reflect.ValueOf(fieldPath[len(fieldPath)-1]) - fieldValue = dataStruct.MapIndex(keyValue) - return &fieldValue, nil - } else if dataStruct.Type().Kind() == reflect.Interface { - structValue := reflect.ValueOf(dataStruct.Interface()) - dataStruct = &structValue - } - - level := len(fieldPath) - fieldValue = dataStruct.FieldByName(fieldPath[0]) - for i := 0; i < level-1; i++ { - if !fieldValue.IsValid() { - break - } - if fieldValue.Kind() == reflect.Struct { - fieldValue = fieldValue.FieldByName(fieldPath[i+1]) - } else if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + var v = *dataStruct + for _, i := range col.FieldIndex { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) } - fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1]) - } else { - return nil, fmt.Errorf("field %v is not valid", col.FieldName) + v = v.Elem() } + v = v.FieldByIndex([]int{i}) } - - if !fieldValue.IsValid() { - return nil, fmt.Errorf("field %v is not valid", col.FieldName) - } - - return &fieldValue, nil + return &v, nil } // ConvertID converts id content to suitable type according column type diff --git a/schemas/table.go b/schemas/table.go index bfa517aa..91b33e06 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -5,7 +5,6 @@ package schemas import ( - "fmt" "reflect" "strconv" "strings" @@ -159,24 +158,8 @@ func (table *Table) IDOfV(rv reflect.Value) (PK, error) { for i, col := range table.PKColumns() { var err error - fieldName := col.FieldName - for { - parts := strings.SplitN(fieldName, ".", 2) - if len(parts) == 1 { - break - } + pkField := v.FieldByIndex(col.FieldIndex) - v = v.FieldByName(parts[0]) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - if v.Kind() != reflect.Struct { - return nil, fmt.Errorf("Unsupported read value of column %s from field %s", col.Name, col.FieldName) - } - fieldName = parts[1] - } - - pkField := v.FieldByName(fieldName) switch pkField.Kind() { case reflect.String: pk[i], err = col.ConvertID(pkField.String()) diff --git a/session.go b/session.go index d5ccb6dc..6df9e20d 100644 --- a/session.go +++ b/session.go @@ -375,6 +375,9 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s if err != nil { return nil, err } + if fieldValue == nil { + return nil, ErrFieldIsNotValid{key, table.Name} + } if !fieldValue.IsValid() || !fieldValue.CanSet() { return nil, ErrFieldIsNotValid{key, table.Name} diff --git a/session_convert.go b/session_convert.go index a6839947..b8218a77 100644 --- a/session_convert.go +++ b/session_convert.go @@ -35,27 +35,20 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time sd, err := strconv.ParseInt(sdata, 10, 64) if err == nil { x = time.Unix(sd, 0) - //session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) - } else { - //session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } } else if len(sdata) > 19 && strings.Contains(sdata, "-") { x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc) - session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) + session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.Name, x, sdata) if err != nil { x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc) - //session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } if err != nil { x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc) - //session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } } else if len(sdata) == 19 && strings.Contains(sdata, "-") { x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc) - //session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc) - //session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else if col.SQLType.Name == schemas.Time { if strings.Contains(sdata, " ") { ssd := strings.Split(sdata, " ") @@ -69,7 +62,6 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time st := fmt.Sprintf("2006-01-02 %v", sdata) x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc) - //session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else { outErr = fmt.Errorf("unsupported time format %v", sdata) return diff --git a/session_insert.go b/session_insert.go index 5f968151..82d91969 100644 --- a/session_insert.go +++ b/session_insert.go @@ -374,9 +374,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) - - return 1, nil + return 1, convertAssignV(aiValue.Addr(), id) } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || session.engine.dialect.URI().DBType == schemas.MSSQL) { res, err := session.queryBytes(sqlStr, args...) @@ -416,9 +414,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) - - return 1, nil + return 1, convertAssignV(aiValue.Addr(), id) } res, err := session.exec(sqlStr, args...) @@ -458,7 +454,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) + if err := convertAssignV(aiValue.Addr(), id); err != nil { + return 0, err + } return res.RowsAffected() } diff --git a/session_update.go b/session_update.go index d96226da..78907e43 100644 --- a/session_update.go +++ b/session_update.go @@ -280,15 +280,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 k = ct.Elem().Kind() } if k == reflect.Struct { - var refTable = session.statement.RefTable - if refTable == nil { - refTable, err = session.engine.TableInfo(condiBean[0]) - if err != nil { - return 0, err - } + condTable, err := session.engine.TableInfo(condiBean[0]) + if err != nil { + return 0, err } - var err error - autoCond, err = session.statement.BuildConds(refTable, condiBean[0], true, true, false, true, false) + + autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, false) if err != nil { return 0, err } diff --git a/tags/parser.go b/tags/parser.go index 599e9e0e..d701e316 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -126,7 +126,7 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index var ErrIgnoreField = errors.New("field will be ignored") -func (parser *Parser) parseFieldWithNoTag(field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { +func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { var sqlType schemas.SQLType if fieldValue.CanAddr() { if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { @@ -141,6 +141,7 @@ func (parser *Parser) parseFieldWithNoTag(field reflect.StructField, fieldValue col := schemas.NewColumn(parser.columnMapper.Obj2Table(field.Name), field.Name, sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, true) + col.FieldIndex = []int{fieldIndex} if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { col.IsAutoIncrement = true @@ -150,9 +151,10 @@ func (parser *Parser) parseFieldWithNoTag(field reflect.StructField, fieldValue return col, nil } -func (parser *Parser) parseFieldWithTags(table *schemas.Table, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) { +func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) { var col = &schemas.Column{ FieldName: field.Name, + FieldIndex: []int{fieldIndex}, Nullable: true, IsPrimaryKey: false, IsAutoIncrement: false, @@ -238,7 +240,7 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, field reflect.Str return col, nil } -func (parser *Parser) parseField(table *schemas.Table, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { +func (parser *Parser) parseField(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { var ( tag = field.Tag ormTagStr = strings.TrimSpace(tag.Get(parser.identifier)) @@ -247,13 +249,13 @@ func (parser *Parser) parseField(table *schemas.Table, field reflect.StructField return nil, ErrIgnoreField } if ormTagStr == "" { - return parser.parseFieldWithNoTag(field, fieldValue) + return parser.parseFieldWithNoTag(fieldIndex, field, fieldValue) } tags, err := splitTag(ormTagStr) if err != nil { return nil, err } - return parser.parseFieldWithTags(table, field, fieldValue, tags) + return parser.parseFieldWithTags(table, fieldIndex, field, fieldValue, tags) } func isNotTitle(n string) bool { @@ -279,16 +281,12 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { table.Name = names.GetTableName(parser.tableMapper, v) for i := 0; i < t.NumField(); i++ { - if isNotTitle(t.Field(i).Name) { + var field = t.Field(i) + if isNotTitle(field.Name) { continue } - var ( - field = t.Field(i) - fieldValue = v.Field(i) - ) - - col, err := parser.parseField(table, field, fieldValue) + col, err := parser.parseField(table, i, field, v.Field(i)) if err == ErrIgnoreField { continue } else if err != nil { diff --git a/tags/tag.go b/tags/tag.go index d8d9bb46..4a39ba54 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -338,6 +338,7 @@ func ExtendsTagHandler(ctx *Context) error { } for _, col := range parentTable.Columns() { col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName) + col.FieldIndex = append(ctx.col.FieldIndex, col.FieldIndex...) var tagPrefix = ctx.col.FieldName if len(ctx.params) > 0 {