From 0f191f3e28c3a18ef6bbe21d585cc8ca62f77a0a Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 26 Feb 2020 08:51:37 +0800 Subject: [PATCH] fix tests --- engine.go | 12 +++++------- engine_group.go | 8 ++++---- engine_table.go | 6 +++--- tags/parser.go | 21 +++++++++++---------- tags/tag_test.go | 26 ++++++++++++++++++++++++++ tags_test.go | 19 ------------------- types_test.go | 18 ++++++++++-------- xorm.go | 4 ++-- 8 files changed, 61 insertions(+), 53 deletions(-) create mode 100644 tags/tag_test.go diff --git a/engine.go b/engine.go index 0de233f5..b97d1c06 100644 --- a/engine.go +++ b/engine.go @@ -37,9 +37,7 @@ type Engine struct { db *core.DB dialect dialects.Dialect - ColumnMapper names.Mapper - TableMapper names.Mapper - Tables map[reflect.Type]*schemas.Table + Tables map[reflect.Type]*schemas.Table mutex *sync.RWMutex @@ -151,12 +149,12 @@ func (engine *Engine) SetMapper(mapper names.Mapper) { // SetTableMapper set the table name mapping rule func (engine *Engine) SetTableMapper(mapper names.Mapper) { - engine.TableMapper = mapper + engine.tagParser.TableMapper = mapper } // SetColumnMapper set the column name mapping rule func (engine *Engine) SetColumnMapper(mapper names.Mapper) { - engine.ColumnMapper = mapper + engine.tagParser.ColumnMapper = mapper } // SupportInsertMany If engine's database support batch insert records like @@ -1333,12 +1331,12 @@ func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{} // GetColumnMapper returns the column name mapper func (engine *Engine) GetColumnMapper() names.Mapper { - return engine.ColumnMapper + return engine.tagParser.ColumnMapper } // GetTableMapper returns the table name mapper func (engine *Engine) GetTableMapper() names.Mapper { - return engine.TableMapper + return engine.tagParser.TableMapper } // GetTZLocation returns time zone of the application diff --git a/engine_group.go b/engine_group.go index 24dc0103..55159d55 100644 --- a/engine_group.go +++ b/engine_group.go @@ -112,9 +112,9 @@ func (eg *EngineGroup) Ping() error { // SetColumnMapper set the column name mapping rule func (eg *EngineGroup) SetColumnMapper(mapper names.Mapper) { - eg.Engine.ColumnMapper = mapper + eg.Engine.SetColumnMapper(mapper) for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].ColumnMapper = mapper + eg.slaves[i].SetColumnMapper(mapper) } } @@ -182,9 +182,9 @@ func (eg *EngineGroup) SetPolicy(policy GroupPolicy) *EngineGroup { // SetTableMapper set the table name mapping rule func (eg *EngineGroup) SetTableMapper(mapper names.Mapper) { - eg.Engine.TableMapper = mapper + eg.Engine.SetTableMapper(mapper) for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].TableMapper = mapper + eg.slaves[i].SetTableMapper(mapper) } } diff --git a/engine_table.go b/engine_table.go index 33da67bb..0954b2d3 100644 --- a/engine_table.go +++ b/engine_table.go @@ -78,7 +78,7 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string { v := rValue(f) t := v.Type() if t.Kind() == reflect.Struct { - table = names.GetTableName(engine.TableMapper, v) + table = names.GetTableName(engine.GetTableMapper(), v) } else { table = engine.Quote(fmt.Sprintf("%v", f)) } @@ -96,12 +96,12 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string { return tablename.(string) case reflect.Value: v := tablename.(reflect.Value) - return names.GetTableName(engine.TableMapper, v) + return names.GetTableName(engine.GetTableMapper(), v) default: v := rValue(tablename) t := v.Type() if t.Kind() == reflect.Struct { - return names.GetTableName(engine.TableMapper, v) + return names.GetTableName(engine.GetTableMapper(), v) } return engine.Quote(fmt.Sprintf("%v", tablename)) } diff --git a/tags/parser.go b/tags/parser.go index 2602a916..15dcaa30 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -19,19 +19,20 @@ import ( ) type Parser struct { - identifier string - dialect dialects.Dialect - tableMapper, columnMapper names.Mapper - handlers map[string]Handler - cacherMgr *caches.Manager + identifier string + dialect dialects.Dialect + ColumnMapper names.Mapper + TableMapper names.Mapper + handlers map[string]Handler + cacherMgr *caches.Manager } func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnMapper names.Mapper, cacherMgr *caches.Manager) *Parser { return &Parser{ identifier: identifier, dialect: dialect, - tableMapper: tableMapper, - columnMapper: columnMapper, + TableMapper: tableMapper, + ColumnMapper: columnMapper, handlers: defaultTagHandlers, cacherMgr: cacherMgr, } @@ -53,7 +54,7 @@ func (parser *Parser) MapType(v reflect.Value) (*schemas.Table, error) { t := v.Type() table := schemas.NewEmptyTable() table.Type = t - table.Name = names.GetTableName(parser.tableMapper, v) + table.Name = names.GetTableName(parser.TableMapper, v) var idFieldColName string var hasCacheTag, hasNoCacheTag bool @@ -170,7 +171,7 @@ func (parser *Parser) MapType(v reflect.Value) (*schemas.Table, error) { col.Length2 = col.SQLType.DefaultLength2 } if col.Name == "" { - col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name) + col.Name = parser.ColumnMapper.Obj2Table(t.Field(i).Name) } if ctx.isUnique { @@ -195,7 +196,7 @@ func (parser *Parser) MapType(v reflect.Value) (*schemas.Table, error) { } else { sqlType = schemas.Type2SQLType(fieldType) } - col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name), + col = schemas.NewColumn(parser.ColumnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, true) diff --git a/tags/tag_test.go b/tags/tag_test.go new file mode 100644 index 00000000..f4a79379 --- /dev/null +++ b/tags/tag_test.go @@ -0,0 +1,26 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tags + +import "testing" + +func TestSplitTag(t *testing.T) { + var cases = []struct { + tag string + tags []string + }{ + {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, + {"TEXT", []string{"TEXT"}}, + {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, + {"json binary", []string{"json", "binary"}}, + } + + for _, kase := range cases { + tags := splitTag(kase.tag) + if !sliceEq(tags, kase.tags) { + t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) + } + } +} diff --git a/tags_test.go b/tags_test.go index 9d9bf83e..2d90948b 100644 --- a/tags_test.go +++ b/tags_test.go @@ -1223,25 +1223,6 @@ func TestTagTime(t *testing.T) { strings.Replace(strings.Replace(tm, "T", " ", -1), "Z", "", -1)) } -func TestSplitTag(t *testing.T) { - var cases = []struct { - tag string - tags []string - }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, - {"TEXT", []string{"TEXT"}}, - {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, - {"json binary", []string{"json", "binary"}}, - } - - for _, kase := range cases { - tags := splitTag(kase.tag) - if !sliceEq(tags, kase.tags) { - t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) - } - } -} - func TestTagAutoIncr(t *testing.T) { assert.NoError(t, prepareEngine()) diff --git a/types_test.go b/types_test.go index 1e21907c..53872372 100644 --- a/types_test.go +++ b/types_test.go @@ -9,8 +9,10 @@ import ( "fmt" "testing" - "github.com/stretchr/testify/assert" + "xorm.io/xorm/convert" "xorm.io/xorm/schemas" + + "github.com/stretchr/testify/assert" ) func TestArrayField(t *testing.T) { @@ -137,8 +139,8 @@ type ConvStruct struct { Conv ConvString Conv2 *ConvString Cfg1 ConvConfig - Cfg2 *ConvConfig `xorm:"TEXT"` - Cfg3 Conversion `xorm:"BLOB"` + Cfg2 *ConvConfig `xorm:"TEXT"` + Cfg3 convert.Conversion `xorm:"BLOB"` Slice SliceType } @@ -267,11 +269,11 @@ type Status struct { } var ( - _ Conversion = &Status{} - Registered Status = Status{"Registered", "white"} - Approved Status = Status{"Approved", "green"} - Removed Status = Status{"Removed", "red"} - Statuses map[string]Status = map[string]Status{ + _ convert.Conversion = &Status{} + Registered Status = Status{"Registered", "white"} + Approved Status = Status{"Approved", "green"} + Removed Status = Status{"Removed", "red"} + Statuses map[string]Status = map[string]Status{ Registered.Name: Registered, Approved.Name: Approved, Removed.Name: Removed, diff --git a/xorm.go b/xorm.go index b5d701c0..f3230aa1 100644 --- a/xorm.go +++ b/xorm.go @@ -80,8 +80,8 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { logger := log.NewSimpleLogger(os.Stdout) logger.SetLevel(log.LOG_INFO) engine.SetLogger(logger) - engine.SetMapper(names.NewCacheMapper(new(names.SnakeMapper))) - engine.tagParser = tags.NewParser("xorm", dialect, engine.TableMapper, engine.ColumnMapper, engine.cacherMgr) + mapper := names.NewCacheMapper(new(names.SnakeMapper)) + engine.tagParser = tags.NewParser("xorm", dialect, mapper, mapper, engine.cacherMgr) runtime.SetFinalizer(engine, close)