diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..9ac0c261 --- /dev/null +++ b/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2013 - 2014 +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the {organization} nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md index c5fb4ca7..4f85368b 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,8 @@ Or # Cases +* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs) + * [Gowalker](http://gowalker.org) - [github.com/Unknwon/gowalker](http://github.com/Unknwon/gowalker) * [Gobuild.io](http://gobuild.io) - [github.com/shxsun/gobuild](http://github.com/shxsun/gobuild) @@ -86,11 +88,7 @@ Or * [Very Hour](http://veryhour.com/) -* [GoCMS](https://github.com/zzdboy/GoCMS) - -# Todo - -[Todo List](https://trello.com/b/IHsuAnhk/xorm) +* [GoCMS - github.com/zzboy/GoCMS](https://github.com/zzdboy/GoCMS) # Discuss @@ -106,4 +104,4 @@ If you want to pull request, please see [CONTRIBUTING](https://github.com/lunny/ # LICENSE BSD License - [http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/) + [http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/) \ No newline at end of file diff --git a/README_CN.md b/README_CN.md index d7e8de80..35b7ed70 100644 --- a/README_CN.md +++ b/README_CN.md @@ -79,6 +79,8 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 ## 案例 +* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs) + * [Gowalker](http://gowalker.org) - [github.com/Unknwon/gowalker](http://github.com/Unknwon/gowalker) * [Gobuild.io](http://gobuild.io) - [github.com/shxsun/gobuild](http://github.com/shxsun/gobuild) @@ -89,12 +91,7 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 * [Very Hour](http://veryhour.com/) -* [GoCMS](https://github.com/zzdboy/GoCMS) - - -## Todo - -[开发计划](https://trello.com/b/IHsuAnhk/xorm) +* [GoCMS - github.com/zzboy/GoCMS](https://github.com/zzdboy/GoCMS) ## 讨论 diff --git a/base_test.go b/base_test.go index 92eb6290..f8d9d794 100644 --- a/base_test.go +++ b/base_test.go @@ -370,6 +370,104 @@ func update(engine *Engine, t *testing.T) { panic(err) return } + + type UpdateAllCols struct { + Id int64 + Bool bool + String string + } + + col1 := &UpdateAllCols{} + err = engine.Sync(col1) + if err != nil { + t.Error(err) + panic(err) + } + + _, err = engine.Insert(col1) + if err != nil { + t.Error(err) + panic(err) + } + + col2 := &UpdateAllCols{col1.Id, true, ""} + _, err = engine.Id(col2.Id).AllCols().Update(col2) + if err != nil { + t.Error(err) + panic(err) + } + + col3 := &UpdateAllCols{} + has, err := engine.Id(col2.Id).Get(col3) + if err != nil { + t.Error(err) + panic(err) + } + + if !has { + err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) + t.Error(err) + panic(err) + return + } + + if *col2 != *col3 { + err = errors.New(fmt.Sprintf("col2 should eq col3")) + t.Error(err) + panic(err) + return + } + + { + type UpdateMustCols struct { + Id int64 + Bool bool + String string + } + + col1 := &UpdateMustCols{} + err = engine.Sync(col1) + if err != nil { + t.Error(err) + panic(err) + } + + _, err = engine.Insert(col1) + if err != nil { + t.Error(err) + panic(err) + } + + col2 := &UpdateMustCols{col1.Id, true, ""} + boolStr := engine.columnMapper.Obj2Table("Bool") + stringStr := engine.columnMapper.Obj2Table("String") + _, err = engine.Id(col2.Id).MustCols(boolStr, stringStr).Update(col2) + if err != nil { + t.Error(err) + panic(err) + } + + col3 := &UpdateMustCols{} + has, err := engine.Id(col2.Id).Get(col3) + if err != nil { + t.Error(err) + panic(err) + } + + if !has { + err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) + t.Error(err) + panic(err) + return + } + + if *col2 != *col3 { + err = errors.New(fmt.Sprintf("col2 should eq col3")) + t.Error(err) + panic(err) + return + } + } } func updateSameMapper(engine *Engine, t *testing.T) { @@ -3892,6 +3990,28 @@ func testCompositeKey2(engine *Engine, t *testing.T) { } } +type CustomTableName struct { + Id int64 + Name string +} + +func (c *CustomTableName) TableName() string { + return "customtablename" +} + +func testCustomTableName(engine *Engine, t *testing.T) { + c := new(CustomTableName) + err := engine.DropTables(c) + if err != nil { + t.Error(err) + } + + err = engine.CreateTables(c) + if err != nil { + t.Error(err) + } +} + func testAll(engine *Engine, t *testing.T) { fmt.Println("-------------- directCreateTable --------------") directCreateTable(engine, t) @@ -4002,6 +4122,8 @@ func testAll2(engine *Engine, t *testing.T) { testProcessors(engine, t) fmt.Println("-------------- transaction --------------") transaction(engine, t) + fmt.Println("-------------- testCustomTableName --------------") + testCustomTableName(engine, t) } // !nash! the 3rd set of the test is intended for non-cache enabled engine diff --git a/docs/QuickStart.md b/docs/QuickStart.md index dc2f0cd3..c00e9727 100644 --- a/docs/QuickStart.md +++ b/docs/QuickStart.md @@ -4,7 +4,7 @@ xorm 快速入门 * [1.创建Orm引擎](#10) * [2.定义表结构体](#20) * [2.1.名称映射规则](#21) - * [2.2.前缀映射规则和后缀映射规则](#22) + * [2.2.前缀映射,后缀映射和缓存映射](#22) * [2.3.使用Table和Tag改变名称映射](#23) * [2.4.Column属性定义](#24) * [2.5.Go与字段类型对应表](#25) @@ -29,12 +29,13 @@ xorm 快速入门 * [9.执行SQL命令](#100) * [10.事务处理](#110) * [11.缓存](#120) -* [12.xorm工具](#130) - * [12.1.反转命令](#131) -* [13.Examples](#140) -* [14.案例](#150) -* [15.那些年我们踩过的坑](#160) -* [16.讨论](#170) +* [12.事件](#125) +* [13.xorm工具](#130) + * [13.1.反转命令](#131) +* [14.Examples](#140) +* [15.案例](#150) +* [16.那些年我们踩过的坑](#160) +* [17.讨论](#170) ## 1.创建Orm引擎 @@ -129,18 +130,20 @@ engine.SetColumnMapper(SnakeMapper{}) ``` -### 2.2.前缀映射规则和后缀映射规则 +### 2.2.前缀映射,后缀映射和缓存映射 * 通过`engine.NewPrefixMapper(SnakeMapper{}, "prefix")`可以在SnakeMapper的基础上在命名中添加统一的前缀,当然也可以把SnakeMapper{}换成SameMapper或者你自定义的Mapper。 * 通过`engine.NewSufffixMapper(SnakeMapper{}, "suffix")`可以在SnakeMapper的基础上在命名中添加统一的后缀,当然也可以把SnakeMapper{}换成SameMapper或者你自定义的Mapper。 -* +* 通过`eneing.NewCacheMapper(SnakeMapper{})`可以组合其它的映射规则,起到在内存中缓存曾经映射过的命名映射。 ### 2.3.使用Table和Tag改变名称映射 如果所有的命名都是按照IMapper的映射来操作的,那当然是最理想的。但是如果碰到某个表名或者某个字段名跟映射规则不匹配时,我们就需要别的机制来改变。 -通过`engine.Table()`方法可以改变struct对应的数据库表的名称,通过sturct中field对应的Tag中使用`xorm:"'column_name'"`可以使该field对应的Column名称为指定名称。这里使用两个单引号将Column名称括起来是为了防止名称冲突,因为我们在Tag中还可以对这个Column进行更多的定义。如果名称不冲突的情况,单引号也可以不使用。 +* 如果struct拥有`Tablename() string`的成员方法,那么此方法的返回值即是该struct默认对应的数据库表名。 + +* 通过`engine.Table()`方法可以改变struct对应的数据库表的名称,通过sturct中field对应的Tag中使用`xorm:"'column_name'"`可以使该field对应的Column名称为指定名称。这里使用两个单引号将Column名称括起来是为了防止名称冲突,因为我们在Tag中还可以对这个Column进行更多的定义。如果名称不冲突的情况,单引号也可以不使用。 ### 2.4.Column属性定义 @@ -153,7 +156,7 @@ type User struct { } ``` -对于不同的数据库系统,数据类型其实是有些差异的。因此xorm中对数据类型有自己的定义,基本的原则是尽量兼容各种数据库的字段类型,具体的字段对应关系可以查看[字段类型对应表](https://github.com/lunny/xorm/blob/master/docs/COLUMNTYPE.md)。 +对于不同的数据库系统,数据类型其实是有些差异的。因此xorm中对数据类型有自己的定义,基本的原则是尽量兼容各种数据库的字段类型,具体的字段对应关系可以查看[字段类型对应表](https://github.com/lunny/xorm/blob/master/docs/COLUMNTYPE.md)。对于使用者,一般只要使用自己熟悉的数据库字段定义即可。 具体的映射规则如下,另Tag中的关键字均不区分大小写,字段名区分大小写: @@ -407,7 +410,11 @@ engine.Cols("age", "name").Update(&user) // UPDATE user SET age=? AND name=? ``` -其中的参数"age", "name"也可以写成"age, name",两种写法均可 +* AllCols() +查询或更新所有字段。 + +* MustCols(…string) +某些字段必须更新。 * Omit(...string) 和cols相反,此函数指定排除某些指定的字段。注意:此方法和Cols方法不可同时使用 @@ -729,20 +736,46 @@ engine.ClearCache(new(User)) ![cache design](https://raw.github.com/lunny/xorm/master/docs/cache_design.png) + +## 12.事件 +xorm支持两种方式的事件,一种是在Struct中的特定方法来作为事件的方法,一种是在执行语句的过程中执行事件。 + +在Struct中作为成员方法的事件如下: + +* BeforeInsert() + +* BeforeUpdate() + +* BeforeDelete() + +* AfterInsert() + +* AfterUpdate() + +* AfterDelete() + +在语句执行过程中的事件方法为: + +* Before(beforeFunc interface{}) + +* After(afterFunc interface{}) + +其中beforeFunc和afterFunc的原型为func(bean interface{}). + -## 12.xorm工具 +## 13.xorm工具 xorm工具提供了xorm命令,能够帮助做很多事情。 -### 12.1.反转命令 +### 13.1.反转命令 参见 [xorm工具](https://github.com/lunny/xorm/tree/master/xorm) -## 13.Examples +## 14.Examples 请访问[https://github.com/lunny/xorm/tree/master/examples](https://github.com/lunny/xorm/tree/master/examples) -## 14.案例 +## 15.案例 * [Gowalker](http://gowalker.org),源代码 [github.com/Unknwon/gowalker](http://github.com/Unknwon/gowalker) @@ -753,7 +786,7 @@ xorm工具提供了xorm命令,能够帮助做很多事情。 * [VeryHour](http://veryhour.com) -## 15.那些年我们踩过的坑 +## 16.那些年我们踩过的坑 * 怎么同时使用xorm的tag和json的tag? 答:使用空格 @@ -797,5 +830,5 @@ money float64 `xorm:"Numeric"` -## 16.讨论 +## 17.讨论 请加入QQ群:280360085 进行讨论。 diff --git a/docs/QuickStartEn.md b/docs/QuickStartEn.md index 5249c81f..7cc08e2e 100644 --- a/docs/QuickStartEn.md +++ b/docs/QuickStartEn.md @@ -100,30 +100,47 @@ engine.Logger = f ## 2.Define struct -xorm支持将一个struct映射为数据库中对应的一张表。映射规则如下: +xorm map a struct to a database table, the rule is below. -### 2.1.名称映射规则 +### 2.1.name mapping rule -名称映射规则主要负责结构体名称到表名和结构体field到表字段的名称映射。由xorm.IMapper接口的实现者来管理,xorm内置了两种IMapper实现:`SnakeMapper` 和 `SameMapper`。SnakeMapper支持struct为驼峰式命名,表结构为下划线命名之间的转换;SameMapper支持相同的命名。 +use xorm.IMapper interface to implement. There are two IMapper implemented: `SnakeMapper` and `SameMapper`. SnakeMapper means struct name is word by word and table name or column name as 下划线. SameMapper means same name between struct and table. -当前SnakeMapper为默认值,如果需要改变时,在engine创建完成后使用 +SnakeMapper is the default. ```Go engine.Mapper = SameMapper{} ``` +同时需要注意的是: + +* 如果你使用了别的命名规则映射方案,也可以自己实现一个IMapper。 +* 表名称和字段名称的映射规则默认是相同的,当然也可以设置为不同,如: + +```Go +engine.SetTableMapper(SameMapper{}) +engine.SetColumnMapper(SnakeMapper{}) +``` + + +### 2.2.前缀映射规则,后缀映射规则和缓存映射规则 + +* 通过`engine.NewPrefixMapper(SnakeMapper{}, "prefix")`可以在SnakeMapper的基础上在命名中添加统一的前缀,当然也可以把SnakeMapper{}换成SameMapper或者你自定义的Mapper。 +* 通过`engine.NewSufffixMapper(SnakeMapper{}, "suffix")`可以在SnakeMapper的基础上在命名中添加统一的后缀,当然也可以把SnakeMapper{}换成SameMapper或者你自定义的Mapper。 +* 通过`eneing.NewCacheMapper(SnakeMapper{})`可以起到在内存中缓存曾经映射过的命名映射。 + 当然,如果你使用了别的命名规则映射方案,也可以自己实现一个IMapper。 -### 2.2.使用Table和Tag改变名称映射 +### 2.3.使用Table和Tag改变名称映射 如果所有的命名都是按照IMapper的映射来操作的,那当然是最理想的。但是如果碰到某个表名或者某个字段名跟映射规则不匹配时,我们就需要别的机制来改变。 通过`engine.Table()`方法可以改变struct对应的数据库表的名称,通过sturct中field对应的Tag中使用`xorm:"'table_name'"`可以使该field对应的Column名称为指定名称。这里使用两个单引号将Column名称括起来是为了防止名称冲突,因为我们在Tag中还可以对这个Column进行更多的定义。如果名称不冲突的情况,单引号也可以不使用。 -### 2.3.Column属性定义 +### 2.4.Column属性定义 我们在field对应的Tag中对Column的一些属性进行定义,定义的方法基本和我们写SQL定义表结构类似,比如: ``` diff --git a/engine.go b/engine.go index 6e2cbad5..2499d494 100644 --- a/engine.go +++ b/engine.go @@ -23,6 +23,7 @@ const ( MSSQL = "mssql" ORACLE_OCI = "oci8" + QL = "ql" ) // a dialect is a driver's wrapper @@ -33,6 +34,8 @@ type dialect interface { SqlType(t *Column) string SupportInsertMany() bool QuoteStr() string + RollBackStr() string + DropTableSql(tableName string) string AutoIncrStr() string SupportEngine() bool SupportCharset() bool @@ -154,9 +157,9 @@ func (engine *Engine) NoCascade() *Session { // Set a table use a special cacher func (engine *Engine) MapCacher(bean interface{}, cacher Cacher) { - t := rType(bean) - engine.autoMapType(t) - engine.Tables[t].Cacher = cacher + v := rValue(bean) + engine.autoMapType(v) + engine.Tables[v.Type()].Cacher = cacher } // OpenDB provides a interface to operate database directly. @@ -333,6 +336,18 @@ func (engine *Engine) Cols(columns ...string) *Session { return session.Cols(columns...) } +func (engine *Engine) AllCols() *Session { + session := engine.NewSession() + session.IsAutoClose = true + return session.AllCols() +} + +func (engine *Engine) MustCols(columns ...string) *Session { + session := engine.NewSession() + session.IsAutoClose = true + return session.MustCols(columns...) +} + // Xorm automatically retrieve condition according struct, but // if struct has bool field, it will ignore them. So use UseBool // to tell system to do not ignore them. @@ -420,12 +435,13 @@ func (engine *Engine) Having(conditions string) *Session { return session.Having(conditions) } -func (engine *Engine) autoMapType(t reflect.Type) *Table { +func (engine *Engine) autoMapType(v reflect.Value) *Table { + t := v.Type() engine.mutex.RLock() table, ok := engine.Tables[t] engine.mutex.RUnlock() if !ok { - table = engine.mapType(t) + table = engine.mapType(v) engine.mutex.Lock() engine.Tables[t] = table engine.mutex.Unlock() @@ -434,8 +450,8 @@ func (engine *Engine) autoMapType(t reflect.Type) *Table { } func (engine *Engine) autoMap(bean interface{}) *Table { - t := rType(bean) - return engine.autoMapType(t) + v := rValue(bean) + return engine.autoMapType(v) } func (engine *Engine) newTable() *Table { @@ -448,9 +464,36 @@ func (engine *Engine) newTable() *Table { return table } -func (engine *Engine) mapType(t reflect.Type) *Table { +func addIndex(indexName string, table *Table, col *Column, indexType int) { + if index, ok := table.Indexes[indexName]; ok { + index.AddColumn(col.Name) + col.Indexes[index.Name] = true + } else { + index := NewIndex(indexName, indexType) + index.AddColumn(col.Name) + table.AddIndex(index) + col.Indexes[index.Name] = true + } +} + +func (engine *Engine) mapType(v reflect.Value) *Table { + t := v.Type() table := engine.newTable() - table.Name = engine.tableMapper.Obj2Table(t.Name()) + method := v.MethodByName("TableName") + if !method.IsValid() { + method = v.Addr().MethodByName("TableName") + } + if method.IsValid() { + params := []reflect.Value{} + results := method.Call(params) + if len(results) == 1 { + table.Name = results[0].Interface().(string) + } + } + + if table.Name == "" { + table.Name = engine.tableMapper.Obj2Table(t.Name()) + } table.Type = t var idFieldColName string @@ -460,7 +503,8 @@ func (engine *Engine) mapType(t reflect.Type) *Table { tag := t.Field(i).Tag ormTagStr := tag.Get(engine.TagIdentifier) var col *Column - fieldType := t.Field(i).Type + fieldValue := v.Field(i) + fieldType := fieldValue.Type() if ormTagStr != "" { col = &Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, @@ -473,7 +517,7 @@ func (engine *Engine) mapType(t reflect.Type) *Table { } if (strings.ToUpper(tags[0]) == "EXTENDS") && (fieldType.Kind() == reflect.Struct) { - parentTable := engine.mapType(fieldType) + parentTable := engine.mapType(fieldValue) for name, col := range parentTable.Columns { col.FieldName = fmt.Sprintf("%v.%v", fieldType.Name(), col.FieldName) table.Columns[strings.ToLower(name)] = col @@ -483,8 +527,9 @@ func (engine *Engine) mapType(t reflect.Type) *Table { table.PrimaryKeys = parentTable.PrimaryKeys continue } - var indexType int - var indexName string + + indexNames := make(map[string]int) + var isIndex, isUnique bool var preKey string for j, key := range tags { k := strings.ToUpper(key) @@ -520,15 +565,15 @@ func (engine *Engine) mapType(t reflect.Type) *Table { case k == "UPDATED": col.IsUpdated = true case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"): - indexType = IndexType - indexName = k[len("INDEX")+1 : len(k)-1] + indexName := k[len("INDEX")+1 : len(k)-1] + indexNames[indexName] = IndexType case k == "INDEX": - indexType = IndexType + isIndex = true case strings.HasPrefix(k, "UNIQUE(") && strings.HasSuffix(k, ")"): - indexName = k[len("UNIQUE")+1 : len(k)-1] - indexType = UniqueType + indexName := k[len("UNIQUE")+1 : len(k)-1] + indexNames[indexName] = UniqueType case k == "UNIQUE": - indexType = UniqueType + isUnique = true case k == "NOTNULL": col.Nullable = false case k == "NOT": @@ -583,32 +628,15 @@ func (engine *Engine) mapType(t reflect.Type) *Table { if col.Name == "" { col.Name = engine.columnMapper.Obj2Table(t.Field(i).Name) } - if indexType == IndexType { - if indexName == "" { - indexName = col.Name - } - if index, ok := table.Indexes[indexName]; ok { - index.AddColumn(col.Name) - col.Indexes[index.Name] = true - } else { - index := NewIndex(indexName, IndexType) - index.AddColumn(col.Name) - table.AddIndex(index) - col.Indexes[index.Name] = true - } - } else if indexType == UniqueType { - if indexName == "" { - indexName = col.Name - } - if index, ok := table.Indexes[indexName]; ok { - index.AddColumn(col.Name) - col.Indexes[index.Name] = true - } else { - index := NewIndex(indexName, UniqueType) - index.AddColumn(col.Name) - table.AddIndex(index) - col.Indexes[index.Name] = true - } + + if isUnique { + indexNames[col.Name] = UniqueType + } else if isIndex { + indexNames[col.Name] = IndexType + } + + for indexName, indexType := range indexNames { + addIndex(indexName, table, col, indexType) } } } else { @@ -660,19 +688,20 @@ func (engine *Engine) mapping(beans ...interface{}) (e error) { engine.mutex.Lock() defer engine.mutex.Unlock() for _, bean := range beans { - t := rType(bean) - engine.Tables[t] = engine.mapType(t) + v := rValue(bean) + engine.Tables[v.Type()] = engine.mapType(v) } return } // If a table has any reocrd func (engine *Engine) IsTableEmpty(bean interface{}) (bool, error) { - t := rType(bean) + v := rValue(bean) + t := v.Type() if t.Kind() != reflect.Struct { return false, errors.New("bean should be a struct or struct's point") } - engine.autoMapType(t) + engine.autoMapType(v) session := engine.NewSession() defer session.Close() rows, err := session.Count(bean) @@ -681,11 +710,11 @@ func (engine *Engine) IsTableEmpty(bean interface{}) (bool, error) { // If a table is exist func (engine *Engine) IsTableExist(bean interface{}) (bool, error) { - t := rType(bean) - if t.Kind() != reflect.Struct { + v := rValue(bean) + if v.Type().Kind() != reflect.Struct { return false, errors.New("bean should be a struct or struct's point") } - table := engine.autoMapType(t) + table := engine.autoMapType(v) session := engine.NewSession() defer session.Close() has, err := session.isTableExist(table.Name) diff --git a/helpers.go b/helpers.go index 96f118f2..25b6ddc8 100644 --- a/helpers.go +++ b/helpers.go @@ -37,9 +37,14 @@ func makeArray(elem string, count int) []string { return res } +func rValue(bean interface{}) reflect.Value { + return reflect.Indirect(reflect.ValueOf(bean)) +} + func rType(bean interface{}) reflect.Type { sliceValue := reflect.Indirect(reflect.ValueOf(bean)) - return reflect.TypeOf(sliceValue.Interface()) + //return reflect.TypeOf(sliceValue.Interface()) + return sliceValue.Type() } func structName(v reflect.Type) string { diff --git a/mssql.go b/mssql.go index 6e9776d2..54c93e71 100644 --- a/mssql.go +++ b/mssql.go @@ -108,6 +108,12 @@ func (db *mssql) AutoIncrStr() string { return "IDENTITY" } +func (db *mssql) DropTableSql(tableName string) string { + return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+ + "object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+ + "DROP TABLE \"%s\"", tableName, tableName) +} + func (db *mssql) SupportCharset() bool { return false } @@ -187,7 +193,7 @@ where a.object_id=object_id('` + tableName + `')` if col.SQLType.IsText() { if col.Default != "" { col.Default = "'" + col.Default + "'" - }else{ + } else { if col.DefaultIsEmpty { col.Default = "''" } diff --git a/mysql.go b/mysql.go index 23b53641..8d0cfaa3 100644 --- a/mysql.go +++ b/mysql.go @@ -89,6 +89,14 @@ func (b *base) DBType() string { return b.uri.dbType } +func (db *base) RollBackStr() string { + return "ROLL BACK" +} + +func (db *base) DropTableSql(tableName string) string { + return fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tableName) +} + type mysql struct { base net string diff --git a/session.go b/session.go index 825acc46..2298e839 100644 --- a/session.go +++ b/session.go @@ -126,6 +126,16 @@ func (session *Session) Cols(columns ...string) *Session { return session } +func (session *Session) AllCols() *Session { + session.Statement.AllCols() + return session +} + +func (session *Session) MustCols(columns ...string) *Session { + session.Statement.MustCols(columns...) + return session +} + func (session *Session) NoCascade() *Session { session.Statement.UseCascade = false return session @@ -281,7 +291,7 @@ func (session *Session) Begin() error { // When using transaction, you can rollback if any error func (session *Session) Rollback() error { if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { - session.Engine.LogSQL("ROLL BACK") + session.Engine.LogSQL(session.Engine.dialect.RollBackStr()) session.IsCommitedOrRollbacked = true return session.Tx.Rollback() } @@ -348,12 +358,12 @@ func cleanupProcessorsClosures(slices *[]func(interface{})) { } func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error { - dataStruct := reflect.Indirect(reflect.ValueOf(obj)) + dataStruct := rValue(obj) if dataStruct.Kind() != reflect.Struct { return errors.New("Expected a pointer to a struct") } - table := session.Engine.autoMapType(rType(obj)) + table := session.Engine.autoMapType(dataStruct) for key, data := range objMap { key = strings.ToLower(key) @@ -1007,12 +1017,14 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) if session.Statement.RefTable == nil { if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Elem().Kind() == reflect.Struct { - table = session.Engine.autoMapType(sliceElementType.Elem()) + pv := reflect.New(sliceElementType.Elem()) + table = session.Engine.autoMapType(pv.Elem()) } else { return errors.New("slice type") } } else if sliceElementType.Kind() == reflect.Struct { - table = session.Engine.autoMapType(sliceElementType) + pv := reflect.New(sliceElementType) + table = session.Engine.autoMapType(pv.Elem()) } else { return errors.New("slice type") } @@ -1023,7 +1035,8 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) if len(condiBean) > 0 { colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true, - false, true, session.Statement.allUseBool, session.Statement.boolColumnMap) + false, true, session.Statement.allUseBool, session.Statement.useAllCols, + session.Statement.mustColumnMap) session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.BeanArgs = args } @@ -1375,13 +1388,12 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *T } func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount int, bean interface{}) error { - - dataStruct := reflect.Indirect(reflect.ValueOf(bean)) + dataStruct := rValue(bean) if dataStruct.Kind() != reflect.Struct { return errors.New("Expected a pointer to a struct") } - table := session.Engine.autoMapType(rType(bean)) + table := session.Engine.autoMapType(dataStruct) var scanResultContainers []interface{} for i := 0; i < fieldsCount; i++ { @@ -1483,7 +1495,7 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in fieldValue.Set(vv) } } else if session.Statement.UseCascade { - table := session.Engine.autoMapType(fieldValue.Type()) + table := session.Engine.autoMapType(*fieldValue) if table != nil { var x int64 if rawValueType.Kind() == reflect.Int64 { @@ -1752,9 +1764,10 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } bean := sliceValue.Index(0).Interface() - sliceElementType := rType(bean) + elementValue := rValue(bean) + //sliceElementType := elementValue.Type() - table := session.Engine.autoMapType(sliceElementType) + table := session.Engine.autoMapType(elementValue) session.Statement.RefTable = table size := sliceValue.Len() @@ -2062,7 +2075,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data v = x fieldValue.Set(reflect.ValueOf(v)) } else if session.Statement.UseCascade { - table := session.Engine.autoMapType(fieldValue.Type()) + table := session.Engine.autoMapType(*fieldValue) if table != nil { x, err := strconv.ParseInt(string(data), 10, 64) if err != nil { @@ -2838,7 +2851,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if session.Statement.ColumnStr == "" { colNames, args = buildConditions(session.Engine, table, bean, false, false, - false, false, session.Statement.allUseBool, session.Statement.boolColumnMap) + false, false, session.Statement.allUseBool, session.Statement.useAllCols, + session.Statement.mustColumnMap) } else { colNames, args, err = table.genCols(session, bean, true, true) if err != nil { @@ -2872,7 +2886,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condiBean) > 0 { condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true, true, - false, true, session.Statement.allUseBool, session.Statement.boolColumnMap) + false, true, session.Statement.allUseBool, session.Statement.useAllCols, + session.Statement.mustColumnMap) } var condition = "" @@ -2895,6 +2910,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var sqlStr, inSql string var inArgs []interface{} + doIncVer := false + var verValue reflect.Value if table.Version != "" && session.Statement.checkVersion { if condition != "" { condition = fmt.Sprintf("WHERE (%v) AND %v = ?", condition, @@ -2917,7 +2934,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1", condition) - condiArgs = append(condiArgs, table.VersionColumn().ValueOf(bean).Interface()) + verValue = table.VersionColumn().ValueOf(bean) + //if err != nil { + // return 0, err + //} + + condiArgs = append(condiArgs, verValue.Interface()) + doIncVer = true } else { if condition != "" { condition = "WHERE " + condition @@ -2944,6 +2967,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 res, err := session.exec(sqlStr, args...) if err != nil { return 0, err + } else if doIncVer { + verValue.SetInt(verValue.Int() + 1) } if table.Cacher != nil && session.Statement.UseCache { @@ -3060,7 +3085,8 @@ func (session *Session) Delete(bean interface{}) (int64, error) { table := session.Engine.autoMap(bean) session.Statement.RefTable = table colNames, args := buildConditions(session.Engine, table, bean, true, true, - false, true, session.Statement.allUseBool, session.Statement.boolColumnMap) + false, true, session.Statement.allUseBool, session.Statement.useAllCols, + session.Statement.mustColumnMap) var condition = "" diff --git a/statement.go b/statement.go index 4bde5c7b..773dd378 100644 --- a/statement.go +++ b/statement.go @@ -25,6 +25,7 @@ type Statement struct { HavingStr string ColumnStr string columnMap map[string]bool + useAllCols bool OmitStr string ConditionStr string AltTableName string @@ -40,7 +41,7 @@ type Statement struct { IsDistinct bool allUseBool bool checkVersion bool - boolColumnMap map[string]bool + mustColumnMap map[string]bool inColumns map[string][]interface{} } @@ -69,7 +70,7 @@ func (statement *Statement) Init() { statement.UseAutoTime = true statement.IsDistinct = false statement.allUseBool = false - statement.boolColumnMap = make(map[string]bool) + statement.mustColumnMap = make(map[string]bool) statement.checkVersion = true statement.inColumns = make(map[string][]interface{}) } @@ -112,11 +113,12 @@ func (statement *Statement) Or(querystring string, args ...interface{}) *Stateme // tempororily set table name func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { - t := rType(tableNameOrBean) + v := rValue(tableNameOrBean) + t := v.Type() if t.Kind() == reflect.String { statement.AltTableName = tableNameOrBean.(string) } else if t.Kind() == reflect.Struct { - statement.RefTable = statement.Engine.autoMapType(t) + statement.RefTable = statement.Engine.autoMapType(v) } return statement } @@ -239,8 +241,9 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { // Auto generating conditions according a struct func buildConditions(engine *Engine, table *Table, bean interface{}, - includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, - boolColumnMap map[string]bool) ([]string, []interface{}) { + includeVersion bool, includeUpdated bool, includeNil bool, + includeAutoIncr bool, allUseBool bool, useAllCols bool, + mustColumnMap map[string]bool) ([]string, []interface{}) { colNames := make([]string, 0) var args = make([]interface{}, 0) @@ -262,7 +265,15 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, fieldValue := col.ValueOf(bean) fieldType := reflect.TypeOf(fieldValue.Interface()) - requiredField := false + requiredField := useAllCols + if b, ok := mustColumnMap[strings.ToLower(col.Name)]; ok { + if b { + requiredField = true + } else { + continue + } + } + if fieldType.Kind() == reflect.Ptr { if fieldValue.IsNil() { if includeNil { @@ -285,8 +296,6 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, case reflect.Bool: if allUseBool || requiredField { val = fieldValue.Interface() - } else if _, ok := boolColumnMap[col.Name]; ok { - val = fieldValue.Interface() } else { // if a bool in a struct, it will not be as a condition because it default is false, // please use Where() instead @@ -334,7 +343,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, val = t } } else { - engine.autoMapType(fieldValue.Type()) + engine.autoMapType(fieldValue) if table, ok := engine.Tables[fieldValue.Type()]; ok { if len(table.PrimaryKeys) == 1 { pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) @@ -517,13 +526,34 @@ func (statement *Statement) Cols(columns ...string) *Statement { return statement } +// Update use only: update all columns +func (statement *Statement) AllCols() *Statement { + statement.useAllCols = true + return statement +} + +// Update use only: must update columns +func (statement *Statement) MustCols(columns ...string) *Statement { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.mustColumnMap[strings.ToLower(nc)] = true + } + return statement +} + +// Update use only: not update columns +/*func (statement *Statement) NotCols(columns ...string) *Statement { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.mustColumnMap[strings.ToLower(nc)] = false + } + return statement +}*/ + // indicates that use bool fields as update contents and query contiditions func (statement *Statement) UseBool(columns ...string) *Statement { if len(columns) > 0 { - newColumns := col2NewCols(columns...) - for _, nc := range newColumns { - statement.boolColumnMap[strings.ToLower(nc)] = true - } + statement.MustCols(columns...) } else { statement.allUseBool = true } @@ -705,13 +735,7 @@ func (s *Statement) genDelIndexSQL() []string { } func (s *Statement) genDropSQL() string { - if s.Engine.dialect.DBType() == MSSQL { - return "IF EXISTS (SELECT * FROM sysobjects WHERE id = object_id(N'" + - s.TableName() + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1) " + - "DROP TABLE " + s.Engine.Quote(s.TableName()) + ";" - } else { - return "DROP TABLE IF EXISTS " + s.Engine.Quote(s.TableName()) + ";" - } + return s.Engine.dialect.DropTableSql(s.TableName()) + ";" } func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) { @@ -719,7 +743,8 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) statement.RefTable = table colNames, args := buildConditions(statement.Engine, table, bean, true, true, - false, true, statement.allUseBool, statement.boolColumnMap) + false, true, statement.allUseBool, statement.useAllCols, + statement.mustColumnMap) statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args @@ -758,7 +783,7 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{} statement.RefTable = table colNames, args := buildConditions(statement.Engine, table, bean, true, true, false, - true, statement.allUseBool, statement.boolColumnMap) + true, statement.allUseBool, statement.useAllCols, statement.mustColumnMap) statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args