v0.1.9 added postgres,mymysql supported;add Cols, StoreEngine, Charset;add many column types

This commit is contained in:
Lunny Xiao 2013-08-08 13:24:38 +08:00
parent a7f6aa92d3
commit b0b3c16372
20 changed files with 1302 additions and 590 deletions

425
COLUMNTYPE.md Normal file
View File

@ -0,0 +1,425 @@
<table>
<tr>
<td>xorm
</td>
<td>mysql
</td>
<td>sqlite3
</td>
<td>postgres
</td>
<td>remark</td>
</tr>
<tr>
<td>BIT
</td>
<td>BIT
</td>
<td>INTEGER
</td>
<td>BIT
</td>
<td></td>
</tr>
<tr>
<td>TINYINT
</td>
<td>TINYINT
</td>
<td>INTEGER
</td>
<td>SMALLINT
</td>
<td></td>
</tr>
<tr>
<td>SMALLINT
</td>
<td>SMALLINT
</td>
<td>INTEGER
</td>
<td>SMALLINT
</td>
<td></td>
</tr>
<tr>
<td>MEDIUMINT
</td>
<td>MEDIUMINT
</td>
<td>INTEGER
</td>
<td>INTEGER
</td>
<td></td>
</tr>
<tr>
<td>INT
</td>
<td>INT
</td>
<td>INTEGER
</td>
<td>INTEGER
</td>
<td></td>
</tr>
<tr>
<td>INTEGER
</td>
<td>INTEGER
</td>
<td>INTEGER
</td>
<td>INTEGER
</td>
<td></td>
</tr>
<tr>
<td>BIGINT
</td>
<td>BIGINT
</td>
<td>INTEGER
</td>
<td>BIGINT
</td>
<td></td>
</tr>
<tr>
<td>CHAR
</td>
<td>CHAR
</td>
<td>TEXT
</td>
<td>CHAR
</td>
<td></td>
</tr>
<tr>
<td>VARCHAR
</td>
<td>VARCHAR
</td>
<td>TEXT
</td>
<td>VARCHAR
</td>
<td></td>
</tr>
<tr>
<td>TINYTEXT
</td>
<td>TINYTEXT
</td>
<td>TEXT
</td>
<td>TEXT
</td>
<td></td>
</tr>
<tr>
<td>TEXT
</td>
<td>TEXT
</td>
<td>TEXT
</td>
<td>TEXT
</td>
<td></td>
</tr>
<tr>
<td>MEDIUMTEXT
</td>
<td>MEDIUMTEXT
</td>
<td>TEXT
</td>
<td>TEXT
</td>
<td></td>
</tr>
<tr>
<td>LONGTEXT
</td>
<td>LONGTEXT
</td>
<td>TEXT
</td>
<td>TEXT
</td>
<td></td>
</tr>
<tr>
<td>BINARY
</td>
<td>BINARY
</td>
<td>BLOB
</td>
<td>BYTEA
</td>
<td></td>
</tr>
<tr>
<td>VARBINARY
</td>
<td>VARBINARY
</td>
<td>BLOB
</td>
<td>BYTEA
</td>
<td></td>
</tr>
<tr>
<td>DATE
</td>
<td>DATE
</td>
<td>NUMERIC
</td>
<td>DATE
</td>
<td></td>
</tr>
<tr>
<td>DATETIME
</td>
<td>DATETIME
</td>
<td>NUMERIC
</td>
<td>TIMESTAMP
</td>
<td></td>
</tr>
<tr>
<td>TIME
</td>
<td>TIME
</td>
<td>NUMERIC
</td>
<td>TIME
</td>
<td></td>
</tr>
<tr>
<td>TIMESTAMP
</td>
<td>TIMESTAMP
</td>
<td>NUMERIC
</td>
<td>TIMESTAMP
</td>
<td></td>
</tr>
<tr>
<td>REAL
</td>
<td>REAL
</td>
<td>REAL
</td>
<td>REAL
</td>
<td></td>
</tr>
<tr>
<td>FLOAT
</td>
<td>FLOAT
</td>
<td>REAL
</td>
<td>REAL
</td>
<td></td>
</tr>
<tr>
<td>DOUBLE
</td>
<td>DOUBLE
</td>
<td>REAL
</td>
<td>DOUBLE PRECISION
</td>
<td></td>
</tr>
<tr>
<td>DECIMAL
</td>
<td>DECIMAL
</td>
<td>NUMERIC
</td>
<td>DECIMAL
</td>
<td></td>
</tr>
<tr>
<td>NUMERIC
</td>
<td>NUMERIC
</td>
<td>NUMERIC
</td>
<td>NUMERIC
</td>
<td></td>
</tr>
<tr>
<td>TINYBLOB
</td>
<td>TINYBLOB
</td>
<td>BLOB
</td>
<td>BYTEA
</td>
<td></td>
</tr>
<tr>
<td>BLOB
</td>
<td>BLOB
</td>
<td>BLOB
</td>
<td>BYTEA
</td>
<td></td>
</tr>
<tr>
<td>MEDIUMBLOB
</td>
<td>MEDIUMBLOB
</td>
<td>BLOB
</td>
<td>BYTEA
</td>
<td></td>
</tr>
<tr>
<td>LONGBLOB
</td>
<td>LONGBLOB
</td>
<td>BLOB
</td>
<td>BYTEA
</td>
<td></td>
</tr>
<tr>
<td>BYTEA
</td>
<td>BLOB
</td>
<td>BLOB
</td>
<td>BYTEA
</td>
<td></td>
</tr>
<tr>
<td>BOOL
</td>
<td>TINYINT
</td>
<td>INTEGER
</td>
<td>BOOLEAN
</td>
<td></td>
</tr>
<tr>
<td>SERIAL
</td>
<td>INT
</td>
<td>INTEGER
</td>
<td>SERIAL
</td>
<td>auto increment</td>
</tr>
<tr>
<td>BIGSERIAL
</td>
<td>BIGINT
</td>
<td>INTEGER
</td>
<td>BIGSERIAL
</td>
<td>auto increment</td>
</tr>
</table>

View File

@ -13,10 +13,16 @@ Drivers for Go's sql package which currently support database/sql includes:
* Mysql: [github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL) * Mysql: [github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL)
* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv)
* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) * SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
* Postgres: [github.com/bylevel/pg](https://github.com/bylevel/pg)
## Changelog ## Changelog
* **v0.1.9** : Added postgres and mymysql supported; Added ` and ? supported on Raw SQL even if postgres; Added Cols, StoreEngine, Charset function, Added many column data type supported, please see [Mapping Rules](#mapping).
* **v0.1.8** : Added union index and union unique supported, please see [Mapping Rules](#mapping). * **v0.1.8** : Added union index and union unique supported, please see [Mapping Rules](#mapping).
* **v0.1.7** : Added IConnectPool interface and NoneConnectPool, SysConnectPool, SimpleConnectPool the three implements. You can choose one of them and the default is SysConnectPool. You can customrize your own connection pool. struct Engine added Close method, It should be invoked before system exit. * **v0.1.7** : Added IConnectPool interface and NoneConnectPool, SysConnectPool, SimpleConnectPool the three implements. You can choose one of them and the default is SysConnectPool. You can customrize your own connection pool. struct Engine added Close method, It should be invoked before system exit.
* **v0.1.6** : Added conversion interface support; added struct derive support; added single mapping support * **v0.1.6** : Added conversion interface support; added struct derive support; added single mapping support
@ -157,6 +163,13 @@ var tenusers []Userinfo
err := engine.In("id", 1, 3, 5).Find(&tenusers) //Get All id in (1, 3, 5) err := engine.In("id", 1, 3, 5).Find(&tenusers) //Get All id in (1, 3, 5)
``` ```
6.4 The default will query all columns of a table. Use Cols function if you want to select some columns
```Go
var tenusers []Userinfo
err := engine.Cols("id", "name").Find(&tenusers) //Find only id and name
```
7.Delete 7.Delete
```Go ```Go
@ -311,7 +324,7 @@ Another is use field tag, field tag support the below keywords which split with
<td>pk</td><td>the field is a primary key</td> <td>pk</td><td>the field is a primary key</td>
</tr> </tr>
<tr> <tr>
<td>int(11)/varchar(50)/text/date/datetime/blob/decimal(26,2)</td><td>column type</td> <td>more than 30 column type supported, please see [Column Type](https://github.com/lunny/xorm/blob/master/COLUMNTYPE.md)</td><td>column type</td>
</tr> </tr>
<tr> <tr>
<td>autoincr</td><td>auto incrment</td> <td>autoincr</td><td>auto incrment</td>

View File

@ -12,10 +12,15 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* Mysql: [github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL) * Mysql: [github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL)
* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv)
* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) * SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
* Postgres: [github.com/bylevel/pg](https://github.com/bylevel/pg)
## 更新日志 ## 更新日志
* **v0.1.9** : 新增 postgres 和 mymysql 驱动支持; 在Postgres中支持原始SQL语句中使用 ` 和 ? 符号; 新增Cols, StoreEngine, Charset 函数SQL语句打印支持io.Writer接口默认打印到控制台新增更多的字段类型支持详见 [映射规则](#mapping)删除废弃的MakeSession和Create函数。
* **v0.1.8** : 新增联合index联合unique支持请查看 [映射规则](#mapping)。 * **v0.1.8** : 新增联合index联合unique支持请查看 [映射规则](#mapping)。
* **v0.1.7** : 新增IConnectPool接口以及NoneConnectPool, SysConnectPool, SimpleConnectPool三种实现可以选择不使用连接池使用系统连接池和使用自带连接池三种实现默认为SysConnectPool即系统自带的连接池。同时支持自定义连接池。Engine新增Close方法在系统退出时应调用此方法。 * **v0.1.7** : 新增IConnectPool接口以及NoneConnectPool, SysConnectPool, SimpleConnectPool三种实现可以选择不使用连接池使用系统连接池和使用自带连接池三种实现默认为SysConnectPool即系统自带的连接池。同时支持自定义连接池。Engine新增Close方法在系统退出时应调用此方法。
* **v0.1.6** : 新增Conversion支持自定义类型到数据库类型的转换新增查询结构体自动检测匿名成员支持新增单向映射支持 * **v0.1.6** : 新增Conversion支持自定义类型到数据库类型的转换新增查询结构体自动检测匿名成员支持新增单向映射支持
@ -36,7 +41,7 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* 使用连写来简化调用 * 使用连写来简化调用
* 支持使用Id, In, Where, Limit, Join, Having, Table, Sql等函数和结构体等方式作为条件 * 支持使用Id, In, Where, Limit, Join, Having, Table, Sql, Cols等函数和结构体等方式作为条件
* 支持数据库连接池 * 支持数据库连接池
@ -82,6 +87,11 @@ Sqlite
engine.ShowSQL = true engine.ShowSQL = true
``` ```
默认打印会打印到控制台如果需要打印到文件只需要设置一个符合io.writer接口的struct即可
```Go
engine.Logger = [io.Writer]
```
1.2.如果要更换连接池实现可使用SetPool方法 1.2.如果要更换连接池实现可使用SetPool方法
```Go ```Go
err = engine.SetPool(NewSimpleConnectPool()) err = engine.SetPool(NewSimpleConnectPool())
@ -142,7 +152,7 @@ users := make(map[int64]Userinfo)
err := engine.Find(&users) err := engine.Find(&users)
``` ```
6.1 你也可以使用Where和Limit方法设定条件和查询数量 6.1 你也可以使用Where和Limit方法设定条件和查询数量Limit参数可为1个到2个第一个参数为查询条数第二个参数为开始条数。
```Go ```Go
var allusers []Userinfo var allusers []Userinfo
@ -163,6 +173,13 @@ var tenusers []Userinfo
err := engine.In("id", 1, 3, 5).Find(&tenusers) //Get All id in (1, 3, 5) err := engine.In("id", 1, 3, 5).Find(&tenusers) //Get All id in (1, 3, 5)
``` ```
6.4 默认将查询出所有字段如果要指定字段则可以调用Cols函数
```Go
var tenusers []Userinfo
err := engine.Cols("id", "name").Find(&tenusers) //Find only id and name
```
7.Delete方法 7.Delete方法
```Go ```Go
@ -317,7 +334,7 @@ UserInfo中的成员UserName将会自动对应名为user_name的字段。
<td>pk</td><td>是否是Primary Key当前仅支持int64类型</td> <td>pk</td><td>是否是Primary Key当前仅支持int64类型</td>
</tr> </tr>
<tr> <tr>
<td>int(11)/varchar(50)/text/date/datetime/blob/decimal(26,2)</td><td>字段类型</td> <td>当前支持30多种字段类型详情参见 [字段类型](https://github.com/lunny/xorm/blob/master/COLUMNTYPE.md)</td><td>字段类型</td>
</tr> </tr>
<tr> <tr>
<td>autoincr</td><td>是否是自增</td> <td>autoincr</td><td>是否是自增</td>

View File

@ -1 +1 @@
xorm v0.1.7 xorm v0.1.9

View File

@ -8,14 +8,14 @@ import (
/* /*
CREATE TABLE `userinfo` ( CREATE TABLE `userinfo` (
`uid` INT(10) NULL AUTO_INCREMENT, `id` INT(10) NULL AUTO_INCREMENT,
`username` VARCHAR(64) NULL, `username` VARCHAR(64) NULL,
`departname` VARCHAR(64) NULL, `departname` VARCHAR(64) NULL,
`created` DATE NULL, `created` DATE NULL,
PRIMARY KEY (`uid`) PRIMARY KEY (`uid`)
); );
CREATE TABLE `userdeatail` ( CREATE TABLE `userdeatail` (
`uid` INT(10) NULL, `id` INT(10) NULL,
`intro` TEXT NULL, `intro` TEXT NULL,
`profile` TEXT NULL, `profile` TEXT NULL,
PRIMARY KEY (`uid`) PRIMARY KEY (`uid`)
@ -41,15 +41,16 @@ type Userdetail struct {
} }
func directCreateTable(engine *Engine, t *testing.T) { func directCreateTable(engine *Engine, t *testing.T) {
err := engine.DropTables(&Userinfo{}) err := engine.DropTables(&Userinfo{}, &Userdetail{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return panic(err)
} }
err = engine.CreateTables(&Userinfo{}) err = engine.CreateTables(&Userinfo{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
} }
@ -57,30 +58,36 @@ func mapper(engine *Engine, t *testing.T) {
err := engine.UnMap(&Userinfo{}) err := engine.UnMap(&Userinfo{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
err = engine.Map(&Userinfo{}, &Userdetail{}) err = engine.Map(&Userinfo{}, &Userdetail{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
err = engine.DropAll() err = engine.DropAll()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
err = engine.CreateAll() err = engine.CreateAll()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
} }
func insert(engine *Engine, t *testing.T) { func insert(engine *Engine, t *testing.T) {
user := Userinfo{1, "xiaolunwen", "dev", "lunny", time.Now(), user := Userinfo{0, "xiaolunwen", "dev", "lunny", time.Now(),
Userdetail{Id: 1}, 1.78, []byte{1, 2, 3}, true} Userdetail{Id: 1}, 1.78, []byte{1, 2, 3}, true}
_, err := engine.Insert(&user) _, err := engine.Insert(&user)
fmt.Println(user.Uid)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
} }
@ -89,6 +96,7 @@ func query(engine *Engine, t *testing.T) {
results, err := engine.Query(sql) results, err := engine.Query(sql)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
fmt.Println(results) fmt.Println(results)
} }
@ -98,6 +106,7 @@ func exec(engine *Engine, t *testing.T) {
res, err := engine.Exec(sql, "xiaolun", 1) res, err := engine.Exec(sql, "xiaolun", 1)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
fmt.Println(res) fmt.Println(res)
} }
@ -107,8 +116,10 @@ func insertAutoIncr(engine *Engine, t *testing.T) {
user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now(), user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now(),
Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true} Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true}
_, err := engine.Insert(&user) _, err := engine.Insert(&user)
fmt.Println(user.Uid)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
} }
@ -123,22 +134,8 @@ func insertMulti(engine *Engine, t *testing.T) {
_, err := engine.Insert(&users) _, err := engine.Insert(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
/*engine.InsertMany = false
users = []Userinfo{
{Username: "xlw9", Departname: "dev", Alias: "lunny9", Created: time.Now()},
{Username: "xlw10", Departname: "dev", Alias: "lunny10", Created: time.Now()},
{Username: "xlw99", Departname: "dev", Alias: "lunny2", Created: time.Now()},
{Username: "xlw1010", Departname: "dev", Alias: "lunny3", Created: time.Now()},
}
_, err = engine.Insert(&users)
if err != nil {
t.Error(err)
}
engine.InsertMany = true*/
} }
func insertTwoTable(engine *Engine, t *testing.T) { func insertTwoTable(engine *Engine, t *testing.T) {
@ -148,6 +145,7 @@ func insertTwoTable(engine *Engine, t *testing.T) {
_, err := engine.Insert(&userinfo, &userdetail) _, err := engine.Insert(&userinfo, &userdetail)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
} }
@ -157,12 +155,13 @@ func update(engine *Engine, t *testing.T) {
_, err := engine.Id(1).Update(&user) _, err := engine.Id(1).Update(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return panic(err)
} }
_, err = engine.Update(&Userinfo{Username: "yyy"}, &user) _, err = engine.Update(&Userinfo{Username: "yyy"}, &user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
} }
@ -171,6 +170,7 @@ func testdelete(engine *Engine, t *testing.T) {
_, err := engine.Delete(&user) _, err := engine.Delete(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
} }
@ -180,6 +180,7 @@ func get(engine *Engine, t *testing.T) {
has, err := engine.Get(&user) has, err := engine.Get(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
if has { if has {
fmt.Println(user) fmt.Println(user)
@ -194,6 +195,7 @@ func cascadeGet(engine *Engine, t *testing.T) {
has, err := engine.Get(&user) has, err := engine.Get(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
if has { if has {
fmt.Println(user) fmt.Println(user)
@ -208,6 +210,7 @@ func find(engine *Engine, t *testing.T) {
err := engine.Find(&users) err := engine.Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
fmt.Println(users) fmt.Println(users)
} }
@ -218,6 +221,7 @@ func findMap(engine *Engine, t *testing.T) {
err := engine.Find(&users) err := engine.Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
fmt.Println(users) fmt.Println(users)
} }
@ -227,6 +231,7 @@ func count(engine *Engine, t *testing.T) {
total, err := engine.Count(&user) total, err := engine.Count(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
fmt.Printf("Total %d records!!!", total) fmt.Printf("Total %d records!!!", total)
} }
@ -236,6 +241,7 @@ func where(engine *Engine, t *testing.T) {
err := engine.Where("id > ?", 2).Find(&users) err := engine.Where("id > ?", 2).Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
fmt.Println(users) fmt.Println(users)
} }
@ -245,7 +251,7 @@ func in(engine *Engine, t *testing.T) {
err := engine.In("id", 1, 2, 3).Find(&users) err := engine.In("id", 1, 2, 3).Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return panic(err)
} }
fmt.Println(users) fmt.Println(users)
@ -253,7 +259,7 @@ func in(engine *Engine, t *testing.T) {
err = engine.Where("id > ?", 2).In("id", ids...).Find(&users) err = engine.Where("id > ?", 2).In("id", ids...).Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return panic(err)
} }
fmt.Println(users) fmt.Println(users)
} }
@ -263,6 +269,7 @@ func limit(engine *Engine, t *testing.T) {
err := engine.Limit(2, 1).Find(&users) err := engine.Limit(2, 1).Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
fmt.Println(users) fmt.Println(users)
} }
@ -272,6 +279,7 @@ func order(engine *Engine, t *testing.T) {
err := engine.OrderBy("id desc").Find(&users) err := engine.OrderBy("id desc").Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
fmt.Println(users) fmt.Println(users)
} }
@ -281,6 +289,7 @@ func join(engine *Engine, t *testing.T) {
err := engine.Join("LEFT", "userdetail", "userinfo.id=userdetail.id").Find(&users) err := engine.Join("LEFT", "userdetail", "userinfo.id=userdetail.id").Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
} }
@ -289,6 +298,7 @@ func having(engine *Engine, t *testing.T) {
err := engine.GroupBy("username").Having("username='xlw'").Find(&users) err := engine.GroupBy("username").Having("username='xlw'").Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
fmt.Println(users) fmt.Println(users)
} }
@ -310,7 +320,7 @@ func transaction(engine *Engine, t *testing.T) {
err := session.Begin() err := session.Begin()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return panic(err)
} }
//session.IsAutoRollback = false //session.IsAutoRollback = false
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
@ -318,7 +328,7 @@ func transaction(engine *Engine, t *testing.T) {
if err != nil { if err != nil {
session.Rollback() session.Rollback()
t.Error(err) t.Error(err)
return panic(err)
} }
user2 := Userinfo{Username: "yyy"} user2 := Userinfo{Username: "yyy"}
_, err = session.Where("uid = ?", 0).Update(&user2) _, err = session.Where("uid = ?", 0).Update(&user2)
@ -333,14 +343,15 @@ func transaction(engine *Engine, t *testing.T) {
if err != nil { if err != nil {
session.Rollback() session.Rollback()
t.Error(err) t.Error(err)
return panic(err)
} }
err = session.Commit() err = session.Commit()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return panic(err)
} }
panic(err)
} }
func combineTransaction(engine *Engine, t *testing.T) { func combineTransaction(engine *Engine, t *testing.T) {
@ -360,7 +371,7 @@ func combineTransaction(engine *Engine, t *testing.T) {
err := session.Begin() err := session.Begin()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return panic(err)
} }
//session.IsAutoRollback = false //session.IsAutoRollback = false
user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
@ -368,32 +379,42 @@ func combineTransaction(engine *Engine, t *testing.T) {
if err != nil { if err != nil {
session.Rollback() session.Rollback()
t.Error(err) t.Error(err)
return panic(err)
} }
user2 := Userinfo{Username: "zzz"} user2 := Userinfo{Username: "zzz"}
_, err = session.Where("id = ?", 0).Update(&user2) _, err = session.Where("id = ?", 0).Update(&user2)
if err != nil { if err != nil {
session.Rollback() session.Rollback()
t.Error(err) t.Error(err)
return panic(err)
} }
_, err = session.Exec("delete from userinfo where username = ?", user2.Username) _, err = session.Exec("delete from userinfo where username = ?", user2.Username)
if err != nil { if err != nil {
session.Rollback() session.Rollback()
t.Error(err) t.Error(err)
return panic(err)
} }
err = session.Commit() err = session.Commit()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return panic(err)
} }
} }
func table(engine *Engine, t *testing.T) { func table(engine *Engine, t *testing.T) {
engine.Table("user_user").CreateTable(&Userinfo{}) err := engine.DropTables("user_user")
if err != nil {
t.Error(err)
panic(err)
}
err = engine.Table("user_user").CreateTable(&Userinfo{})
if err != nil {
t.Error(err)
panic(err)
}
} }
func createMultiTables(engine *Engine, t *testing.T) { func createMultiTables(engine *Engine, t *testing.T) {
@ -404,19 +425,29 @@ func createMultiTables(engine *Engine, t *testing.T) {
err := session.Begin() err := session.Begin()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return panic(err)
} }
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
err = session.Table(fmt.Sprintf("user_%v", i)).CreateTable(user) tableName := fmt.Sprintf("user_%v", i)
err = engine.DropTables(tableName)
if err != nil { if err != nil {
session.Rollback() session.Rollback()
t.Error(err) t.Error(err)
return panic(err)
}
err = session.Table(tableName).CreateTable(user)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
} }
} }
err = session.Commit() err = session.Commit()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
} }
@ -426,26 +457,118 @@ func tableOp(engine *Engine, t *testing.T) {
id, err := engine.Table(tableName).Insert(&user) id, err := engine.Table(tableName).Insert(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
_, err = engine.Table(tableName).Get(&Userinfo{Username: "tablexiao"}) _, err = engine.Table(tableName).Get(&Userinfo{Username: "tablexiao"})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
err = engine.Table(tableName).Find(&users) err = engine.Table(tableName).Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
_, err = engine.Table(tableName).Id(id).Update(&Userinfo{Username: "tableda"}) _, err = engine.Table(tableName).Id(id).Update(&Userinfo{Username: "tableda"})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
_, err = engine.Table(tableName).Id(id).Delete(&Userinfo{}) _, err = engine.Table(tableName).Id(id).Delete(&Userinfo{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err)
} }
} }
func testCharst(engine *Engine, t *testing.T) {
err := engine.DropTables("user_charset")
if err != nil {
t.Error(err)
panic(err)
}
err = engine.Charset("utf8").Table("user_charset").CreateTable(&Userinfo{})
if err != nil {
t.Error(err)
panic(err)
}
}
func testStoreEngine(engine *Engine, t *testing.T) {
err := engine.DropTables("user_store_engine")
if err != nil {
t.Error(err)
panic(err)
}
err = engine.StoreEngine("InnoDB").Table("user_store_engine").CreateTable(&Userinfo{})
if err != nil {
t.Error(err)
panic(err)
}
}
type tempUser struct {
Id int64
Username string
}
func testCols(engine *Engine, t *testing.T) {
users := []Userinfo{}
err := engine.Cols("id, username").Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(users)
tmpUsers := []tempUser{}
err = engine.Table("userinfo").Cols("id, username").Find(&tmpUsers)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(tmpUsers)
}
func testTrans(engine *Engine, t *testing.T) {
}
func testAll(engine *Engine, t *testing.T) {
directCreateTable(engine, t)
mapper(engine, t)
insert(engine, t)
query(engine, t)
exec(engine, t)
insertAutoIncr(engine, t)
insertMulti(engine, t)
insertTwoTable(engine, t)
update(engine, t)
testdelete(engine, t)
get(engine, t)
cascadeGet(engine, t)
find(engine, t)
findMap(engine, t)
count(engine, t)
where(engine, t)
in(engine, t)
limit(engine, t)
order(engine, t)
join(engine, t)
having(engine, t)
transaction(engine, t)
combineTransaction(engine, t)
table(engine, t)
createMultiTables(engine, t)
tableOp(engine, t)
testCols(engine, t)
testCharst(engine, t)
testStoreEngine(engine, t)
}

View File

@ -1,21 +0,0 @@
// Copyright 2013 The XORM Authors. All rights reserved.
// Use of this source code is governed by a BSD
// license that can be found in the LICENSE file.
// Package xorm provides is a simple and powerful ORM for Go. It makes your
// database operation simple.
// Warning: All contents in this file will be removed from xorm some times after
package xorm
// @deprecated : please use NewSession instead
func (engine *Engine) MakeSession() (Session, error) {
s := engine.NewSession()
return *s, nil
}
// @deprecated : please use NewEngine instead
func Create(driverName string, dataSourceName string) Engine {
engine, _ := NewEngine(driverName, dataSourceName)
return *engine
}

View File

@ -10,6 +10,7 @@ package xorm
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"io"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
@ -17,8 +18,7 @@ import (
) )
const ( const (
PQSQL = "pqsql" POSTGRES = "postgres"
MSSQL = "mssql"
SQLITE = "sqlite3" SQLITE = "sqlite3"
MYSQL = "mysql" MYSQL = "mysql"
MYMYSQL = "mymysql" MYMYSQL = "mymysql"
@ -27,8 +27,10 @@ const (
type dialect interface { type dialect interface {
SqlType(t *Column) string SqlType(t *Column) string
SupportInsertMany() bool SupportInsertMany() bool
QuoteIdentifier() string QuoteStr() string
AutoIncrIdentifier() string AutoIncrStr() string
SupportEngine() bool
SupportCharset() bool
} }
type Engine struct { type Engine struct {
@ -42,18 +44,28 @@ type Engine struct {
ShowSQL bool ShowSQL bool
pool IConnectPool pool IConnectPool
CacheMapping bool CacheMapping bool
Filters []Filter
Logger io.Writer
} }
func (engine *Engine) SupportInsertMany() bool { func (engine *Engine) SupportInsertMany() bool {
return engine.Dialect.SupportInsertMany() return engine.Dialect.SupportInsertMany()
} }
func (engine *Engine) QuoteIdentifier() string { func (engine *Engine) QuoteStr() string {
return engine.Dialect.QuoteIdentifier() return engine.Dialect.QuoteStr()
} }
func (engine *Engine) AutoIncrIdentifier() string { func (engine *Engine) Quote(sql string) string {
return engine.Dialect.AutoIncrIdentifier() return engine.Dialect.QuoteStr() + sql + engine.Dialect.QuoteStr()
}
func (engine *Engine) SqlType(c *Column) string {
return engine.Dialect.SqlType(c)
}
func (engine *Engine) AutoIncrStr() string {
return engine.Dialect.AutoIncrStr()
} }
func (engine *Engine) SetPool(pool IConnectPool) error { func (engine *Engine) SetPool(pool IConnectPool) error {
@ -90,12 +102,20 @@ func (engine *Engine) Close() error {
func (engine *Engine) Test() error { func (engine *Engine) Test() error {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
if engine.ShowSQL { engine.LogSQL("PING DATABASE", engine.DriverName)
fmt.Printf("PING DATABASE %v\n", engine.DriverName)
}
return session.Ping() return session.Ping()
} }
func (engine *Engine) LogSQL(contents ...interface{}) {
if engine.ShowSQL {
io.WriteString(engine.Logger, fmt.Sprintln(contents...))
}
}
func (engine *Engine) LogError(contents ...interface{}) {
io.WriteString(engine.Logger, fmt.Sprintln(contents...))
}
func (engine *Engine) Sql(querystring string, args ...interface{}) *Session { func (engine *Engine) Sql(querystring string, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
return session.Sql(querystring, args...) return session.Sql(querystring, args...)
@ -126,6 +146,16 @@ func (engine *Engine) StoreEngine(storeEngine string) *Session {
return session.StoreEngine(storeEngine) return session.StoreEngine(storeEngine)
} }
func (engine *Engine) Cols(columns ...string) *Session {
session := engine.NewSession()
return session.Cols(columns...)
}
func (engine *Engine) Trans(t string) *Session {
session := engine.NewSession()
return session.Trans(t)
}
func (engine *Engine) In(column string, args ...interface{}) *Session { func (engine *Engine) In(column string, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
return session.In(column, args...) return session.In(column, args...)
@ -273,6 +303,10 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
} }
case k == "date": case k == "date":
col.SQLType = Date col.SQLType = Date
case k == "float":
col.SQLType = Float
case k == "double":
col.SQLType = Double
case k == "datetime": case k == "datetime":
col.SQLType = DateTime col.SQLType = DateTime
case k == "timestamp": case k == "timestamp":
@ -375,7 +409,8 @@ func (e *Engine) DropAll() error {
} }
err = session.DropAll() err = session.DropAll()
if err != nil { if err != nil {
return session.Rollback() session.Rollback()
return err
} }
return session.Commit() return session.Commit()
} }
@ -418,17 +453,8 @@ func (e *Engine) DropTables(beans ...interface{}) error {
func (e *Engine) CreateAll() error { func (e *Engine) CreateAll() error {
session := e.NewSession() session := e.NewSession()
err := session.Begin()
defer session.Close() defer session.Close()
if err != nil { return session.CreateAll()
return err
}
err = session.CreateAll()
if err != nil {
return session.Rollback()
}
return session.Commit()
} }
func (engine *Engine) Exec(sql string, args ...interface{}) (sql.Result, error) { func (engine *Engine) Exec(sql string, args ...interface{}) (sql.Result, error) {
@ -449,6 +475,12 @@ func (engine *Engine) Insert(beans ...interface{}) (int64, error) {
return session.Insert(beans...) return session.Insert(beans...)
} }
func (engine *Engine) InsertOne(bean interface{}) (int64, error) {
session := engine.NewSession()
defer session.Close()
return session.InsertOne(bean)
}
func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64, error) { func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64, error) {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()

View File

@ -22,7 +22,7 @@ func sqliteEngine() (*xorm.Engine, error) {
} }
func mysqlEngine() (*xorm.Engine, error) { func mysqlEngine() (*xorm.Engine, error) {
return xorm.NewEngine("mysql", "root:123@/test?charset=utf8") return xorm.NewEngine("mysql", "root:@/test?charset=utf8")
} }
var u *User = &User{} var u *User = &User{}

43
filter.go Normal file
View File

@ -0,0 +1,43 @@
package xorm
import (
"fmt"
"strings"
)
type Filter interface {
Do(sql string, session *Session) string
}
type PgSeqFilter struct {
}
func (s *PgSeqFilter) Do(sql string, session *Session) string {
segs := strings.Split(sql, "?")
size := len(segs)
res := ""
for i, c := range segs {
if i < size-1 {
res += c + fmt.Sprintf("$%v", i+1)
}
}
res += segs[size-1]
return res
}
type PgQuoteFilter struct {
}
func (s *PgQuoteFilter) Do(sql string, session *Session) string {
return strings.Replace(sql, "`", session.Engine.QuoteStr(), -1)
}
type IdFilter struct {
}
func (i *IdFilter) Do(sql string, session *Session) string {
if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" {
return strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1)
}
return sql
}

23
mymysql_test.go Normal file
View File

@ -0,0 +1,23 @@
package xorm
import (
_ "github.com/ziutek/mymysql/godrv"
"testing"
)
/*
CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET
utf8 COLLATE utf8_general_ci;
*/
func TestMyMysql(t *testing.T) {
engine, err := NewEngine("mymysql", "xorm_test/root/")
defer engine.Close()
if err != nil {
t.Error(err)
return
}
engine.ShowSQL = true
testAll(engine, t)
}

View File

@ -13,26 +13,48 @@ type mysql struct {
} }
func (db *mysql) SqlType(c *Column) string { func (db *mysql) SqlType(c *Column) string {
var res string
switch t := c.SQLType; t { switch t := c.SQLType; t {
case Date, DateTime, TimeStamp: case Bool:
return "DATETIME" res = TinyInt.Name
case Varchar: case Serial:
return t.Name + "(" + strconv.Itoa(c.Length) + ")" c.IsAutoIncrement = true
case Decimal: res = Int.Name
return t.Name + "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" case BigSerial:
c.IsAutoIncrement = true
res = Integer.Name
case Bytea:
res = Blob.Name
default: default:
return t.Name res = t.Name
} }
var hasLen1 bool = (c.Length > 0)
var hasLen2 bool = (c.Length2 > 0)
if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")"
} else if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
}
return res
} }
func (db *mysql) SupportInsertMany() bool { func (db *mysql) SupportInsertMany() bool {
return true return true
} }
func (db *mysql) QuoteIdentifier() string { func (db *mysql) QuoteStr() string {
return "`" return "`"
} }
func (db *mysql) AutoIncrIdentifier() string { func (db *mysql) AutoIncrStr() string {
return "AUTO_INCREMENT" return "AUTO_INCREMENT"
} }
func (db *mysql) SupportEngine() bool {
return true
}
func (db *mysql) SupportCharset() bool {
return true
}

View File

@ -5,47 +5,19 @@ import (
"testing" "testing"
) )
var me Engine
/* /*
CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET
utf8 COLLATE utf8_general_ci; utf8 COLLATE utf8_general_ci;
*/ */
func TestMysql(t *testing.T) { func TestMysql(t *testing.T) {
// You should drop all tables before executing this testing
engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8") engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8")
defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
me = *engine engine.ShowSQL = true
me.ShowSQL = true
directCreateTable(&me, t) testAll(engine, t)
mapper(&me, t)
insert(&me, t)
query(&me, t)
exec(&me, t)
insertAutoIncr(&me, t)
insertMulti(&me, t)
insertTwoTable(&me, t)
update(&me, t)
testdelete(&me, t)
get(&me, t)
cascadeGet(&me, t)
find(&me, t)
findMap(&me, t)
count(&me, t)
where(&me, t)
in(&me, t)
limit(&me, t)
order(&me, t)
join(&me, t)
having(&me, t)
transaction(&me, t)
combineTransaction(&me, t)
table(&me, t)
createMultiTables(&me, t)
tableOp(&me, t)
} }

72
postgres.go Normal file
View File

@ -0,0 +1,72 @@
// Copyright 2013 The XORM Authors. All rights reserved.
// Use of this source code is governed by a BSD
// license that can be found in the LICENSE file.
// Package xorm provides is a simple and powerful ORM for Go. It makes your
// database operation simple.
package xorm
import "strconv"
type postgres struct {
}
func (db *postgres) SqlType(c *Column) string {
var res string
switch t := c.SQLType; t {
case TinyInt:
res = SmallInt.Name
case MediumInt, Int, Integer:
return Integer.Name
case Serial, BigSerial:
c.IsAutoIncrement = true
res = t.Name
case Binary, VarBinary:
res = Bytea.Name
case DateTime:
res = TimeStamp.Name
case Float:
res = Real.Name
case TinyText, MediumText, LongText:
res = Text.Name
case Blob, TinyBlob, MediumBlob, LongBlob:
res = Bytea.Name
case Double:
return "DOUBLE PRECISION"
default:
if c.IsAutoIncrement {
return Serial.Name
}
res = t.Name
}
var hasLen1 bool = (c.Length > 0)
var hasLen2 bool = (c.Length2 > 0)
if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")"
} else if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
}
return res
}
func (db *postgres) SupportInsertMany() bool {
return true
}
func (db *postgres) QuoteStr() string {
return "\""
}
func (db *postgres) AutoIncrStr() string {
return ""
}
func (db *postgres) SupportEngine() bool {
return false
}
func (db *postgres) SupportCharset() bool {
return false
}

18
postgres_test.go Normal file
View File

@ -0,0 +1,18 @@
package xorm
import (
_ "github.com/bylevel/pq"
"testing"
)
func TestPostgres(t *testing.T) {
engine, err := NewEngine("postgres", "dbname=xorm_test sslmode=disable")
defer engine.Close()
if err != nil {
t.Error(err)
return
}
engine.ShowSQL = true
testAll(engine, t)
}

View File

@ -24,6 +24,7 @@ type Session struct {
Statement Statement Statement Statement
IsAutoCommit bool IsAutoCommit bool
IsCommitedOrRollbacked bool IsCommitedOrRollbacked bool
TransType string
} }
func (session *Session) Init() { func (session *Session) Init() {
@ -69,6 +70,16 @@ func (session *Session) In(column string, args ...interface{}) *Session {
return session return session
} }
func (session *Session) Cols(columns ...string) *Session {
session.Statement.Cols(columns...)
return session
}
func (session *Session) Trans(t string) *Session {
session.TransType = t
return session
}
func (session *Session) Limit(limit int, start ...int) *Session { func (session *Session) Limit(limit int, start ...int) *Session {
session.Statement.Limit(limit, start...) session.Statement.Limit(limit, start...)
return session return session
@ -136,18 +147,15 @@ func (session *Session) Begin() error {
session.IsAutoCommit = false session.IsAutoCommit = false
session.IsCommitedOrRollbacked = false session.IsCommitedOrRollbacked = false
session.Tx = tx session.Tx = tx
if session.Engine.ShowSQL {
fmt.Println("BEGIN TRANSACTION") session.Engine.LogSQL("BEGIN TRANSACTION")
}
} }
return nil return nil
} }
func (session *Session) Rollback() error { func (session *Session) Rollback() error {
if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { if !session.IsAutoCommit && !session.IsCommitedOrRollbacked {
if session.Engine.ShowSQL { session.Engine.LogSQL("ROLL BACK")
fmt.Println("ROLL BACK")
}
session.IsCommitedOrRollbacked = true session.IsCommitedOrRollbacked = true
return session.Tx.Rollback() return session.Tx.Rollback()
} }
@ -156,9 +164,7 @@ func (session *Session) Rollback() error {
func (session *Session) Commit() error { func (session *Session) Commit() error {
if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { if !session.IsAutoCommit && !session.IsCommitedOrRollbacked {
if session.Engine.ShowSQL { session.Engine.LogSQL("COMMIT")
fmt.Println("COMMIT")
}
session.IsCommitedOrRollbacked = true session.IsCommitedOrRollbacked = true
return session.Tx.Commit() return session.Tx.Commit()
} }
@ -168,7 +174,7 @@ func (session *Session) Commit() error {
func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error { func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error {
dataStruct := reflect.Indirect(reflect.ValueOf(obj)) dataStruct := reflect.Indirect(reflect.ValueOf(obj))
if dataStruct.Kind() != reflect.Struct { if dataStruct.Kind() != reflect.Struct {
return errors.New("expected a pointer to a struct") return errors.New("Expected a pointer to a struct")
} }
table := session.Engine.Tables[Type(obj)] table := session.Engine.Tables[Type(obj)]
@ -181,7 +187,7 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
fieldPath := strings.Split(fieldName, ".") fieldPath := strings.Split(fieldName, ".")
var structField reflect.Value var structField reflect.Value
if len(fieldPath) > 2 { if len(fieldPath) > 2 {
fmt.Printf("xorm: Warning! Unsupported mutliderive %v\n", fieldName) session.Engine.LogError("Unsupported mutliderive", fieldName)
continue continue
} else if len(fieldPath) == 2 { } else if len(fieldPath) == 2 {
parentField := dataStruct.FieldByName(fieldPath[0]) parentField := dataStruct.FieldByName(fieldPath[0])
@ -207,7 +213,7 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
case reflect.String: case reflect.String:
v = string(data) v = string(data)
case reflect.Bool: case reflect.Bool:
v = string(data) == "1" v = (string(data) == "1")
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
x, err := strconv.Atoi(string(data)) x, err := strconv.Atoi(string(data))
if err != nil { if err != nil {
@ -269,14 +275,14 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
if has { if has {
v = structInter.Elem().Interface() v = structInter.Elem().Interface()
} else { } else {
fmt.Println("cascade obj is not exist!") session.Engine.LogError("cascade obj is not exist!")
continue continue
} }
} else { } else {
continue continue
} }
} else { } else {
fmt.Println("unsupported struct type in Scan: " + structField.Type().String()) session.Engine.LogError("unsupported struct type in Scan: " + structField.Type().String())
continue continue
} }
} else { } else {
@ -309,20 +315,17 @@ func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result,
func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) {
err := session.newDb() err := session.newDb()
if session.IsAutoCommit {
defer session.Close()
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { for _, filter := range session.Engine.Filters {
sql = strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1) sql = filter.Do(sql, session)
}
if session.Engine.ShowSQL {
fmt.Println(sql)
fmt.Println(args)
} }
session.Engine.LogSQL(sql)
session.Engine.LogSQL(args)
if session.IsAutoCommit { if session.IsAutoCommit {
return session.innerExec(sql, args...) return session.innerExec(sql, args...)
} }
@ -331,13 +334,21 @@ func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error
// this function create a table according a bean // this function create a table according a bean
func (session *Session) CreateTable(bean interface{}) error { func (session *Session) CreateTable(bean interface{}) error {
statement := session.Statement session.Statement.RefTable = session.Engine.AutoMap(bean)
defer statement.Init()
statement.RefTable = session.Engine.AutoMap(bean) err := session.newDb()
sql := statement.genCreateSQL() if err != nil {
return err
}
return session.createOneTable()
}
func (session *Session) createOneTable() error {
sql := session.Statement.genCreateSQL()
_, err := session.Exec(sql) _, err := session.Exec(sql)
if err == nil { if err == nil {
sqls := statement.genIndexSQL() sqls := session.Statement.genIndexSQL()
for _, sql := range sqls { for _, sql := range sqls {
_, err = session.Exec(sql) _, err = session.Exec(sql)
if err != nil { if err != nil {
@ -346,7 +357,7 @@ func (session *Session) CreateTable(bean interface{}) error {
} }
} }
if err == nil { if err == nil {
sqls := statement.genUniqueSQL() sqls := session.Statement.genUniqueSQL()
for _, sql := range sqls { for _, sql := range sqls {
_, err = session.Exec(sql) _, err = session.Exec(sql)
if err != nil { if err != nil {
@ -357,26 +368,59 @@ func (session *Session) CreateTable(bean interface{}) error {
return err return err
} }
func (session *Session) CreateAll() error {
err := session.newDb()
if err != nil {
return err
}
for _, table := range session.Engine.Tables {
session.Statement.RefTable = table
err := session.createOneTable()
if err != nil {
return err
}
}
return nil
}
func (session *Session) DropTable(bean interface{}) error { func (session *Session) DropTable(bean interface{}) error {
statement := session.Statement err := session.newDb()
defer statement.Init() if err != nil {
statement.RefTable = session.Engine.AutoMap(bean) return err
sql := statement.genDropSQL() }
_, err := session.Exec(sql)
t := reflect.Indirect(reflect.ValueOf(bean)).Type()
defer session.Statement.Init()
if t.Kind() == reflect.String {
session.Statement.AltTableName = bean.(string)
} else if t.Kind() == reflect.Struct {
session.Statement.RefTable = session.Engine.AutoMap(bean)
} else {
return errors.New("Unsupported type")
}
sql := session.Statement.genDropSQL()
_, err = session.Exec(sql)
return err return err
} }
func (session *Session) Get(bean interface{}) (bool, error) { func (session *Session) Get(bean interface{}) (bool, error) {
statement := session.Statement err := session.newDb()
defer statement.Init() if err != nil {
statement.Limit(1) return false, err
}
defer session.Statement.Init()
session.Statement.Limit(1)
var sql string var sql string
var args []interface{} var args []interface{}
if statement.RawSQL == "" { if session.Statement.RawSQL == "" {
sql, args = statement.genGetSql(bean) sql, args = session.Statement.genGetSql(bean)
} else { } else {
sql = statement.RawSQL sql = session.Statement.RawSQL
args = statement.RawParams args = session.Statement.RawParams
session.Engine.AutoMap(bean)
} }
resultsSlice, err := session.Query(sql, args...) resultsSlice, err := session.Query(sql, args...)
if err != nil { if err != nil {
@ -387,7 +431,6 @@ func (session *Session) Get(bean interface{}) (bool, error) {
} }
results := resultsSlice[0] results := resultsSlice[0]
session.Engine.AutoMap(bean)
err = session.scanMapIntoStruct(bean, results) err = session.scanMapIntoStruct(bean, results)
if err != nil { if err != nil {
return false, err return false, err
@ -400,15 +443,19 @@ func (session *Session) Get(bean interface{}) (bool, error) {
} }
func (session *Session) Count(bean interface{}) (int64, error) { func (session *Session) Count(bean interface{}) (int64, error) {
statement := session.Statement err := session.newDb()
if err != nil {
return 0, err
}
defer session.Statement.Init() defer session.Statement.Init()
var sql string var sql string
var args []interface{} var args []interface{}
if statement.RawSQL == "" { if session.Statement.RawSQL == "" {
sql, args = statement.genCountSql(bean) sql, args = session.Statement.genCountSql(bean)
} else { } else {
sql = statement.RawSQL sql = session.Statement.RawSQL
args = statement.RawParams args = session.Statement.RawParams
} }
resultsSlice, err := session.Query(sql, args...) resultsSlice, err := session.Query(sql, args...)
@ -429,7 +476,11 @@ func (session *Session) Count(bean interface{}) (int64, error) {
} }
func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) error { func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
statement := session.Statement err := session.newDb()
if err != nil {
return err
}
defer session.Statement.Init() defer session.Statement.Init()
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map {
@ -438,22 +489,26 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
sliceElementType := sliceValue.Type().Elem() sliceElementType := sliceValue.Type().Elem()
table := session.Engine.AutoMapType(sliceElementType) table := session.Engine.AutoMapType(sliceElementType)
statement.RefTable = table session.Statement.RefTable = table
if len(condiBean) > 0 { if len(condiBean) > 0 {
colNames, args := BuildConditions(session.Engine, table, condiBean[0]) colNames, args := BuildConditions(session.Engine, table, condiBean[0])
statement.ColumnStr = strings.Join(colNames, " and ") session.Statement.ConditionStr = strings.Join(colNames, " and ")
statement.BeanArgs = args session.Statement.BeanArgs = args
} }
var sql string var sql string
var args []interface{} var args []interface{}
if statement.RawSQL == "" { if session.Statement.RawSQL == "" {
sql = statement.generateSql() var columnStr string = session.Statement.ColumnStr
args = append(statement.Params, statement.BeanArgs...) if columnStr == "" {
columnStr = session.Statement.genColumnStr()
}
sql = session.Statement.genSelectSql(columnStr)
args = append(session.Statement.Params, session.Statement.BeanArgs...)
} else { } else {
sql = statement.RawSQL sql = session.Statement.RawSQL
args = statement.RawParams args = session.Statement.RawParams
} }
resultsSlice, err := session.Query(sql, args...) resultsSlice, err := session.Query(sql, args...)
@ -496,20 +551,14 @@ func (session *Session) Ping() error {
return session.Db.Ping() return session.Db.Ping()
} }
func (session *Session) CreateAll() error { func (session *Session) DropAll() error {
for _, table := range session.Engine.Tables { err := session.newDb()
session.Statement.RefTable = table
sql := session.Statement.genCreateSQL()
_, err := session.Exec(sql)
if err != nil { if err != nil {
return err return err
} }
}
return nil
}
func (session *Session) DropAll() error {
for _, table := range session.Engine.Tables { for _, table := range session.Engine.Tables {
session.Statement.Init()
session.Statement.RefTable = table session.Statement.RefTable = table
sql := session.Statement.genDropSQL() sql := session.Statement.genDropSQL()
_, err := session.Exec(sql) _, err := session.Exec(sql)
@ -526,18 +575,13 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice
return nil, err return nil, err
} }
if session.IsAutoCommit { for _, filter := range session.Engine.Filters {
defer session.Close() sql = filter.Do(sql, session)
} }
// TODO: this statement should be invoke before Query session.Engine.LogSQL(sql)
if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { session.Engine.LogSQL(paramStr)
sql = strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1)
}
if session.Engine.ShowSQL {
fmt.Println(sql)
fmt.Println(paramStr)
}
s, err := session.Db.Prepare(sql) s, err := session.Db.Prepare(sql)
if err != nil { if err != nil {
return nil, err return nil, err
@ -597,7 +641,7 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice
str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700") str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700")
result[key] = []byte(str) result[key] = []byte(str)
} else { } else {
fmt.Print("Unsupported struct type") session.Engine.LogError("Unsupported struct type")
} }
} }
//default: //default:
@ -625,7 +669,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
sliceValue := reflect.Indirect(reflect.ValueOf(bean)) sliceValue := reflect.Indirect(reflect.ValueOf(bean))
if sliceValue.Kind() == reflect.Slice { if sliceValue.Kind() == reflect.Slice {
if session.Engine.SupportInsertMany() { if session.Engine.SupportInsertMany() {
lastId, err = session.InsertMulti(bean) lastId, err = session.innerInsertMulti(bean)
if err != nil { if err != nil {
if !isInTransaction { if !isInTransaction {
err1 := session.Rollback() err1 := session.Rollback()
@ -639,7 +683,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
} else { } else {
size := sliceValue.Len() size := sliceValue.Len()
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
lastId, err = session.InsertOne(sliceValue.Index(i).Interface()) lastId, err = session.innerInsert(sliceValue.Index(i).Interface())
if err != nil { if err != nil {
if !isInTransaction { if !isInTransaction {
err1 := session.Rollback() err1 := session.Rollback()
@ -653,7 +697,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
} }
} }
} else { } else {
lastId, err = session.InsertOne(bean) lastId, err = session.innerInsert(bean)
if err != nil { if err != nil {
if !isInTransaction { if !isInTransaction {
err1 := session.Rollback() err1 := session.Rollback()
@ -672,7 +716,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
return lastId, err return lastId, err
} }
func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) {
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice { if sliceValue.Kind() != reflect.Slice {
return -1, errors.New("needs a pointer to a slice") return -1, errors.New("needs a pointer to a slice")
@ -698,20 +742,18 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
if i == 0 { if i == 0 {
for _, col := range table.Columns { for _, col := range table.Columns {
fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName) fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName)
val := fieldValue.Interface()
if col.IsAutoIncrement && fieldValue.Int() == 0 { if col.IsAutoIncrement && fieldValue.Int() == 0 {
continue continue
} }
if col.MapType == ONLYFROMDB { if col.MapType == ONLYFROMDB {
continue continue
} }
if table, ok := session.Engine.Tables[fieldValue.Type()]; ok { arg, err := session.value2Interface(fieldValue)
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName) if err != nil {
//fmt.Println(pkField.Interface()) return 0, err
args = append(args, pkField.Interface())
} else {
args = append(args, val)
} }
args = append(args, arg)
colNames = append(colNames, col.Name) colNames = append(colNames, col.Name)
cols = append(cols, col) cols = append(cols, col)
colPlaces = append(colPlaces, "?") colPlaces = append(colPlaces, "?")
@ -719,30 +761,36 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
} else { } else {
for _, col := range cols { for _, col := range cols {
fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName) fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName)
val := fieldValue.Interface()
if col.IsAutoIncrement && fieldValue.Int() == 0 { if col.IsAutoIncrement && fieldValue.Int() == 0 {
continue continue
} }
if col.MapType == ONLYFROMDB { if col.MapType == ONLYFROMDB {
continue continue
} }
if table, ok := session.Engine.Tables[fieldValue.Type()]; ok { if session.Statement.ColumnStr != "" {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName) if _, ok := session.Statement.columnMap[col.Name]; !ok {
args = append(args, pkField.Interface()) continue
} else {
args = append(args, val)
} }
}
arg, err := session.value2Interface(fieldValue)
if err != nil {
return 0, err
}
args = append(args, arg)
colPlaces = append(colPlaces, "?") colPlaces = append(colPlaces, "?")
} }
} }
colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", ")) colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", "))
} }
statement := fmt.Sprintf("INSERT INTO %v%v%v (%v) VALUES (%v)", statement := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v);",
session.Engine.QuoteIdentifier(), session.Engine.QuoteStr(),
session.Statement.TableName(), session.Statement.TableName(),
session.Engine.QuoteIdentifier(), session.Engine.QuoteStr(),
strings.Join(colNames, ", "), session.Engine.QuoteStr(),
strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()),
session.Engine.QuoteStr(),
strings.Join(colMultiPlaces, "),(")) strings.Join(colMultiPlaces, "),("))
res, err := session.Exec(statement, args...) res, err := session.Exec(statement, args...)
@ -759,24 +807,27 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
return id, nil return id, nil
} }
func (session *Session) InsertOne(bean interface{}) (int64, error) { func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
table := session.Engine.AutoMap(bean) err := session.newDb()
//fmt.Printf("table: %v\n", table) if session.IsAutoCommit {
session.Statement.RefTable = table defer session.Close()
colNames := make([]string, 0)
colPlaces := make([]string, 0)
var args = make([]interface{}, 0)
for _, col := range table.Columns {
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
val := fieldValue.Interface()
if col.IsAutoIncrement && fieldValue.Int() == 0 {
continue
} }
if col.MapType == ONLYFROMDB { if err != nil {
continue return 0, err
} }
if fieldValue.Type().String() == "time.Time" {
args = append(args, val) return session.innerInsertMulti(rowsSlicePtr)
}
func (session *Session) value2Interface(fieldValue reflect.Value) (interface{}, error) {
if fieldValue.Type().Kind() == reflect.Bool {
if fieldValue.Bool() {
return 1, nil
} else {
return 0, nil
}
} else if fieldValue.Type().String() == "time.Time" {
return fieldValue.Interface(), nil
} else if fieldValue.Type().Kind() == reflect.Struct { } else if fieldValue.Type().Kind() == reflect.Struct {
if fieldValue.CanAddr() { if fieldValue.CanAddr() {
if fieldConvert, ok := fieldValue.Addr().Interface().(Conversion); ok { if fieldConvert, ok := fieldValue.Addr().Interface().(Conversion); ok {
@ -784,69 +835,121 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} else { } else {
args = append(args, string(data)) return string(data), nil
} }
} else { }
}
if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok { if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok {
if fieldTable.PrimaryKey != "" { if fieldTable.PrimaryKey != "" {
pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumn().FieldName) pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumn().FieldName)
args = append(args, pkField.Interface()) return pkField.Interface(), nil
} else { } else {
continue return 0, errors.New("no primary key")
} }
} else { } else {
//args = append(args, val) return 0, errors.New(fmt.Sprintf("Unsupported type %v", fieldValue.Type()))
}
} else {
return fieldValue.Interface(), nil
}
}
func (session *Session) innerInsert(bean interface{}) (int64, error) {
table := session.Engine.AutoMap(bean)
session.Statement.RefTable = table
colNames := make([]string, 0)
colPlaces := make([]string, 0)
var args = make([]interface{}, 0)
for _, col := range table.Columns {
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
if col.IsAutoIncrement && fieldValue.Int() == 0 {
continue
}
if col.MapType == ONLYFROMDB {
continue
}
if session.Statement.ColumnStr != "" {
if _, ok := session.Statement.columnMap[col.Name]; !ok {
continue continue
} }
} }
} else {
continue arg, err := session.value2Interface(fieldValue)
} if err != nil {
} else { return 0, err
args = append(args, val)
} }
args = append(args, arg)
colNames = append(colNames, col.Name) colNames = append(colNames, col.Name)
colPlaces = append(colPlaces, "?") colPlaces = append(colPlaces, "?")
} }
sql := fmt.Sprintf("INSERT INTO %v%v%v (%v) VALUES (%v)", sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v);",
session.Engine.QuoteIdentifier(), session.Engine.QuoteStr(),
session.Statement.TableName(), session.Statement.TableName(),
session.Engine.QuoteIdentifier(), session.Engine.QuoteStr(),
strings.Join(colNames, ", "), session.Engine.QuoteStr(),
strings.Join(colNames, session.Engine.Quote(", ")),
session.Engine.QuoteStr(),
strings.Join(colPlaces, ", ")) strings.Join(colPlaces, ", "))
res, err := session.Exec(sql, args...) res, err := session.Exec(sql, args...)
if err != nil { if err != nil {
return -1, err return 0, err
} }
id, err := res.LastInsertId() if table.PrimaryKey == "" {
if err != nil { return 0, nil
return -1, err
} }
if id > 0 && table.PrimaryKey != "" {
var id int64 = 0
pkValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(table.PKColumn().FieldName) pkValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(table.PKColumn().FieldName)
if pkValue.CanSet() { if pkValue.Int() != 0 || !pkValue.CanSet() {
return 0, nil
}
id, err = res.LastInsertId()
if err != nil || id <= 0 {
return 0, err
}
var v interface{} = id var v interface{} = id
switch pkValue.Type().Kind() { switch pkValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32: case reflect.Int8, reflect.Int16, reflect.Int32:
v = int(id) v = int(id)
pkValue.Set(reflect.ValueOf(v))
case reflect.Int64:
pkValue.Set(reflect.ValueOf(v))
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
v = uint(id) v = uint(id)
pkValue.Set(reflect.ValueOf(v))
}
} }
} pkValue.Set(reflect.ValueOf(v))
return id, nil return id, nil
} }
func (session *Session) InsertOne(bean interface{}) (int64, error) {
err := session.newDb()
if session.IsAutoCommit {
defer session.Close()
}
if err != nil {
return 0, err
}
return session.innerInsert(bean)
}
func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) { func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) {
err := session.newDb()
if session.IsAutoCommit {
defer session.Close()
}
if err != nil {
return 0, err
}
table := session.Engine.AutoMap(bean) table := session.Engine.AutoMap(bean)
session.Statement.RefTable = table session.Statement.RefTable = table
colNames, args := BuildConditions(session.Engine, table, bean) colNames, args := BuildConditions(session.Engine, table, bean)
@ -874,10 +977,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
sql := fmt.Sprintf("UPDATE %v%v%v SET %v %v", sql := fmt.Sprintf("UPDATE %v SET %v %v",
session.Engine.QuoteIdentifier(), session.Engine.Quote(session.Statement.TableName()),
session.Statement.TableName(),
session.Engine.QuoteIdentifier(),
strings.Join(colNames, ", "), strings.Join(colNames, ", "),
condition) condition)
@ -887,15 +988,23 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return -1, err return -1, err
} }
id, err := res.RowsAffected() rows, err := res.RowsAffected()
if err != nil { if err != nil {
return -1, err return -1, err
} }
return id, nil return rows, nil
} }
func (session *Session) Delete(bean interface{}) (int64, error) { func (session *Session) Delete(bean interface{}) (int64, error) {
err := session.newDb()
if session.IsAutoCommit {
defer session.Close()
}
if err != nil {
return 0, err
}
table := session.Engine.AutoMap(bean) table := session.Engine.AutoMap(bean)
session.Statement.RefTable = table session.Statement.RefTable = table
colNames, args := BuildConditions(session.Engine, table, bean) colNames, args := BuildConditions(session.Engine, table, bean)
@ -914,9 +1023,9 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
} }
statement := fmt.Sprintf("DELETE FROM %v%v%v %v", statement := fmt.Sprintf("DELETE FROM %v%v%v %v",
session.Engine.QuoteIdentifier(), session.Engine.QuoteStr(),
session.Statement.TableName(), session.Statement.TableName(),
session.Engine.QuoteIdentifier(), session.Engine.QuoteStr(),
condition) condition)
res, err := session.Exec(statement, append(st.Params, args...)...) res, err := session.Exec(statement, append(st.Params, args...)...)

View File

@ -12,18 +12,21 @@ type sqlite3 struct {
func (db *sqlite3) SqlType(c *Column) string { func (db *sqlite3) SqlType(c *Column) string {
switch t := c.SQLType; t { switch t := c.SQLType; t {
case Date, DateTime, TimeStamp: case Date, DateTime, TimeStamp, Time:
return "NUMERIC" return Numeric.Name
case Char, Varchar, Text: case Char, Varchar, TinyText, Text, MediumText, LongText:
return "TEXT" return Text.Name
case TinyInt, SmallInt, MediumInt, Int, BigInt: case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool:
return "INTEGER" return Integer.Name
case Float, Double: case Float, Double, Real:
return "REAL" return Real.Name
case Decimal: case Decimal, Numeric:
return "NUMERIC" return Numeric.Name
case Blob: case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary:
return "BLOB" return Blob.Name
case Serial, BigSerial:
c.IsAutoIncrement = true
return Integer.Name
default: default:
return t.Name return t.Name
} }
@ -33,10 +36,18 @@ func (db *sqlite3) SupportInsertMany() bool {
return true return true
} }
func (db *sqlite3) QuoteIdentifier() string { func (db *sqlite3) QuoteStr() string {
return "`" return "`"
} }
func (db *sqlite3) AutoIncrIdentifier() string { func (db *sqlite3) AutoIncrStr() string {
return "AUTOINCREMENT" return "AUTOINCREMENT"
} }
func (db *sqlite3) SupportEngine() bool {
return false
}
func (db *sqlite3) SupportCharset() bool {
return false
}

View File

@ -6,147 +6,15 @@ import (
"testing" "testing"
) )
var se *Engine func TestSqlite3(t *testing.T) {
func autoConn() {
if se == nil {
os.Remove("./test.db") os.Remove("./test.db")
se, _ = NewEngine("sqlite3", "./test.db") engine, err := NewEngine("sqlite3", "./test.db")
se.ShowSQL = true defer engine.Close()
} if err != nil {
t.Error(err)
return
} }
engine.ShowSQL = true
func TestSqliteCreateTable(t *testing.T) { testAll(engine, t)
autoConn()
directCreateTable(se, t)
}
func TestSqliteMapper(t *testing.T) {
autoConn()
mapper(se, t)
}
func TestSqliteInsert(t *testing.T) {
autoConn()
insert(se, t)
}
func TestSqliteQuery(t *testing.T) {
autoConn()
query(se, t)
}
func TestSqliteExec(t *testing.T) {
autoConn()
exec(se, t)
}
func TestSqliteInsertAutoIncr(t *testing.T) {
autoConn()
insertAutoIncr(se, t)
}
func TestInsertMulti(t *testing.T) {
autoConn()
insertMulti(se, t)
}
func TestSqliteInsertMulti(t *testing.T) {
autoConn()
insertMulti(se, t)
}
func TestSqliteInsertTwoTable(t *testing.T) {
autoConn()
insertTwoTable(se, t)
}
func TestSqliteUpdate(t *testing.T) {
autoConn()
update(se, t)
}
func TestSqliteDelete(t *testing.T) {
autoConn()
testdelete(se, t)
}
func TestSqliteGet(t *testing.T) {
autoConn()
get(se, t)
}
func TestSqliteCascadeGet(t *testing.T) {
autoConn()
cascadeGet(se, t)
}
func TestSqliteFind(t *testing.T) {
autoConn()
find(se, t)
}
func TestSqliteFindMap(t *testing.T) {
autoConn()
findMap(se, t)
}
func TestSqliteCount(t *testing.T) {
autoConn()
count(se, t)
}
func TestSqliteWhere(t *testing.T) {
autoConn()
where(se, t)
}
func TestSqliteIn(t *testing.T) {
autoConn()
in(se, t)
}
func TestSqliteLimit(t *testing.T) {
autoConn()
limit(se, t)
}
func TestSqliteOrder(t *testing.T) {
autoConn()
order(se, t)
}
func TestSqliteJoin(t *testing.T) {
autoConn()
join(se, t)
}
func TestSqliteHaving(t *testing.T) {
autoConn()
having(se, t)
}
func TestSqliteTransaction(t *testing.T) {
autoConn()
transaction(se, t)
}
func TestSqliteCombineTransaction(t *testing.T) {
autoConn()
combineTransaction(se, t)
}
func TestSqliteTable(t *testing.T) {
autoConn()
table(se, t)
}
func TestSqliteCreateMultiTables(t *testing.T) {
autoConn()
createMultiTables(se, t)
}
func TestSqliteTableOp(t *testing.T) {
autoConn()
tableOp(se, t)
} }

View File

@ -27,6 +27,8 @@ type Statement struct {
GroupByStr string GroupByStr string
HavingStr string HavingStr string
ColumnStr string ColumnStr string
columnMap map[string]bool
ConditionStr string
AltTableName string AltTableName string
RawSQL string RawSQL string
RawParams []interface{} RawParams []interface{}
@ -57,6 +59,8 @@ func (statement *Statement) Init() {
statement.GroupByStr = "" statement.GroupByStr = ""
statement.HavingStr = "" statement.HavingStr = ""
statement.ColumnStr = "" statement.ColumnStr = ""
statement.columnMap = make(map[string]bool)
statement.ConditionStr = ""
statement.AltTableName = "" statement.AltTableName = ""
statement.RawSQL = "" statement.RawSQL = ""
statement.RawParams = make([]interface{}, 0) statement.RawParams = make([]interface{}, 0)
@ -116,8 +120,7 @@ func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string,
} else { } else {
args = append(args, val) args = append(args, val)
} }
colNames = append(colNames, fmt.Sprintf("%v%v%v = ?", engine.QuoteIdentifier(), colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
col.Name, engine.QuoteIdentifier()))
} }
return colNames, args return colNames, args
@ -155,6 +158,13 @@ func (statement *Statement) In(column string, args ...interface{}) {
} }
} }
func (statement *Statement) Cols(columns ...string) {
statement.ColumnStr = strings.Join(columns, statement.Engine.Quote(", "))
for _, column := range columns {
statement.columnMap[column] = true
}
}
func (statement *Statement) Limit(limit int, start ...int) { func (statement *Statement) Limit(limit int, start ...int) {
statement.LimitN = limit statement.LimitN = limit
if len(start) > 0 { if len(start) > 0 {
@ -176,65 +186,36 @@ func (statement *Statement) Join(join_operator, tablename, condition string) {
} }
func (statement *Statement) GroupBy(keys string) { func (statement *Statement) GroupBy(keys string) {
statement.GroupByStr = fmt.Sprintf("GROUP BY %v", keys) statement.GroupByStr = keys
} }
func (statement *Statement) Having(conditions string) { func (statement *Statement) Having(conditions string) {
statement.HavingStr = fmt.Sprintf("HAVING %v", conditions) statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
} }
func (statement *Statement) genColumnStr(col *Column) string { func (statement *Statement) genColumnStr() string {
sql := "`" + col.Name + "` "
sql += statement.Engine.Dialect.SqlType(col) + " "
if col.IsPrimaryKey {
sql += "PRIMARY KEY "
}
if col.IsAutoIncrement {
sql += statement.Engine.AutoIncrIdentifier() + " "
}
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
/*if col.UniqueType == SINGLEUNIQUE {
sql += "UNIQUE "
}*/
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
return sql
}
func (statement *Statement) selectColumnStr() string {
table := statement.RefTable table := statement.RefTable
colNames := make([]string, 0) colNames := make([]string, 0)
for _, col := range table.Columns { for _, col := range table.Columns {
if col.MapType != ONLYTODB { if col.MapType != ONLYTODB {
colNames = append(colNames, statement.TableName()+"."+col.Name) colNames = append(colNames, statement.Engine.Quote(statement.TableName())+"."+statement.Engine.Quote(col.Name))
} }
} }
return strings.Join(colNames, ", ") return strings.Join(colNames, ", ")
} }
func (statement *Statement) genCreateSQL() string { func (statement *Statement) genCreateSQL() string {
sql := "CREATE TABLE IF NOT EXISTS `" + statement.TableName() + "` (" sql := "CREATE TABLE IF NOT EXISTS " + statement.Engine.Quote(statement.TableName()) + " ("
for _, col := range statement.RefTable.Columns { for _, col := range statement.RefTable.Columns {
sql += statement.genColumnStr(&col) sql += col.String(statement.Engine)
sql = strings.TrimSpace(sql) sql = strings.TrimSpace(sql)
sql += ", " sql += ", "
} }
sql = sql[:len(sql)-2] + ")" sql = sql[:len(sql)-2] + ")"
if statement.StoreEngine != "" { if statement.Engine.Dialect.SupportEngine() && statement.StoreEngine != "" {
sql += " ENGINE=" + statement.StoreEngine sql += " ENGINE=" + statement.StoreEngine
} }
if statement.Charset != "" { if statement.Engine.Dialect.SupportCharset() && statement.Charset != "" {
sql += " DEFAULT CHARSET " + statement.Charset sql += " DEFAULT CHARSET " + statement.Charset
} }
sql += ";" sql += ";"
@ -262,24 +243,24 @@ func (statement *Statement) genUniqueSQL() []string {
} }
func (statement *Statement) genDropSQL() string { func (statement *Statement) genDropSQL() string {
sql := "DROP TABLE IF EXISTS `" + statement.TableName() + "`;" sql := "DROP TABLE IF EXISTS " + statement.Engine.Quote(statement.TableName()) + ";"
return sql return sql
} }
func (statement Statement) generateSql() string {
columnStr := statement.selectColumnStr()
return statement.genSelectSql(columnStr)
}
func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) { func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) {
table := statement.Engine.AutoMap(bean) table := statement.Engine.AutoMap(bean)
statement.RefTable = table statement.RefTable = table
colNames, args := BuildConditions(statement.Engine, table, bean) colNames, args := BuildConditions(statement.Engine, table, bean)
statement.ColumnStr = strings.Join(colNames, " and ") statement.ConditionStr = strings.Join(colNames, " and ")
statement.BeanArgs = args statement.BeanArgs = args
return statement.generateSql(), append(statement.Params, statement.BeanArgs...) var columnStr string = statement.ColumnStr
if columnStr == "" {
columnStr = statement.genColumnStr()
}
return statement.genSelectSql(columnStr), append(statement.Params, statement.BeanArgs...)
} }
func (statement Statement) genCountSql(bean interface{}) (string, []interface{}) { func (statement Statement) genCountSql(bean interface{}) (string, []interface{}) {
@ -287,86 +268,31 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{})
statement.RefTable = table statement.RefTable = table
colNames, args := BuildConditions(statement.Engine, table, bean) colNames, args := BuildConditions(statement.Engine, table, bean)
statement.ColumnStr = strings.Join(colNames, " and ") statement.ConditionStr = strings.Join(colNames, " and ")
statement.BeanArgs = args statement.BeanArgs = args
return statement.genSelectSql("count(*) as total"), append(statement.Params, statement.BeanArgs...) return statement.genSelectSql(fmt.Sprintf("count(*) as %v", statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...)
} }
func (statement Statement) genSelectSql(columnStr string) (a string) { func (statement Statement) genSelectSql(columnStr string) (a string) {
if statement.Engine.DriverName == MSSQL {
if statement.Start > 0 {
a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v",
statement.RefTable.PKColumn().Name,
columnStr,
statement.TableName())
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
}
} else if statement.ColumnStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
}
a = fmt.Sprintf("select %v from (%v) "+
"as a where rownum between %v and %v",
columnStr,
a,
statement.Start,
statement.LimitN)
} else if statement.LimitN > 0 {
a = fmt.Sprintf("SELECT top %v %v FROM %v", statement.LimitN, columnStr, statement.TableName())
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
}
} else if statement.ColumnStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
}
if statement.GroupByStr != "" { if statement.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, statement.GroupByStr) columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
statement.GroupByStr = columnStr
} }
if statement.HavingStr != "" { a = fmt.Sprintf("SELECT %v FROM %v", columnStr,
a = fmt.Sprintf("%v %v", a, statement.HavingStr) statement.Engine.Quote(statement.TableName()))
}
if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
}
} else {
a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.TableName())
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
}
} else if statement.ColumnStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
}
if statement.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
}
if statement.HavingStr != "" {
a = fmt.Sprintf("%v %v", a, statement.HavingStr)
}
if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
}
}
} else {
a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.TableName())
if statement.JoinStr != "" { if statement.JoinStr != "" {
a = fmt.Sprintf("%v %v", a, statement.JoinStr) a = fmt.Sprintf("%v %v", a, statement.JoinStr)
} }
if statement.WhereStr != "" { if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" { if statement.ConditionStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr) a = fmt.Sprintf("%v and %v", a, statement.ConditionStr)
} }
} else if statement.ColumnStr != "" { } else if statement.ConditionStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr) a = fmt.Sprintf("%v WHERE %v", a, statement.ConditionStr)
} }
if statement.GroupByStr != "" { if statement.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, statement.GroupByStr) a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr)
} }
if statement.HavingStr != "" { if statement.HavingStr != "" {
a = fmt.Sprintf("%v %v", a, statement.HavingStr) a = fmt.Sprintf("%v %v", a, statement.HavingStr)
@ -375,10 +301,9 @@ func (statement Statement) genSelectSql(columnStr string) (a string) {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
} }
if statement.Start > 0 { if statement.Start > 0 {
a = fmt.Sprintf("%v LIMIT %v, %v", a, statement.Start, statement.LimitN) a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
} else if statement.LimitN > 0 { } else if statement.LimitN > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
} }
}
return return
} }

View File

@ -21,21 +21,46 @@ type SQLType struct {
} }
var ( var (
Bit = SQLType{"BIT", 0, 0}
TinyInt = SQLType{"TINYINT", 0, 0} TinyInt = SQLType{"TINYINT", 0, 0}
SmallInt = SQLType{"SMALLINT", 0, 0} SmallInt = SQLType{"SMALLINT", 0, 0}
MediumInt = SQLType{"MEDIUMINT", 0, 0} MediumInt = SQLType{"MEDIUMINT", 0, 0}
Int = SQLType{"INT", 11, 0} Int = SQLType{"INT", 0, 0}
Integer = SQLType{"INTEGER", 0, 0}
BigInt = SQLType{"BIGINT", 0, 0} BigInt = SQLType{"BIGINT", 0, 0}
Char = SQLType{"CHAR", 1, 0}
Char = SQLType{"CHAR", 0, 0}
Varchar = SQLType{"VARCHAR", 64, 0} Varchar = SQLType{"VARCHAR", 64, 0}
Text = SQLType{"TEXT", 16, 0} TinyText = SQLType{"TINYTEXT", 0, 0}
Date = SQLType{"DATE", 24, 0} Text = SQLType{"TEXT", 0, 0}
MediumText = SQLType{"MEDIUMTEXT", 0, 0}
LongText = SQLType{"LONGTEXT", 0, 0}
Binary = SQLType{"BINARY", 0, 0}
VarBinary = SQLType{"VARBINARY", 0, 0}
Date = SQLType{"DATE", 0, 0}
DateTime = SQLType{"DATETIME", 0, 0} DateTime = SQLType{"DATETIME", 0, 0}
Decimal = SQLType{"DECIMAL", 26, 2} Time = SQLType{"TIME", 0, 0}
Float = SQLType{"FLOAT", 31, 0}
Double = SQLType{"DOUBLE", 31, 0}
Blob = SQLType{"BLOB", 0, 0}
TimeStamp = SQLType{"TIMESTAMP", 0, 0} TimeStamp = SQLType{"TIMESTAMP", 0, 0}
Decimal = SQLType{"DECIMAL", 26, 2}
Numeric = SQLType{"NUMERIC", 0, 0}
Real = SQLType{"REAL", 0, 0}
Float = SQLType{"FLOAT", 0, 0}
Double = SQLType{"DOUBLE", 0, 0}
//Money = SQLType{"MONEY", 0, 0}
TinyBlob = SQLType{"TINYBLOB", 0, 0}
Blob = SQLType{"BLOB", 0, 0}
MediumBlob = SQLType{"MEDIUMBLOB", 0, 0}
LongBlob = SQLType{"LONGBLOB", 0, 0}
Bytea = SQLType{"BYTEA", 0, 0}
Bool = SQLType{"BOOL", 0, 0}
Serial = SQLType{"SERIAL", 0, 0}
BigSerial = SQLType{"BIGSERIAL", 0, 0}
) )
var b byte var b byte
@ -106,6 +131,31 @@ type Column struct {
MapType int MapType int
} }
func (col *Column) String(engine *Engine) string {
sql := engine.Quote(col.Name) + " "
sql += engine.SqlType(col) + " "
if col.IsPrimaryKey {
sql += "PRIMARY KEY "
}
if col.IsAutoIncrement {
sql += engine.AutoIncrStr() + " "
}
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
return sql
}
type Table struct { type Table struct {
Name string Name string
Type reflect.Type Type reflect.Type

16
xorm.go
View File

@ -8,18 +8,19 @@
package xorm package xorm
import ( import (
//"database/sql"
"errors" "errors"
"fmt" "fmt"
"os"
"reflect" "reflect"
"sync" "sync"
//"time"
) )
const ( const (
version string = "0.1.8" version string = "0.1.9"
) )
// new a db manager according to the parameter. Currently support three
// driver
func NewEngine(driverName string, dataSourceName string) (*Engine, error) { func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
engine := &Engine{ShowSQL: false, DriverName: driverName, Mapper: SnakeMapper{}, engine := &Engine{ShowSQL: false, DriverName: driverName, Mapper: SnakeMapper{},
DataSourceName: dataSourceName} DataSourceName: dataSourceName}
@ -27,13 +28,22 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
engine.Tables = make(map[reflect.Type]*Table) engine.Tables = make(map[reflect.Type]*Table)
engine.mutex = &sync.Mutex{} engine.mutex = &sync.Mutex{}
engine.TagIdentifier = "xorm" engine.TagIdentifier = "xorm"
engine.Filters = make([]Filter, 0)
if driverName == SQLITE { if driverName == SQLITE {
engine.Dialect = &sqlite3{} engine.Dialect = &sqlite3{}
} else if driverName == MYSQL { } else if driverName == MYSQL {
engine.Dialect = &mysql{} engine.Dialect = &mysql{}
} else if driverName == POSTGRES {
engine.Dialect = &postgres{}
engine.Filters = append(engine.Filters, &PgSeqFilter{})
engine.Filters = append(engine.Filters, &PgQuoteFilter{})
} else if driverName == MYMYSQL {
engine.Dialect = &mysql{}
} else { } else {
return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName)) return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName))
} }
engine.Filters = append(engine.Filters, &IdFilter{})
engine.Logger = os.Stdout
//engine.Pool = NewSimpleConnectPool() //engine.Pool = NewSimpleConnectPool()
//engine.Pool = NewNoneConnectPool() //engine.Pool = NewNoneConnectPool()