diff --git a/engine.go b/engine.go index d0a9dfbe..134e6b14 100644 --- a/engine.go +++ b/engine.go @@ -44,6 +44,8 @@ type Engine struct { DatabaseTZ *time.Location // The timezone of the database disableGlobalCache bool + + tagHandlers map[string]tagHandler } // ShowSQL show SQL statement or not on logger if log level is great than INFO @@ -780,13 +782,18 @@ func (engine *Engine) autoMapType(v reflect.Value) *core.Table { defer engine.mutex.Unlock() table, ok := engine.Tables[t] if !ok { - table = engine.mapType(v) - engine.Tables[t] = table - if engine.Cacher != nil { - if v.CanAddr() { - engine.GobRegister(v.Addr().Interface()) - } else { - engine.GobRegister(v.Interface()) + var err error + table, err = engine.mapType(v) + if err != nil { + engine.logger.Error(err) + } else { + engine.Tables[t] = table + if engine.Cacher != nil { + if v.CanAddr() { + engine.GobRegister(v.Addr().Interface()) + } else { + engine.GobRegister(v.Interface()) + } } } } @@ -842,7 +849,7 @@ var ( tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() ) -func (engine *Engine) mapType(v reflect.Value) *core.Table { +func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { t := v.Type() table := engine.newTable() if tb, ok := v.Interface().(TableName); ok { @@ -861,7 +868,6 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { table.Type = t var idFieldColName string - var err error var hasCacheTag, hasNoCacheTag bool for i := 0; i < t.NumField(); i++ { @@ -881,186 +887,94 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { if tags[0] == "-" { continue } + + var ctx = tagContext{ + table: table, + col: col, + fieldValue: fieldValue, + indexNames: make(map[string]int), + engine: engine, + } + if strings.ToUpper(tags[0]) == "EXTENDS" { - switch fieldValue.Kind() { - case reflect.Ptr: - f := fieldValue.Type().Elem() - if f.Kind() == reflect.Struct { - fieldPtr := fieldValue - fieldValue = fieldValue.Elem() - if !fieldValue.IsValid() || fieldPtr.IsNil() { - fieldValue = reflect.New(f).Elem() - } - } - fallthrough - case reflect.Struct: - parentTable := engine.mapType(fieldValue) - for _, col := range parentTable.Columns() { - col.FieldName = fmt.Sprintf("%v.%v", t.Field(i).Name, col.FieldName) - table.AddColumn(col) - for indexName, indexType := range col.Indexes { - addIndex(indexName, table, col, indexType) - } - } - continue - default: - //TODO: warning + if err := ExtendsTagHandler(&ctx); err != nil { + return nil, err } + continue } - indexNames := make(map[string]int) - var isIndex, isUnique bool - var preKey string for j, key := range tags { + if ctx.ignoreNext { + ctx.ignoreNext = false + continue + } + k := strings.ToUpper(key) - switch { - case k == "<-": - col.MapType = core.ONLYFROMDB - case k == "->": - col.MapType = core.ONLYTODB - case k == "PK": - col.IsPrimaryKey = true - col.Nullable = false - case k == "NULL": - if j == 0 { - col.Nullable = true - } else { - col.Nullable = (strings.ToUpper(tags[j-1]) != "NOT") - } - // TODO: for postgres how add autoincr? - /*case strings.HasPrefix(k, "AUTOINCR(") && strings.HasSuffix(k, ")"): - col.IsAutoIncrement = true + ctx.tagName = k - autoStart := k[len("AUTOINCR")+1 : len(k)-1] - autoStartInt, err := strconv.Atoi(autoStart) - if err != nil { - engine.LogError(err) + pStart := strings.Index(k, "(") + if pStart == 0 { + return nil, errors.New("( could not be the first charactor") } - col.AutoIncrStart = autoStartInt*/ - case k == "AUTOINCR": - col.IsAutoIncrement = true - //col.AutoIncrStart = 1 - case k == "DEFAULT": - col.Default = tags[j+1] - case k == "CREATED": - col.IsCreated = true - case k == "VERSION": - col.IsVersion = true - col.Default = "1" - case k == "UTC": - col.TimeZone = time.UTC - case k == "LOCAL": - col.TimeZone = time.Local - case strings.HasPrefix(k, "LOCALE(") && strings.HasSuffix(k, ")"): - location := k[len("LOCALE")+1 : len(k)-1] - col.TimeZone, err = time.LoadLocation(location) - if err != nil { - engine.logger.Error(err) + if pStart > -1 { + if !strings.HasSuffix(k, ")") { + return nil, errors.New("cannot match ) charactor") } - case k == "UPDATED": - col.IsUpdated = true - case k == "DELETED": - col.IsDeleted = true - case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"): - indexName := k[len("INDEX")+1 : len(k)-1] - indexNames[indexName] = core.IndexType - case k == "INDEX": - isIndex = true - case strings.HasPrefix(k, "UNIQUE(") && strings.HasSuffix(k, ")"): - indexName := k[len("UNIQUE")+1 : len(k)-1] - indexNames[indexName] = core.UniqueType - case k == "UNIQUE": - isUnique = true - case k == "NOTNULL": - col.Nullable = false - case k == "CACHE": - if !hasCacheTag { - hasCacheTag = true - } - case k == "NOCACHE": - if !hasNoCacheTag { - hasNoCacheTag = true - } - case k == "NOT": - default: - if strings.HasPrefix(k, "'") && strings.HasSuffix(k, "'") { - if preKey != "DEFAULT" { - col.Name = key[1 : len(key)-1] - } - } else if strings.Contains(k, "(") && strings.HasSuffix(k, ")") { - fs := strings.Split(k, "(") - if _, ok := core.SqlTypes[fs[0]]; !ok { - preKey = k - continue - } - col.SQLType = core.SQLType{Name: fs[0]} - if fs[0] == core.Enum && fs[1][0] == '\'' { //enum - options := strings.Split(fs[1][0:len(fs[1])-1], ",") - col.EnumOptions = make(map[string]int) - for k, v := range options { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - col.EnumOptions[v] = k - } - } else if fs[0] == core.Set && fs[1][0] == '\'' { //set - options := strings.Split(fs[1][0:len(fs[1])-1], ",") - col.SetOptions = make(map[string]int) - for k, v := range options { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - col.SetOptions[v] = k - } - } else { - fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",") - if len(fs2) == 2 { - col.Length, err = strconv.Atoi(fs2[0]) - if err != nil { - engine.logger.Error(err) - } - col.Length2, err = strconv.Atoi(fs2[1]) - if err != nil { - engine.logger.Error(err) - } - } else if len(fs2) == 1 { - col.Length, err = strconv.Atoi(fs2[0]) - if err != nil { - engine.logger.Error(err) - } - } - } - } else { - if _, ok := core.SqlTypes[k]; ok { - col.SQLType = core.SQLType{Name: k} - } else if key != col.Default { - col.Name = key - } - } - engine.dialect.SqlType(col) + ctx.tagName = k[:pStart] + ctx.params = strings.Split(k[pStart+1:len(k)-1], ",") + } + + if j > 0 { + ctx.preTag = strings.ToUpper(tags[j-1]) + } + if j < len(tags)-1 { + ctx.nextTag = strings.ToUpper(tags[j+1]) + } else { + ctx.nextTag = "" + } + + if h, ok := engine.tagHandlers[ctx.tagName]; ok { + if err := h(&ctx); err != nil { + return nil, err + } + } else { + if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") { + col.Name = key[1 : len(key)-1] + } else { + col.Name = key + } + } + + if ctx.hasCacheTag { + hasCacheTag = true + } + if ctx.hasNoCacheTag { + hasNoCacheTag = true } - preKey = k } + if col.SQLType.Name == "" { col.SQLType = core.Type2SQLType(fieldType) } + engine.dialect.SqlType(col) if col.Length == 0 { col.Length = col.SQLType.DefaultLength } if col.Length2 == 0 { col.Length2 = col.SQLType.DefaultLength2 } - if col.Name == "" { col.Name = engine.ColumnMapper.Obj2Table(t.Field(i).Name) } - if isUnique { - indexNames[col.Name] = core.UniqueType - } else if isIndex { - indexNames[col.Name] = core.IndexType + if ctx.isUnique { + ctx.indexNames[col.Name] = core.UniqueType + } else if ctx.isIndex { + ctx.indexNames[col.Name] = core.IndexType } - for indexName, indexType := range indexNames { + for indexName, indexType := range ctx.indexNames { addIndex(indexName, table, col, indexType) } } @@ -1114,7 +1028,7 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { table.Cacher = nil } - return table + return table, nil } // IsTableEmpty if a table has any reocrd diff --git a/session_schema.go b/session_schema.go index 9011adad..21fa2996 100644 --- a/session_schema.go +++ b/session_schema.go @@ -306,7 +306,10 @@ func (session *Session) Sync2(beans ...interface{}) error { for _, bean := range beans { v := rValue(bean) - table := engine.mapType(v) + table, err := engine.mapType(v) + if err != nil { + return err + } structTables = append(structTables, table) var tbName = session.tbNameNoSchema(table) diff --git a/statement_test.go b/statement_test.go index eba4f698..cb3730ef 100644 --- a/statement_test.go +++ b/statement_test.go @@ -26,11 +26,9 @@ var colStrTests = []struct { } func TestColumnsStringGeneration(t *testing.T) { - var statement *Statement for ndx, testCase := range colStrTests { - statement = createTestStatement() if testCase.omitColumn != "" { @@ -54,7 +52,6 @@ func TestColumnsStringGeneration(t *testing.T) { } func BenchmarkColumnsStringGeneration(b *testing.B) { - b.StopTimer() statement := createTestStatement() diff --git a/tag.go b/tag.go new file mode 100644 index 00000000..4b0e3f54 --- /dev/null +++ b/tag.go @@ -0,0 +1,281 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "time" + + "github.com/go-xorm/core" +) + +type tagContext struct { + tagName string + params []string + preTag, nextTag string + table *core.Table + col *core.Column + fieldValue reflect.Value + isIndex bool + isUnique bool + indexNames map[string]int + engine *Engine + hasCacheTag bool + hasNoCacheTag bool + ignoreNext bool +} + +// tagHandler describes tag handler for XORM +type tagHandler func(ctx *tagContext) error + +var ( + // defaultTagHandlers enumerates all the default tag handler + defaultTagHandlers = map[string]tagHandler{ + "<-": OnlyFromDBTagHandler, + "->": OnlyToDBTagHandler, + "PK": PKTagHandler, + "NULL": NULLTagHandler, + "NOT": IgnoreTagHandler, + "AUTOINCR": AutoIncrTagHandler, + "DEFAULT": DefaultTagHandler, + "CREATED": CreatedTagHandler, + "UPDATED": UpdatedTagHandler, + "DELETED": DeletedTagHandler, + "VERSION": VersionTagHandler, + "UTC": UTCTagHandler, + "LOCAL": LocalTagHandler, + "NOTNULL": NotNullTagHandler, + "INDEX": IndexTagHandler, + "UNIQUE": UniqueTagHandler, + "CACHE": CacheTagHandler, + "NOCACHE": NoCacheTagHandler, + } +) + +func init() { + for k := range core.SqlTypes { + defaultTagHandlers[k] = SQLTypeTagHandler + } +} + +// IgnoreTagHandler describes ignored tag handler +func IgnoreTagHandler(ctx *tagContext) error { + return nil +} + +// OnlyFromDBTagHandler describes mapping direction tag handler +func OnlyFromDBTagHandler(ctx *tagContext) error { + ctx.col.MapType = core.ONLYFROMDB + return nil +} + +// OnlyToDBTagHandler describes mapping direction tag handler +func OnlyToDBTagHandler(ctx *tagContext) error { + ctx.col.MapType = core.ONLYTODB + return nil +} + +// PKTagHandler decribes primary key tag handler +func PKTagHandler(ctx *tagContext) error { + ctx.col.IsPrimaryKey = true + ctx.col.Nullable = false + return nil +} + +// NULLTagHandler describes null tag handler +func NULLTagHandler(ctx *tagContext) error { + ctx.col.Nullable = (strings.ToUpper(ctx.preTag) != "NOT") + return nil +} + +// NotNullTagHandler describes notnull tag handler +func NotNullTagHandler(ctx *tagContext) error { + ctx.col.Nullable = false + return nil +} + +// AutoIncrTagHandler describes autoincr tag handler +func AutoIncrTagHandler(ctx *tagContext) error { + ctx.col.IsAutoIncrement = true + /* + if len(ctx.params) > 0 { + autoStartInt, err := strconv.Atoi(ctx.params[0]) + if err != nil { + return err + } + ctx.col.AutoIncrStart = autoStartInt + } else { + ctx.col.AutoIncrStart = 1 + } + */ + return nil +} + +// DefaultTagHandler describes default tag handler +func DefaultTagHandler(ctx *tagContext) error { + if len(ctx.params) > 0 { + ctx.col.Default = ctx.params[0] + } else { + ctx.col.Default = ctx.nextTag + ctx.ignoreNext = true + } + return nil +} + +// CreatedTagHandler describes created tag handler +func CreatedTagHandler(ctx *tagContext) error { + ctx.col.IsCreated = true + return nil +} + +// VersionTagHandler describes version tag handler +func VersionTagHandler(ctx *tagContext) error { + ctx.col.IsVersion = true + ctx.col.Default = "1" + return nil +} + +// UTCTagHandler describes utc tag handler +func UTCTagHandler(ctx *tagContext) error { + ctx.col.TimeZone = time.UTC + return nil +} + +// LocalTagHandler describes local tag handler +func LocalTagHandler(ctx *tagContext) error { + if len(ctx.params) == 0 { + ctx.col.TimeZone = time.Local + } else { + var err error + ctx.col.TimeZone, err = time.LoadLocation(ctx.params[0]) + if err != nil { + return err + } + } + return nil +} + +// UpdatedTagHandler describes updated tag handler +func UpdatedTagHandler(ctx *tagContext) error { + ctx.col.IsUpdated = true + return nil +} + +// DeletedTagHandler describes deleted tag handler +func DeletedTagHandler(ctx *tagContext) error { + ctx.col.IsDeleted = true + return nil +} + +// IndexTagHandler describes index tag handler +func IndexTagHandler(ctx *tagContext) error { + if len(ctx.params) > 0 { + ctx.indexNames[ctx.params[0]] = core.IndexType + } else { + ctx.isIndex = true + } + return nil +} + +// UniqueTagHandler describes unique tag handler +func UniqueTagHandler(ctx *tagContext) error { + if len(ctx.params) > 0 { + ctx.indexNames[ctx.params[0]] = core.UniqueType + } else { + ctx.isUnique = true + } + return nil +} + +// SQLTypeTagHandler describes SQL Type tag handler +func SQLTypeTagHandler(ctx *tagContext) error { + ctx.col.SQLType = core.SQLType{Name: ctx.tagName} + if len(ctx.params) > 0 { + if ctx.tagName == core.Enum { + ctx.col.EnumOptions = make(map[string]int) + for k, v := range ctx.params { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + ctx.col.EnumOptions[v] = k + } + } else if ctx.tagName == core.Set { + ctx.col.SetOptions = make(map[string]int) + for k, v := range ctx.params { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + ctx.col.SetOptions[v] = k + } + } else { + var err error + if len(ctx.params) == 2 { + ctx.col.Length, err = strconv.Atoi(ctx.params[0]) + if err != nil { + return err + } + ctx.col.Length2, err = strconv.Atoi(ctx.params[1]) + if err != nil { + return err + } + } else if len(ctx.params) == 1 { + ctx.col.Length, err = strconv.Atoi(ctx.params[0]) + if err != nil { + return err + } + } + } + } + return nil +} + +// ExtendsTagHandler describes extends tag handler +func ExtendsTagHandler(ctx *tagContext) error { + var fieldValue = ctx.fieldValue + switch fieldValue.Kind() { + case reflect.Ptr: + f := fieldValue.Type().Elem() + if f.Kind() == reflect.Struct { + fieldPtr := fieldValue + fieldValue = fieldValue.Elem() + if !fieldValue.IsValid() || fieldPtr.IsNil() { + fieldValue = reflect.New(f).Elem() + } + } + fallthrough + case reflect.Struct: + parentTable, err := ctx.engine.mapType(fieldValue) + if err != nil { + return err + } + for _, col := range parentTable.Columns() { + col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName) + ctx.table.AddColumn(col) + for indexName, indexType := range col.Indexes { + addIndex(indexName, ctx.table, col, indexType) + } + } + default: + //TODO: warning + } + return nil +} + +// CacheTagHandler describes cache tag handler +func CacheTagHandler(ctx *tagContext) error { + if !ctx.hasCacheTag { + ctx.hasCacheTag = true + } + return nil +} + +// NoCacheTagHandler describes nocache tag handler +func NoCacheTagHandler(ctx *tagContext) error { + if !ctx.hasNoCacheTag { + ctx.hasNoCacheTag = true + } + return nil +} diff --git a/xorm.go b/xorm.go index 0d690804..2cfbe9ec 100644 --- a/xorm.go +++ b/xorm.go @@ -86,6 +86,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { mutex: &sync.RWMutex{}, TagIdentifier: "xorm", TZLocation: time.Local, + tagHandlers: defaultTagHandlers, } logger := NewSimpleLogger(os.Stdout)