From c8b4ea56bc8a9738185a439fd5c7aa74b85ae504 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 29 Feb 2020 14:14:28 +0800 Subject: [PATCH] Some improvements --- dialects/dialect.go | 23 ++++------ dialects/mysql.go | 16 ++++--- engine.go | 34 ++++++-------- engine_group.go | 8 ---- interface.go | 5 +- internal/statements/statement.go | 18 ++++---- internal/statements/update.go | 2 +- schemas/table.go | 8 ++-- schemas/type.go | 12 +++-- session.go | 2 +- session_convert.go | 6 +-- session_find.go | 2 +- session_schema.go | 2 +- tags/parser.go | 79 +++++++++++++++++++++++--------- tags/parser_test.go | 40 ++++++++++++++++ tags/tag.go | 2 +- tags_test.go | 2 +- xorm_test.go | 2 +- 18 files changed, 161 insertions(+), 102 deletions(-) create mode 100644 tags/parser_test.go diff --git a/dialects/dialect.go b/dialects/dialect.go index 186f94be..5efd1da4 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -14,10 +14,8 @@ import ( "xorm.io/xorm/schemas" ) -type DBType string - type URI struct { - DBType DBType + DBType schemas.DBType Proto string Host string Port string @@ -31,12 +29,12 @@ type URI struct { Schema string } -// a dialect is a driver's wrapper +// Dialect represents a kind of database type Dialect interface { Init(*core.DB, *URI, string, string) error URI() *URI DB() *core.DB - DBType() DBType + DBType() schemas.DBType SQLType(*schemas.Column) string FormatBytes(b []byte) string DefaultSchema() string @@ -111,7 +109,7 @@ func (b *Base) URI() *URI { return b.uri } -func (b *Base) DBType() DBType { +func (b *Base) DBType() schemas.DBType { return b.uri.DBType } @@ -221,13 +219,8 @@ func (db *Base) IsColumnExist(ctx context.Context, tableName, colName string) (b } func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { - quoter := db.dialect.Quoter() - sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), + return fmt.Sprintf("ALTER TABLE %v ADD %v", db.dialect.Quoter().Quote(tableName), db.String(col)) - if db.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 { - sql += " COMMENT '" + col.Comment + "'" - } - return sql } func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string { @@ -323,7 +316,7 @@ var ( ) // RegisterDialect register database dialect -func RegisterDialect(dbName DBType, dialectFunc func() Dialect) { +func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) { if dialectFunc == nil { panic("core: Register dialect is nil") } @@ -331,7 +324,7 @@ func RegisterDialect(dbName DBType, dialectFunc func() Dialect) { } // QueryDialect query if registered database dialect -func QueryDialect(dbName DBType) Dialect { +func QueryDialect(dbName schemas.DBType) Dialect { if d, ok := dialects[strings.ToLower(string(dbName))]; ok { return d() } @@ -340,7 +333,7 @@ func QueryDialect(dbName DBType) Dialect { func regDrvsNDialects() bool { providedDrvsNDialects := map[string]struct { - dbType DBType + dbType schemas.DBType getDriver func() Driver getDialect func() Dialect }{ diff --git a/dialects/mysql.go b/dialects/mysql.go index 09384e89..5ed2d8f1 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -303,18 +303,22 @@ func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{} return sql, args } -/*func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{db.DbName, tableName, colName} - sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" - return sql, args -}*/ - func (db *mysql) TableCheckSQL(tableName string) (string, []interface{}) { args := []interface{}{db.uri.DBName, tableName} sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" return sql, args } +func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { + quoter := db.dialect.Quoter() + sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), + db.String(col)) + if len(col.Comment) > 0 { + sql += " COMMENT '" + col.Comment + "'" + } + return sql +} + func (db *mysql) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{db.uri.DBName, tableName} s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + diff --git a/engine.go b/engine.go index f40ce2e4..b34f0716 100644 --- a/engine.go +++ b/engine.go @@ -118,12 +118,12 @@ func (engine *Engine) SetMapper(mapper names.Mapper) { // SetTableMapper set the table name mapping rule func (engine *Engine) SetTableMapper(mapper names.Mapper) { - engine.tagParser.TableMapper = mapper + engine.tagParser.SetTableMapper(mapper) } // SetColumnMapper set the column name mapping rule func (engine *Engine) SetColumnMapper(mapper names.Mapper) { - engine.tagParser.ColumnMapper = mapper + engine.tagParser.SetColumnMapper(mapper) } // SupportInsertMany If engine's database support batch insert records like @@ -320,7 +320,7 @@ func (engine *Engine) DBMetas() ([]*schemas.Table, error) { } // DumpAllToFile dump database all table structs and data to a file -func (engine *Engine) DumpAllToFile(fp string, tp ...dialects.DBType) error { +func (engine *Engine) DumpAllToFile(fp string, tp ...schemas.DBType) error { f, err := os.Create(fp) if err != nil { return err @@ -330,7 +330,7 @@ func (engine *Engine) DumpAllToFile(fp string, tp ...dialects.DBType) error { } // DumpAll dump database all table structs and data to w -func (engine *Engine) DumpAll(w io.Writer, tp ...dialects.DBType) error { +func (engine *Engine) DumpAll(w io.Writer, tp ...schemas.DBType) error { tables, err := engine.DBMetas() if err != nil { return err @@ -339,7 +339,7 @@ func (engine *Engine) DumpAll(w io.Writer, tp ...dialects.DBType) error { } // DumpTablesToFile dump specified tables to SQL file. -func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp ...dialects.DBType) error { +func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp ...schemas.DBType) error { f, err := os.Create(fp) if err != nil { return err @@ -349,12 +349,12 @@ func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp .. } // DumpTables dump specify tables to io.Writer -func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...dialects.DBType) error { +func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { return engine.dumpTables(tables, w, tp...) } // dumpTables dump database all table structs and data to w with specify db type -func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...dialects.DBType) error { +func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { var dialect dialects.Dialect var distDBName string if len(tp) == 0 { @@ -480,7 +480,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...dia } // FIXME: Hack for postgres - if string(dialect.DBType()) == schemas.POSTGRES && table.AutoIncrColumn() != nil { + if dialect.DBType() == schemas.POSTGRES && table.AutoIncrColumn() != nil { _, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quoter().Quote(table.Name)+"), 1), false);\n") if err != nil { return err @@ -723,13 +723,9 @@ func (t *Table) IsValid() bool { } // TableInfo get table info according to bean's content -func (engine *Engine) TableInfo(bean interface{}) (*Table, error) { +func (engine *Engine) TableInfo(bean interface{}) (*schemas.Table, error) { v := utils.ReflectValue(bean) - tb, err := engine.tagParser.MapType(v) - if err != nil { - return nil, err - } - return &Table{tb, dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)}, nil + return engine.tagParser.ParseWithCache(v) } // IsTableEmpty if a table has any reocrd @@ -763,7 +759,7 @@ func (engine *Engine) IDOfV(rv reflect.Value) (schemas.PK, error) { func (engine *Engine) idOfV(rv reflect.Value) (schemas.PK, error) { v := reflect.Indirect(rv) - table, err := engine.tagParser.MapType(v) + table, err := engine.tagParser.ParseWithCache(v) if err != nil { return nil, err } @@ -861,7 +857,7 @@ func (engine *Engine) ClearCache(beans ...interface{}) error { // UnMapType remove table from tables cache func (engine *Engine) UnMapType(t reflect.Type) { - engine.tagParser.ClearTable(t) + engine.tagParser.ClearCacheTable(t) } // Sync the new struct changes to database, this method will automatically add @@ -874,7 +870,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { for _, bean := range beans { v := utils.ReflectValue(bean) tableNameNoSchema := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean) - table, err := engine.tagParser.MapType(v) + table, err := engine.tagParser.ParseWithCache(v) if err != nil { return err } @@ -1222,12 +1218,12 @@ func (engine *Engine) formatColTime(col *schemas.Column, t time.Time) (v interfa // GetColumnMapper returns the column name mapper func (engine *Engine) GetColumnMapper() names.Mapper { - return engine.tagParser.ColumnMapper + return engine.tagParser.GetColumnMapper() } // GetTableMapper returns the table name mapper func (engine *Engine) GetTableMapper() names.Mapper { - return engine.tagParser.TableMapper + return engine.tagParser.GetTableMapper() } // GetTZLocation returns time zone of the application diff --git a/engine_group.go b/engine_group.go index 71095a91..8177697e 100644 --- a/engine_group.go +++ b/engine_group.go @@ -188,14 +188,6 @@ func (eg *EngineGroup) SetTableMapper(mapper names.Mapper) { } } -// ShowExecTime show SQL statement and execute time or not on logger if log level is great than INFO -/*func (eg *EngineGroup) ShowExecTime(show ...bool) { - eg.Engine.ShowExecTime(show...) - for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].ShowExecTime(show...) - } -}*/ - // ShowSQL show SQL statement or not on logger if log level is great than INFO func (eg *EngineGroup) ShowSQL(show ...bool) { eg.Engine.ShowSQL(show...) diff --git a/interface.go b/interface.go index 694cddf9..13f1e12a 100644 --- a/interface.go +++ b/interface.go @@ -83,7 +83,7 @@ type EngineInterface interface { DBMetas() ([]*schemas.Table, error) Dialect() dialects.Dialect DropTables(...interface{}) error - DumpAllToFile(fp string, tp ...dialects.DBType) error + DumpAllToFile(fp string, tp ...schemas.DBType) error GetCacher(string) caches.Cacher GetColumnMapper() names.Mapper GetDefaultCacher() caches.Cacher @@ -107,12 +107,11 @@ type EngineInterface interface { SetTableMapper(names.Mapper) SetTZDatabase(tz *time.Location) SetTZLocation(tz *time.Location) - //ShowExecTime(...bool) ShowSQL(show ...bool) Sync(...interface{}) error Sync2(...interface{}) error StoreEngine(storeEngine string) *Session - TableInfo(bean interface{}) (*Table, error) + TableInfo(bean interface{}) (*schemas.Table, error) TableName(interface{}, ...bool) string UnMapType(reflect.Type) } diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 92b1809a..68738b90 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -253,11 +253,11 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement func (statement *Statement) SetRefValue(v reflect.Value) error { var err error - statement.RefTable, err = statement.tagParser.MapType(reflect.Indirect(v)) + statement.RefTable, err = statement.tagParser.ParseWithCache(reflect.Indirect(v)) if err != nil { return err } - statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, v, true) + statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), v, true) return nil } @@ -267,11 +267,11 @@ func rValue(bean interface{}) reflect.Value { func (statement *Statement) SetRefBean(bean interface{}) error { var err error - statement.RefTable, err = statement.tagParser.MapType(rValue(bean)) + statement.RefTable, err = statement.tagParser.ParseWithCache(rValue(bean)) if err != nil { return err } - statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, bean, true) + statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), bean, true) return nil } @@ -507,13 +507,13 @@ func (statement *Statement) SetTable(tableNameOrBean interface{}) error { t := v.Type() if t.Kind() == reflect.Struct { var err error - statement.RefTable, err = statement.tagParser.MapType(v) + statement.RefTable, err = statement.tagParser.ParseWithCache(v) if err != nil { return err } } - statement.AltTableName = dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, tableNameOrBean, true) + statement.AltTableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tableNameOrBean, true) return nil } @@ -554,7 +554,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) default: - tbName := dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, tablename, true) + tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) if !utils.IsSubQuery(tbName) { var buf strings.Builder statement.dialect.Quoter().QuoteTo(&buf, tbName) @@ -689,7 +689,7 @@ func (statement *Statement) GenDelIndexSQL() []string { } else if index.Type == schemas.IndexType { rIdxName = utils.IndexName(idxPrefixName, idxName) } - sql := fmt.Sprintf("DROP INDEX %v", statement.quote(dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, rIdxName, true))) + sql := fmt.Sprintf("DROP INDEX %v", statement.quote(dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), rIdxName, true))) if statement.dialect.IndexOnTable() { sql += fmt.Sprintf(" ON %v", statement.quote(tbName)) } @@ -844,7 +844,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, val = bytes } } else { - table, err := statement.tagParser.MapType(fieldValue) + table, err := statement.tagParser.ParseWithCache(fieldValue) if err != nil { val = fieldValue.Interface() } else { diff --git a/internal/statements/update.go b/internal/statements/update.go index a5d7ec5a..e9cdd98c 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -187,7 +187,7 @@ func (statement *Statement) BuildUpdates(bean interface{}, val, _ = nulType.Value() } else { if !col.SQLType.IsJson() { - table, err := statement.tagParser.MapType(fieldValue) + table, err := statement.tagParser.ParseWithCache(fieldValue) if err != nil { val = fieldValue.Interface() } else { diff --git a/schemas/table.go b/schemas/table.go index 44aa8152..2dac3ea2 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -7,7 +7,6 @@ package schemas import ( "reflect" "strings" - //"xorm.io/xorm/cache" ) // Table represents a database table @@ -24,10 +23,9 @@ type Table struct { Updated string Deleted string Version string - //Cacher caches.Cacher - StoreEngine string - Charset string - Comment string + StoreEngine string + Charset string + Comment string } func NewEmptyTable() *Table { diff --git a/schemas/type.go b/schemas/type.go index 2aaa2a44..39f1bf4e 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -11,12 +11,14 @@ import ( "time" ) +type DBType string + const ( - POSTGRES = "postgres" - SQLITE = "sqlite3" - MYSQL = "mysql" - MSSQL = "mssql" - ORACLE = "oracle" + POSTGRES DBType = "postgres" + SQLITE DBType = "sqlite3" + MYSQL DBType = "mysql" + MSSQL DBType = "mssql" + ORACLE DBType = "oracle" ) // SQLType represents SQL types diff --git a/session.go b/session.go index 0f9099da..287465ca 100644 --- a/session.go +++ b/session.go @@ -698,7 +698,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } } } else if session.statement.UseCascade { - table, err := session.engine.tagParser.MapType(*fieldValue) + table, err := session.engine.tagParser.ParseWithCache(*fieldValue) if err != nil { return nil, err } diff --git a/session_convert.go b/session_convert.go index 41ab75a9..1cd00627 100644 --- a/session_convert.go +++ b/session_convert.go @@ -207,7 +207,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val v = x fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) } else if session.statement.UseCascade { - table, err := session.engine.tagParser.MapType(*fieldValue) + table, err := session.engine.tagParser.ParseWithCache(*fieldValue) if err != nil { return err } @@ -488,7 +488,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val default: if session.statement.UseCascade { structInter := reflect.New(fieldType.Elem()) - table, err := session.engine.tagParser.MapType(structInter.Elem()) + table, err := session.engine.tagParser.ParseWithCache(structInter.Elem()) if err != nil { return err } @@ -599,7 +599,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. return v.Value() } - fieldTable, err := session.engine.tagParser.MapType(fieldValue) + fieldTable, err := session.engine.tagParser.ParseWithCache(fieldValue) if err != nil { return nil, err } diff --git a/session_find.go b/session_find.go index 97273428..9551b767 100644 --- a/session_find.go +++ b/session_find.go @@ -225,7 +225,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect if elemType.Kind() == reflect.Struct { var newValue = newElemFunc(fields) dataStruct := utils.ReflectValue(newValue.Interface()) - tb, err := session.engine.tagParser.MapType(dataStruct) + tb, err := session.engine.tagParser.ParseWithCache(dataStruct) if err != nil { return err } diff --git a/session_schema.go b/session_schema.go index 0279ced7..3617a6b8 100644 --- a/session_schema.go +++ b/session_schema.go @@ -242,7 +242,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for _, bean := range beans { v := utils.ReflectValue(bean) - table, err := engine.tagParser.MapType(v) + table, err := engine.tagParser.ParseWithCache(v) if err != nil { return err } diff --git a/tags/parser.go b/tags/parser.go index 5c94c55b..236d2d46 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -20,11 +20,15 @@ import ( "xorm.io/xorm/schemas" ) +var ( + ErrUnsupportedType = errors.New("Unsupported type") +) + type Parser struct { identifier string dialect dialects.Dialect - ColumnMapper names.Mapper - TableMapper names.Mapper + columnMapper names.Mapper + tableMapper names.Mapper handlers map[string]Handler cacherMgr *caches.Manager tableCache sync.Map // map[reflect.Type]*schemas.Table @@ -34,33 +38,39 @@ func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnM return &Parser{ identifier: identifier, dialect: dialect, - TableMapper: tableMapper, - ColumnMapper: columnMapper, + tableMapper: tableMapper, + columnMapper: columnMapper, handlers: defaultTagHandlers, cacherMgr: cacherMgr, } } -func addIndex(indexName string, table *schemas.Table, col *schemas.Column, indexType int) { - if index, ok := table.Indexes[indexName]; ok { - index.AddColumn(col.Name) - col.Indexes[index.Name] = indexType - } else { - index := schemas.NewIndex(indexName, indexType) - index.AddColumn(col.Name) - table.AddIndex(index) - col.Indexes[index.Name] = indexType - } +func (parser *Parser) GetTableMapper() names.Mapper { + return parser.tableMapper } -func (parser *Parser) MapType(v reflect.Value) (*schemas.Table, error) { +func (parser *Parser) SetTableMapper(mapper names.Mapper) { + parser.ClearCaches() + parser.tableMapper = mapper +} + +func (parser *Parser) GetColumnMapper() names.Mapper { + return parser.columnMapper +} + +func (parser *Parser) SetColumnMapper(mapper names.Mapper) { + parser.ClearCaches() + parser.columnMapper = mapper +} + +func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) { t := v.Type() tableI, ok := parser.tableCache.Load(t) if ok { return tableI.(*schemas.Table), nil } - table, err := parser.mapType(v) + table, err := parser.Parse(v) if err != nil { return nil, err } @@ -78,16 +88,41 @@ func (parser *Parser) MapType(v reflect.Value) (*schemas.Table, error) { return table, nil } -// ClearTable removes the database mapper of a type from the cache -func (parser *Parser) ClearTable(t reflect.Type) { +// ClearCacheTable removes the database mapper of a type from the cache +func (parser *Parser) ClearCacheTable(t reflect.Type) { parser.tableCache.Delete(t) } -func (parser *Parser) mapType(v reflect.Value) (*schemas.Table, error) { +// ClearCaches removes all the cached table information parsed by structs +func (parser *Parser) ClearCaches() { + parser.tableCache = sync.Map{} +} + +func addIndex(indexName string, table *schemas.Table, col *schemas.Column, indexType int) { + if index, ok := table.Indexes[indexName]; ok { + index.AddColumn(col.Name) + col.Indexes[index.Name] = indexType + } else { + index := schemas.NewIndex(indexName, indexType) + index.AddColumn(col.Name) + table.AddIndex(index) + col.Indexes[index.Name] = indexType + } +} + +// Parse parses a struct as a table information +func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { t := v.Type() + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return nil, ErrUnsupportedType + } + table := schemas.NewEmptyTable() table.Type = t - table.Name = names.GetTableName(parser.TableMapper, v) + table.Name = names.GetTableName(parser.tableMapper, v) var idFieldColName string var hasCacheTag, hasNoCacheTag bool @@ -204,7 +239,7 @@ func (parser *Parser) mapType(v reflect.Value) (*schemas.Table, error) { col.Length2 = col.SQLType.DefaultLength2 } if col.Name == "" { - col.Name = parser.ColumnMapper.Obj2Table(t.Field(i).Name) + col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name) } if ctx.isUnique { @@ -229,7 +264,7 @@ func (parser *Parser) mapType(v reflect.Value) (*schemas.Table, error) { } else { sqlType = schemas.Type2SQLType(fieldType) } - col = schemas.NewColumn(parser.ColumnMapper.Obj2Table(t.Field(i).Name), + col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, true) diff --git a/tags/parser_test.go b/tags/parser_test.go new file mode 100644 index 00000000..929b3718 --- /dev/null +++ b/tags/parser_test.go @@ -0,0 +1,40 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tags + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "xorm.io/xorm/caches" + "xorm.io/xorm/dialects" + "xorm.io/xorm/names" +) + +type ParseTableName1 struct{} + +type ParseTableName2 struct{} + +func (p ParseTableName2) TableName() string { + return "p_parseTableName" +} + +func TestParseTableName(t *testing.T) { + parser := NewParser( + "xorm", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.SnakeMapper{}, + caches.NewManager(), + ) + table, err := parser.Parse(reflect.ValueOf(new(ParseTableName1))) + assert.NoError(t, err) + assert.EqualValues(t, "parse_table_name1", table.Name) + + table, err = parser.Parse(reflect.ValueOf(new(ParseTableName2))) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableName", table.Name) +} diff --git a/tags/tag.go b/tags/tag.go index a043ed77..ee3f1e82 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -280,7 +280,7 @@ func ExtendsTagHandler(ctx *Context) error { isPtr = true fallthrough case reflect.Struct: - parentTable, err := ctx.parser.mapType(fieldValue) + parentTable, err := ctx.parser.Parse(fieldValue) if err != nil { return err } diff --git a/tags_test.go b/tags_test.go index 9d41a5fa..775fcf60 100644 --- a/tags_test.go +++ b/tags_test.go @@ -871,7 +871,7 @@ func TestAutoIncrTag(t *testing.T) { func TestTagComment(t *testing.T) { assert.NoError(t, prepareEngine()) // FIXME: only support mysql - if testEngine.Dialect().DriverName() != schemas.MYSQL { + if testEngine.Dialect().DBType() != schemas.MYSQL { return } diff --git a/xorm_test.go b/xorm_test.go index 2a24edb3..59f6c1a9 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -47,7 +47,7 @@ func createEngine(dbType, connStr string) error { var err error if !*cluster { - switch strings.ToLower(dbType) { + switch schemas.DBType(strings.ToLower(dbType)) { case schemas.MSSQL: db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "master", -1)) if err != nil {