diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..c7fde071 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,33 @@ +## Contributing to xorm + +`xorm` has a backlog of pull requests, but contributions are still very +much welcome. You can help with patch review, submitting bug reports, +or adding new functionality. There is no formal style guide, but +please conform to the style of existing code and general Go formatting +conventions when submitting patches. + +### Patch review + +Help review existing open pull requests by commenting on the code or +proposed functionality. + +### Bug reports + +We appreciate any bug reports, but especially ones with self-contained +(doesn't depend on code outside of xorm), minimal (can't be simplified +further) test cases. It's especially helpful if you can submit a pull +request with just the failing test case (you'll probably want to +pattern it after the tests in +[base_test.go](https://github.com/lunny/xorm/blob/master/base_test.go) and +[benchmark_base_test.go](https://github.com/lunny/xorm/blob/master/benchmark_base_test.go)). + +If you implement a new database interface, you may need to add a _test.go file. +For example, see [mysql_test.go](https://github.com/lunny/xorm/blob/master/mysql_test.go). + +### New functionality + +There are a number of pending patches for new functionality, so +additional feature patches will take a while to merge. Still, patches +are generally reviewed based on usefulness and complexity in addition +to time-in-queue, so if you have a knockout idea, take a shot. Feel +free to open an issue discussing your proposed patch beforehand. diff --git a/README.md b/README.md index d67951f9..ecdfa721 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,17 @@ -[中文](https://github.com/lunny/xorm/blob/master/README_CN.md) - -Xorm is a simple and powerful ORM for Go. - +[中文](https://github.com/lunny/xorm/blob/master/README_CN.md) + +Xorm is a simple and powerful ORM for Go. + [![Build Status](https://drone.io/github.com/lunny/xorm/status.png)](https://drone.io/github.com/lunny/xorm/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/lunny/xorm) [![Bitdeli Badge](https://d2weczhvl823v0.cloudfront.net/lunny/xorm/trend.png)](https://bitdeli.com/free "Bitdeli Badge") -# Features - +# Features + * Struct <-> Table Mapping Support * Chainable APIs * Transaction Support - + * Both ORM and raw SQL operation Support * Sync database schema Support @@ -24,44 +24,52 @@ Xorm is a simple and powerful ORM for Go. 
* Optimistic Locking support - -# Drivers Support - -Drivers for Go's sql package which currently support database/sql includes: - + +# Drivers Support + +The following drivers for Go's database/sql package are currently supported: + * Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) - -* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) - -* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) - -* Postgres: [github.com/lib/pq](https://github.com/lib/pq) - - + +* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) + +* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) + +* Postgres: [github.com/lib/pq](https://github.com/lib/pq) + +* MsSql: [github.com/lunny/godbc](https://github.com/lunny/godbc) + # Changelog -* **v0.2.3** : Improved documents; Optimistic Locking support; Timestamp with time zone support; Mapper change to tableMapper and columnMapper & added PrefixMapper & SuffixMapper support custom table or column name's prefix and suffix;Insert now return affected, err instead of id, err; Added UseBool & Distinct; -* **v0.2.2** : Postgres drivers now support lib/pq; Added method Iterate for record by record to handler;Added SetMaxConns(go1.2+) support; some bugs fixed. -* **v0.2.1** : Added database reverse tool, now support generate go & c++ codes, see [Xorm Tool README](https://github.com/lunny/xorm/blob/master/xorm/README.md); some bug fixed. -* **v0.2.0** : Added Cache supported, select is speeder up 3~5x; Added SameMapper for same name between struct and table; Added Sync method for auto added tables, columns, indexes; +* **v0.3.1** + + Features: + * Support for MSSQL via the ODBC driver ([github.com/lunny/godbc](https://github.com/lunny/godbc)) + * Composite primary keys, declared with multiple pk xorm tags + * Added Rows() API as an alternative to the Iterate() API for traversing result sets, with usage similar to the sql.Rows type + * ORM structs may now declare pointers to builtin types as members, to allow null DB fields + * Before and After event processors + + Improvements: + * Allowed int/int32/int64/uint/uint32/uint64/string as the primary key type + * Performance improvements for Get()/Find()/Iterate() + +[More changelogs ...](https://github.com/lunny/xorm/blob/master/docs/Changelog.md) -[More changelogs ...](https://github.com/lunny/xorm/blob/master/docs/Changelog.md) - - # Installation If you have [gopm](https://github.com/gpmgo/gopm) installed, gopm get github.com/lunny/xorm -Or - - go get github.com/lunny/xorm - +Or + + go get github.com/lunny/xorm + # Documents -* [GoDoc](http://godoc.org/github.com/lunny/xorm) - +* [GoDoc](http://godoc.org/github.com/lunny/xorm) + * [GoWalker](http://gowalker.org/github.com/lunny/xorm) * [Quick Start](https://github.com/lunny/xorm/blob/master/docs/QuickStartEn.md) @@ -74,18 +82,24 @@ Or * [Godaily](http://godaily.org) - [github.com/govc/godaily](http://github.com/govc/godaily) -* [Very Hour](http://veryhour.com/) - +* [Very Hour](http://veryhour.com/) + +# Todo + +[Todo List](https://trello.com/b/IHsuAnhk/xorm) + # Discuss Please visit [Xorm on Google Groups](https://groups.google.com/forum/#!forum/xorm) # Contributors +If you want to submit a pull request, please see [CONTRIBUTING](https://github.com/lunny/xorm/blob/master/CONTRIBUTING.md) + * [Lunny](https://github.com/lunny)
+* [Nashtsai](https://github.com/nashtsai) + +# LICENSE + + BSD License + [http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/) diff --git a/README_CN.md b/README_CN.md index febc65f9..d1b88e67 100644 --- a/README_CN.md +++ b/README_CN.md @@ -1,7 +1,7 @@ # xorm - -[English](https://github.com/lunny/xorm/blob/master/README.md) - + +[English](https://github.com/lunny/xorm/blob/master/README.md) + xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作非常简便。 [![Build Status](https://drone.io/github.com/lunny/xorm/status.png)](https://drone.io/github.com/lunny/xorm/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/lunny/xorm) @@ -27,27 +27,37 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 * 支持记录版本(即乐观锁) ## 驱动支持 - -目前支持的Go数据库驱动如下: - -* Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) - + +目前支持的Go数据库驱动和对应的数据库如下: + +* Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) + * MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) * SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) * Postgres: [github.com/lib/pq](https://github.com/lib/pq) -* Postgres: [github.com/bylevel/pq](https://github.com/bylevel/pq) +* MsSql: [github.com/lunny/godbc](https://github.com/lunny/godbc) ## 更新日志 -* **v0.2.3** : 改善了文档;提供了乐观锁支持;添加了带时区时间字段支持;Mapper现在分成表名Mapper和字段名Mapper,同时实现了表或字段的自定义前缀后缀;Insert方法的返回值含义从id, err更改为 affected, err,请大家注意;添加了UseBool 和 Distinct函数。 -* **v0.2.2** : Postgres驱动新增了对lib/pq的支持;新增了逐条遍历方法Iterate;新增了SetMaxConns(go1.2+)支持,修复了bug若干; -* **v0.2.1** : 新增数据库反转工具,当前支持go和c++代码的生成,详见 [Xorm Tool README](https://github.com/lunny/xorm/blob/master/xorm/README.md); 修复了一些bug. -* **v0.2.0** : 新增 [缓存](https://github.com/lunny/xorm/blob/master/docs/QuickStart.md#120)支持,查询速度提升3-5倍; 新增数据库表和Struct同名的映射方式; 新增Sync同步表结构; + +* **v0.3.1** + + 新特性: + * 支持 MSSQL DB 通过 ODBC 驱动 ([github.com/lunny/godbc](https://github.com/lunny/godbc)); + * 通过多个pk标记支持联合主键; + * 新增 Rows() API 用来遍历查询结果,该函数提供了类似sql.Rows的相似用法,可作为 Iterate() API 的可选替代; + * ORM 结构体现在允许内建类型的指针作为成员,使得数据库为null成为可能; + * Before 和 After 支持 + + 改进: + * 允许 int/int32/int64/uint/uint32/uint64/string 作为主键类型 + * 查询函数 Get()/Find()/Iterate() 在性能上的改进 + [更多更新日志...](https://github.com/lunny/xorm/blob/master/docs/ChangelogCN.md) - + ## 安装 推荐使用 [gopm](https://github.com/gpmgo/gopm) 进行安装: @@ -56,10 +66,10 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 或者您也可以使用go工具进行安装: - go get github.com/lunny/xorm - + go get github.com/lunny/xorm + ## 文档 - + * [快速开始](https://github.com/lunny/xorm/blob/master/docs/QuickStart.md) * [GoWalker代码文档](http://gowalker.org/github.com/lunny/xorm) @@ -77,16 +87,22 @@ xorm是一个简单而强大的Go语言ORM库. 
通过它可以使数据库操作 * [Very Hour](http://veryhour.com/) +## Todo + +[开发计划](https://trello.com/b/IHsuAnhk/xorm) + ## 讨论 请加入QQ群:280360085 进行讨论。 # 贡献者 +如果您也想为Xorm贡献您的力量,请查看 [CONTRIBUTING](https://github.com/lunny/xorm/blob/master/CONTRIBUTING.md) + * [Lunny](https://github.com/lunny) * [Nashtsai](https://github.com/nashtsai) - -## LICENSE - -BSD License -[http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/) + +## LICENSE + +BSD License +[http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/) diff --git a/VERSION b/VERSION index 6a4bbfd8..a67c0014 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -xorm v0.2.3 +xorm v0.3.1 diff --git a/base_test.go b/base_test.go index b10643a4..91da58c6 100644 --- a/base_test.go +++ b/base_test.go @@ -267,6 +267,16 @@ func insertTwoTable(engine *Engine, t *testing.T) { } } +type Article struct { + Id int32 `xorm:"pk INT autoincr"` + Name string `xorm:"VARCHAR(45)"` + Img string `xorm:"VARCHAR(100)"` + Aside string `xorm:"VARCHAR(200)"` + Desc string `xorm:"VARCHAR(200)"` + Content string `xorm:"TEXT"` + Status int8 `xorm:"TINYINT(4)"` +} + type Condi map[string]interface{} func update(engine *Engine, t *testing.T) { @@ -314,6 +324,44 @@ func update(engine *Engine, t *testing.T) { panic(err) return } + + err = engine.Sync(&Article{}) + if err != nil { + t.Error(err) + panic(err) + } + + cnt, err = engine.Insert(&Article{0, "1", "2", "3", "4", "5", 2}) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } + + cnt, err = engine.Id(1).Update(&Article{Name: "6"}) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != 1 { + err = errors.New("update not returned 1") + t.Error(err) + panic(err) + return + } + + err = engine.DropTables(&Article{}) + if err != nil { + t.Error(err) + panic(err) + } } func updateSameMapper(engine *Engine, t *testing.T) { @@ -359,7 +407,7 @@ func updateSameMapper(engine *Engine, t *testing.T) { } } -func testdelete(engine *Engine, t *testing.T) { +func testDelete(engine *Engine, t *testing.T) { user := Userinfo{Uid: 1} cnt, err := engine.Delete(&user) if err != nil { @@ -557,20 +605,48 @@ func where(engine *Engine, t *testing.T) { func in(engine *Engine, t *testing.T) { users := make([]Userinfo, 0) - err := engine.In("(id)", 1, 2, 3).Find(&users) + err := engine.In("(id)", 7, 8, 9).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) + if len(users) != 3 { + err = errors.New("in uses should be 7,8,9 total 3") + t.Error(err) + panic(err) + } + + for _, user := range users { + if user.Uid != 7 && user.Uid != 8 && user.Uid != 9 { + err = errors.New("in uses should be 7,8,9 total 3") + t.Error(err) + panic(err) + } + } + + users = make([]Userinfo, 0) + ids := []interface{}{7, 8, 9} + err = engine.Where("departname = ?", "dev").In("(id)", ids...).Find(&users) if err != nil { t.Error(err) panic(err) } fmt.Println(users) - ids := []interface{}{1, 2, 3} - err = engine.Where("(id) > ?", 2).In("(id)", ids...).Find(&users) - if err != nil { + if len(users) != 3 { + err = errors.New("in uses should be 7,8,9 total 3") t.Error(err) panic(err) } - fmt.Println(users) + + for _, user := range users { + if user.Uid != 7 && user.Uid != 8 && user.Uid != 9 { + err = errors.New("in uses should be 7,8,9 total 3") + t.Error(err) + panic(err) + } + } err = engine.In("(id)", 1).In("(id)", 2).In("departname", "dev").Find(&users) if err != nil { @@ -1448,12 +1524,32 @@ func 
testIndexAndUnique(engine *Engine, t *testing.T) { } type IntId struct { - Id int + Id int `xorm:"pk autoincr"` Name string } type Int32Id struct { - Id int32 + Id int32 `xorm:"pk autoincr"` + Name string +} + +type UintId struct { + Id uint `xorm:"pk autoincr"` + Name string +} + +type Uint32Id struct { + Id uint32 `xorm:"pk autoincr"` + Name string +} + +type Uint64Id struct { + Id uint64 `xorm:"pk autoincr"` + Name string +} + +type StringPK struct { + Id string `xorm:"pk notnull"` Name string } @@ -1470,11 +1566,51 @@ func testIntId(engine *Engine, t *testing.T) { panic(err) } - _, err = engine.Insert(&IntId{Name: "test"}) + cnt, err := engine.Insert(&IntId{Name: "test"}) if err != nil { t.Error(err) panic(err) } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } + + bean := new(IntId) + has, err := engine.Get(bean) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + beans := make([]IntId, 0) + err = engine.Find(&beans) + if err != nil { + t.Error(err) + panic(err) + } + if len(beans) != 1 { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + cnt, err = engine.Id(bean.Id).Delete(&IntId{}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } } func testInt32Id(engine *Engine, t *testing.T) { @@ -1490,11 +1626,295 @@ func testInt32Id(engine *Engine, t *testing.T) { panic(err) } - _, err = engine.Insert(&Int32Id{Name: "test"}) + cnt, err := engine.Insert(&Int32Id{Name: "test"}) if err != nil { t.Error(err) panic(err) } + + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } + + bean := new(Int32Id) + has, err := engine.Get(bean) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + beans := make([]Int32Id, 0) + err = engine.Find(&beans) + if err != nil { + t.Error(err) + panic(err) + } + if len(beans) != 1 { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + cnt, err = engine.Id(bean.Id).Delete(&Int32Id{}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } +} + +func testUintId(engine *Engine, t *testing.T) { + err := engine.DropTables(&UintId{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&UintId{}) + if err != nil { + t.Error(err) + panic(err) + } + + cnt, err := engine.Insert(&UintId{Name: "test"}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } + + bean := new(UintId) + has, err := engine.Get(bean) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + beans := make([]UintId, 0) + err = engine.Find(&beans) + if err != nil { + t.Error(err) + panic(err) + } + if len(beans) != 1 { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + cnt, err = engine.Id(bean.Id).Delete(&UintId{}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } +} + +func testUint32Id(engine *Engine, t *testing.T) { + err := engine.DropTables(&Uint32Id{}) + if err 
!= nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&Uint32Id{}) + if err != nil { + t.Error(err) + panic(err) + } + + cnt, err := engine.Insert(&Uint32Id{Name: "test"}) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } + + bean := new(Uint32Id) + has, err := engine.Get(bean) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + beans := make([]Uint32Id, 0) + err = engine.Find(&beans) + if err != nil { + t.Error(err) + panic(err) + } + if len(beans) != 1 { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + cnt, err = engine.Id(bean.Id).Delete(&Uint32Id{}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } +} + +func testUint64Id(engine *Engine, t *testing.T) { + err := engine.DropTables(&Uint64Id{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&Uint64Id{}) + if err != nil { + t.Error(err) + panic(err) + } + + cnt, err := engine.Insert(&Uint64Id{Name: "test"}) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } + + bean := new(Uint64Id) + has, err := engine.Get(bean) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + beans := make([]Uint64Id, 0) + err = engine.Find(&beans) + if err != nil { + t.Error(err) + panic(err) + } + if len(beans) != 1 { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + cnt, err = engine.Id(bean.Id).Delete(&Uint64Id{}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } +} + +func testStringPK(engine *Engine, t *testing.T) { + err := engine.DropTables(&StringPK{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&StringPK{}) + if err != nil { + t.Error(err) + panic(err) + } + + cnt, err := engine.Insert(&StringPK{Name: "test"}) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } + + bean := new(StringPK) + has, err := engine.Get(bean) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + beans := make([]StringPK, 0) + err = engine.Find(&beans) + if err != nil { + t.Error(err) + panic(err) + } + if len(beans) != 1 { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + cnt, err = engine.Id(bean.Id).Delete(&StringPK{}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } } func testMetaInfo(engine *Engine, t *testing.T) { @@ -1529,6 +1949,27 @@ func testIterate(engine *Engine, t *testing.T) { } } +func testRows(engine *Engine, t *testing.T) { + rows, err := engine.Omit("is_man").Rows(new(Userinfo)) + if err != nil { + t.Error(err) + panic(err) + } + defer rows.Close() + + idx := 0 + user := new(Userinfo) + for rows.Next() { + err = rows.Scan(user) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(idx, "--", user) + idx++ + } +} + type StrangeName 
struct { Id_t int64 `xorm:"pk autoincr"` Name string @@ -2810,7 +3251,9 @@ func testPointerData(engine *Engine, t *testing.T) { // using instance type should just work too nullData2Get := NullData2{} - has, err = engine.Table("null_data").Id(nullData.Id).Get(&nullData2Get) + tableName := engine.tableMapper.Obj2Table("NullData") + + has, err = engine.Table(tableName).Id(nullData.Id).Get(&nullData2Get) if err != nil { t.Error(err) panic(err) @@ -3156,45 +3599,54 @@ func testNullValue(engine *Engine, t *testing.T) { // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex128Ptr))) // } - /*if (*nullDataGet.TimePtr).Unix() != (*nullDataUpdate.TimePtr).Unix() { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr))) - } else { - // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver - fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr) - fmt.Println() - }*/ - // -- + // !nashtsai! skipped mymysql test due to driver will round up time caused inaccuracy comparison + // skipped postgres test due to postgres driver doesn't read time.Time's timzezone info when stored in the db + // mysql and sqlite3 seem have done this correctly by storing datatime in UTC timezone, I think postgres driver + // prefer using timestamp with timezone to sovle the issue + if engine.DriverName != POSTGRES && engine.DriverName != MYMYSQL && + engine.DriverName != MYSQL { + if (*nullDataGet.TimePtr).Unix() != (*nullDataUpdate.TimePtr).Unix() { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr))) + } else { + // !nashtsai! 
mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver + // inserted value unmatch: [2013-12-25 12:12:45 +0800 CST]:[2013-12-25 12:12:44.878903653 +0800 CST] + fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr) + fmt.Println() + } + } // update to null values - /*nullDataUpdate = NullData{} + nullDataUpdate = NullData{} - cnt, err = engine.Id(nullData.Id).Update(&nullDataUpdate) - if err != nil { - t.Error(err) - panic(err) - } else if cnt != 1 { - t.Error(errors.New("update count == 0, how can this happen!?")) - return - }*/ + string_ptr := engine.columnMapper.Obj2Table("StringPtr") + + cnt, err = engine.Id(nullData.Id).Cols(string_ptr).Update(&nullDataUpdate) + if err != nil { + t.Error(err) + panic(err) + } else if cnt != 1 { + t.Error(errors.New("update count == 0, how can this happen!?")) + return + } // verify get values - /*nullDataGet = NullData{} - has, err = engine.Id(nullData.Id).Get(&nullDataGet) - if err != nil { - t.Error(err) - return - } else if !has { - t.Error(errors.New("ID not found")) - return - } + nullDataGet = NullData{} + has, err = engine.Id(nullData.Id).Get(&nullDataGet) + if err != nil { + t.Error(err) + return + } else if !has { + t.Error(errors.New("ID not found")) + return + } - fmt.Printf("%+v", nullDataGet) - fmt.Println() - - if nullDataGet.StringPtr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr))) - } + fmt.Printf("%+v", nullDataGet) + fmt.Println() + if nullDataGet.StringPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr))) + } + /* if nullDataGet.StringPtr2 != nil { t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr2))) } @@ -3342,17 +3794,11 @@ type Lowercase struct { func testLowerCase(engine *Engine, t *testing.T) { err := engine.Sync(&Lowercase{}) - if err != nil { - t.Error(err) - panic(err) - } - _, err = engine.Where("id > 0").Delete(&Lowercase{}) if err != nil { t.Error(err) panic(err) } - _, err = engine.Insert(&Lowercase{ended: 1}) if err != nil { t.Error(err) @@ -3373,6 +3819,71 @@ func testLowerCase(engine *Engine, t *testing.T) { } } +type User struct { + UserId string `xorm:"varchar(19) not null pk"` + NickName string `xorm:"varchar(19) not null"` + GameId uint32 `xorm:"integer pk"` + Score int32 `xorm:"integer"` +} + +func testCompositeKey2(engine *Engine, t *testing.T) { + err := engine.DropTables(&User{}) + + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&User{}) + if err != nil { + t.Error(err) + panic(err) + } + + cnt, err := engine.Insert(&User{"11", "nick", 22, 5}) + if err != nil { + t.Error(err) + } else if cnt != 1 { + t.Error(errors.New("failed to insert User{11, 22}")) + } + + cnt, err = engine.Insert(&User{"11", "nick", 22, 6}) + if err == nil || cnt == 1 { + t.Error(errors.New("inserted User{11, 22}")) + } + + var user User + has, err := engine.Id(PK{"11", 22}).Get(&user) + if err != nil { + t.Error(err) + } else if !has { + t.Error(errors.New("can't get User{11, 22}")) + } + + // test passing PK ptr, this test seem failed withCache + has, err = engine.Id(&PK{"11", 22}).Get(&user) + if err != nil { + t.Error(err) + } else if !has { + t.Error(errors.New("can't get User{11, 22}")) + } + + user = User{NickName: "test1"} + cnt, err = engine.Id(PK{"11", 22}).Update(&user) + if err != nil { + t.Error(err) + } else if cnt != 1 { + t.Error(errors.New("can't update User{11, 22}")) 
+ } + + cnt, err = engine.Id(PK{"11", 22}).Delete(&User{}) + if err != nil { + t.Error(err) + } else if cnt != 1 { + t.Error(errors.New("can't delete CompositeKey{11, 22}")) + } +} + func testAll(engine *Engine, t *testing.T) { fmt.Println("-------------- directCreateTable --------------") directCreateTable(engine, t) @@ -3390,8 +3901,8 @@ func testAll(engine *Engine, t *testing.T) { insertTwoTable(engine, t) fmt.Println("-------------- update --------------") update(engine, t) - fmt.Println("-------------- testdelete --------------") - testdelete(engine, t) + fmt.Println("-------------- testDelete --------------") + testDelete(engine, t) fmt.Println("-------------- get --------------") get(engine, t) fmt.Println("-------------- cascadeGet --------------") @@ -3446,13 +3957,21 @@ func testAll2(engine *Engine, t *testing.T) { fmt.Println("-------------- testIndexAndUnique --------------") testIndexAndUnique(engine, t) fmt.Println("-------------- testIntId --------------") - //testIntId(engine, t) + testIntId(engine, t) fmt.Println("-------------- testInt32Id --------------") - //testInt32Id(engine, t) + testInt32Id(engine, t) + fmt.Println("-------------- testUintId --------------") + testUintId(engine, t) + fmt.Println("-------------- testUint32Id --------------") + testUint32Id(engine, t) + fmt.Println("-------------- testUint64Id --------------") + testUint64Id(engine, t) fmt.Println("-------------- testMetaInfo --------------") testMetaInfo(engine, t) fmt.Println("-------------- testIterate --------------") testIterate(engine, t) + fmt.Println("-------------- testRows --------------") + testRows(engine, t) fmt.Println("-------------- testStrangeName --------------") testStrangeName(engine, t) fmt.Println("-------------- testVersion --------------") @@ -3487,5 +4006,16 @@ func testAll3(engine *Engine, t *testing.T) { testNullValue(engine, t) fmt.Println("-------------- testCompositeKey --------------") testCompositeKey(engine, t) + fmt.Println("-------------- testCompositeKey2 --------------") + testCompositeKey2(engine, t) + fmt.Println("-------------- testStringPK --------------") + testStringPK(engine, t) +} + +func testAllSnakeMapper(engine *Engine, t *testing.T) { + +} + +func testAllSameMapper(engine *Engine, t *testing.T) { } diff --git a/benchmark.bat b/benchmark.bat new file mode 100644 index 00000000..d35fe044 --- /dev/null +++ b/benchmark.bat @@ -0,0 +1 @@ +go test -v -bench=. -run=XXX \ No newline at end of file diff --git a/doc.go b/doc.go index 37c00516..765f2dfe 100644 --- a/doc.go +++ b/doc.go @@ -60,7 +60,14 @@ There are 7 major ORM methods and many helpful methods to use to operate databas err := engine.Find(...) // SELECT * FROM user -4. Query multiple records and record by record handle +4. Query multiple records and handle them record by record; there are two methods, one is Iterate, +the other is Rows + + rows, err := engine.Rows(...) + // SELECT * FROM user + for rows.Next() { + rows.Scan(bean) + } err := engine.Iterate(...)
// SELECT * FROM user diff --git a/docs/Changelog.md b/docs/Changelog.md index bc737e68..aeeccbbe 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -1,5 +1,19 @@ ## Changelog +* **v0.3.1** + + Features: + * Support for MSSQL via the ODBC driver ([github.com/lunny/godbc](https://github.com/lunny/godbc)) + * Composite primary keys, declared with multiple pk xorm tags + * Added Rows() API as an alternative to the Iterate() API for traversing result sets, with usage similar to the sql.Rows type + * ORM structs may now declare pointers to builtin types as members, to allow null DB fields + * Before and After event processors + + Improvements: + * Allowed int/int32/int64/uint/uint32/uint64/string as the primary key type + * Performance improvements for Get()/Find()/Iterate() + + * **v0.2.3** : Improved documents; Optimistic Locking support; Timestamp with time zone support; Mapper change to tableMapper and columnMapper & added PrefixMapper & SuffixMapper support custom table or column name's prefix and suffix;Insert now return affected, err instead of id, err; Added UseBool & Distinct; * **v0.2.2** : Postgres drivers now support lib/pq; Added method Iterate for record by record to handler;Added SetMaxConns(go1.2+) support; some bugs fixed. * **v0.2.1** : Added database reverse tool, now support generate go & c++ codes, see [Xorm Tool README](https://github.com/lunny/xorm/blob/master/xorm/README.md); some bug fixed. diff --git a/docs/ChangelogCN.md b/docs/ChangelogCN.md index 19f6022d..79140be4 100644 --- a/docs/ChangelogCN.md +++ b/docs/ChangelogCN.md @@ -1,5 +1,18 @@ ## 更新日志 +* **v0.3.1** + + 新特性: + * 支持 MSSQL DB 通过 ODBC 驱动 ([github.com/lunny/godbc](https://github.com/lunny/godbc)); + * 通过多个pk标记支持联合主键; + * 新增 Rows() API 用来遍历查询结果,该函数提供了类似sql.Rows的相似用法,可作为 Iterate() API 的可选替代; + * ORM 结构体现在允许内建类型的指针作为成员,使得数据库为null成为可能; + * Before 和 After 支持 + + 改进: + * 允许 int/int32/int64/uint/uint32/uint64/string 作为主键类型 + * 查询函数 Get()/Find()/Iterate() 在性能上的改进 + * **v0.2.3** : 改善了文档;提供了乐观锁支持;添加了带时区时间字段支持;Mapper现在分成表名Mapper和字段名Mapper,同时实现了表或字段的自定义前缀后缀;Insert方法的返回值含义从id, err更改为 affected, err,请大家注意;添加了UseBool 和 Distinct函数。 * **v0.2.2** : Postgres驱动新增了对lib/pq的支持;新增了逐条遍历方法Iterate;新增了SetMaxConns(go1.2+)支持,修复了bug若干; * **v0.2.1** : 新增数据库反转工具,当前支持go和c++代码的生成,详见 [Xorm Tool README](https://github.com/lunny/xorm/blob/master/xorm/README.md); 修复了一些bug.
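The two headline v0.3.1 features described in the changelog entries above — composite primary keys declared with multiple `pk` tags, and the new `Rows()` iterator — are exercised by `testCompositeKey2` and `testRows` elsewhere in this patch. The following is a minimal sketch of how they combine from a caller's point of view; the `Account` struct, the SQLite driver, and the database file name are illustrative assumptions, not part of the patch.

```Go
package main

import (
	"fmt"

	"github.com/lunny/xorm"
	_ "github.com/mattn/go-sqlite3" // any supported driver works; SQLite is used here only for illustration
)

// Account is a hypothetical struct: UserId and GameId both carry the pk tag,
// so together they form a composite primary key.
type Account struct {
	UserId string `xorm:"varchar(19) not null pk"`
	GameId uint32 `xorm:"integer pk"`
	Score  int32  `xorm:"integer"`
}

func main() {
	engine, err := xorm.NewEngine("sqlite3", "./test.db")
	if err != nil {
		panic(err)
	}
	defer engine.Close()

	if err = engine.Sync(new(Account)); err != nil {
		panic(err)
	}

	if _, err = engine.Insert(&Account{UserId: "11", GameId: 22, Score: 5}); err != nil {
		panic(err)
	}

	// Composite-key lookup: the values in xorm.PK follow the order in which
	// the pk-tagged fields appear in the struct.
	var one Account
	if _, err = engine.Id(xorm.PK{"11", 22}).Get(&one); err != nil {
		panic(err)
	}

	// Rows() returns a forward-only iterator over the result set,
	// similar in usage to sql.Rows.
	rows, err := engine.Where("score > ?", 0).Rows(new(Account))
	if err != nil {
		panic(err)
	}
	defer rows.Close()
	for rows.Next() {
		var acc Account
		if err = rows.Scan(&acc); err != nil {
			panic(err)
		}
		fmt.Println(acc)
	}
}
```

Compared with `Iterate()`, which pushes each record into a callback, `Rows()` leaves loop control with the caller, which is why the changelog describes it as an alternative with `sql.Rows`-like usage.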
diff --git a/docs/QuickStart.md b/docs/QuickStart.md index c3d5768b..6fc55f7b 100644 --- a/docs/QuickStart.md +++ b/docs/QuickStart.md @@ -3,10 +3,11 @@ xorm 快速入门 * [1.创建Orm引擎](#10) * [2.定义表结构体](#20) - * [2.1.名称映射规则](#21) - * [2.2.使用Table和Tag改变名称映射](#22) - * [2.3.Column属性定义](#23) - * [2.4.Go与字段类型对应表](#24) + * [2.1.名称映射规则](#21) + * [2.2.前缀映射规则和后缀映射规则](#22) + * [2.3.使用Table和Tag改变名称映射](#23) + * [2.4.Column属性定义](#24) + * [2.5.Go与字段类型对应表](#25) * [3.表结构操作](#30) * [3.1 获取数据库信息](#31) * [3.2 表操作](#32) @@ -19,7 +20,8 @@ xorm 快速入门 * [5.3.Get方法](#63) * [5.4.Find方法](#64) * [5.5.Iterate方法](#65) - * [5.6.Count方法](#66) + * [5.6.Count方法](#66) + * [5.7.Rows方法](#67) * [6.更新数据](#70) * [6.1.乐观锁](#71) * [7.删除数据](#80) @@ -61,7 +63,7 @@ defer engine.Close() 一般如果只针对一个数据库进行操作,只需要创建一个Engine即可。Engine支持在多GoRutine下使用。 -xorm当前支持四种驱动如下: +xorm当前支持五种驱动四个数据库如下: * Mysql: [github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL) @@ -69,13 +71,15 @@ xorm当前支持四种驱动如下: * SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) -* Postgres: [github.com/lib/pq](https://github.com/lib/pq) +* Postgres: [github.com/lib/pq](https://github.com/lib/pq) + +* MsSql: [github.com/lunny/godbc](https://githubcom/lunny/godbc) NewEngine传入的参数和`sql.Open`传入的参数完全相同,因此,使用哪个驱动前,请查看此驱动中关于传入参数的说明文档。 在engine创建完成后可以进行一些设置,如: -1.设置 +1.错误显示设置,默认如下均为`false` * `engine.ShowSQL = true`,则会在控制台打印出生成的SQL语句; * `engine.ShowDebug = true`,则会在控制台打印调试信息; @@ -93,7 +97,7 @@ f, err := os.Create("sql.log") engine.Logger = f ``` -3.engine内部支持连接池接口,默认使用的Go所实现的连接池,同时自带了另外两种实现:一种是不使用连接池,另一种为一个自实现的连接池。推荐使用Go所实现的连接池。如果要使用自己实现的连接池,可以实现`xorm.IConnectPool`并通过`engine.SetPool`进行设置。 +3.engine内部支持连接池接口,默认使用的Go所实现的连接池,同时自带了另外两种实现:一种是不使用连接池,另一种为一个自实现的连接池。推荐使用Go所实现的连接池。如果要使用自己实现的连接池,可以实现`xorm.IConnectPool`并通过`engine.SetPool`进行设置。推荐使用Go默认的连接池。 * 如果需要设置连接池的空闲数大小,可以使用`engine.SetIdleConns()`来实现。 * 如果需要设置最大打开连接数,则可以使用`engine.SetMaxConns()`来实现。 @@ -106,25 +110,40 @@ xorm支持将一个struct映射为数据库中对应的一张表。映射规则 ### 2.1.名称映射规则 -名称映射规则主要负责结构体名称到表名和结构体field到表字段的名称映射。由xorm.IMapper接口的实现者来管理,xorm内置了两种IMapper实现:`SnakeMapper` 和 `SameMapper`。SnakeMapper支持struct为驼峰式命名,表结构为下划线命名之间的转换;SameMapper支持相同的命名。 +名称映射规则主要负责结构体名称到表名和结构体field到表字段的名称映射。由xorm.IMapper接口的实现者来管理,xorm内置了两种IMapper实现:`SnakeMapper` 和 `SameMapper`。SnakeMapper支持struct为驼峰式命名,表结构为下划线命名之间的转换;SameMapper支持结构体名称和对应的表名称以及结构体field名称与对应的表字段名称相同的命名。 当前SnakeMapper为默认值,如果需要改变时,在engine创建完成后使用 ```Go -engine.Mapper = SameMapper{} +engine.SetMapper(SameMapper{}) ``` -当然,如果你使用了别的命名规则映射方案,也可以自己实现一个IMapper。 +同时需要注意的是: + +* 如果你使用了别的命名规则映射方案,也可以自己实现一个IMapper。 +* 表名称和字段名称的映射规则默认是相同的,当然也可以设置为不同,如: + +```Go +engine.SetTableMapper(SameMapper{}) +engine.SetColumnMapper(SnakeMapper{}) +``` -### 2.2.使用Table和Tag改变名称映射 +### 2.2.前缀映射规则和后缀映射规则 + +* 通过`engine.NewPrefixMapper(SnakeMapper{}, "prefix")`可以在SnakeMapper的基础上在命名中添加统一的前缀,当然也可以把SnakeMapper{}换成SameMapper或者你自定义的Mapper。 +* 通过`engine.NewSufffixMapper(SnakeMapper{}, "suffix")`可以在SnakeMapper的基础上在命名中添加统一的后缀,当然也可以把SnakeMapper{}换成SameMapper或者你自定义的Mapper。 +* + + +### 2.3.使用Table和Tag改变名称映射 如果所有的命名都是按照IMapper的映射来操作的,那当然是最理想的。但是如果碰到某个表名或者某个字段名跟映射规则不匹配时,我们就需要别的机制来改变。 通过`engine.Table()`方法可以改变struct对应的数据库表的名称,通过sturct中field对应的Tag中使用`xorm:"'column_name'"`可以使该field对应的Column名称为指定名称。这里使用两个单引号将Column名称括起来是为了防止名称冲突,因为我们在Tag中还可以对这个Column进行更多的定义。如果名称不冲突的情况,单引号也可以不使用。 -### 2.3.Column属性定义 +### 2.4.Column属性定义 我们在field对应的Tag中对Column的一些属性进行定义,定义的方法基本和我们写SQL定义表结构类似,比如: ``` @@ -140,10 +159,10 @@ type User struct { - + - + @@ -152,7 +171,7 @@ type User struct { - + @@ -188,11 +207,11 @@ type User struct { 另外有如下几条自动映射的规则: -- 
1.如果field名称为`Id`而且类型为`int64`的话,会被xorm视为主键,并且拥有自增属性。如果想用`Id`以外的名字做为主键名,可以在对应的Tag上加上`xorm:"pk"`来定义主键。 +- 1.如果field名称为`Id`而且类型为`int64`并且没有定义tag,则会被xorm视为主键,并且拥有自增属性。如果想用`Id`以外的名字或非int64类型做为主键名,必须在对应的Tag上加上`xorm:"pk"`来定义主键,加上`xorm:"autoincr"`作为自增。这里需要注意的是,有些数据库并不允许非主键的自增属性。 - 2.string类型默认映射为varchar(255),如果需要不同的定义,可以在tag中自定义 -- 3.支持`type MyString string`等自定义的field,支持Slice, Map等field成员,这些成员默认存储为Text类型,并且默认将使用Json格式来序列化和反序列化。也支持数据库字段类型为Blob类型,如果是Blob类型,则先使用Json格式序列化再转成[]byte格式。当然[]byte或者[]uint8默认为Blob类型并且都以二进制方式存储。 +- 3.支持`type MyString string`等自定义的field,支持Slice, Map等field成员,这些成员默认存储为Text类型,并且默认将使用Json格式来序列化和反序列化。也支持数据库字段类型为Blob类型,如果是Blob类型,则先使用Json格式序列化再转成[]byte格式。当然[]byte或者[]uint8默认为Blob类型并且都以二进制方式存储。具体参见 [go类型<->数据库类型对应表](https://github.com/lunny/xorm/blob/master/docs/AutoMap.md) - 4.实现了Conversion接口的类型或者结构体,将根据接口的转换方式在类型和数据库记录之间进行相互转换。 ```Go @@ -218,41 +237,50 @@ xorm提供了一些动态获取和修改表结构的方法。对于一般的应 ## 3.1 获取数据库信息 * DBMetas() -xorm支持获取表结构信息,通过调用`engine.DBMetas()`可以获取到所有的表的信息 + +xorm支持获取表结构信息,通过调用`engine.DBMetas()`可以获取到所有的表,字段,索引的信息。 ## 3.2.表操作 * CreateTables() + 创建表使用`engine.CreateTables()`,参数为一个或多个空的对应Struct的指针。同时可用的方法有Charset()和StoreEngine(),如果对应的数据库支持,这两个方法可以在创建表时指定表的字符编码和使用的引擎。当前仅支持Mysql数据库。 * IsTableEmpty() + 判断表是否为空,参数和CreateTables相同 * IsTableExist() + 判断表是否存在 * DropTables() + 删除表使用`engine.DropTables()`,参数为一个或多个空的对应Struct的指针或者表的名字。如果为string传入,则只删除对应的表,如果传入的为Struct,则删除表的同时还会删除对应的索引。 ## 3.3.创建索引和唯一索引 * CreateIndexes + 根据struct中的tag来创建索引 * CreateUniques + 根据struct中的tag来创建唯一索引 ## 3.4.同步数据库结构 同步能够部分智能的根据结构体的变动检测表结构的变动,并自动同步。目前能够实现: -1) 自动检测和创建表,这个检测是根据表的名字 -2)自动检测和新增表中的字段,这个检测是根据字段名 -3)自动检测和创建索引和唯一索引,这个检测是根据一个或多个字段名,而不根据索引名称 + +* 1) 自动检测和创建表,这个检测是根据表的名字 +* 2)自动检测和新增表中的字段,这个检测是根据字段名 +* 3)自动检测和创建索引和唯一索引,这个检测是根据一个或多个字段名,而不根据索引名称 调用方法如下: + ```Go err := engine.Sync(new(User)) ``` @@ -264,18 +292,21 @@ err := engine.Sync(new(User)) 如果传入的是Slice并且当数据库支持批量插入时,Insert会使用批量插入的方式进行插入。 * 插入一条数据 + ```Go user := new(User) user.Name = "myname" affected, err := engine.Insert(user) ``` -在插入成功后,如果该结构体有PK字段,则PK字段会被自动赋值为数据库中的id +在插入单条数据成功后,如果该结构体有自增字段,则自增字段会被自动赋值为数据库中的id + ```Go fmt.Println(user.Id) ``` * 插入同一个表的多条数据 + ```Go users := make([]User, 0) users[0].Name = "name0" @@ -284,6 +315,7 @@ affected, err := engine.Insert(&users) ``` * 使用指针Slice插入多条记录 + ```Go users := make([]*User, 0) users[0] = new(User) @@ -293,6 +325,7 @@ affected, err := engine.Insert(&users) ``` * 插入不同表的一条记录 + ```Go user := new(User) user.Name = "myname" @@ -302,6 +335,7 @@ affected, err := engine.Insert(user, question) ``` * 插入不同表的多条记录 + ```Go users := make([]User, 0) users[0].Name = "name0" @@ -321,25 +355,27 @@ questions[0].Content = "whywhywhwy?" 
affected, err := engine.Insert(user, &questions) ``` -注意:这里虽然支持同时插入,但这些插入并没有事务关系。因此有可能在中间插入出错后,后面的插入将不会继续。 +这里需要注意以下几点: +* 这里虽然支持同时插入,但这些插入并没有事务关系。因此有可能在中间插入出错后,后面的插入将不会继续。 +* 多条插入会自动生成`Insert into table values (),(),()`的语句,因此这样的语句有一个最大的记录数,根据经验测算在150条左右。大于150条后,生成的sql语句将太长可能导致执行失败。因此在插入大量数据时,目前需要自行分割成每150条插入一次。 ## 5.查询和统计数据 -所有的查询条件不区分调用顺序,但必须在调用Get,Find,Count这三个函数之前调用。同时需要注意的一点是,在调用的参数中,所有的字符字段名均为映射后的数据库的字段名,而不是field的名字。 +所有的查询条件不区分调用顺序,但必须在调用Get,Find,Count, Iterate, Rows这几个函数之前调用。同时需要注意的一点是,在调用的参数中,如果采用默认的`SnakeMapper`所有的字符字段名均为映射后的数据库的字段名,而不是field的名字。 ### 5.1.查询条件方法 -查询和统计主要使用`Get`, `Find`, `Count`三个方法。在进行查询时可以使用多个方法来形成查询条件,条件函数如下: +查询和统计主要使用`Get`, `Find`, `Count`, `Rows`, `Iterate`这几个方法。在进行查询时可以使用多个方法来形成查询条件,条件函数如下: * Id(interface{}) 传入一个PK字段的值,作为查询条件,如果是复合主键,则 `Id(xorm.PK{1, 2})` -传入的两个参数按照struct中定义的顺序赋值。 +传入的两个参数按照struct中pk标记字段出现的顺序赋值。 * Where(string, …interface{}) -和Where语句中的条件基本相同,作为条件 +和SQL中Where语句中的条件基本相同,作为条件 * And(string, …interface{}) 和Where函数中的条件基本相同,作为条件 @@ -360,7 +396,7 @@ affected, err := engine.Insert(user, &questions) 按照指定的顺序进行排序 * In(string, …interface{}) -某字段在一些值中 +某字段在一些值中,这里需要注意必须是[]interface{}才可以展开,由于Go语言的限制,[]int64等均不可以展开。 * Cols(…string) 只查询或更新某些指定的字段,默认是查询所有映射的字段或者根据Update的第一个参数来判断更新的字段。例如: @@ -429,20 +465,28 @@ Having的参数字符串 如: 1) 根据Id来获得单条数据: + ```Go user := new(User) has, err := engine.Id(id).Get(user) +// 复合主键的获取方法 +// has, errr := engine.Id(xorm.PK{1,2}).Get(user) ``` + 2) 根据Where来获得单条数据: + ```Go user := new(User) has, err := engine.Where("name=?", "xlw").Get(user) ``` + 3) 根据user结构体中已有的非空数据来获得单条数据: + ```Go user := &User{Id:1} has, err := engine.Get(user) ``` + 或者其它条件 ```Go @@ -458,6 +502,7 @@ has, err := engine.Get(user) 查询多条数据使用`Find`方法,Find方法的第一个参数为`slice`的指针或`Map`指针,即为查询后返回的结果,第二个参数可选,为查询的条件struct的指针。 1) 传入Slice用于返回数据 + ```Go everyone := make([]Userinfo, 0) err := engine.Find(&everyone) @@ -466,7 +511,8 @@ pEveryOne := make([]*Userinfo, 0) err := engine.Find(&pEveryOne) ``` -2) 传入Map用户返回数据,map必须为`map[int64]Userinfo`的形式,map的key为id +2) 传入Map用户返回数据,map必须为`map[int64]Userinfo`的形式,map的key为id,因此对于复合主键无法使用这种方式。 + ```Go users := make(map[int64]Userinfo) err := engine.Find(&users) @@ -476,6 +522,7 @@ err := engine.Find(&pUsers) ``` 3) 也可以加入各种条件 + ```Go users := make([]Userinfo, 0) err := engine.Where("age > ? or name = ?", 30, "xlw").Limit(20, 10).Find(&users) @@ -485,6 +532,7 @@ err := engine.Where("age > ? or name = ?", 30, "xlw").Limit(20, 10).Find(&users) ### 5.5.Iterate方法 Iterate方法提供逐条执行查询到的记录的方法,他所能使用的条件和Find方法完全相同 + ```Go err := engine.Where("age > ? or name=?)", 30, "xlw").Iterate(new(Userinfo), func(i int, bean interface{})error{ user := bean.(*Userinfo) @@ -501,6 +549,22 @@ user := new(User) total, err := engine.Where("id >?", 1).Count(user) ``` + +### 5.7.Rows方法 + +Rows方法和Iterate方法类似,提供逐条执行查询到的记录的方法,不过Rows更加灵活好用。 +```Go +user := new(User) +rows, err := engine.Where("id >?", 1).Rows(user) +if err != nil { +} +defer rows.Close() +for rows.Next() { + err = rows.Scan(user) + //... +} +``` + ## 6.更新数据 @@ -514,12 +578,14 @@ affected, err := engine.Id(id).Update(user) 这里需要注意,Update会自动从user结构体中提取非0和非nil得值作为需要更新的内容,因此,如果需要更新一个值为0,则此种方法将无法实现,因此有两种选择: -1. 通过添加Cols函数指定需要更新结构体中的哪些值,未指定的将不更新,指定了的即使为0也会更新。 +* 1.通过添加Cols函数指定需要更新结构体中的哪些值,未指定的将不更新,指定了的即使为0也会更新。 + ```Go affected, err := engine.Id(id).Cols("age").Update(&user) ``` -2. 
通过传入map[string]interface{}来进行更新,但这时需要额外指定更新到哪个表,因为通过map是无法自动检测更新哪个表的。 +* 2.通过传入map[string]interface{}来进行更新,但这时需要额外指定更新到哪个表,因为通过map是无法自动检测更新哪个表的。 + ```Go affected, err := engine.Table(new(User)).Id(id).Update(map[string]interface{}{"age":0}) ``` @@ -548,6 +614,7 @@ engine.Id(1).Update(&user) ## 7.删除数据 删除数据`Delete`方法,参数为struct的指针并且成为查询条件。 + ```Go user := new(User) affected, err := engine.Id(id).Delete(user) @@ -561,29 +628,33 @@ affected, err := engine.Id(id).Delete(user) ## 8.执行SQL查询 也可以直接执行一个SQL查询,即Select命令。在Postgres中支持原始SQL语句中使用 ` 和 ? 符号。 + ```Go sql := "select * from userinfo" results, err := engine.Query(sql) ``` +当调用`Query`时,第一个返回值`results`为`[]map[string][]byte`的形式。 + ## 9.执行SQL命令 -也可以直接执行一个SQL命令,即执行Insert, Update, Delete 等操作。同样在Postgres中支持原始SQL语句中使用 ` 和 ? 符号。 +也可以直接执行一个SQL命令,即执行Insert, Update, Delete 等操作。此时不管数据库是何种类型,都可以使用 ` 和 ? 符号。 + ```Go -sql = "update userinfo set username=? where id=?" +sql = "update `userinfo` set username=? where id=?" res, err := engine.Exec(sql, "xiaolun", 1) ``` ## 10.事务处理 -当使用事务处理时,需要创建Session对象。 +当使用事务处理时,需要创建Session对象。在进行事物处理时,可以混用ORM方法和RAW方法,如下代码所示: ```Go session := engine.NewSession() defer session.Close() // add Begin() before any action -err := session.Begin() +err := session.Begin() user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} _, err = session.Insert(&user1) if err != nil { @@ -620,14 +691,17 @@ xorm内置了一致性缓存支持,不过默认并没有开启。要开启缓 cacher := xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000) engine.SetDefaultCacher(cacher) ``` + 上述代码采用了LRU算法的一个缓存,缓存方式是存放到内存中,缓存struct的记录数为1000条,缓存针对的范围是所有具有主键的表,没有主键的表中的数据将不会被缓存。 如果只想针对部分表,则: + ```Go cacher := xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000) engine.MapCacher(&user, cacher) ``` 如果要禁用某个表的缓存,则: + ```Go engine.MapCacher(&user, nil) ``` @@ -638,16 +712,17 @@ engine.MapCacher(&user, nil) 不过需要特别注意不适用缓存或者需要手动编码的地方: -1. 在Get或者Find时使用了Cols方法,在开启缓存后此方法无效,系统仍旧会取出这个表中的所有字段。 +1. 当使用了`Distinct`,`Having`,`GroupBy`方法将不会使用缓存 + +2. 在`Get`或者`Find`时使用了`Cols`,`Omit`方法,则在开启缓存后此方法无效,系统仍旧会取出这个表中的所有字段。 + +3. 在使用Exec方法执行了方法之后,可能会导致缓存与数据库不一致的地方。因此如果启用缓存,尽量避免使用Exec。如果必须使用,则需要在使用了Exec之后调用ClearCache手动做缓存清除的工作。比如: -2. 在使用Exec方法执行了方法之后,可能会导致缓存与数据库不一致的地方。因此如果启用缓存,尽量避免使用Exec。如果必须使用,则需要在使用了Exec之后调用ClearCache手动做缓存清除的工作。比如: ```Go engine.Exec("update user set name = ? 
where id = ?", "xlw", 1) engine.ClearCache(new(User)) ``` -ClearCacheBean - 缓存的实现原理如下图所示: ![cache design](https://raw.github.com/lunny/xorm/master/docs/cache_design.png) diff --git a/engine.go b/engine.go index 65d2f095..93494048 100644 --- a/engine.go +++ b/engine.go @@ -28,6 +28,7 @@ const ( // a dialect is a driver's wrapper type dialect interface { Init(DriverName, DataSourceName string) error + URI() *uri DBType() string SqlType(t *Column) string SupportInsertMany() bool @@ -472,15 +473,16 @@ func (engine *Engine) mapType(t reflect.Type) *Table { parentTable := engine.mapType(fieldType) for name, col := range parentTable.Columns { col.FieldName = fmt.Sprintf("%v.%v", fieldType.Name(), col.FieldName) - table.Columns[name] = col + table.Columns[strings.ToLower(name)] = col table.ColumnsSeq = append(table.ColumnsSeq, name) } - table.PrimaryKey = parentTable.PrimaryKey + table.PrimaryKeys = parentTable.PrimaryKeys continue } var indexType int var indexName string + var preKey string for j, key := range tags { k := strings.ToUpper(key) switch { @@ -519,12 +521,13 @@ func (engine *Engine) mapType(t reflect.Type) *Table { case k == "NOT": default: if strings.HasPrefix(k, "'") && strings.HasSuffix(k, "'") { - if key != col.Default { + if preKey != "DEFAULT" { col.Name = key[1 : len(key)-1] } } else if strings.Contains(k, "(") && strings.HasSuffix(k, ")") { fs := strings.Split(k, "(") if _, ok := sqlTypes[fs[0]]; !ok { + preKey = k continue } col.SQLType = SQLType{fs[0], 0, 0} @@ -538,12 +541,13 @@ func (engine *Engine) mapType(t reflect.Type) *Table { } else { if _, ok := sqlTypes[k]; ok { col.SQLType = SQLType{k, 0, 0} - } else if key != col.Default { + } else if preKey != "DEFAULT" { col.Name = key } } engine.SqlType(col) } + preKey = k } if col.SQLType.Name == "" { col.SQLType = Type2SQLType(fieldType) @@ -602,12 +606,13 @@ func (engine *Engine) mapType(t reflect.Type) *Table { } } - if idFieldColName != "" && table.PrimaryKey == "" { - col := table.Columns[idFieldColName] + if idFieldColName != "" && len(table.PrimaryKeys) == 0 { + col := table.Columns[strings.ToLower(idFieldColName)] col.IsPrimaryKey = true col.IsAutoIncrement = true col.Nullable = false - table.PrimaryKey = col.Name + table.PrimaryKeys = append(table.PrimaryKeys, col.Name) + table.AutoIncrement = col.Name } return table @@ -933,6 +938,13 @@ func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error { return session.Iterate(bean, fun) } +// Return sql.Rows compatible Rows obj, as a forward Iterator object for iterating record by record, bean's non-empty fields +// are conditions. +func (engine *Engine) Rows(bean interface{}) (*Rows, error) { + session := engine.NewSession() + return session.Rows(bean) +} + // Count counts the records. bean's non-empty fields // are conditions. 
func (engine *Engine) Count(bean interface{}) (int64, error) { @@ -972,18 +984,22 @@ func (engine *Engine) Import(ddlPath string) ([]sql.Result, error) { scanner.Split(semiColSpliter) session := engine.NewSession() - session.IsAutoClose = false + defer session.Close() + err = session.newDb() + if err != nil { + return results, err + } + for scanner.Scan() { query := scanner.Text() query = strings.Trim(query, " \t") if len(query) > 0 { - result, err := session.Exec(query) + result, err := session.Db.Exec(query) results = append(results, result) if err != nil { lastError = err } } } - session.Close() return results, lastError } diff --git a/filter.go b/filter.go index 5fff4c0d..d2b6c468 100644 --- a/filter.go +++ b/filter.go @@ -40,10 +40,10 @@ type IdFilter struct { } func (i *IdFilter) Do(sql string, session *Session) string { - if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { - sql = strings.Replace(sql, "`(id)`", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) - sql = strings.Replace(sql, session.Engine.Quote("(id)"), session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) - return strings.Replace(sql, "(id)", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) + if session.Statement.RefTable != nil && len(session.Statement.RefTable.PrimaryKeys) == 1 { + sql = strings.Replace(sql, "`(id)`", session.Engine.Quote(session.Statement.RefTable.PrimaryKeys[0]), -1) + sql = strings.Replace(sql, session.Engine.Quote("(id)"), session.Engine.Quote(session.Statement.RefTable.PrimaryKeys[0]), -1) + return strings.Replace(sql, "(id)", session.Engine.Quote(session.Statement.RefTable.PrimaryKeys[0]), -1) } return sql } diff --git a/helpers.go b/helpers.go index 307353c2..96f118f2 100644 --- a/helpers.go +++ b/helpers.go @@ -1,8 +1,12 @@ package xorm import ( + "database/sql" + "fmt" "reflect" + "strconv" "strings" + "time" ) func indexNoCase(s, sep string) int { @@ -61,3 +65,72 @@ func sliceEq(left, right []string) bool { return true } + +func value2Bytes(rawValue *reflect.Value) (data []byte, err error) { + + aa := reflect.TypeOf((*rawValue).Interface()) + vv := reflect.ValueOf((*rawValue).Interface()) + + var str string + switch aa.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + str = strconv.FormatInt(vv.Int(), 10) + data = []byte(str) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + str = strconv.FormatUint(vv.Uint(), 10) + data = []byte(str) + case reflect.Float32, reflect.Float64: + str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) + data = []byte(str) + case reflect.String: + str = vv.String() + data = []byte(str) + case reflect.Array, reflect.Slice: + switch aa.Elem().Kind() { + case reflect.Uint8: + data = rawValue.Interface().([]byte) + default: + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + //时间类型 + case reflect.Struct: + if aa == reflect.TypeOf(c_TIME_DEFAULT) { + str = rawValue.Interface().(time.Time).Format(time.RFC3339Nano) + data = []byte(str) + } else { + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + case reflect.Bool: + str = strconv.FormatBool(vv.Bool()) + data = []byte(str) + case reflect.Complex128, reflect.Complex64: + str = fmt.Sprintf("%v", vv.Complex()) + data = []byte(str) + /* TODO: unsupported types below + case reflect.Map: + case reflect.Ptr: + case reflect.Uintptr: + case reflect.UnsafePointer: + case reflect.Chan, reflect.Func, reflect.Interface: + */ + default: + err 
= fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + return +} + +func rows2maps(rows *sql.Rows) (resultsSlice []map[string][]byte, err error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + for rows.Next() { + result, err := row2map(rows, fields) + if err != nil { + return nil, err + } + resultsSlice = append(resultsSlice, result) + } + + return resultsSlice, nil +} diff --git a/mssql.go b/mssql.go index 0f4bf5e4..3606332f 100644 --- a/mssql.go +++ b/mssql.go @@ -22,6 +22,7 @@ type odbcParser struct { func (p *odbcParser) parse(driverName, dataSourceName string) (*uri, error) { kv := strings.Split(dataSourceName, ";") var dbName string + for _, c := range kv { vv := strings.Split(strings.TrimSpace(c), "=") if len(vv) == 2 { @@ -155,6 +156,7 @@ where a.object_id=object_id('` + tableName + `')` for name, content := range record { switch name { case "name": + col.Name = strings.Trim(string(content), "` ") case "ctype": ct := strings.ToUpper(string(content)) @@ -163,11 +165,14 @@ where a.object_id=object_id('` + tableName + `')` col.SQLType = SQLType{TimeStampz, 0, 0} case "NVARCHAR": col.SQLType = SQLType{Varchar, 0, 0} + case "IMAGE": + col.SQLType = SQLType{VarBinary, 0, 0} default: if _, ok := sqlTypes[ct]; ok { col.SQLType = SQLType{ct, 0, 0} } else { - return nil, nil, errors.New(fmt.Sprintf("unknow colType %v for %v", ct, col)) + return nil, nil, errors.New(fmt.Sprintf("unknow colType %v for %v - %v", + ct, tableName, col.Name)) } } diff --git a/mssql_test.go b/mssql_test.go index cde4081a..a9093866 100644 --- a/mssql_test.go +++ b/mssql_test.go @@ -4,18 +4,16 @@ package xorm // +build windows import ( + "database/sql" "testing" _ "github.com/lunny/godbc" ) -/* -CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET -utf8 COLLATE utf8_general_ci; -*/ +const mssqlConnStr = "driver={SQL Server};Server=192.168.20.135;Database=xorm_test; uid=sa; pwd=1234;" func newMssqlEngine() (*Engine, error) { - return NewEngine("odbc", "driver={SQL Server};Server=192.168.20.135;Database=xorm_test; uid=sa; pwd=1234;") + return NewEngine("odbc", mssqlConnStr) } func TestMssql(t *testing.T) { @@ -51,7 +49,41 @@ func TestMssqlWithCache(t *testing.T) { testAll2(engine, t) } -func BenchmarkMssqlNoCache(t *testing.B) { +func newMssqlDriverDB() (*sql.DB, error) { + return sql.Open("odbc", mssqlConnStr) +} + +const ( + createTableMssql = `IF NOT EXISTS (SELECT [name] FROM sys.tables WHERE [name] = 'big_struct' ) CREATE TABLE + "big_struct" ("id" BIGINT PRIMARY KEY IDENTITY NOT NULL, "name" VARCHAR(255) NULL, "title" VARCHAR(255) NULL, + "age" VARCHAR(255) NULL, "alias" VARCHAR(255) NULL, "nick_name" VARCHAR(255) NULL); + ` + + dropTableMssql = "IF EXISTS (SELECT * FROM sysobjects WHERE id = object_id(N'big_struct') and OBJECTPROPERTY(id, N'IsUserTable') = 1) DROP TABLE IF EXISTS `big_struct`;" +) + +func BenchmarkMssqlDriverInsert(t *testing.B) { + doBenchDriver(newMssqlDriverDB, createTableMssql, dropTableMssql, + doBenchDriverInsert, t) +} + +func BenchmarkMssqlDriverFind(t *testing.B) { + doBenchDriver(newMssqlDriverDB, createTableMssql, dropTableMssql, + doBenchDriverFind, t) +} + +func BenchmarkMssqlNoCacheInsert(t *testing.B) { + engine, err := newMssqlEngine() + defer engine.Close() + if err != nil { + t.Error(err) + return + } + //engine.ShowSQL = true + doBenchInsert(engine, t) +} + +func BenchmarkMssqlNoCacheFind(t *testing.B) { engine, err := newMssqlEngine() defer engine.Close() if err != nil { @@ -62,7 +94,18 @@ func BenchmarkMssqlNoCache(t *testing.B) { 
doBenchFind(engine, t) } -func BenchmarkMssqlCache(t *testing.B) { +func BenchmarkMssqlNoCacheFindPtr(t *testing.B) { + engine, err := newMssqlEngine() + defer engine.Close() + if err != nil { + t.Error(err) + return + } + //engine.ShowSQL = true + doBenchFindPtr(engine, t) +} + +func BenchmarkMssqlCacheInsert(t *testing.B) { engine, err := newMssqlEngine() defer engine.Close() if err != nil { @@ -70,5 +113,30 @@ func BenchmarkMssqlCache(t *testing.B) { return } engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) + + doBenchInsert(engine, t) +} + +func BenchmarkMssqlCacheFind(t *testing.B) { + engine, err := newMssqlEngine() + defer engine.Close() + if err != nil { + t.Error(err) + return + } + engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) + doBenchFind(engine, t) } + +func BenchmarkMssqlCacheFindPtr(t *testing.B) { + engine, err := newMssqlEngine() + defer engine.Close() + if err != nil { + t.Error(err) + return + } + engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) + + doBenchFindPtr(engine, t) +} diff --git a/mysql.go b/mysql.go index 6e51780b..4dcde839 100644 --- a/mysql.go +++ b/mysql.go @@ -33,7 +33,6 @@ type mysqlParser struct { } func (p *mysqlParser) parse(driverName, dataSourceName string) (*uri, error) { - //cfg.params = make(map[string]string) dsnPattern := regexp.MustCompile( `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] @@ -49,6 +48,20 @@ func (p *mysqlParser) parse(driverName, dataSourceName string) (*uri, error) { switch names[i] { case "dbname": uri.dbName = match + case "params": + if len(match) > 0 { + kvs := strings.Split(match, "&") + for _, kv := range kvs { + splits := strings.Split(kv, "=") + if len(splits) == 2 { + switch splits[0] { + case "charset": + uri.charset = splits[1] + } + } + } + } + } } return uri, nil @@ -68,6 +81,10 @@ func (b *base) init(parser parser, drivername, dataSourceName string) (err error return } +func (b *base) URI() *uri { + return b.uri +} + func (b *base) DBType() string { return b.uri.dbType } diff --git a/mysql_test.go b/mysql_test.go index 2e106231..e0c3deac 100644 --- a/mysql_test.go +++ b/mysql_test.go @@ -12,8 +12,6 @@ CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET utf8 COLLATE utf8_general_ci; */ -var mysqlShowTestSql bool = true - func TestMysql(t *testing.T) { err := mysqlDdlImport() if err != nil { @@ -27,10 +25,34 @@ func TestMysql(t *testing.T) { t.Error(err) return } - engine.ShowSQL = mysqlShowTestSql - engine.ShowErr = mysqlShowTestSql - engine.ShowWarn = mysqlShowTestSql - engine.ShowDebug = mysqlShowTestSql + engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql + + testAll(engine, t) + testAll2(engine, t) + testAll3(engine, t) +} + +func TestMysqlSameMapper(t *testing.T) { + err := mysqlDdlImport() + if err != nil { + t.Error(err) + return + } + + engine, err := NewEngine("mysql", "root:@/xorm_test3?charset=utf8") + defer engine.Close() + if err != nil { + t.Error(err) + return + } + engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql + engine.SetMapper(SameMapper{}) testAll(engine, t) testAll2(engine, t) @@ -51,10 +73,10 @@ func TestMysqlWithCache(t *testing.T) { return } engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) - engine.ShowSQL = mysqlShowTestSql - engine.ShowErr = mysqlShowTestSql - engine.ShowWarn = mysqlShowTestSql - engine.ShowDebug = 
mysqlShowTestSql + engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql testAll(engine, t) testAll2(engine, t) @@ -69,10 +91,10 @@ func mysqlDdlImport() error { if err != nil { return err } - engine.ShowSQL = mysqlShowTestSql - engine.ShowErr = mysqlShowTestSql - engine.ShowWarn = mysqlShowTestSql - engine.ShowDebug = mysqlShowTestSql + engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql sqlResults, _ := engine.Import("tests/mysql_ddl.sql") engine.LogDebug("sql results: %v", sqlResults) diff --git a/oracle.go b/oracle.go index b863b6ab..4e3c6fb6 100644 --- a/oracle.go +++ b/oracle.go @@ -1,258 +1,258 @@ -package xorm - -import ( - "database/sql" - "errors" - "fmt" - "regexp" - "strconv" - "strings" -) - -type oracle struct { - base -} - -type oracleParser struct { -} - -//dataSourceName=user/password@ipv4:port/dbname -//dataSourceName=user/password@[ipv6]:port/dbname -func (p *oracleParser) parse(driverName, dataSourceName string) (*uri, error) { - db := &uri{dbType: ORACLE_OCI} - dsnPattern := regexp.MustCompile( - `^(?P.*)\/(?P.*)@` + // user:password@ - `(?P.*)` + // ip:port - `\/(?P.*)`) // dbname - matches := dsnPattern.FindStringSubmatch(dataSourceName) - names := dsnPattern.SubexpNames() - for i, match := range matches { - switch names[i] { - case "dbname": - db.dbName = match - } - } - if db.dbName == "" { - return nil, errors.New("dbname is empty") - } - return db, nil -} - -func (db *oracle) Init(drivername, uri string) error { - return db.base.init(&oracleParser{}, drivername, uri) -} - -func (db *oracle) SqlType(c *Column) string { - var res string - switch t := c.SQLType.Name; t { - case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool, Serial, BigSerial: - return "NUMBER" - case Binary, VarBinary, Blob, TinyBlob, MediumBlob, LongBlob, Bytea: - return Blob - case Time, DateTime, TimeStamp: - res = TimeStamp - case TimeStampz: - res = "TIMESTAMP WITH TIME ZONE" - case Float, Double, Numeric, Decimal: - res = "NUMBER" - case Text, MediumText, LongText: - res = "CLOB" - case Char, Varchar, TinyText: - return "VARCHAR2" - default: - res = t - } - - var hasLen1 bool = (c.Length > 0) - var hasLen2 bool = (c.Length2 > 0) - if hasLen1 { - res += "(" + strconv.Itoa(c.Length) + ")" - } else if hasLen2 { - res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" - } - return res -} - -func (db *oracle) SupportInsertMany() bool { - return true -} - -func (db *oracle) QuoteStr() string { - return "\"" -} - -func (db *oracle) AutoIncrStr() string { - return "" -} - -func (db *oracle) SupportEngine() bool { - return false -} - -func (db *oracle) SupportCharset() bool { - return false -} - -func (db *oracle) IndexOnTable() bool { - return false -} - -func (db *oracle) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(idxName)} - return `SELECT INDEX_NAME FROM USER_INDEXES ` + - `WHERE TABLE_NAME = ? 
AND INDEX_NAME = ?`, args -} - -func (db *oracle) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{strings.ToUpper(tableName)} - return `SELECT table_name FROM user_tables WHERE table_name = ?`, args -} - -func (db *oracle) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(colName)} - return "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" + - " AND column_name = ?", args -} - -func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, error) { - args := []interface{}{strings.ToUpper(tableName)} - s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + - "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" - - cnn, err := sql.Open(db.driverName, db.dataSourceName) - if err != nil { - return nil, nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, nil, err - } - cols := make(map[string]*Column) - colSeq := make([]string, 0) - for _, record := range res { - col := new(Column) - col.Indexes = make(map[string]bool) - for name, content := range record { - switch name { - case "column_name": - col.Name = strings.Trim(string(content), `" `) - case "data_default": - col.Default = string(content) - case "nullable": - if string(content) == "Y" { - col.Nullable = true - } else { - col.Nullable = false - } - case "data_type": - ct := string(content) - switch ct { - case "VARCHAR2": - col.SQLType = SQLType{Varchar, 0, 0} - case "TIMESTAMP WITH TIME ZONE": - col.SQLType = SQLType{TimeStamp, 0, 0} - default: - col.SQLType = SQLType{strings.ToUpper(ct), 0, 0} - } - if _, ok := sqlTypes[col.SQLType.Name]; !ok { - return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct)) - } - case "data_length": - i, err := strconv.Atoi(string(content)) - if err != nil { - return nil, nil, errors.New("retrieve length error") - } - col.Length = i - case "data_precision": - case "data_scale": - } - } - if col.SQLType.IsText() { - if col.Default != "" { - col.Default = "'" + col.Default + "'" - } - } - cols[col.Name] = col - colSeq = append(colSeq, col.Name) - } - - return colSeq, cols, nil -} - -func (db *oracle) GetTables() ([]*Table, error) { - args := []interface{}{} - s := "SELECT table_name FROM user_tables" - cnn, err := sql.Open(db.driverName, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, err - } - - tables := make([]*Table, 0) - for _, record := range res { - table := new(Table) - for name, content := range record { - switch name { - case "table_name": - table.Name = string(content) - } - } - tables = append(tables, table) - } - return tables, nil -} - -func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) { - args := []interface{}{tableName} - s := "SELECT t.column_name,i.table_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + - "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1" - - cnn, err := sql.Open(db.driverName, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) 
- if err != nil { - return nil, err - } - - indexes := make(map[string]*Index, 0) - for _, record := range res { - var indexType int - var indexName string - var colName string - - for name, content := range record { - switch name { - case "index_name": - indexName = strings.Trim(string(content), `" `) - case "uniqueness": - c := string(content) - if c == "UNIQUE" { - indexType = UniqueType - } else { - indexType = IndexType - } - case "column_name": - colName = string(content) - } - } - - var index *Index - var ok bool - if index, ok = indexes[indexName]; !ok { - index = new(Index) - index.Type = indexType - index.Name = indexName - indexes[indexName] = index - } - index.AddColumn(colName) - } - return indexes, nil -} +package xorm + +import ( + "database/sql" + "errors" + "fmt" + "regexp" + "strconv" + "strings" +) + +type oracle struct { + base +} + +type oracleParser struct { +} + +//dataSourceName=user/password@ipv4:port/dbname +//dataSourceName=user/password@[ipv6]:port/dbname +func (p *oracleParser) parse(driverName, dataSourceName string) (*uri, error) { + db := &uri{dbType: ORACLE_OCI} + dsnPattern := regexp.MustCompile( + `^(?P.*)\/(?P.*)@` + // user:password@ + `(?P.*)` + // ip:port + `\/(?P.*)`) // dbname + matches := dsnPattern.FindStringSubmatch(dataSourceName) + names := dsnPattern.SubexpNames() + for i, match := range matches { + switch names[i] { + case "dbname": + db.dbName = match + } + } + if db.dbName == "" { + return nil, errors.New("dbname is empty") + } + return db, nil +} + +func (db *oracle) Init(drivername, uri string) error { + return db.base.init(&oracleParser{}, drivername, uri) +} + +func (db *oracle) SqlType(c *Column) string { + var res string + switch t := c.SQLType.Name; t { + case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool, Serial, BigSerial: + return "NUMBER" + case Binary, VarBinary, Blob, TinyBlob, MediumBlob, LongBlob, Bytea: + return Blob + case Time, DateTime, TimeStamp: + res = TimeStamp + case TimeStampz: + res = "TIMESTAMP WITH TIME ZONE" + case Float, Double, Numeric, Decimal: + res = "NUMBER" + case Text, MediumText, LongText: + res = "CLOB" + case Char, Varchar, TinyText: + return "VARCHAR2" + default: + res = t + } + + var hasLen1 bool = (c.Length > 0) + var hasLen2 bool = (c.Length2 > 0) + if hasLen1 { + res += "(" + strconv.Itoa(c.Length) + ")" + } else if hasLen2 { + res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" + } + return res +} + +func (db *oracle) SupportInsertMany() bool { + return true +} + +func (db *oracle) QuoteStr() string { + return "\"" +} + +func (db *oracle) AutoIncrStr() string { + return "" +} + +func (db *oracle) SupportEngine() bool { + return false +} + +func (db *oracle) SupportCharset() bool { + return false +} + +func (db *oracle) IndexOnTable() bool { + return false +} + +func (db *oracle) IndexCheckSql(tableName, idxName string) (string, []interface{}) { + args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(idxName)} + return `SELECT INDEX_NAME FROM USER_INDEXES ` + + `WHERE TABLE_NAME = ? 
AND INDEX_NAME = ?`, args +} + +func (db *oracle) TableCheckSql(tableName string) (string, []interface{}) { + args := []interface{}{strings.ToUpper(tableName)} + return `SELECT table_name FROM user_tables WHERE table_name = ?`, args +} + +func (db *oracle) ColumnCheckSql(tableName, colName string) (string, []interface{}) { + args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(colName)} + return "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" + + " AND column_name = ?", args +} + +func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, error) { + args := []interface{}{strings.ToUpper(tableName)} + s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + + "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" + + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, nil, err + } + cols := make(map[string]*Column) + colSeq := make([]string, 0) + for _, record := range res { + col := new(Column) + col.Indexes = make(map[string]bool) + for name, content := range record { + switch name { + case "column_name": + col.Name = strings.Trim(string(content), `" `) + case "data_default": + col.Default = string(content) + case "nullable": + if string(content) == "Y" { + col.Nullable = true + } else { + col.Nullable = false + } + case "data_type": + ct := string(content) + switch ct { + case "VARCHAR2": + col.SQLType = SQLType{Varchar, 0, 0} + case "TIMESTAMP WITH TIME ZONE": + col.SQLType = SQLType{TimeStamp, 0, 0} + default: + col.SQLType = SQLType{strings.ToUpper(ct), 0, 0} + } + if _, ok := sqlTypes[col.SQLType.Name]; !ok { + return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct)) + } + case "data_length": + i, err := strconv.Atoi(string(content)) + if err != nil { + return nil, nil, errors.New("retrieve length error") + } + col.Length = i + case "data_precision": + case "data_scale": + } + } + if col.SQLType.IsText() { + if col.Default != "" { + col.Default = "'" + col.Default + "'" + } + } + cols[col.Name] = col + colSeq = append(colSeq, col.Name) + } + + return colSeq, cols, nil +} + +func (db *oracle) GetTables() ([]*Table, error) { + args := []interface{}{} + s := "SELECT table_name FROM user_tables" + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } + + tables := make([]*Table, 0) + for _, record := range res { + table := new(Table) + for name, content := range record { + switch name { + case "table_name": + table.Name = string(content) + } + } + tables = append(tables, table) + } + return tables, nil +} + +func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) { + args := []interface{}{tableName} + s := "SELECT t.column_name,i.table_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + + "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1" + + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) 
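Reviewer note, not part of this patch: oracle.go above is deleted and re-added with what appears to be identical content, so the change amounts to a line-ending normalization. Its oracleParser.parse accepts DSNs of the form user/password@host:port/dbname through a named-group regexp whose group names have been garbled in this extract; the standalone sketch below reconstructs the same idea with illustrative group names of its own (only the dbname group is actually read by the parser).

package main

import (
	"fmt"
	"regexp"
)

func main() {
	// Same shape as the pattern in oracleParser.parse; the group names here
	// are illustrative, and only "dbname" matters to the real parser.
	re := regexp.MustCompile(`^(?P<user>.*)\/(?P<password>.*)@(?P<addr>.*)\/(?P<dbname>.*)`)
	matches := re.FindStringSubmatch("scott/tiger@127.0.0.1:1521/orcl")
	for i, name := range re.SubexpNames() {
		if name != "" {
			fmt.Printf("%s = %s\n", name, matches[i])
		}
	}
}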
+ if err != nil { + return nil, err + } + + indexes := make(map[string]*Index, 0) + for _, record := range res { + var indexType int + var indexName string + var colName string + + for name, content := range record { + switch name { + case "index_name": + indexName = strings.Trim(string(content), `" `) + case "uniqueness": + c := string(content) + if c == "UNIQUE" { + indexType = UniqueType + } else { + indexType = IndexType + } + case "column_name": + colName = string(content) + } + } + + var index *Index + var ok bool + if index, ok = indexes[indexName]; !ok { + index = new(Index) + index.Type = indexType + index.Name = indexName + indexes[indexName] = index + } + index.AddColumn(colName) + } + return indexes, nil +} diff --git a/postgres.go b/postgres.go index c316f9b5..97550543 100644 --- a/postgres.go +++ b/postgres.go @@ -67,7 +67,11 @@ func (db *postgres) SqlType(c *Column) string { switch t := c.SQLType.Name; t { case TinyInt: res = SmallInt + return res case MediumInt, Int, Integer: + if c.IsAutoIncrement { + return Serial + } return Integer case Serial, BigSerial: c.IsAutoIncrement = true diff --git a/postgres_test.go b/postgres_test.go index e7cc363f..9657cffd 100644 --- a/postgres_test.go +++ b/postgres_test.go @@ -7,12 +7,16 @@ import ( _ "github.com/lib/pq" ) +//var connStr string = "dbname=xorm_test user=lunny password=1234 sslmode=disable" + +var connStr string = "dbname=xorm_test sslmode=disable" + func newPostgresEngine() (*Engine, error) { - return NewEngine("postgres", "dbname=xorm_test sslmode=disable") + return NewEngine("postgres", connStr) } func newPostgresDriverDB() (*sql.DB, error) { - return sql.Open("postgres", "dbname=xorm_test sslmode=disable") + return sql.Open("postgres", connStr) } func TestPostgres(t *testing.T) { diff --git a/rows.go b/rows.go new file mode 100644 index 00000000..0ac6c956 --- /dev/null +++ b/rows.go @@ -0,0 +1,145 @@ +package xorm + +import ( + "database/sql" + "fmt" + "reflect" +) + +type Rows struct { + NoTypeCheck bool + + session *Session + stmt *sql.Stmt + rows *sql.Rows + fields []string + fieldsCount int + beanType reflect.Type + lastError error +} + +func newRows(session *Session, bean interface{}) (*Rows, error) { + rows := new(Rows) + rows.session = session + rows.beanType = reflect.Indirect(reflect.ValueOf(bean)).Type() + + err := rows.session.newDb() + if err != nil { + return nil, err + } + + defer rows.session.Statement.Init() + + var sql string + var args []interface{} + rows.session.Statement.RefTable = rows.session.Engine.autoMap(bean) + if rows.session.Statement.RawSQL == "" { + sql, args = rows.session.Statement.genGetSql(bean) + } else { + sql = rows.session.Statement.RawSQL + args = rows.session.Statement.RawParams + } + + for _, filter := range rows.session.Engine.Filters { + sql = filter.Do(sql, session) + } + + rows.session.Engine.LogSQL(sql) + rows.session.Engine.LogSQL(args) + + rows.stmt, err = rows.session.Db.Prepare(sql) + if err != nil { + rows.lastError = err + defer rows.Close() + return nil, err + } + + rows.rows, err = rows.stmt.Query(args...) 
+ if err != nil { + rows.lastError = err + defer rows.Close() + return nil, err + } + + rows.fields, err = rows.rows.Columns() + if err != nil { + rows.lastError = err + defer rows.Close() + return nil, err + } + rows.fieldsCount = len(rows.fields) + + return rows, nil +} + +// move cursor to next record, return false if end has reached +func (rows *Rows) Next() bool { + if rows.lastError == nil && rows.rows != nil { + hasNext := rows.rows.Next() + if !hasNext { + rows.lastError = sql.ErrNoRows + } + return hasNext + } + return false +} + +// Err returns the error, if any, that was encountered during iteration. Err may be called after an explicit or implicit Close. +func (rows *Rows) Err() error { + return rows.lastError +} + +// scan row record to bean properties +func (rows *Rows) Scan(bean interface{}) error { + if rows.lastError != nil { + return rows.lastError + } + + if !rows.NoTypeCheck && reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { + return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) + } + + return rows.session.row2Bean(rows.rows, rows.fields, rows.fieldsCount, bean) + + // result, err := row2map(rows.rows, rows.fields) // !nashtsai! TODO remove row2map then scanMapIntoStruct conversation for better performance + // if err == nil { + // err = rows.session.scanMapIntoStruct(bean, result) + // } + // return err +} + +// // Columns returns the column names. Columns returns an error if the rows are closed, or if the rows are from QueryRow and there was a deferred error. +// func (rows *Rows) Columns() ([]string, error) { +// if rows.lastError == nil && rows.rows != nil { +// return rows.rows.Columns() +// } +// return nil, rows.lastError +// } + +// close session if session.IsAutoClose is true, and claimed any opened resources +func (rows *Rows) Close() error { + if rows.session.IsAutoClose { + defer rows.session.Close() + } + + if rows.lastError == nil { + if rows.rows != nil { + rows.lastError = rows.rows.Close() + if rows.lastError != nil { + defer rows.stmt.Close() + return rows.lastError + } + } + if rows.stmt != nil { + rows.lastError = rows.stmt.Close() + } + } else { + if rows.stmt != nil { + defer rows.stmt.Close() + } + if rows.rows != nil { + defer rows.rows.Close() + } + } + return rows.lastError +} diff --git a/session.go b/session.go index 9c0c4b22..bd647de6 100644 --- a/session.go +++ b/session.go @@ -24,9 +24,6 @@ type Session struct { IsAutoClose bool // !nashtsai! storing these beans due to yet committed tx - // afterInsertBeans []interface{} - // afterUpdateBeans []interface{} - // afterDeleteBeans []interface{} afterInsertBeans map[interface{}]*[]func(interface{}) afterUpdateBeans map[interface{}]*[]func(interface{}) afterDeleteBeans map[interface{}]*[]func(interface{}) @@ -189,8 +186,8 @@ func (session *Session) Desc(colNames ...string) *Session { session.Statement.OrderStr += ", " } newColNames := col2NewCols(colNames...) - sql := strings.Join(newColNames, session.Engine.Quote(" DESC, ")) - session.Statement.OrderStr += session.Engine.Quote(sql) + " DESC" + sqlStr := strings.Join(newColNames, session.Engine.Quote(" DESC, ")) + session.Statement.OrderStr += session.Engine.Quote(sqlStr) + " DESC" return session } @@ -200,8 +197,8 @@ func (session *Session) Asc(colNames ...string) *Session { session.Statement.OrderStr += ", " } newColNames := col2NewCols(colNames...) 
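Reviewer note, not part of this patch: rows.go above is the whole of the new forward-only Rows API (newRows prepares and runs the generated SELECT, then Next/Scan/Close walk and release it), and the Iterate rewrite further down in session.go is a thin wrapper over it. A self-contained sketch of both, under stated assumptions: the sqlite3 driver, a made-up User table, the usual Where chaining on a session, and a *string member that leans on the pointer handling added to row2Bean later in this diff to represent a NULLable column.

package main

import (
	"fmt"
	"log"

	"github.com/lunny/xorm"
	_ "github.com/mattn/go-sqlite3"
)

// User is illustrative only; Nickname being *string maps to a NULLable column.
type User struct {
	Id       int64
	Name     string
	Nickname *string
	Age      int
}

func main() {
	engine, err := xorm.NewEngine("sqlite3", "./rows_example.db")
	if err != nil {
		log.Fatal(err)
	}
	defer engine.Close()

	if err = engine.Sync(new(User)); err != nil {
		log.Fatal(err)
	}

	// Forward-only iteration with the new Rows API.
	session := engine.NewSession()
	defer session.Close()
	rows, err := session.Where("age > ?", 18).Rows(new(User))
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()
	for rows.Next() {
		user := new(User)
		if err = rows.Scan(user); err != nil {
			log.Fatal(err)
		}
		fmt.Println(user.Id, user.Name)
	}

	// Iterate now rides on the same Rows machinery, handing each bean to a callback.
	session2 := engine.NewSession()
	defer session2.Close()
	err = session2.Where("age > ?", 18).Iterate(new(User), func(i int, bean interface{}) error {
		fmt.Println(i, bean.(*User).Name)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}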
- sql := strings.Join(newColNames, session.Engine.Quote(" ASC, ")) - session.Statement.OrderStr += session.Engine.Quote(sql) + " ASC" + sqlStr := strings.Join(newColNames, session.Engine.Quote(" ASC, ")) + session.Statement.OrderStr += session.Engine.Quote(sqlStr) + " ASC" return session } @@ -360,11 +357,13 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b for key, data := range objMap { key = strings.ToLower(key) - if _, ok := table.Columns[key]; !ok { + var col *Column + var ok bool + if col, ok = table.Columns[key]; !ok { session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. %v", table.Name, key, table.ColumnsSeq)) continue } - col := table.Columns[key] + fieldName := col.FieldName fieldPath := strings.Split(fieldName, ".") var fieldValue reflect.Value @@ -395,8 +394,8 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b } //Execute sql -func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, error) { - rs, err := session.Db.Prepare(sql) +func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Result, error) { + rs, err := session.Db.Prepare(sqlStr) if err != nil { return nil, err } @@ -409,22 +408,22 @@ func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, return res, nil } -func (session *Session) exec(sql string, args ...interface{}) (sql.Result, error) { +func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) { for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - session.Engine.LogSQL(sql) + session.Engine.LogSQL(sqlStr) session.Engine.LogSQL(args) if session.IsAutoCommit { - return session.innerExec(sql, args...) + return session.innerExec(sqlStr, args...) } - return session.Tx.Exec(sql, args...) + return session.Tx.Exec(sqlStr, args...) } // Exec raw sql -func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { +func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { err := session.newDb() if err != nil { return nil, err @@ -434,7 +433,7 @@ func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error defer session.Close() } - return session.exec(sql, args...) + return session.exec(sqlStr, args...) 
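Reviewer note, not part of this patch: the session.go hunks above (and the matching ones further down) only rename the local sql string variables to sqlStr, which stops them shadowing the imported database/sql package identifier; behaviour is unchanged. For orientation, a compile-only sketch of the Session-level raw SQL helpers being touched: Exec for statements and Query for ad-hoc selects returning []map[string][]byte. The userinfo table and its columns are assumptions.

package example

import (
	"fmt"
	"log"

	"github.com/lunny/xorm"
)

// rawSQLDemo exercises Session.Exec and Session.Query as declared in session.go.
func rawSQLDemo(session *xorm.Session) {
	// Exec runs a statement and returns a sql.Result (discarded here).
	if _, err := session.Exec("UPDATE userinfo SET age = age + 1 WHERE id = ?", 1); err != nil {
		log.Fatal(err)
	}

	// Query returns each row as a map[string][]byte keyed by column name.
	results, err := session.Query("SELECT id, name FROM userinfo WHERE age > ?", 18)
	if err != nil {
		log.Fatal(err)
	}
	for _, rec := range results {
		fmt.Println(string(rec["id"]), string(rec["name"]))
	}
}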
} // this function create a table according a bean @@ -467,8 +466,8 @@ func (session *Session) CreateIndexes(bean interface{}) error { } sqls := session.Statement.genIndexSQL() - for _, sql := range sqls { - _, err = session.exec(sql) + for _, sqlStr := range sqls { + _, err = session.exec(sqlStr) if err != nil { return err } @@ -490,8 +489,8 @@ func (session *Session) CreateUniques(bean interface{}) error { } sqls := session.Statement.genUniqueSQL() - for _, sql := range sqls { - _, err = session.exec(sql) + for _, sqlStr := range sqls { + _, err = session.exec(sqlStr) if err != nil { return err } @@ -500,9 +499,9 @@ func (session *Session) CreateUniques(bean interface{}) error { } func (session *Session) createOneTable() error { - sql := session.Statement.genCreateTableSQL() - session.Engine.LogDebug("create table sql: [", sql, "]") - _, err := session.exec(sql) + sqlStr := session.Statement.genCreateTableSQL() + session.Engine.LogDebug("create table sql: [", sqlStr, "]") + _, err := session.exec(sqlStr) return err } @@ -539,8 +538,8 @@ func (session *Session) DropIndexes(bean interface{}) error { } sqls := session.Statement.genDelIndexSQL() - for _, sql := range sqls { - _, err = session.exec(sql) + for _, sqlStr := range sqls { + _, err = session.exec(sqlStr) if err != nil { return err } @@ -570,16 +569,16 @@ func (session *Session) DropTable(bean interface{}) error { return errors.New("Unsupported type") } - sql := session.Statement.genDropSQL() - _, err = session.exec(sql) + sqlStr := session.Statement.genDropSQL() + _, err = session.exec(sqlStr) return err } -func (statement *Statement) convertIdSql(sql string) string { +func (statement *Statement) convertIdSql(sqlStr string) string { if statement.RefTable != nil { - col := statement.RefTable.PKColumn() + col := statement.RefTable.PKColumns()[0] if col != nil { - sqls := splitNNoCase(sql, "from", 2) + sqls := splitNNoCase(sqlStr, "from", 2) if len(sqls) != 2 { return "" } @@ -591,14 +590,15 @@ func (statement *Statement) convertIdSql(sql string) string { return "" } -func (session *Session) cacheGet(bean interface{}, sql string, args ...interface{}) (has bool, err error) { - if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { +func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { + // if has no reftable or number of pks is not equal to 1, then don't use cache currently + if session.Statement.RefTable == nil || len(session.Statement.RefTable.PrimaryKeys) != 1 { return false, ErrCacheFailed } for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - newsql := session.Statement.convertIdSql(sql) + newsql := session.Statement.convertIdSql(sqlStr) if newsql == "" { return false, ErrCacheFailed } @@ -617,7 +617,7 @@ func (session *Session) cacheGet(bean interface{}, sql string, args ...interface if len(resultsSlice) > 0 { data := resultsSlice[0] var id int64 - if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + if v, ok := data[session.Statement.RefTable.PrimaryKeys[0]]; !ok { return false, ErrCacheFailed } else { id, err = strconv.ParseInt(string(v), 10, 64) @@ -670,19 +670,19 @@ func (session *Session) cacheGet(bean interface{}, sql string, args ...interface return false, nil } -func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr interface{}, args ...interface{}) (err error) { +func (session *Session) cacheFind(t reflect.Type, sqlStr string, 
rowsSlicePtr interface{}, args ...interface{}) (err error) { if session.Statement.RefTable == nil || - session.Statement.RefTable.PrimaryKey == "" || - indexNoCase(sql, "having") != -1 || - indexNoCase(sql, "group by") != -1 { + len(session.Statement.RefTable.PrimaryKeys) != 1 || + indexNoCase(sqlStr, "having") != -1 || + indexNoCase(sqlStr, "group by") != -1 { return ErrCacheFailed } for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - newsql := session.Statement.convertIdSql(sql) + newsql := session.Statement.convertIdSql(sqlStr) if newsql == "" { return ErrCacheFailed } @@ -708,7 +708,7 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter for _, data := range resultsSlice { //fmt.Println(data) var id int64 - if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + if v, ok := data[session.Statement.RefTable.PrimaryKeys[0]]; !ok { return errors.New("no id") } else { id, err = strconv.ParseInt(string(v), 10, 64) @@ -729,7 +729,7 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter } sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) - pkFieldName := session.Statement.RefTable.PKColumn().FieldName + pkFieldName := session.Statement.RefTable.PKColumns()[0].FieldName ididxes := make(map[int64]int) var ides []interface{} = make([]interface{}, 0) @@ -743,7 +743,18 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter } else { session.Engine.LogDebug("[xorm:cacheFind] cached bean:", tableName, id, bean) - sid := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(pkFieldName).Int() + pkField := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(pkFieldName) + + var sid int64 + switch pkField.Type().Kind() { + case reflect.Int32, reflect.Int, reflect.Int64: + sid = pkField.Int() + case reflect.Uint, reflect.Uint32, reflect.Uint64: + sid = int64(pkField.Uint()) + default: + return ErrCacheFailed + } + if sid != id { session.Engine.LogError("[xorm:cacheFind] error cache", id, sid, bean) return ErrCacheFailed @@ -795,7 +806,7 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter } } else if sliceValue.Kind() == reflect.Map { var key int64 - if table.PrimaryKey != "" { + if table.PrimaryKeys[0] != "" { key = ids[j] } else { key = int64(j) @@ -821,69 +832,38 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter // IterFunc only use by Iterate type IterFunc func(idx int, bean interface{}) error +// Return sql.Rows compatible Rows obj, as a forward Iterator object for iterating record by record, bean's non-empty fields +// are conditions. +func (session *Session) Rows(bean interface{}) (*Rows, error) { + return newRows(session, bean) +} + // Iterate record by record handle records from table, condiBeans's non-empty fields // are conditions. 
beans could be []Struct, []*Struct, map[int64]Struct // map[int64]*Struct func (session *Session) Iterate(bean interface{}, fun IterFunc) error { - err := session.newDb() + + rows, err := session.Rows(bean) if err != nil { return err - } - - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } - - var sql string - var args []interface{} - session.Statement.RefTable = session.Engine.autoMap(bean) - if session.Statement.RawSQL == "" { - sql, args = session.Statement.genGetSql(bean) } else { - sql = session.Statement.RawSQL - args = session.Statement.RawParams - } - - for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) - } - - session.Engine.LogSQL(sql) - session.Engine.LogSQL(args) - - s, err := session.Db.Prepare(sql) - if err != nil { - return err - } - defer s.Close() - rows, err := s.Query(args...) - if err != nil { - return err - } - defer rows.Close() - - fields, err := rows.Columns() - if err != nil { - return err - } - t := reflect.Indirect(reflect.ValueOf(bean)).Type() - b := reflect.New(t).Interface() - i := 0 - for rows.Next() { - result, err := row2map(rows, fields) - if err == nil { - err = session.scanMapIntoStruct(b, result) - } - if err == nil { + defer rows.Close() + //b := reflect.New(iterator.beanType).Interface() + i := 0 + for rows.Next() { + b := reflect.New(rows.beanType).Interface() + err = rows.Scan(b) + if err != nil { + return err + } err = fun(i, b) - i = i + 1 - } - if err != nil { - return err + if err != nil { + return err + } + i++ } + return err } - return nil } @@ -901,41 +881,68 @@ func (session *Session) Get(bean interface{}) (bool, error) { } session.Statement.Limit(1) - var sql string + var sqlStr string var args []interface{} session.Statement.RefTable = session.Engine.autoMap(bean) if session.Statement.RawSQL == "" { - sql, args = session.Statement.genGetSql(bean) + sqlStr, args = session.Statement.genGetSql(bean) } else { - sql = session.Statement.RawSQL + sqlStr = session.Statement.RawSQL args = session.Statement.RawParams } if session.Statement.RefTable.Cacher != nil && session.Statement.UseCache { - has, err := session.cacheGet(bean, sql, args...) + has, err := session.cacheGet(bean, sqlStr, args...) if err != ErrCacheFailed { return has, err } } - resultsSlice, err := session.query(sql, args...) + var rawRows *sql.Rows + session.queryPreprocess(&sqlStr, args...) + if session.IsAutoCommit { + stmt, err := session.Db.Prepare(sqlStr) + if err != nil { + return false, err + } + defer stmt.Close() + rawRows, err = stmt.Query(args...) + } else { + rawRows, err = session.Tx.Query(sqlStr, args...) + } if err != nil { return false, err } - if len(resultsSlice) < 1 { + + defer rawRows.Close() + + if rawRows.Next() { + if fields, err := rawRows.Columns(); err == nil { + err = session.row2Bean(rawRows, fields, len(fields), bean) + } + return true, err + } else { return false, nil } - err = session.scanMapIntoStruct(bean, resultsSlice[0]) - if err != nil { - return true, err - } - if len(resultsSlice) == 1 { - return true, nil - } else { - return true, errors.New("More than one record") - } + // resultsSlice, err := session.query(sqlStr, args...) 
+ // if err != nil { + // return false, err + // } + // if len(resultsSlice) < 1 { + // return false, nil + // } + + // err = session.scanMapIntoStruct(bean, resultsSlice[0]) + // if err != nil { + // return true, err + // } + // if len(resultsSlice) == 1 { + // return true, nil + // } else { + // return true, errors.New("More than one record") + // } } // Count counts the records. bean's non-empty fields @@ -951,16 +958,16 @@ func (session *Session) Count(bean interface{}) (int64, error) { defer session.Close() } - var sql string + var sqlStr string var args []interface{} if session.Statement.RawSQL == "" { - sql, args = session.Statement.genCountSql(bean) + sqlStr, args = session.Statement.genCountSql(bean) } else { - sql = session.Statement.RawSQL + sqlStr = session.Statement.RawSQL args = session.Statement.RawParams } - resultsSlice, err := session.query(sql, args...) + resultsSlice, err := session.query(sqlStr, args...) if err != nil { return 0, err } @@ -1021,7 +1028,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) session.Statement.BeanArgs = args } - var sql string + var sqlStr string var args []interface{} if session.Statement.RawSQL == "" { var columnStr string = session.Statement.ColumnStr @@ -1031,51 +1038,103 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) session.Statement.attachInSql() - sql = session.Statement.genSelectSql(columnStr) + sqlStr = session.Statement.genSelectSql(columnStr) args = append(session.Statement.Params, session.Statement.BeanArgs...) } else { - sql = session.Statement.RawSQL + sqlStr = session.Statement.RawSQL args = session.Statement.RawParams } if table.Cacher != nil && session.Statement.UseCache && !session.Statement.IsDistinct { - err = session.cacheFind(sliceElementType, sql, rowsSlicePtr, args...) + err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) if err != ErrCacheFailed { return err } + err = nil // !nashtsai! reset err to nil for ErrCacheFailed session.Engine.LogWarn("Cache Find Failed") } - resultsSlice, err := session.query(sql, args...) - if err != nil { - return err - } + if sliceValue.Kind() != reflect.Map { + var rawRows *sql.Rows + var stmt *sql.Stmt - for i, results := range resultsSlice { - var newValue reflect.Value - if sliceElementType.Kind() == reflect.Ptr { - newValue = reflect.New(sliceElementType.Elem()) + session.queryPreprocess(&sqlStr, args...) + // err = session.queryRows(&stmt, &rawRows, sqlStr, args...) + // if err != nil { + // return err + // } + // if stmt != nil { + // defer stmt.Close() + // } + // defer rawRows.Close() + + if session.IsAutoCommit { + stmt, err = session.Db.Prepare(sqlStr) + if err != nil { + return err + } + defer stmt.Close() + rawRows, err = stmt.Query(args...) } else { - newValue = reflect.New(sliceElementType) + rawRows, err = session.Tx.Query(sqlStr, args...) 
} - err := session.scanMapIntoStruct(newValue.Interface(), results) if err != nil { return err } - if sliceValue.Kind() == reflect.Slice { + defer rawRows.Close() + + fields, err := rawRows.Columns() + if err != nil { + return err + } + + fieldsCount := len(fields) + + for rawRows.Next() { + var newValue reflect.Value if sliceElementType.Kind() == reflect.Ptr { - sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) + newValue = reflect.New(sliceElementType.Elem()) } else { - sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + newValue = reflect.New(sliceElementType) + } + err := session.row2Bean(rawRows, fields, fieldsCount, newValue.Interface()) + if err != nil { + return err + } + if sliceValue.Kind() == reflect.Slice { + if sliceElementType.Kind() == reflect.Ptr { + sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) + } else { + sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + } + } + } + } else { + resultsSlice, err := session.query(sqlStr, args...) + if err != nil { + return err + } + + for i, results := range resultsSlice { + var newValue reflect.Value + if sliceElementType.Kind() == reflect.Ptr { + newValue = reflect.New(sliceElementType.Elem()) + } else { + newValue = reflect.New(sliceElementType) + } + err := session.scanMapIntoStruct(newValue.Interface(), results) + if err != nil { + return err } - } else if sliceValue.Kind() == reflect.Map { var key int64 - if table.PrimaryKey != "" { - x, err := strconv.ParseInt(string(results[table.PrimaryKey]), 10, 64) + // if there is only one pk, we can put the id as map key. + // TODO: should know if the column is ints + if len(table.PrimaryKeys) == 1 { + x, err := strconv.ParseInt(string(results[table.PrimaryKeys[0]]), 10, 64) if err != nil { - return errors.New("pk " + table.PrimaryKey + " as int64: " + err.Error()) + return errors.New("pk " + table.PrimaryKeys[0] + " as int64: " + err.Error()) } key = x } else { @@ -1091,6 +1150,20 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) return nil } +func (session *Session) queryRows(rawStmt **sql.Stmt, rawRows **sql.Rows, sqlStr string, args ...interface{}) error { + var err error + if session.IsAutoCommit { + *rawStmt, err = session.Db.Prepare(sqlStr) + if err != nil { + return err + } + *rawRows, err = (*rawStmt).Query(args...) + } else { + *rawRows, err = session.Tx.Query(sqlStr, args...) + } + return err +} + // Test if database is ok func (session *Session) Ping() error { err := session.newDb() @@ -1114,8 +1187,8 @@ func (session *Session) isColumnExist(tableName, colName string) (bool, error) { if session.IsAutoClose { defer session.Close() } - sql, args := session.Engine.dialect.ColumnCheckSql(tableName, colName) - results, err := session.query(sql, args...) + sqlStr, args := session.Engine.dialect.ColumnCheckSql(tableName, colName) + results, err := session.query(sqlStr, args...) return len(results) > 0, err } @@ -1128,8 +1201,8 @@ func (session *Session) isTableExist(tableName string) (bool, error) { if session.IsAutoClose { defer session.Close() } - sql, args := session.Engine.dialect.TableCheckSql(tableName) - results, err := session.query(sql, args...) + sqlStr, args := session.Engine.dialect.TableCheckSql(tableName) + results, err := session.query(sqlStr, args...) 
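Reviewer note, not part of this patch: with the Find changes above, slice destinations are now filled by streaming rows straight into beans via row2Bean, while map destinations keep the []map[string][]byte path and are keyed by the table's single integer primary key; when a table declares composite primary keys, the map key falls back to the row index and the query cache is bypassed (the len(PrimaryKeys) != 1 guards). A compile-only sketch of both destination kinds, assuming the engine-level Where/Find wrappers and a pared-down illustrative User struct.

package example

import (
	"fmt"
	"log"

	"github.com/lunny/xorm"
)

type User struct {
	Id   int64
	Name string
	Age  int
}

func findDemo(engine *xorm.Engine) {
	// Slice destination: each row is scanned directly into a User value.
	var users []User
	if err := engine.Where("age > ?", 18).Find(&users); err != nil {
		log.Fatal(err)
	}

	// Map destination: keyed by the single int64 primary key (User.Id).
	byId := make(map[int64]User)
	if err := engine.Find(&byId); err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(users), len(byId))
}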
return len(results) > 0, err } @@ -1148,8 +1221,8 @@ func (session *Session) isIndexExist(tableName, idxName string, unique bool) (bo } else { idx = indexName(tableName, idxName) } - sql, args := session.Engine.dialect.IndexCheckSql(tableName, idx) - results, err := session.query(sql, args...) + sqlStr, args := session.Engine.dialect.IndexCheckSql(tableName, idx) + results, err := session.query(sqlStr, args...) return len(results) > 0, err } @@ -1182,7 +1255,8 @@ func (session *Session) addColumn(colName string) error { defer session.Close() } //fmt.Println(session.Statement.RefTable) - col := session.Statement.RefTable.Columns[colName] + + col := session.Statement.RefTable.Columns[strings.ToLower(colName)] sql, args := session.Statement.genAddColumnStr(col) _, err = session.exec(sql, args...) return err @@ -1199,8 +1273,8 @@ func (session *Session) addIndex(tableName, idxName string) error { } //fmt.Println(idxName) cols := session.Statement.RefTable.Indexes[idxName].Cols - sql, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols) - _, err = session.exec(sql, args...) + sqlStr, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols) + _, err = session.exec(sqlStr, args...) return err } @@ -1215,8 +1289,8 @@ func (session *Session) addUnique(tableName, uqeName string) error { } //fmt.Println(uqeName, session.Statement.RefTable.Uniques) cols := session.Statement.RefTable.Indexes[uqeName].Cols - sql, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols) - _, err = session.exec(sql, args...) + sqlStr, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols) + _, err = session.exec(sqlStr, args...) return err } @@ -1234,8 +1308,8 @@ func (session *Session) dropAll() error { for _, table := range session.Engine.Tables { session.Statement.Init() session.Statement.RefTable = table - sql := session.Statement.genDropSQL() - _, err := session.exec(sql) + sqlStr := session.Statement.genDropSQL() + _, err := session.exec(sqlStr) if err != nil { return err } @@ -1254,100 +1328,342 @@ func row2map(rows *sql.Rows, fields []string) (resultsMap map[string][]byte, err return nil, err } - // !nashtsai! 
TODO optimization for query performance, where current process has gone from - // sql driver converted type back to []bytes then to ORM's fields for ii, key := range fields { rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) - //if row is null then ignore if rawValue.Interface() == nil { //fmt.Println("ignore ...", key, rawValue) continue } - aa := reflect.TypeOf(rawValue.Interface()) - vv := reflect.ValueOf(rawValue.Interface()) - var str string - switch aa.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - str = strconv.FormatInt(vv.Int(), 10) - result[key] = []byte(str) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - str = strconv.FormatUint(vv.Uint(), 10) - result[key] = []byte(str) - case reflect.Float32, reflect.Float64: - str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) - result[key] = []byte(str) - case reflect.String: - str = vv.String() - result[key] = []byte(str) - case reflect.Array, reflect.Slice: - switch aa.Elem().Kind() { - case reflect.Uint8: - result[key] = rawValue.Interface().([]byte) - str = string(result[key]) - default: - return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) - } - //时间类型 - case reflect.Struct: - if aa == reflect.TypeOf(c_TIME_DEFAULT) { - str = rawValue.Interface().(time.Time).Format(time.RFC3339Nano) - result[key] = []byte(str) - } else { - return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) - } - case reflect.Bool: - str = strconv.FormatBool(vv.Bool()) - result[key] = []byte(str) - case reflect.Complex128, reflect.Complex64: - str = fmt.Sprintf("%v", vv.Complex()) - result[key] = []byte(str) - /* TODO: unsupported types below - case reflect.Map: - case reflect.Ptr: - case reflect.Uintptr: - case reflect.UnsafePointer: - case reflect.Chan, reflect.Func, reflect.Interface: - */ - default: - return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) + + if data, err := value2Bytes(&rawValue); err == nil { + result[key] = data + } else { + return nil, err // !nashtsai! REVIEW, should return err or just error log? } } return result, nil } -func rows2maps(rows *sql.Rows) (resultsSlice []map[string][]byte, err error) { - fields, err := rows.Columns() - if err != nil { - return nil, err +func (session *Session) getField(dataStruct *reflect.Value, key string, table *Table) *reflect.Value { + key = strings.ToLower(key) + if _, ok := table.Columns[key]; !ok { + session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. 
%v", table.Name, key, table.ColumnsSeq)) + return nil } - for rows.Next() { - result, err := row2map(rows, fields) - if err != nil { - return nil, err + col := table.Columns[key] + fieldName := col.FieldName + fieldPath := strings.Split(fieldName, ".") + var fieldValue reflect.Value + if len(fieldPath) > 2 { + session.Engine.LogError("Unsupported mutliderive", fieldName) + return nil + } else if len(fieldPath) == 2 { + parentField := dataStruct.FieldByName(fieldPath[0]) + if parentField.IsValid() { + fieldValue = parentField.FieldByName(fieldPath[1]) } - resultsSlice = append(resultsSlice, result) + } else { + fieldValue = dataStruct.FieldByName(fieldName) } - - return resultsSlice, nil + if !fieldValue.IsValid() || !fieldValue.CanSet() { + session.Engine.LogWarn("table %v's column %v is not valid or cannot set", + table.Name, key) + return nil + } + return &fieldValue } -func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { - for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) +func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount int, bean interface{}) error { + + dataStruct := reflect.Indirect(reflect.ValueOf(bean)) + if dataStruct.Kind() != reflect.Struct { + return errors.New("Expected a pointer to a struct") } - session.Engine.LogSQL(sql) + table := session.Engine.autoMapType(rType(bean)) + + var scanResultContainers []interface{} + for i := 0; i < fieldsCount; i++ { + var scanResultContainer interface{} + scanResultContainers = append(scanResultContainers, &scanResultContainer) + } + if err := rows.Scan(scanResultContainers...); err != nil { + return err + } + + for ii, key := range fields { + if fieldValue := session.getField(&dataStruct, key, table); fieldValue != nil { + + rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) + + //if row is null then ignore + if rawValue.Interface() == nil { + //fmt.Println("ignore ...", key, rawValue) + continue + } + + if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok { + if data, err := value2Bytes(&rawValue); err == nil { + structConvert.FromDB(data) + } else { + session.Engine.LogError(err) + } + continue + } + + rawValueType := reflect.TypeOf(rawValue.Interface()) + vv := reflect.ValueOf(rawValue.Interface()) + + fieldType := fieldValue.Type() + + //fmt.Println("column name:", key, ", fieldType:", fieldType.String()) + + hasAssigned := false + + switch fieldType.Kind() { + + case reflect.Complex64, reflect.Complex128: + if rawValueType.Kind() == reflect.String { + hasAssigned = true + x := reflect.New(fieldType) + err := json.Unmarshal([]byte(vv.String()), x.Interface()) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(x.Elem()) + } + case reflect.Slice, reflect.Array: + switch rawValueType.Kind() { + case reflect.Slice, reflect.Array: + switch rawValueType.Elem().Kind() { + case reflect.Uint8: + if fieldType.Elem().Kind() == reflect.Uint8 { + hasAssigned = true + fieldValue.Set(vv) + } + } + } + case reflect.String: + if rawValueType.Kind() == reflect.String { + hasAssigned = true + fieldValue.SetString(vv.String()) + } + case reflect.Bool: + if rawValueType.Kind() == reflect.Bool { + hasAssigned = true + fieldValue.SetBool(vv.Bool()) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch rawValueType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + hasAssigned = true + 
fieldValue.SetInt(vv.Int()) + } + case reflect.Float32, reflect.Float64: + switch rawValueType.Kind() { + case reflect.Float32, reflect.Float64: + hasAssigned = true + fieldValue.SetFloat(vv.Float()) + } + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + switch rawValueType.Kind() { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + hasAssigned = true + fieldValue.SetUint(vv.Uint()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + hasAssigned = true + fieldValue.SetUint(uint64(vv.Int())) + } + case reflect.Struct: + if fieldType == reflect.TypeOf(c_TIME_DEFAULT) { + if rawValueType == reflect.TypeOf(c_TIME_DEFAULT) { + hasAssigned = true + fieldValue.Set(vv) + } + } else if session.Statement.UseCascade { + table := session.Engine.autoMapType(fieldValue.Type()) + if table != nil { + var x int64 + if rawValueType.Kind() == reflect.Int64 { + x = vv.Int() + } + if x != 0 { + // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch + // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne + // property to be fetched lazily + structInter := reflect.New(fieldValue.Type()) + newsession := session.Engine.NewSession() + defer newsession.Close() + has, err := newsession.Id(x).Get(structInter.Interface()) + if err != nil { + return err + } + if has { + v := structInter.Elem().Interface() + fieldValue.Set(reflect.ValueOf(v)) + } else { + return errors.New("cascade obj is not exist!") + } + } + } else { + session.Engine.LogError("unsupported struct type in Scan: ", fieldValue.Type().String()) + } + } + case reflect.Ptr: + // !nashtsai! TODO merge duplicated codes above + //typeStr := fieldType.String() + switch fieldType { + // following types case matching ptr's native type, therefore assign ptr directly + case reflect.TypeOf(&c_EMPTY_STRING): + if rawValueType.Kind() == reflect.String { + x := vv.String() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_BOOL_DEFAULT): + if rawValueType.Kind() == reflect.Bool { + x := vv.Bool() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_TIME_DEFAULT): + if rawValueType == reflect.TypeOf(c_TIME_DEFAULT) { + hasAssigned = true + var x time.Time = rawValue.Interface().(time.Time) + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_FLOAT64_DEFAULT): + if rawValueType.Kind() == reflect.Float64 { + x := vv.Float() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_UINT64_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x uint64 = uint64(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_INT64_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + x := vv.Int() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_FLOAT32_DEFAULT): + if rawValueType.Kind() == reflect.Float64 { + var x float32 = float32(vv.Float()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_INT_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x int = int(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_INT32_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x int32 = int32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_INT8_DEFAULT): + if 
rawValueType.Kind() == reflect.Int64 { + var x int8 = int8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_INT16_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x int16 = int16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_UINT_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x uint = uint(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_UINT32_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x uint32 = uint32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_UINT8_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x uint8 = uint8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_UINT16_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x uint16 = uint16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_COMPLEX64_DEFAULT): + var x complex64 + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + session.Engine.LogError(err) + } else { + fieldValue.Set(reflect.ValueOf(&x)) + } + hasAssigned = true + case reflect.TypeOf(&c_COMPLEX128_DEFAULT): + var x complex128 + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + session.Engine.LogError(err) + } else { + fieldValue.Set(reflect.ValueOf(&x)) + } + hasAssigned = true + } // switch fieldType + // default: + // session.Engine.LogError("unsupported type in Scan: ", reflect.TypeOf(v).String()) + } // switch fieldType.Kind() + + // !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value + if !hasAssigned { + data, err := value2Bytes(&rawValue) + if err == nil { + session.bytes2Value(table.Columns[strings.ToLower(key)], fieldValue, data) + } else { + session.Engine.LogError(err.Error()) + } + } + } + } + return nil + +} + +func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { + for _, filter := range session.Engine.Filters { + *sqlStr = filter.Do(*sqlStr, session) + } + + session.Engine.LogSQL(*sqlStr) session.Engine.LogSQL(paramStr) +} + +func (session *Session) query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { + session.queryPreprocess(&sqlStr, paramStr...) if session.IsAutoCommit { - return query(session.Db, sql, paramStr...) + return query(session.Db, sqlStr, paramStr...) } - return txQuery(session.Tx, sql, paramStr...) + return txQuery(session.Tx, sqlStr, paramStr...) } -func txQuery(tx *sql.Tx, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { - rows, err := tx.Query(sql, params...) +func txQuery(tx *sql.Tx, sqlStr string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { + rows, err := tx.Query(sqlStr, params...) 
if err != nil { return nil, err } @@ -1356,8 +1672,8 @@ func txQuery(tx *sql.Tx, sql string, params ...interface{}) (resultsSlice []map[ return rows2maps(rows) } -func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { - s, err := db.Prepare(sql) +func query(db *sql.DB, sqlStr string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { + s, err := db.Prepare(sqlStr) if err != nil { return nil, err } @@ -1368,12 +1684,11 @@ func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[st } defer rows.Close() //fmt.Println(rows) - return rows2maps(rows) } // Exec a raw sql and return records as []map[string][]byte -func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { +func (session *Session) Query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { err = session.newDb() if err != nil { return nil, err @@ -1383,7 +1698,7 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice defer session.Close() } - return session.query(sql, paramStr...) + return session.query(sqlStr, paramStr...) } // insert one or more beans @@ -1618,10 +1933,14 @@ func (session *Session) byte2Time(col *Column, data []byte) (outTime time.Time, ssd := strings.Split(sdata, " ") sdata = ssd[1] } - if len(sdata) > 8 { + + sdata = strings.TrimSpace(sdata) + //fmt.Println(sdata) + if session.Engine.dialect.DBType() == MYSQL && len(sdata) > 8 { sdata = sdata[len(sdata)-8:] } - fmt.Println(sdata) + //fmt.Println(sdata) + st := fmt.Sprintf("2006-01-02 %v", sdata) x, err = time.Parse("2006-01-02 15:04:05", st) } else { @@ -1653,7 +1972,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data err := json.Unmarshal(data, x.Interface()) if err != nil { - session.Engine.LogSQL(err) + session.Engine.LogError(err) return err } fieldValue.Set(x.Elem()) @@ -1665,7 +1984,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x := reflect.New(fieldType) err := json.Unmarshal(data, x.Interface()) if err != nil { - session.Engine.LogSQL(err) + session.Engine.LogError(err) return err } fieldValue.Set(x.Elem()) @@ -1676,7 +1995,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x := reflect.New(fieldType) err := json.Unmarshal(data, x.Interface()) if err != nil { - session.Engine.LogSQL(err) + session.Engine.LogError(err) return err } fieldValue.Set(x.Elem()) @@ -1690,7 +2009,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data d := string(data) v, err := strconv.ParseBool(d) if err != nil { - return errors.New("arg " + key + " as bool: " + err.Error()) + return fmt.Errorf("arg %v as bool: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(v)) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -1718,19 +2037,19 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x, err = strconv.ParseInt(sdata, 10, 64) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.SetInt(x) case reflect.Float32, reflect.Float64: x, err := strconv.ParseFloat(string(data), 64) if err != nil { - return errors.New("arg " + key + " as float64: " + err.Error()) + return fmt.Errorf("arg %v as float64: %s", key, err.Error()) } fieldValue.SetFloat(x) case reflect.Uint8, reflect.Uint16, 
reflect.Uint32, reflect.Uint64, reflect.Uint: x, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.SetUint(x) //Currently only support Time type @@ -1747,9 +2066,12 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data if table != nil { x, err := strconv.ParseInt(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } if x != 0 { + // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch + // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne + // property to be fetched lazily structInter := reflect.New(fieldValue.Type()) newsession := session.Engine.NewSession() defer newsession.Close() @@ -1765,7 +2087,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data } } } else { - return errors.New("unsupported struct type in Scan: " + fieldValue.Type().String()) + return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String()) } } case reflect.Ptr: @@ -1781,7 +2103,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data d := string(data) v, err := strconv.ParseBool(d) if err != nil { - return errors.New("arg " + key + " as bool: " + err.Error()) + return fmt.Errorf("arg %v as bool: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&v)) // case "*complex64": @@ -1789,7 +2111,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x complex64 err := json.Unmarshal(data, &x) if err != nil { - session.Engine.LogSQL(err) + session.Engine.LogError(err) return err } fieldValue.Set(reflect.ValueOf(&x)) @@ -1798,7 +2120,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x complex128 err := json.Unmarshal(data, &x) if err != nil { - session.Engine.LogSQL(err) + session.Engine.LogError(err) return err } fieldValue.Set(reflect.ValueOf(&x)) @@ -1806,7 +2128,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data case reflect.TypeOf(&c_FLOAT64_DEFAULT): x, err := strconv.ParseFloat(string(data), 64) if err != nil { - return errors.New("arg " + key + " as float64: " + err.Error()) + return fmt.Errorf("arg %v as float64: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*float32": @@ -1814,7 +2136,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x float32 x1, err := strconv.ParseFloat(string(data), 32) if err != nil { - return errors.New("arg " + key + " as float32: " + err.Error()) + return fmt.Errorf("arg %v as float32: %s", key, err.Error()) } x = float32(x1) fieldValue.Set(reflect.ValueOf(&x)) @@ -1831,7 +2153,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x uint64 x, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*uint": @@ -1839,7 +2161,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x uint x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as 
int: %s", key, err.Error()) } x = uint(x1) fieldValue.Set(reflect.ValueOf(&x)) @@ -1848,7 +2170,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x uint32 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } x = uint32(x1) fieldValue.Set(reflect.ValueOf(&x)) @@ -1857,7 +2179,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x uint8 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } x = uint8(x1) fieldValue.Set(reflect.ValueOf(&x)) @@ -1866,7 +2188,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x uint16 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } x = uint16(x1) fieldValue.Set(reflect.ValueOf(&x)) @@ -1892,7 +2214,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x, err = strconv.ParseInt(sdata, 10, 64) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*int": @@ -1921,7 +2243,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x = int(x1) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*int32": @@ -1950,7 +2272,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x = int32(x1) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*int8": @@ -1979,7 +2301,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x = int8(x1) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*int16": @@ -2008,14 +2330,14 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x = int16(x1) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) default: - return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) + return fmt.Errorf("unsupported type in Scan: %s", reflect.TypeOf(v).String()) } default: - return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) + return fmt.Errorf("unsupported type in Scan: %s", reflect.TypeOf(v).String()) } return nil @@ -2082,14 +2404,14 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( return fieldValue.Interface(), nil } if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok { - if fieldTable.PrimaryKey != "" { - pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumn().FieldName) + if len(fieldTable.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName) return pkField.Interface(), 
nil } else { - return 0, errors.New("no primary key") + return 0, fmt.Errorf("no primary key for col %v", col.Name) } } else { - return 0, errors.New(fmt.Sprintf("Unsupported type %v", fieldValue.Type())) + return 0, fmt.Errorf("Unsupported type %v\n", fieldValue.Type()) } case reflect.Complex64, reflect.Complex128: bytes, err := json.Marshal(fieldValue.Interface()) @@ -2156,7 +2478,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { fmt.Println(colNames, args) colPlaces = colPlaces[0 : len(colPlaces)-2] - sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)", + sqlStr := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)", session.Engine.QuoteStr(), session.Statement.TableName(), session.Engine.QuoteStr(), @@ -2196,8 +2518,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { // for postgres, many of them didn't implement lastInsertId, so we should // implemented it ourself. - if session.Engine.DriverName != POSTGRES || table.PrimaryKey == "" { - res, err := session.exec(sql, args...) + + if session.Engine.DriverName != POSTGRES || table.AutoIncrement == "" { + res, err := session.exec(sqlStr, args...) if err != nil { return 0, err } else { @@ -2215,7 +2538,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } } - if table.PrimaryKey == "" || table.PKColumn().SQLType.IsText() { + if table.AutoIncrement == "" { return res.RowsAffected() } @@ -2225,24 +2548,32 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } - pkValue := table.PKColumn().ValueOf(bean) - if !pkValue.IsValid() || pkValue.Int() != 0 || !pkValue.CanSet() { + aiValue := table.AutoIncrColumn().ValueOf(bean) + if !aiValue.IsValid() /*|| aiValue.Int() != 0*/ || !aiValue.CanSet() { return res.RowsAffected() } var v interface{} = id - switch pkValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int: + switch aiValue.Type().Kind() { + case reflect.Int32: + v = int32(id) + case reflect.Int: v = int(id) - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + case reflect.Uint32: + v = uint32(id) + case reflect.Uint64: + v = uint64(id) + case reflect.Uint: v = uint(id) } - pkValue.Set(reflect.ValueOf(v)) + aiValue.Set(reflect.ValueOf(v)) return res.RowsAffected() } else { - sql = sql + " RETURNING (id)" - res, err := session.query(sql, args...) + //assert table.AutoIncrement != "" + sqlStr = sqlStr + " RETURNING " + session.Engine.Quote(table.AutoIncrement) + res, err := session.query(sqlStr, args...) + if err != nil { return 0, err } else { @@ -2264,25 +2595,31 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 0, errors.New("insert no error but not returned id") } - idByte := res[0][table.PrimaryKey] + idByte := res[0][table.AutoIncrement] id, err := strconv.ParseInt(string(idByte), 10, 64) if err != nil { return 1, err } - pkValue := table.PKColumn().ValueOf(bean) - if !pkValue.IsValid() || pkValue.Int() != 0 || !pkValue.CanSet() { + aiValue := table.AutoIncrColumn().ValueOf(bean) + if !aiValue.IsValid() /*|| aiValue. 
!= 0*/ || !aiValue.CanSet() { return 1, nil } var v interface{} = id - switch pkValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int: + switch aiValue.Type().Kind() { + case reflect.Int32: + v = int32(id) + case reflect.Int: v = int(id) - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + case reflect.Uint32: + v = uint32(id) + case reflect.Uint64: + v = uint64(id) + case reflect.Uint: v = uint(id) } - pkValue.Set(reflect.ValueOf(v)) + aiValue.Set(reflect.ValueOf(v)) return 1, nil } @@ -2304,15 +2641,15 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { return session.innerInsert(bean) } -func (statement *Statement) convertUpdateSql(sql string) (string, string) { - if statement.RefTable == nil || statement.RefTable.PrimaryKey == "" { +func (statement *Statement) convertUpdateSql(sqlStr string) (string, string) { + if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { return "", "" } - sqls := splitNNoCase(sql, "where", 2) + sqls := splitNNoCase(sqlStr, "where", 2) if len(sqls) != 2 { if len(sqls) == 1 { return sqls[0], fmt.Sprintf("SELECT %v FROM %v", - statement.Engine.Quote(statement.RefTable.PrimaryKey), + statement.Engine.Quote(statement.RefTable.PrimaryKeys[0]), statement.Engine.Quote(statement.RefTable.Name)) } return "", "" @@ -2321,22 +2658,31 @@ func (statement *Statement) convertUpdateSql(sql string) (string, string) { var whereStr = sqls[1] //TODO: for postgres only, if any other database? - if strings.Contains(sqls[1], "$") { - dollers := strings.Split(sqls[1], "$") - whereStr = dollers[0] - for i, c := range dollers[1:] { - ccs := strings.SplitN(c, " ", 2) - whereStr += fmt.Sprintf("$%v %v", i+1, ccs[1]) + var paraStr string + if statement.Engine.dialect.DBType() == POSTGRES { + paraStr = "$" + } else if statement.Engine.dialect.DBType() == MSSQL { + paraStr = ":" + } + + if paraStr != "" { + if strings.Contains(sqls[1], paraStr) { + dollers := strings.Split(sqls[1], paraStr) + whereStr = dollers[0] + for i, c := range dollers[1:] { + ccs := strings.SplitN(c, " ", 2) + whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1]) + } } } return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", - statement.Engine.Quote(statement.RefTable.PrimaryKey), statement.Engine.Quote(statement.TableName()), + statement.Engine.Quote(statement.RefTable.PrimaryKeys[0]), statement.Engine.Quote(statement.TableName()), whereStr) } func (session *Session) cacheInsert(tables ...string) error { - if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { + if session.Statement.RefTable == nil || len(session.Statement.RefTable.PrimaryKeys) != 1 { return ErrCacheFailed } @@ -2351,12 +2697,12 @@ func (session *Session) cacheInsert(tables ...string) error { return nil } -func (session *Session) cacheUpdate(sql string, args ...interface{}) error { - if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { +func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { + if session.Statement.RefTable == nil || len(session.Statement.RefTable.PrimaryKeys) != 1 { return ErrCacheFailed } - oldhead, newsql := session.Statement.convertUpdateSql(sql) + oldhead, newsql := session.Statement.convertUpdateSql(sqlStr) if newsql == "" { return ErrCacheFailed } @@ -2367,7 +2713,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { var nStart int if len(args) > 0 { - if strings.Index(sql, "?") > -1 { + if 
strings.Index(sqlStr, "?") > -1 { nStart = strings.Count(oldhead, "?") } else { // only for pq, TODO: if any other databse? @@ -2390,7 +2736,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { if len(resultsSlice) > 0 { for _, data := range resultsSlice { var id int64 - if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + if v, ok := data[session.Statement.RefTable.PrimaryKeys[0]]; !ok { return errors.New("no id") } else { id, err = strconv.ParseInt(string(v), 10, 64) @@ -2408,7 +2754,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { for _, id := range ids { if bean := cacher.GetBean(tableName, id); bean != nil { - sqls := splitNNoCase(sql, "where", 2) + sqls := splitNNoCase(sqlStr, "where", 2) if len(sqls) == 0 || len(sqls) > 2 { return ErrCacheFailed } @@ -2431,7 +2777,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { return ErrCacheFailed } - if col, ok := table.Columns[colName]; ok { + if col, ok := table.Columns[strings.ToLower(colName)]; ok { fieldValue := col.ValueOf(bean) session.Engine.LogDebug("[xorm:cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) if col.IsVersion && session.Statement.checkVersion { @@ -2547,7 +2893,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - var sql, inSql string + var sqlStr, inSql string var inArgs []interface{} if table.Version != "" && session.Statement.checkVersion { if condition != "" { @@ -2565,7 +2911,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - sql = fmt.Sprintf("UPDATE %v SET %v, %v %v", + sqlStr = fmt.Sprintf("UPDATE %v SET %v, %v %v", session.Engine.Quote(session.Statement.TableName()), strings.Join(colNames, ", "), session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1", @@ -2585,7 +2931,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - sql = fmt.Sprintf("UPDATE %v SET %v %v", + sqlStr = fmt.Sprintf("UPDATE %v SET %v %v", session.Engine.Quote(session.Statement.TableName()), strings.Join(colNames, ", "), condition) @@ -2595,13 +2941,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 args = append(args, inArgs...) args = append(args, condiArgs...) - res, err := session.exec(sql, args...) + res, err := session.exec(sqlStr, args...) if err != nil { return 0, err } if table.Cacher != nil && session.Statement.UseCache { - //session.cacheUpdate(sql, args...) + //session.cacheUpdate(sqlStr, args...) 
table.Cacher.ClearIds(session.Statement.TableName()) table.Cacher.ClearBeans(session.Statement.TableName()) } @@ -2638,16 +2984,16 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return res.RowsAffected() } -func (session *Session) cacheDelete(sql string, args ...interface{}) error { - if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { +func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error { + if session.Statement.RefTable == nil || len(session.Statement.RefTable.PrimaryKeys) != 1 { return ErrCacheFailed } for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - newsql := session.Statement.convertIdSql(sql) + newsql := session.Statement.convertIdSql(sqlStr) if newsql == "" { return ErrCacheFailed } @@ -2664,7 +3010,7 @@ func (session *Session) cacheDelete(sql string, args ...interface{}) error { if len(resultsSlice) > 0 { for _, data := range resultsSlice { var id int64 - if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + if v, ok := data[session.Statement.RefTable.PrimaryKeys[0]]; !ok { return errors.New("no id") } else { id, err = strconv.ParseInt(string(v), 10, 64) @@ -2739,16 +3085,16 @@ func (session *Session) Delete(bean interface{}) (int64, error) { return 0, ErrNeedDeletedCond } - sql := fmt.Sprintf("DELETE FROM %v WHERE %v", + sqlStr := fmt.Sprintf("DELETE FROM %v WHERE %v", session.Engine.Quote(session.Statement.TableName()), condition) args = append(session.Statement.Params, args...) if table.Cacher != nil && session.Statement.UseCache { - session.cacheDelete(sql, args...) + session.cacheDelete(sqlStr, args...) } - res, err := session.exec(sql, args...) + res, err := session.exec(sqlStr, args...) if err != nil { return 0, err } diff --git a/sqlite3.go b/sqlite3.go index 84a9d1b0..d52b2966 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -190,16 +190,19 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { indexes := make(map[string]*Index, 0) for _, record := range res { - var sql string index := new(Index) - for name, content := range record { - if name == "sql" { - sql = string(content) - } + sql := string(record["sql"]) + + if sql == "" { + continue } nNStart := strings.Index(sql, "INDEX") nNEnd := strings.Index(sql, "ON") + if nNStart == -1 || nNEnd == -1 { + continue + } + indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") //fmt.Println(indexName) if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { diff --git a/statement.go b/statement.go index 59c9421f..b36a2ebf 100644 --- a/statement.go +++ b/statement.go @@ -335,11 +335,15 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, } else { engine.autoMapType(fieldValue.Type()) if table, ok := engine.Tables[fieldValue.Type()]; ok { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName) - if pkField.Int() != 0 { - val = pkField.Interface() + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + if pkField.Int() != 0 { + val = pkField.Interface() + } else { + continue + } } else { - continue + //TODO: how to handler? } } else { val = fieldValue.Interface() @@ -506,7 +510,7 @@ func (statement *Statement) Distinct(columns ...string) *Statement { func (statement *Statement) Cols(columns ...string) *Statement { newColumns := col2NewCols(columns...) 
for _, nc := range newColumns { - statement.columnMap[nc] = true + statement.columnMap[strings.ToLower(nc)] = true } statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) return statement @@ -517,7 +521,7 @@ func (statement *Statement) UseBool(columns ...string) *Statement { if len(columns) > 0 { newColumns := col2NewCols(columns...) for _, nc := range newColumns { - statement.boolColumnMap[nc] = true + statement.boolColumnMap[strings.ToLower(nc)] = true } } else { statement.allUseBool = true @@ -529,7 +533,7 @@ func (statement *Statement) UseBool(columns ...string) *Statement { func (statement *Statement) Omit(columns ...string) { newColumns := col2NewCols(columns...) for _, nc := range newColumns { - statement.columnMap[nc] = false + statement.columnMap[strings.ToLower(nc)] = false } statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) } @@ -582,7 +586,7 @@ func (statement *Statement) genColumnStr() string { colNames := make([]string, 0) for _, col := range table.Columns { if statement.OmitStr != "" { - if _, ok := statement.columnMap[col.Name]; ok { + if _, ok := statement.columnMap[strings.ToLower(col.Name)]; ok { continue } } @@ -606,7 +610,7 @@ func (statement *Statement) genCreateTableSQL() string { pkList := []string{} for _, colName := range statement.RefTable.ColumnsSeq { - col := statement.RefTable.Columns[colName] + col := statement.RefTable.Columns[strings.ToLower(colName)] if col.IsPrimaryKey { pkList = append(pkList, col.Name) } @@ -614,7 +618,7 @@ func (statement *Statement) genCreateTableSQL() string { statement.Engine.LogDebug("len:", len(pkList)) for _, colName := range statement.RefTable.ColumnsSeq { - col := statement.RefTable.Columns[colName] + col := statement.RefTable.Columns[strings.ToLower(colName)] if col.IsPrimaryKey && len(pkList) == 1 { sql += col.String(statement.Engine.dialect) } else { @@ -634,8 +638,12 @@ func (statement *Statement) genCreateTableSQL() string { if statement.Engine.dialect.SupportEngine() && statement.StoreEngine != "" { sql += " ENGINE=" + statement.StoreEngine } - if statement.Engine.dialect.SupportCharset() && statement.Charset != "" { - sql += " DEFAULT CHARSET " + statement.Charset + if statement.Engine.dialect.SupportCharset() { + if statement.Charset != "" { + sql += " DEFAULT CHARSET " + statement.Charset + } else if statement.Engine.dialect.URI().charset != "" { + sql += " DEFAULT CHARSET " + statement.Engine.dialect.URI().charset + } } sql += ";" return sql @@ -753,9 +761,10 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{} statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args - var id string = "*" - if table.PrimaryKey != "" { - id = statement.Engine.Quote(table.PrimaryKey) + // count(index fieldname) > count(0) > count(*) + var id string = "0" + if len(table.PrimaryKeys) == 1 { + id = statement.Engine.Quote(table.PrimaryKeys[0]) } return statement.genSelectSql(fmt.Sprintf("COUNT(%v) AS %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...) 
} @@ -818,7 +827,7 @@ func (statement *Statement) processIdParam() { for _, elem := range *(statement.IdParam) { for ; i < colCnt; i++ { colName := statement.RefTable.ColumnsSeq[i] - col := statement.RefTable.Columns[colName] + col := statement.RefTable.Columns[strings.ToLower(colName)] if col.IsPrimaryKey { statement.And(fmt.Sprintf("%v=?", col.Name), elem) i++ @@ -832,7 +841,7 @@ func (statement *Statement) processIdParam() { // false update/delete for ; i < colCnt; i++ { colName := statement.RefTable.ColumnsSeq[i] - col := statement.RefTable.Columns[colName] + col := statement.RefTable.Columns[strings.ToLower(colName)] if col.IsPrimaryKey { statement.And(fmt.Sprintf("%v=?", col.Name), "") } diff --git a/table.go b/table.go index aac87528..76b4c3ae 100644 --- a/table.go +++ b/table.go @@ -163,6 +163,7 @@ func Type2SQLType(t reflect.Type) (st SQLType) { if t == reflect.TypeOf(c_TIME_DEFAULT) { st = SQLType{DateTime, 0, 0} } else { + // TODO need to handle association struct st = SQLType{Text, 0, 0} } case reflect.Ptr: @@ -297,6 +298,8 @@ func (col *Column) String(d dialect) string { if col.Default != "" { sql += "DEFAULT " + col.Default + " " + } else if col.IsVersion { + sql += "DEFAULT 1 " } return sql @@ -315,6 +318,8 @@ func (col *Column) stringNoPk(d dialect) string { if col.Default != "" { sql += "DEFAULT " + col.Default + " " + } else if col.IsVersion { + sql += "DEFAULT 1 " } return sql @@ -339,16 +344,17 @@ func (col *Column) ValueOf(bean interface{}) reflect.Value { // database table type Table struct { - Name string - Type reflect.Type - ColumnsSeq []string - Columns map[string]*Column - Indexes map[string]*Index - PrimaryKey string - Created map[string]bool - Updated string - Version string - Cacher Cacher + Name string + Type reflect.Type + ColumnsSeq []string + Columns map[string]*Column + Indexes map[string]*Index + PrimaryKeys []string + AutoIncrement string + Created map[string]bool + Updated string + Version string + Cacher Cacher } /* @@ -362,20 +368,31 @@ func NewTable(name string, t reflect.Type) *Table { }*/ // if has primary key, return column -func (table *Table) PKColumn() *Column { - return table.Columns[table.PrimaryKey] +func (table *Table) PKColumns() []*Column { + columns := make([]*Column, 0) + for _, name := range table.PrimaryKeys { + columns = append(columns, table.Columns[strings.ToLower(name)]) + } + return columns +} + +func (table *Table) AutoIncrColumn() *Column { + return table.Columns[strings.ToLower(table.AutoIncrement)] } func (table *Table) VersionColumn() *Column { - return table.Columns[table.Version] + return table.Columns[strings.ToLower(table.Version)] } // add a column to table func (table *Table) AddColumn(col *Column) { table.ColumnsSeq = append(table.ColumnsSeq, col.Name) - table.Columns[col.Name] = col + table.Columns[strings.ToLower(col.Name)] = col if col.IsPrimaryKey { - table.PrimaryKey = col.Name + table.PrimaryKeys = append(table.PrimaryKeys, col.Name) + } + if col.IsAutoIncrement { + table.AutoIncrement = col.Name } if col.IsCreated { table.Created[col.Name] = true @@ -398,8 +415,9 @@ func (table *Table) genCols(session *Session, bean interface{}, useCol bool, inc args := make([]interface{}, 0) for _, col := range table.Columns { + lColName := strings.ToLower(col.Name) if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated { - if _, ok := session.Statement.columnMap[col.Name]; !ok { + if _, ok := session.Statement.columnMap[lColName]; !ok { continue } } @@ -408,17 +426,30 @@ func (table *Table) genCols(session *Session, 
bean interface{}, useCol bool, inc } fieldValue := col.ValueOf(bean) - if col.IsAutoIncrement && fieldValue.Int() == 0 { - continue + if col.IsAutoIncrement { + switch fieldValue.Type().Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: + if fieldValue.Int() == 0 { + continue + } + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: + if fieldValue.Uint() == 0 { + continue + } + case reflect.String: + if len(fieldValue.String()) == 0 { + continue + } + } } if session.Statement.ColumnStr != "" { - if _, ok := session.Statement.columnMap[col.Name]; !ok { + if _, ok := session.Statement.columnMap[lColName]; !ok { continue } } if session.Statement.OmitStr != "" { - if _, ok := session.Statement.columnMap[col.Name]; ok { + if _, ok := session.Statement.columnMap[lColName]; ok { continue } } diff --git a/xorm.go b/xorm.go index 06c30c41..bff3f1bc 100644 --- a/xorm.go +++ b/xorm.go @@ -10,7 +10,7 @@ import ( ) const ( - Version string = "0.2.3" + Version string = "0.3.1" ) func close(engine *Engine) { diff --git a/xorm/README.md b/xorm/README.md index aebfca77..b0d39b86 100644 --- a/xorm/README.md +++ b/xorm/README.md @@ -1,30 +1,30 @@ -# xorm tools - - -xorm tools is a set of tools for database operation. - -## Install - +# xorm tools + + +xorm tools is a set of tools for database operation. + +## Install + `go get github.com/lunny/xorm/xorm` and you should install the depends below: -* github.com/lunny/xorm - -* Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) - -* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) - -* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) - +* github.com/lunny/xorm + +* Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) + +* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) + +* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) + * Postgres: [github.com/bylevel/pq](https://github.com/bylevel/pq) - - -## Reverse + + +## Reverse After you installed the tool, you can type -`xorm help reverse` +`xorm help reverse` to get help @@ -50,13 +50,13 @@ Now, xorm tool supports go and c++ two languages and have go, goxorm, c++ three ```` lang=go -genJson=1 +genJson=1 ``` lang must be go or c++ now. genJson can be 1 or 0, if 1 then the struct will have json tag. - -## LICENSE - - BSD License - [http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/) + +## LICENSE + + BSD License + [http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/)
* `name` — the database column name for this field. Optional; if omitted, the column name is derived automatically from the field name and the active naming rule. If it conflicts with another keyword, wrap it in single quotes.
* `pk` — marks the field as (part of) the primary key. If more than one field in a struct carries this tag, those fields together form a composite primary key. A single primary key currently supports the seven Go types int32, int, int64, uint32, uint, uint64, and string; a composite primary key may combine any of these seven types.
* column type — any of the 30+ supported column types may be given here to set the column's SQL type; see [Column Types](https://github.com/lunny/xorm/blob/master/docs/COLUMNTYPE.md) for details.
* `autoincr` — marks the column as auto-increment.
* `[not ]null` or `notnull` — whether the column may be NULL.
* `unique` or `unique(uniquename)` — marks the column as unique. Without parentheses, the column itself must not contain duplicates; with parentheses, the name inside them is the name of a composite unique index, and all fields sharing the same uniquename together form that composite unique index.
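A rough sketch of how the tags above might be combined, assuming conventional xorm struct-tag syntax; the struct, column, and index names (User, Profile, login_name, full_name) are invented for illustration and should be checked against the linked docs:

```go
package models

// User: single auto-increment primary key plus a typed, renamed column.
// 'login_name' is quoted to set the column name explicitly; varchar(64),
// notnull and unique all apply to that same column.
type User struct {
	Id   int64  `xorm:"pk autoincr"`
	Name string `xorm:"varchar(64) 'login_name' notnull unique"`
}

// Profile: composite primary key — every field tagged pk is part of the key.
// FirstName and LastName share unique(full_name), so together they form one
// composite unique index named full_name.
type Profile struct {
	UserId    int64  `xorm:"pk"`
	RegionId  int64  `xorm:"pk"`
	FirstName string `xorm:"unique(full_name)"`
	LastName  string `xorm:"unique(full_name)"`
}
```

Per the v0.3.1 changelog, the single key in User could equally be declared as any of the seven listed Go types, including string.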