Some improvements
This commit is contained in:
parent
439784b33b
commit
c8b4ea56bc
|
@ -14,10 +14,8 @@ import (
|
|||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
type DBType string
|
||||
|
||||
type URI struct {
|
||||
DBType DBType
|
||||
DBType schemas.DBType
|
||||
Proto string
|
||||
Host string
|
||||
Port string
|
||||
|
@ -31,12 +29,12 @@ type URI struct {
|
|||
Schema string
|
||||
}
|
||||
|
||||
// a dialect is a driver's wrapper
|
||||
// Dialect represents a kind of database
|
||||
type Dialect interface {
|
||||
Init(*core.DB, *URI, string, string) error
|
||||
URI() *URI
|
||||
DB() *core.DB
|
||||
DBType() DBType
|
||||
DBType() schemas.DBType
|
||||
SQLType(*schemas.Column) string
|
||||
FormatBytes(b []byte) string
|
||||
DefaultSchema() string
|
||||
|
@ -111,7 +109,7 @@ func (b *Base) URI() *URI {
|
|||
return b.uri
|
||||
}
|
||||
|
||||
func (b *Base) DBType() DBType {
|
||||
func (b *Base) DBType() schemas.DBType {
|
||||
return b.uri.DBType
|
||||
}
|
||||
|
||||
|
@ -221,13 +219,8 @@ func (db *Base) IsColumnExist(ctx context.Context, tableName, colName string) (b
|
|||
}
|
||||
|
||||
func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string {
|
||||
quoter := db.dialect.Quoter()
|
||||
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName),
|
||||
return fmt.Sprintf("ALTER TABLE %v ADD %v", db.dialect.Quoter().Quote(tableName),
|
||||
db.String(col))
|
||||
if db.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
|
||||
sql += " COMMENT '" + col.Comment + "'"
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
|
||||
|
@ -323,7 +316,7 @@ var (
|
|||
)
|
||||
|
||||
// RegisterDialect register database dialect
|
||||
func RegisterDialect(dbName DBType, dialectFunc func() Dialect) {
|
||||
func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) {
|
||||
if dialectFunc == nil {
|
||||
panic("core: Register dialect is nil")
|
||||
}
|
||||
|
@ -331,7 +324,7 @@ func RegisterDialect(dbName DBType, dialectFunc func() Dialect) {
|
|||
}
|
||||
|
||||
// QueryDialect query if registered database dialect
|
||||
func QueryDialect(dbName DBType) Dialect {
|
||||
func QueryDialect(dbName schemas.DBType) Dialect {
|
||||
if d, ok := dialects[strings.ToLower(string(dbName))]; ok {
|
||||
return d()
|
||||
}
|
||||
|
@ -340,7 +333,7 @@ func QueryDialect(dbName DBType) Dialect {
|
|||
|
||||
func regDrvsNDialects() bool {
|
||||
providedDrvsNDialects := map[string]struct {
|
||||
dbType DBType
|
||||
dbType schemas.DBType
|
||||
getDriver func() Driver
|
||||
getDialect func() Dialect
|
||||
}{
|
||||
|
|
|
@ -303,18 +303,22 @@ func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{}
|
|||
return sql, args
|
||||
}
|
||||
|
||||
/*func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
|
||||
args := []interface{}{db.DbName, tableName, colName}
|
||||
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
|
||||
return sql, args
|
||||
}*/
|
||||
|
||||
func (db *mysql) TableCheckSQL(tableName string) (string, []interface{}) {
|
||||
args := []interface{}{db.uri.DBName, tableName}
|
||||
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
|
||||
return sql, args
|
||||
}
|
||||
|
||||
func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string {
|
||||
quoter := db.dialect.Quoter()
|
||||
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName),
|
||||
db.String(col))
|
||||
if len(col.Comment) > 0 {
|
||||
sql += " COMMENT '" + col.Comment + "'"
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
func (db *mysql) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
args := []interface{}{db.uri.DBName, tableName}
|
||||
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
|
||||
|
|
34
engine.go
34
engine.go
|
@ -118,12 +118,12 @@ func (engine *Engine) SetMapper(mapper names.Mapper) {
|
|||
|
||||
// SetTableMapper set the table name mapping rule
|
||||
func (engine *Engine) SetTableMapper(mapper names.Mapper) {
|
||||
engine.tagParser.TableMapper = mapper
|
||||
engine.tagParser.SetTableMapper(mapper)
|
||||
}
|
||||
|
||||
// SetColumnMapper set the column name mapping rule
|
||||
func (engine *Engine) SetColumnMapper(mapper names.Mapper) {
|
||||
engine.tagParser.ColumnMapper = mapper
|
||||
engine.tagParser.SetColumnMapper(mapper)
|
||||
}
|
||||
|
||||
// SupportInsertMany If engine's database support batch insert records like
|
||||
|
@ -320,7 +320,7 @@ func (engine *Engine) DBMetas() ([]*schemas.Table, error) {
|
|||
}
|
||||
|
||||
// DumpAllToFile dump database all table structs and data to a file
|
||||
func (engine *Engine) DumpAllToFile(fp string, tp ...dialects.DBType) error {
|
||||
func (engine *Engine) DumpAllToFile(fp string, tp ...schemas.DBType) error {
|
||||
f, err := os.Create(fp)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -330,7 +330,7 @@ func (engine *Engine) DumpAllToFile(fp string, tp ...dialects.DBType) error {
|
|||
}
|
||||
|
||||
// DumpAll dump database all table structs and data to w
|
||||
func (engine *Engine) DumpAll(w io.Writer, tp ...dialects.DBType) error {
|
||||
func (engine *Engine) DumpAll(w io.Writer, tp ...schemas.DBType) error {
|
||||
tables, err := engine.DBMetas()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -339,7 +339,7 @@ func (engine *Engine) DumpAll(w io.Writer, tp ...dialects.DBType) error {
|
|||
}
|
||||
|
||||
// DumpTablesToFile dump specified tables to SQL file.
|
||||
func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp ...dialects.DBType) error {
|
||||
func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp ...schemas.DBType) error {
|
||||
f, err := os.Create(fp)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -349,12 +349,12 @@ func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp ..
|
|||
}
|
||||
|
||||
// DumpTables dump specify tables to io.Writer
|
||||
func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...dialects.DBType) error {
|
||||
func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error {
|
||||
return engine.dumpTables(tables, w, tp...)
|
||||
}
|
||||
|
||||
// dumpTables dump database all table structs and data to w with specify db type
|
||||
func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...dialects.DBType) error {
|
||||
func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error {
|
||||
var dialect dialects.Dialect
|
||||
var distDBName string
|
||||
if len(tp) == 0 {
|
||||
|
@ -480,7 +480,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...dia
|
|||
}
|
||||
|
||||
// FIXME: Hack for postgres
|
||||
if string(dialect.DBType()) == schemas.POSTGRES && table.AutoIncrColumn() != nil {
|
||||
if dialect.DBType() == schemas.POSTGRES && table.AutoIncrColumn() != nil {
|
||||
_, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quoter().Quote(table.Name)+"), 1), false);\n")
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -723,13 +723,9 @@ func (t *Table) IsValid() bool {
|
|||
}
|
||||
|
||||
// TableInfo get table info according to bean's content
|
||||
func (engine *Engine) TableInfo(bean interface{}) (*Table, error) {
|
||||
func (engine *Engine) TableInfo(bean interface{}) (*schemas.Table, error) {
|
||||
v := utils.ReflectValue(bean)
|
||||
tb, err := engine.tagParser.MapType(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Table{tb, dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)}, nil
|
||||
return engine.tagParser.ParseWithCache(v)
|
||||
}
|
||||
|
||||
// IsTableEmpty if a table has any reocrd
|
||||
|
@ -763,7 +759,7 @@ func (engine *Engine) IDOfV(rv reflect.Value) (schemas.PK, error) {
|
|||
|
||||
func (engine *Engine) idOfV(rv reflect.Value) (schemas.PK, error) {
|
||||
v := reflect.Indirect(rv)
|
||||
table, err := engine.tagParser.MapType(v)
|
||||
table, err := engine.tagParser.ParseWithCache(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -861,7 +857,7 @@ func (engine *Engine) ClearCache(beans ...interface{}) error {
|
|||
|
||||
// UnMapType remove table from tables cache
|
||||
func (engine *Engine) UnMapType(t reflect.Type) {
|
||||
engine.tagParser.ClearTable(t)
|
||||
engine.tagParser.ClearCacheTable(t)
|
||||
}
|
||||
|
||||
// Sync the new struct changes to database, this method will automatically add
|
||||
|
@ -874,7 +870,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
|
|||
for _, bean := range beans {
|
||||
v := utils.ReflectValue(bean)
|
||||
tableNameNoSchema := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)
|
||||
table, err := engine.tagParser.MapType(v)
|
||||
table, err := engine.tagParser.ParseWithCache(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1222,12 +1218,12 @@ func (engine *Engine) formatColTime(col *schemas.Column, t time.Time) (v interfa
|
|||
|
||||
// GetColumnMapper returns the column name mapper
|
||||
func (engine *Engine) GetColumnMapper() names.Mapper {
|
||||
return engine.tagParser.ColumnMapper
|
||||
return engine.tagParser.GetColumnMapper()
|
||||
}
|
||||
|
||||
// GetTableMapper returns the table name mapper
|
||||
func (engine *Engine) GetTableMapper() names.Mapper {
|
||||
return engine.tagParser.TableMapper
|
||||
return engine.tagParser.GetTableMapper()
|
||||
}
|
||||
|
||||
// GetTZLocation returns time zone of the application
|
||||
|
|
|
@ -188,14 +188,6 @@ func (eg *EngineGroup) SetTableMapper(mapper names.Mapper) {
|
|||
}
|
||||
}
|
||||
|
||||
// ShowExecTime show SQL statement and execute time or not on logger if log level is great than INFO
|
||||
/*func (eg *EngineGroup) ShowExecTime(show ...bool) {
|
||||
eg.Engine.ShowExecTime(show...)
|
||||
for i := 0; i < len(eg.slaves); i++ {
|
||||
eg.slaves[i].ShowExecTime(show...)
|
||||
}
|
||||
}*/
|
||||
|
||||
// ShowSQL show SQL statement or not on logger if log level is great than INFO
|
||||
func (eg *EngineGroup) ShowSQL(show ...bool) {
|
||||
eg.Engine.ShowSQL(show...)
|
||||
|
|
|
@ -83,7 +83,7 @@ type EngineInterface interface {
|
|||
DBMetas() ([]*schemas.Table, error)
|
||||
Dialect() dialects.Dialect
|
||||
DropTables(...interface{}) error
|
||||
DumpAllToFile(fp string, tp ...dialects.DBType) error
|
||||
DumpAllToFile(fp string, tp ...schemas.DBType) error
|
||||
GetCacher(string) caches.Cacher
|
||||
GetColumnMapper() names.Mapper
|
||||
GetDefaultCacher() caches.Cacher
|
||||
|
@ -107,12 +107,11 @@ type EngineInterface interface {
|
|||
SetTableMapper(names.Mapper)
|
||||
SetTZDatabase(tz *time.Location)
|
||||
SetTZLocation(tz *time.Location)
|
||||
//ShowExecTime(...bool)
|
||||
ShowSQL(show ...bool)
|
||||
Sync(...interface{}) error
|
||||
Sync2(...interface{}) error
|
||||
StoreEngine(storeEngine string) *Session
|
||||
TableInfo(bean interface{}) (*Table, error)
|
||||
TableInfo(bean interface{}) (*schemas.Table, error)
|
||||
TableName(interface{}, ...bool) string
|
||||
UnMapType(reflect.Type)
|
||||
}
|
||||
|
|
|
@ -253,11 +253,11 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement
|
|||
|
||||
func (statement *Statement) SetRefValue(v reflect.Value) error {
|
||||
var err error
|
||||
statement.RefTable, err = statement.tagParser.MapType(reflect.Indirect(v))
|
||||
statement.RefTable, err = statement.tagParser.ParseWithCache(reflect.Indirect(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, v, true)
|
||||
statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), v, true)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -267,11 +267,11 @@ func rValue(bean interface{}) reflect.Value {
|
|||
|
||||
func (statement *Statement) SetRefBean(bean interface{}) error {
|
||||
var err error
|
||||
statement.RefTable, err = statement.tagParser.MapType(rValue(bean))
|
||||
statement.RefTable, err = statement.tagParser.ParseWithCache(rValue(bean))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, bean, true)
|
||||
statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), bean, true)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -507,13 +507,13 @@ func (statement *Statement) SetTable(tableNameOrBean interface{}) error {
|
|||
t := v.Type()
|
||||
if t.Kind() == reflect.Struct {
|
||||
var err error
|
||||
statement.RefTable, err = statement.tagParser.MapType(v)
|
||||
statement.RefTable, err = statement.tagParser.ParseWithCache(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
statement.AltTableName = dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, tableNameOrBean, true)
|
||||
statement.AltTableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tableNameOrBean, true)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -554,7 +554,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
|
|||
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
||||
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
||||
default:
|
||||
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, tablename, true)
|
||||
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
|
||||
if !utils.IsSubQuery(tbName) {
|
||||
var buf strings.Builder
|
||||
statement.dialect.Quoter().QuoteTo(&buf, tbName)
|
||||
|
@ -689,7 +689,7 @@ func (statement *Statement) GenDelIndexSQL() []string {
|
|||
} else if index.Type == schemas.IndexType {
|
||||
rIdxName = utils.IndexName(idxPrefixName, idxName)
|
||||
}
|
||||
sql := fmt.Sprintf("DROP INDEX %v", statement.quote(dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, rIdxName, true)))
|
||||
sql := fmt.Sprintf("DROP INDEX %v", statement.quote(dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), rIdxName, true)))
|
||||
if statement.dialect.IndexOnTable() {
|
||||
sql += fmt.Sprintf(" ON %v", statement.quote(tbName))
|
||||
}
|
||||
|
@ -844,7 +844,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
|
|||
val = bytes
|
||||
}
|
||||
} else {
|
||||
table, err := statement.tagParser.MapType(fieldValue)
|
||||
table, err := statement.tagParser.ParseWithCache(fieldValue)
|
||||
if err != nil {
|
||||
val = fieldValue.Interface()
|
||||
} else {
|
||||
|
|
|
@ -187,7 +187,7 @@ func (statement *Statement) BuildUpdates(bean interface{},
|
|||
val, _ = nulType.Value()
|
||||
} else {
|
||||
if !col.SQLType.IsJson() {
|
||||
table, err := statement.tagParser.MapType(fieldValue)
|
||||
table, err := statement.tagParser.ParseWithCache(fieldValue)
|
||||
if err != nil {
|
||||
val = fieldValue.Interface()
|
||||
} else {
|
||||
|
|
|
@ -7,7 +7,6 @@ package schemas
|
|||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
//"xorm.io/xorm/cache"
|
||||
)
|
||||
|
||||
// Table represents a database table
|
||||
|
@ -24,10 +23,9 @@ type Table struct {
|
|||
Updated string
|
||||
Deleted string
|
||||
Version string
|
||||
//Cacher caches.Cacher
|
||||
StoreEngine string
|
||||
Charset string
|
||||
Comment string
|
||||
StoreEngine string
|
||||
Charset string
|
||||
Comment string
|
||||
}
|
||||
|
||||
func NewEmptyTable() *Table {
|
||||
|
|
|
@ -11,12 +11,14 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
type DBType string
|
||||
|
||||
const (
|
||||
POSTGRES = "postgres"
|
||||
SQLITE = "sqlite3"
|
||||
MYSQL = "mysql"
|
||||
MSSQL = "mssql"
|
||||
ORACLE = "oracle"
|
||||
POSTGRES DBType = "postgres"
|
||||
SQLITE DBType = "sqlite3"
|
||||
MYSQL DBType = "mysql"
|
||||
MSSQL DBType = "mssql"
|
||||
ORACLE DBType = "oracle"
|
||||
)
|
||||
|
||||
// SQLType represents SQL types
|
||||
|
|
|
@ -698,7 +698,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
|
|||
}
|
||||
}
|
||||
} else if session.statement.UseCascade {
|
||||
table, err := session.engine.tagParser.MapType(*fieldValue)
|
||||
table, err := session.engine.tagParser.ParseWithCache(*fieldValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -207,7 +207,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
|
|||
v = x
|
||||
fieldValue.Set(reflect.ValueOf(v).Convert(fieldType))
|
||||
} else if session.statement.UseCascade {
|
||||
table, err := session.engine.tagParser.MapType(*fieldValue)
|
||||
table, err := session.engine.tagParser.ParseWithCache(*fieldValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -488,7 +488,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
|
|||
default:
|
||||
if session.statement.UseCascade {
|
||||
structInter := reflect.New(fieldType.Elem())
|
||||
table, err := session.engine.tagParser.MapType(structInter.Elem())
|
||||
table, err := session.engine.tagParser.ParseWithCache(structInter.Elem())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -599,7 +599,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.
|
|||
return v.Value()
|
||||
}
|
||||
|
||||
fieldTable, err := session.engine.tagParser.MapType(fieldValue)
|
||||
fieldTable, err := session.engine.tagParser.ParseWithCache(fieldValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -225,7 +225,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
|
|||
if elemType.Kind() == reflect.Struct {
|
||||
var newValue = newElemFunc(fields)
|
||||
dataStruct := utils.ReflectValue(newValue.Interface())
|
||||
tb, err := session.engine.tagParser.MapType(dataStruct)
|
||||
tb, err := session.engine.tagParser.ParseWithCache(dataStruct)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -242,7 +242,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
|
|||
|
||||
for _, bean := range beans {
|
||||
v := utils.ReflectValue(bean)
|
||||
table, err := engine.tagParser.MapType(v)
|
||||
table, err := engine.tagParser.ParseWithCache(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -20,11 +20,15 @@ import (
|
|||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnsupportedType = errors.New("Unsupported type")
|
||||
)
|
||||
|
||||
type Parser struct {
|
||||
identifier string
|
||||
dialect dialects.Dialect
|
||||
ColumnMapper names.Mapper
|
||||
TableMapper names.Mapper
|
||||
columnMapper names.Mapper
|
||||
tableMapper names.Mapper
|
||||
handlers map[string]Handler
|
||||
cacherMgr *caches.Manager
|
||||
tableCache sync.Map // map[reflect.Type]*schemas.Table
|
||||
|
@ -34,33 +38,39 @@ func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnM
|
|||
return &Parser{
|
||||
identifier: identifier,
|
||||
dialect: dialect,
|
||||
TableMapper: tableMapper,
|
||||
ColumnMapper: columnMapper,
|
||||
tableMapper: tableMapper,
|
||||
columnMapper: columnMapper,
|
||||
handlers: defaultTagHandlers,
|
||||
cacherMgr: cacherMgr,
|
||||
}
|
||||
}
|
||||
|
||||
func addIndex(indexName string, table *schemas.Table, col *schemas.Column, indexType int) {
|
||||
if index, ok := table.Indexes[indexName]; ok {
|
||||
index.AddColumn(col.Name)
|
||||
col.Indexes[index.Name] = indexType
|
||||
} else {
|
||||
index := schemas.NewIndex(indexName, indexType)
|
||||
index.AddColumn(col.Name)
|
||||
table.AddIndex(index)
|
||||
col.Indexes[index.Name] = indexType
|
||||
}
|
||||
func (parser *Parser) GetTableMapper() names.Mapper {
|
||||
return parser.tableMapper
|
||||
}
|
||||
|
||||
func (parser *Parser) MapType(v reflect.Value) (*schemas.Table, error) {
|
||||
func (parser *Parser) SetTableMapper(mapper names.Mapper) {
|
||||
parser.ClearCaches()
|
||||
parser.tableMapper = mapper
|
||||
}
|
||||
|
||||
func (parser *Parser) GetColumnMapper() names.Mapper {
|
||||
return parser.columnMapper
|
||||
}
|
||||
|
||||
func (parser *Parser) SetColumnMapper(mapper names.Mapper) {
|
||||
parser.ClearCaches()
|
||||
parser.columnMapper = mapper
|
||||
}
|
||||
|
||||
func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) {
|
||||
t := v.Type()
|
||||
tableI, ok := parser.tableCache.Load(t)
|
||||
if ok {
|
||||
return tableI.(*schemas.Table), nil
|
||||
}
|
||||
|
||||
table, err := parser.mapType(v)
|
||||
table, err := parser.Parse(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -78,16 +88,41 @@ func (parser *Parser) MapType(v reflect.Value) (*schemas.Table, error) {
|
|||
return table, nil
|
||||
}
|
||||
|
||||
// ClearTable removes the database mapper of a type from the cache
|
||||
func (parser *Parser) ClearTable(t reflect.Type) {
|
||||
// ClearCacheTable removes the database mapper of a type from the cache
|
||||
func (parser *Parser) ClearCacheTable(t reflect.Type) {
|
||||
parser.tableCache.Delete(t)
|
||||
}
|
||||
|
||||
func (parser *Parser) mapType(v reflect.Value) (*schemas.Table, error) {
|
||||
// ClearCaches removes all the cached table information parsed by structs
|
||||
func (parser *Parser) ClearCaches() {
|
||||
parser.tableCache = sync.Map{}
|
||||
}
|
||||
|
||||
func addIndex(indexName string, table *schemas.Table, col *schemas.Column, indexType int) {
|
||||
if index, ok := table.Indexes[indexName]; ok {
|
||||
index.AddColumn(col.Name)
|
||||
col.Indexes[index.Name] = indexType
|
||||
} else {
|
||||
index := schemas.NewIndex(indexName, indexType)
|
||||
index.AddColumn(col.Name)
|
||||
table.AddIndex(index)
|
||||
col.Indexes[index.Name] = indexType
|
||||
}
|
||||
}
|
||||
|
||||
// Parse parses a struct as a table information
|
||||
func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
|
||||
t := v.Type()
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() != reflect.Struct {
|
||||
return nil, ErrUnsupportedType
|
||||
}
|
||||
|
||||
table := schemas.NewEmptyTable()
|
||||
table.Type = t
|
||||
table.Name = names.GetTableName(parser.TableMapper, v)
|
||||
table.Name = names.GetTableName(parser.tableMapper, v)
|
||||
|
||||
var idFieldColName string
|
||||
var hasCacheTag, hasNoCacheTag bool
|
||||
|
@ -204,7 +239,7 @@ func (parser *Parser) mapType(v reflect.Value) (*schemas.Table, error) {
|
|||
col.Length2 = col.SQLType.DefaultLength2
|
||||
}
|
||||
if col.Name == "" {
|
||||
col.Name = parser.ColumnMapper.Obj2Table(t.Field(i).Name)
|
||||
col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name)
|
||||
}
|
||||
|
||||
if ctx.isUnique {
|
||||
|
@ -229,7 +264,7 @@ func (parser *Parser) mapType(v reflect.Value) (*schemas.Table, error) {
|
|||
} else {
|
||||
sqlType = schemas.Type2SQLType(fieldType)
|
||||
}
|
||||
col = schemas.NewColumn(parser.ColumnMapper.Obj2Table(t.Field(i).Name),
|
||||
col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name),
|
||||
t.Field(i).Name, sqlType, sqlType.DefaultLength,
|
||||
sqlType.DefaultLength2, true)
|
||||
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
// Copyright 2017 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tags
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"xorm.io/xorm/caches"
|
||||
"xorm.io/xorm/dialects"
|
||||
"xorm.io/xorm/names"
|
||||
)
|
||||
|
||||
type ParseTableName1 struct{}
|
||||
|
||||
type ParseTableName2 struct{}
|
||||
|
||||
func (p ParseTableName2) TableName() string {
|
||||
return "p_parseTableName"
|
||||
}
|
||||
|
||||
func TestParseTableName(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"xorm",
|
||||
dialects.QueryDialect("mysql"),
|
||||
names.SnakeMapper{},
|
||||
names.SnakeMapper{},
|
||||
caches.NewManager(),
|
||||
)
|
||||
table, err := parser.Parse(reflect.ValueOf(new(ParseTableName1)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "parse_table_name1", table.Name)
|
||||
|
||||
table, err = parser.Parse(reflect.ValueOf(new(ParseTableName2)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "p_parseTableName", table.Name)
|
||||
}
|
|
@ -280,7 +280,7 @@ func ExtendsTagHandler(ctx *Context) error {
|
|||
isPtr = true
|
||||
fallthrough
|
||||
case reflect.Struct:
|
||||
parentTable, err := ctx.parser.mapType(fieldValue)
|
||||
parentTable, err := ctx.parser.Parse(fieldValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -871,7 +871,7 @@ func TestAutoIncrTag(t *testing.T) {
|
|||
func TestTagComment(t *testing.T) {
|
||||
assert.NoError(t, prepareEngine())
|
||||
// FIXME: only support mysql
|
||||
if testEngine.Dialect().DriverName() != schemas.MYSQL {
|
||||
if testEngine.Dialect().DBType() != schemas.MYSQL {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ func createEngine(dbType, connStr string) error {
|
|||
var err error
|
||||
|
||||
if !*cluster {
|
||||
switch strings.ToLower(dbType) {
|
||||
switch schemas.DBType(strings.ToLower(dbType)) {
|
||||
case schemas.MSSQL:
|
||||
db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "master", -1))
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue