From ec2c685583e4c4fd985e0a08b44fa94ac08d8241 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 27 Nov 2013 15:53:05 +0800 Subject: [PATCH] add PrefixMapper & SuffixMapper --- base_test.go | 132 +++++++++++++++++++++++++++++++++++++++------ docs/QuickStart.md | 2 +- engine.go | 23 ++++++-- mapper.go | 36 +++++++++++++ session.go | 42 +++++++++++++-- statement.go | 4 +- table.go | 22 ++++++-- xorm.go | 3 +- 8 files changed, 232 insertions(+), 32 deletions(-) diff --git a/base_test.go b/base_test.go index e7353aec..17cfedf1 100644 --- a/base_test.go +++ b/base_test.go @@ -429,6 +429,13 @@ func find(engine *Engine, t *testing.T) { for _, user := range users { fmt.Println(user) } + + users2 := make([]Userinfo, 0) + err = engine.Sql("select * from userinfo").Find(&users2) + if err != nil { + t.Error(err) + panic(err) + } } func find2(engine *Engine, t *testing.T) { @@ -1439,57 +1446,76 @@ type Version struct { } func testVersion(engine *Engine, t *testing.T) { - err := engine.DropTables(new(Version)) + /*err := engine.DropTables(new(Version)) if err != nil { t.Error(err) - return + panic(err) } err = engine.CreateTables(new(Version)) if err != nil { t.Error(err) - return + panic(err) } ver := &Version{Name: "sfsfdsfds"} - _, err = engine.Cols("name").Insert(ver) + _, err = engine.Insert(ver) if err != nil { t.Error(err) - return + panic(err) + } + fmt.Println(ver) + if ver.Ver != 1 { + err = errors.New("insert error") + t.Error(err) + panic(err) } newVer := new(Version) has, err := engine.Id(ver.Id).Get(newVer) if err != nil { t.Error(err) - return + panic(err) } if !has { t.Error(errors.New(fmt.Sprintf("no version id is %v", ver.Id))) - return + panic(err) + } + fmt.Println(newVer) + if newVer.Ver != 1 { + err = errors.New("insert error") + t.Error(err) + panic(err) } newVer.Name = "-------" - _, err = engine.Id(ver.Id).Update(newVer, &Version{Ver: newVer.Ver}) + _, err = engine.Id(ver.Id).Update(newVer) if err != nil { t.Error(err) - return + panic(err) } + newVer = new(Version) has, err = engine.Id(ver.Id).Get(newVer) if err != nil { t.Error(err) - return + panic(err) } - fmt.Println(ver) - - newVer.Name = "-------" - _, err = engine.Id(ver.Id).Update(newVer, &Version{Ver: newVer.Ver}) - if err != nil { + fmt.Println(newVer) + if newVer.Ver != 2 { + err = errors.New("insert error") t.Error(err) - return - } + panic(err) + }*/ + + /* + newVer.Name = "-------" + _, err = engine.Id(ver.Id).Update(newVer) + if err != nil { + t.Error(err) + return + }*/ } func testDistinct(engine *Engine, t *testing.T) { @@ -1603,6 +1629,74 @@ func testBool(engine *Engine, t *testing.T) { } } +type TTime struct { + Id int64 + T time.Time +} + +func testTime(engine *Engine, t *testing.T) { + err := engine.Sync(&TTime{}) + if err != nil { + t.Error(err) + panic(err) + } + + tt := &TTime{} + _, err = engine.Insert(tt) + if err != nil { + t.Error(err) + panic(err) + } + + tt2 := &TTime{Id: tt.Id} + has, err := engine.Get(tt2) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("no record error") + t.Error(err) + panic(err) + } + + tt3 := &TTime{T: time.Now()} + _, err = engine.Insert(tt3) + if err != nil { + t.Error(err) + panic(err) + } +} + +func testPrefixTableName(engine *Engine, t *testing.T) { + tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName) + if err != nil { + t.Error(err) + panic(err) + } + tempEngine.ShowSQL = true + mapper := NewPrefixMapper(SnakeMapper{}, "xlw_") + //tempEngine.SetMapper(mapper) + tempEngine.SetTableMapper(mapper) + exist, err := tempEngine.IsTableExist(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } + if exist { + err = tempEngine.DropTables(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } + } + err = tempEngine.CreateTables(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } +} + func testAll(engine *Engine, t *testing.T) { fmt.Println("-------------- directCreateTable --------------") directCreateTable(engine, t) @@ -1693,6 +1787,10 @@ func testAll2(engine *Engine, t *testing.T) { testUseBool(engine, t) fmt.Println("-------------- testBool --------------") testBool(engine, t) + fmt.Println("-------------- testTime --------------") + testTime(engine, t) + fmt.Println("-------------- testPrefixTableName --------------") + testPrefixTableName(engine, t) fmt.Println("-------------- transaction --------------") transaction(engine, t) } diff --git a/docs/QuickStart.md b/docs/QuickStart.md index 63cb6dfd..8a58c8a4 100644 --- a/docs/QuickStart.md +++ b/docs/QuickStart.md @@ -120,7 +120,7 @@ engine.Mapper = SameMapper{} 如果所有的命名都是按照IMapper的映射来操作的,那当然是最理想的。但是如果碰到某个表名或者某个字段名跟映射规则不匹配时,我们就需要别的机制来改变。 -通过`engine.Table()`方法可以改变struct对应的数据库表的名称,通过sturct中field对应的Tag中使用`xorm:"'table_name'"`可以使该field对应的Column名称为指定名称。这里使用两个单引号将Column名称括起来是为了防止名称冲突,因为我们在Tag中还可以对这个Column进行更多的定义。如果名称不冲突的情况,单引号也可以不使用。 +通过`engine.Table()`方法可以改变struct对应的数据库表的名称,通过sturct中field对应的Tag中使用`xorm:"'column_name'"`可以使该field对应的Column名称为指定名称。这里使用两个单引号将Column名称括起来是为了防止名称冲突,因为我们在Tag中还可以对这个Column进行更多的定义。如果名称不冲突的情况,单引号也可以不使用。 ### 2.3.Column属性定义 diff --git a/engine.go b/engine.go index 580d76af..570f221d 100644 --- a/engine.go +++ b/engine.go @@ -40,7 +40,8 @@ type dialect interface { // Engine is the major struct of xorm, it means a database manager. // Commonly, an application only need one engine type Engine struct { - Mapper IMapper + columnMapper IMapper + tableMapper IMapper TagIdentifier string DriverName string DataSourceName string @@ -58,6 +59,19 @@ type Engine struct { UseCache bool } +func (engine *Engine) SetMapper(mapper IMapper) { + engine.SetTableMapper(mapper) + engine.SetColumnMapper(mapper) +} + +func (engine *Engine) SetTableMapper(mapper IMapper) { + engine.tableMapper = mapper +} + +func (engine *Engine) SetColumnMapper(mapper IMapper) { + engine.columnMapper = mapper +} + // If engine's database support batch insert records like // "insert into user values (name, age), (name, age)". // When the return is ture, then engine.Insert(&users) will @@ -396,13 +410,14 @@ func (engine *Engine) newTable() *Table { table.Indexes = make(map[string]*Index) table.Columns = make(map[string]*Column) table.ColumnsSeq = make([]string, 0) + table.Created = make(map[string]bool) table.Cacher = engine.Cacher return table } func (engine *Engine) mapType(t reflect.Type) *Table { table := engine.newTable() - table.Name = engine.Mapper.Obj2Table(t.Name()) + table.Name = engine.tableMapper.Obj2Table(t.Name()) table.Type = t var idFieldColName string @@ -510,7 +525,7 @@ func (engine *Engine) mapType(t reflect.Type) *Table { col.Length2 = col.SQLType.DefaultLength2 } if col.Name == "" { - col.Name = engine.Mapper.Obj2Table(t.Field(i).Name) + col.Name = engine.columnMapper.Obj2Table(t.Field(i).Name) } if indexType == IndexType { if indexName == "" { @@ -542,7 +557,7 @@ func (engine *Engine) mapType(t reflect.Type) *Table { } } else { sqlType := Type2SQLType(fieldType) - col = &Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, + col = &Column{engine.columnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, true, "", make(map[string]bool), false, false, TWOSIDES, false, false, false, false} } diff --git a/mapper.go b/mapper.go index 9bcb544b..f98f0f9a 100644 --- a/mapper.go +++ b/mapper.go @@ -92,3 +92,39 @@ func titleCasedName(name string) string { func (mapper SnakeMapper) Table2Obj(name string) string { return titleCasedName(name) } + +// provide prefix table name support +type PrefixMapper struct { + Mapper IMapper + Prefix string +} + +func (mapper PrefixMapper) Obj2Table(name string) string { + return mapper.Prefix + mapper.Mapper.Obj2Table(name) +} + +func (mapper PrefixMapper) Table2Obj(name string) string { + return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):]) +} + +func NewPrefixMapper(mapper IMapper, prefix string) PrefixMapper { + return PrefixMapper{mapper, prefix} +} + +// provide suffix table name support +type SuffixMapper struct { + Mapper IMapper + Suffix string +} + +func (mapper SuffixMapper) Obj2Table(name string) string { + return mapper.Suffix + mapper.Mapper.Obj2Table(name) +} + +func (mapper SuffixMapper) Table2Obj(name string) string { + return mapper.Mapper.Table2Obj(name[len(mapper.Suffix):]) +} + +func NewSuffixMapper(mapper IMapper, suffix string) SuffixMapper { + return SuffixMapper{mapper, suffix} +} diff --git a/session.go b/session.go index 8edabf41..ddb3e2e9 100644 --- a/session.go +++ b/session.go @@ -1706,6 +1706,13 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.cacheInsert(session.Statement.TableName()) } + if table.Version != "" && session.Statement.checkVersion { + verValue := table.VersionColumn().ValueOf(bean) + if verValue.IsValid() && verValue.CanSet() { + verValue.SetInt(1) + } + } + if table.PrimaryKey == "" { return res.RowsAffected() } @@ -1742,6 +1749,13 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.cacheInsert(session.Statement.TableName()) } + if table.Version != "" && session.Statement.checkVersion { + verValue := table.VersionColumn().ValueOf(bean) + if verValue.IsValid() && verValue.CanSet() { + verValue.SetInt(1) + } + } + if len(res) < 1 { return 0, errors.New("insert no error but not returned id") } @@ -1916,7 +1930,13 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { if col, ok := table.Columns[colName]; ok { fieldValue := col.ValueOf(bean) session.Engine.LogDebug("[xorm:cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) - fieldValue.Set(reflect.ValueOf(args[idx])) + if col.IsVersion && session.Statement.checkVersion { + fieldValue.SetInt(fieldValue.Int() + 1) + fmt.Println("-----", fieldValue) + } else { + fieldValue.Set(reflect.ValueOf(args[idx])) + fmt.Println("xxxxxx", fieldValue) + } } else { session.Engine.LogError("[xorm:cacheUpdate] ERROR: column %v is not table %v's", colName, table.Name) @@ -2001,27 +2021,38 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 st := session.Statement defer session.Statement.Init() if st.WhereStr != "" { - condition = fmt.Sprintf("WHERE %v", st.WhereStr) + condition = fmt.Sprintf("%v", st.WhereStr) } if condition == "" { if len(condiColNames) > 0 { - condition = fmt.Sprintf("WHERE %v ", strings.Join(condiColNames, " and ")) + condition = fmt.Sprintf("%v", strings.Join(condiColNames, " AND ")) } } else { if len(condiColNames) > 0 { - condition = fmt.Sprintf("%v and %v", condition, strings.Join(condiColNames, " and ")) + condition = fmt.Sprintf("(%v) AND (%v)", condition, strings.Join(condiColNames, " AND ")) } } var sql string - if table.Version != "" { + if table.Version != "" && session.Statement.checkVersion { + if condition != "" { + condition = fmt.Sprintf("WHERE (%v) AND %v = ?", condition, + session.Engine.Quote(table.Version)) + } else { + condition = fmt.Sprintf("WHERE %v = ?", session.Engine.Quote(table.Version)) + } sql = fmt.Sprintf("UPDATE %v SET %v, %v %v", session.Engine.Quote(session.Statement.TableName()), strings.Join(colNames, ", "), session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1", condition) + + condiArgs = append(condiArgs, table.VersionColumn().ValueOf(bean).Interface()) } else { + if condition != "" { + condition = "WHERE " + condition + } sql = fmt.Sprintf("UPDATE %v SET %v %v", session.Engine.Quote(session.Statement.TableName()), strings.Join(colNames, ", "), @@ -2034,6 +2065,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if err != nil { return 0, err } + if table.Cacher != nil && session.Statement.UseCache { session.cacheUpdate(sql, args...) } diff --git a/statement.go b/statement.go index 4dd66084..1dc06dc6 100644 --- a/statement.go +++ b/statement.go @@ -37,6 +37,7 @@ type Statement struct { UseAutoTime bool IsDistinct bool allUseBool bool + checkVersion bool boolColumnMap map[string]bool } @@ -65,6 +66,7 @@ func (statement *Statement) Init() { statement.IsDistinct = false statement.allUseBool = false statement.boolColumnMap = make(map[string]bool) + statement.checkVersion = true } // add the raw sql statement @@ -164,7 +166,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers case reflect.Struct: if fieldType == reflect.TypeOf(time.Now()) { t := fieldValue.Interface().(time.Time) - if t.IsZero() { + if t.IsZero() || !fieldValue.IsValid() { continue } var str string diff --git a/table.go b/table.go index beee9912..fdd62bfd 100644 --- a/table.go +++ b/table.go @@ -272,17 +272,31 @@ type Table struct { Columns map[string]*Column Indexes map[string]*Index PrimaryKey string - Created string + Created map[string]bool Updated string Version string Cacher Cacher } +/* +func NewTable(name string, t reflect.Type) *Table { + return &Table{Name: name, Type: t, + ColumnsSeq: make([]string, 0), + Columns: make(map[string]*Column), + Indexes: make(map[string]*Index), + Created: make(map[string]bool), + } +}*/ + // if has primary key, return column func (table *Table) PKColumn() *Column { return table.Columns[table.PrimaryKey] } +func (table *Table) VersionColumn() *Column { + return table.Columns[table.Version] +} + // add a column to table func (table *Table) AddColumn(col *Column) { table.ColumnsSeq = append(table.ColumnsSeq, col.Name) @@ -291,7 +305,7 @@ func (table *Table) AddColumn(col *Column) { table.PrimaryKey = col.Name } if col.IsCreated { - table.Created = col.Name + table.Created[col.Name] = true } if col.IsUpdated { table.Updated = col.Name @@ -311,7 +325,7 @@ func (table *Table) genCols(session *Session, bean interface{}, useCol bool, inc args := make([]interface{}, 0) for _, col := range table.Columns { - if useCol { + if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated { if _, ok := session.Statement.columnMap[col.Name]; !ok { continue } @@ -338,6 +352,8 @@ func (table *Table) genCols(session *Session, bean interface{}, useCol bool, inc if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { args = append(args, time.Now()) + } else if col.IsVersion && session.Statement.checkVersion { + args = append(args, 1) } else { arg, err := session.value2Interface(col, fieldValue) if err != nil { diff --git a/xorm.go b/xorm.go index 80d17fe5..9f11c45b 100644 --- a/xorm.go +++ b/xorm.go @@ -20,8 +20,9 @@ func close(engine *Engine) { // new a db manager according to the parameter. Currently support four // drivers func NewEngine(driverName string, dataSourceName string) (*Engine, error) { - engine := &Engine{DriverName: driverName, Mapper: SnakeMapper{}, + engine := &Engine{DriverName: driverName, DataSourceName: dataSourceName, Filters: make([]Filter, 0)} + engine.SetMapper(SnakeMapper{}) if driverName == SQLITE { engine.dialect = &sqlite3{}