diff --git a/COLUMNTYPE.md b/COLUMNTYPE.md
new file mode 100644
index 00000000..8cded9aa
--- /dev/null
+++ b/COLUMNTYPE.md
@@ -0,0 +1,425 @@
+
+
+ xorm
+ |
+ mysql
+ |
+ sqlite3
+ |
+ postgres
+ |
+ remark |
+
+
+
+ BIT
+ |
+ BIT
+ |
+ INTEGER
+ |
+ BIT
+ |
+ |
+
+
+
+ TINYINT
+ |
+ TINYINT
+ |
+ INTEGER
+ |
+ SMALLINT
+ |
+ |
+
+
+
+
+ SMALLINT
+ |
+ SMALLINT
+ |
+ INTEGER
+ |
+ SMALLINT
+ |
+ |
+
+
+
+
+ MEDIUMINT
+ |
+ MEDIUMINT
+ |
+ INTEGER
+ |
+ INTEGER
+ |
+ |
+
+
+
+
+ INT
+ |
+ INT
+ |
+ INTEGER
+ |
+ INTEGER
+ |
+ |
+
+
+
+ INTEGER
+ |
+ INTEGER
+ |
+ INTEGER
+ |
+ INTEGER
+ |
+ |
+
+
+
+
+ BIGINT
+ |
+ BIGINT
+ |
+ INTEGER
+ |
+ BIGINT
+ |
+ |
+
+
+
+
+ CHAR
+ |
+ CHAR
+ |
+ TEXT
+ |
+ CHAR
+ |
+ |
+
+
+
+
+ VARCHAR
+ |
+ VARCHAR
+ |
+ TEXT
+ |
+ VARCHAR
+ |
+ |
+
+
+
+
+ TINYTEXT
+ |
+ TINYTEXT
+ |
+ TEXT
+ |
+ TEXT
+ |
+ |
+
+
+
+ TEXT
+ |
+ TEXT
+ |
+ TEXT
+ |
+ TEXT
+ |
+ |
+
+
+
+ MEDIUMTEXT
+ |
+ MEDIUMTEXT
+ |
+ TEXT
+ |
+ TEXT
+ |
+ |
+
+
+
+
+ LONGTEXT
+ |
+ LONGTEXT
+ |
+ TEXT
+ |
+ TEXT
+ |
+ |
+
+
+
+
+ BINARY
+ |
+ BINARY
+ |
+ BLOB
+ |
+ BYTEA
+ |
+ |
+
+
+
+
+ VARBINARY
+ |
+ VARBINARY
+ |
+ BLOB
+ |
+ BYTEA
+ |
+ |
+
+
+
+
+ DATE
+ |
+ DATE
+ |
+ NUMERIC
+ |
+ DATE
+ |
+ |
+
+
+
+
+ DATETIME
+ |
+ DATETIME
+ |
+ NUMERIC
+ |
+ TIMESTAMP
+ |
+ |
+
+
+
+
+ TIME
+ |
+ TIME
+ |
+ NUMERIC
+ |
+ TIME
+ |
+ |
+
+
+
+
+ TIMESTAMP
+ |
+ TIMESTAMP
+ |
+ NUMERIC
+ |
+ TIMESTAMP
+ |
+ |
+
+
+
+
+ REAL
+ |
+ REAL
+ |
+ REAL
+ |
+ REAL
+ |
+ |
+
+
+
+
+ FLOAT
+ |
+ FLOAT
+ |
+ REAL
+ |
+ REAL
+ |
+ |
+
+
+
+
+ DOUBLE
+ |
+ DOUBLE
+ |
+ REAL
+ |
+ DOUBLE PRECISION
+ |
+ |
+
+
+
+
+ DECIMAL
+ |
+ DECIMAL
+ |
+ NUMERIC
+ |
+ DECIMAL
+ |
+ |
+
+
+
+
+ NUMERIC
+ |
+ NUMERIC
+ |
+ NUMERIC
+ |
+ NUMERIC
+ |
+ |
+
+
+
+
+ TINYBLOB
+ |
+ TINYBLOB
+ |
+ BLOB
+ |
+ BYTEA
+ |
+ |
+
+
+
+
+ BLOB
+ |
+ BLOB
+ |
+ BLOB
+ |
+ BYTEA
+ |
+ |
+
+
+
+
+ MEDIUMBLOB
+ |
+ MEDIUMBLOB
+ |
+ BLOB
+ |
+ BYTEA
+ |
+ |
+
+
+
+
+ LONGBLOB
+ |
+ LONGBLOB
+ |
+ BLOB
+ |
+ BYTEA
+ |
+ |
+
+
+
+ BYTEA
+ |
+ BLOB
+ |
+ BLOB
+ |
+ BYTEA
+ |
+ |
+
+
+
+
+
+ BOOL
+ |
+ TINYINT
+ |
+ INTEGER
+ |
+ BOOLEAN
+ |
+ |
+
+
+
+
+ SERIAL
+ |
+ INT
+ |
+ INTEGER
+ |
+ SERIAL
+ |
+ auto increment |
+
+
+
+ BIGSERIAL
+ |
+ BIGINT
+ |
+ INTEGER
+ |
+ BIGSERIAL
+ |
+ auto increment |
+
+
+
+
\ No newline at end of file
diff --git a/README.md b/README.md
index d85e445b..c0ec4d5a 100644
--- a/README.md
+++ b/README.md
@@ -13,10 +13,16 @@ Drivers for Go's sql package which currently support database/sql includes:
* Mysql: [github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL)
+* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv)
+
* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
+* Postgres: [github.com/bylevel/pg](https://github.com/bylevel/pg)
+
+
## Changelog
+* **v0.1.9** : Added postgres and mymysql supported; Added ` and ? supported on Raw SQL even if postgres; Added Cols, StoreEngine, Charset function, Added many column data type supported, please see [Mapping Rules](#mapping).
* **v0.1.8** : Added union index and union unique supported, please see [Mapping Rules](#mapping).
* **v0.1.7** : Added IConnectPool interface and NoneConnectPool, SysConnectPool, SimpleConnectPool the three implements. You can choose one of them and the default is SysConnectPool. You can customrize your own connection pool. struct Engine added Close method, It should be invoked before system exit.
* **v0.1.6** : Added conversion interface support; added struct derive support; added single mapping support
@@ -157,6 +163,13 @@ var tenusers []Userinfo
err := engine.In("id", 1, 3, 5).Find(&tenusers) //Get All id in (1, 3, 5)
```
+6.4 The default will query all columns of a table. Use Cols function if you want to select some columns
+
+```Go
+var tenusers []Userinfo
+err := engine.Cols("id", "name").Find(&tenusers) //Find only id and name
+```
+
7.Delete
```Go
@@ -311,7 +324,7 @@ Another is use field tag, field tag support the below keywords which split with
pk | the field is a primary key |
- int(11)/varchar(50)/text/date/datetime/blob/decimal(26,2) | column type |
+ more than 30 column type supported, please see [Column Type](https://github.com/lunny/xorm/blob/master/COLUMNTYPE.md) | column type |
autoincr | auto incrment |
diff --git a/README_CN.md b/README_CN.md
index 7251f024..d420d332 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -12,11 +12,16 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* Mysql: [github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL)
+* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv)
+
* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
+* Postgres: [github.com/bylevel/pg](https://github.com/bylevel/pg)
+
## 更新日志
-* **v0.1.8** : 新增联合index,联合unique支持,请查看[映射规则](#mapping)。
+* **v0.1.9** : 新增 postgres 和 mymysql 驱动支持; 在Postgres中支持原始SQL语句中使用 ` 和 ? 符号; 新增Cols, StoreEngine, Charset 函数;SQL语句打印支持io.Writer接口,默认打印到控制台;新增更多的字段类型支持,详见 [映射规则](#mapping);删除废弃的MakeSession和Create函数。
+* **v0.1.8** : 新增联合index,联合unique支持,请查看 [映射规则](#mapping)。
* **v0.1.7** : 新增IConnectPool接口以及NoneConnectPool, SysConnectPool, SimpleConnectPool三种实现,可以选择不使用连接池,使用系统连接池和使用自带连接池三种实现,默认为SysConnectPool,即系统自带的连接池。同时支持自定义连接池。Engine新增Close方法,在系统退出时应调用此方法。
* **v0.1.6** : 新增Conversion,支持自定义类型到数据库类型的转换;新增查询结构体自动检测匿名成员支持;新增单向映射支持;
* **v0.1.5** : 新增对多线程的支持;新增Sql()函数;支持任意sql语句的struct查询;Get函数返回值变动;MakeSession和Create函数被NewSession和NewEngine函数替代;
@@ -36,7 +41,7 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* 使用连写来简化调用
-* 支持使用Id, In, Where, Limit, Join, Having, Table, Sql等函数和结构体等方式作为条件
+* 支持使用Id, In, Where, Limit, Join, Having, Table, Sql, Cols等函数和结构体等方式作为条件
* 支持数据库连接池
@@ -82,6 +87,11 @@ Sqlite:
engine.ShowSQL = true
```
+默认打印会打印到控制台,如果需要打印到文件,只需要设置一个符合io.writer接口的struct即可:
+```Go
+engine.Logger = [io.Writer]
+```
+
1.2.如果要更换连接池实现,可使用SetPool方法
```Go
err = engine.SetPool(NewSimpleConnectPool())
@@ -142,7 +152,7 @@ users := make(map[int64]Userinfo)
err := engine.Find(&users)
```
-6.1 你也可以使用Where和Limit方法设定条件和查询数量
+6.1 你也可以使用Where和Limit方法设定条件和查询数量,Limit参数可为1个到2个,第一个参数为查询条数,第二个参数为开始条数。
```Go
var allusers []Userinfo
@@ -163,6 +173,13 @@ var tenusers []Userinfo
err := engine.In("id", 1, 3, 5).Find(&tenusers) //Get All id in (1, 3, 5)
```
+6.4 默认将查询出所有字段,如果要指定字段,则可以调用Cols函数
+
+```Go
+var tenusers []Userinfo
+err := engine.Cols("id", "name").Find(&tenusers) //Find only id and name
+```
+
7.Delete方法
```Go
@@ -317,7 +334,7 @@ UserInfo中的成员UserName将会自动对应名为user_name的字段。
pk | 是否是Primary Key,当前仅支持int64类型 |
- int(11)/varchar(50)/text/date/datetime/blob/decimal(26,2) | 字段类型 |
+ 当前支持30多种字段类型,详情参见 [字段类型](https://github.com/lunny/xorm/blob/master/COLUMNTYPE.md) | 字段类型 |
autoincr | 是否是自增 |
diff --git a/VERSION b/VERSION
index 56fecb09..a992b005 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-xorm v0.1.7
+xorm v0.1.9
diff --git a/testbase.go b/base_test.go
similarity index 74%
rename from testbase.go
rename to base_test.go
index 6ae6edef..75575292 100644
--- a/testbase.go
+++ b/base_test.go
@@ -8,14 +8,14 @@ import (
/*
CREATE TABLE `userinfo` (
- `uid` INT(10) NULL AUTO_INCREMENT,
+ `id` INT(10) NULL AUTO_INCREMENT,
`username` VARCHAR(64) NULL,
`departname` VARCHAR(64) NULL,
`created` DATE NULL,
PRIMARY KEY (`uid`)
);
CREATE TABLE `userdeatail` (
- `uid` INT(10) NULL,
+ `id` INT(10) NULL,
`intro` TEXT NULL,
`profile` TEXT NULL,
PRIMARY KEY (`uid`)
@@ -41,15 +41,16 @@ type Userdetail struct {
}
func directCreateTable(engine *Engine, t *testing.T) {
- err := engine.DropTables(&Userinfo{})
+ err := engine.DropTables(&Userinfo{}, &Userdetail{})
if err != nil {
t.Error(err)
- return
+ panic(err)
}
err = engine.CreateTables(&Userinfo{})
if err != nil {
t.Error(err)
+ panic(err)
}
}
@@ -57,30 +58,36 @@ func mapper(engine *Engine, t *testing.T) {
err := engine.UnMap(&Userinfo{})
if err != nil {
t.Error(err)
+ panic(err)
}
err = engine.Map(&Userinfo{}, &Userdetail{})
if err != nil {
t.Error(err)
+ panic(err)
}
err = engine.DropAll()
if err != nil {
t.Error(err)
+ panic(err)
}
err = engine.CreateAll()
if err != nil {
t.Error(err)
+ panic(err)
}
}
func insert(engine *Engine, t *testing.T) {
- user := Userinfo{1, "xiaolunwen", "dev", "lunny", time.Now(),
+ user := Userinfo{0, "xiaolunwen", "dev", "lunny", time.Now(),
Userdetail{Id: 1}, 1.78, []byte{1, 2, 3}, true}
_, err := engine.Insert(&user)
+ fmt.Println(user.Uid)
if err != nil {
t.Error(err)
+ panic(err)
}
}
@@ -89,6 +96,7 @@ func query(engine *Engine, t *testing.T) {
results, err := engine.Query(sql)
if err != nil {
t.Error(err)
+ panic(err)
}
fmt.Println(results)
}
@@ -98,6 +106,7 @@ func exec(engine *Engine, t *testing.T) {
res, err := engine.Exec(sql, "xiaolun", 1)
if err != nil {
t.Error(err)
+ panic(err)
}
fmt.Println(res)
}
@@ -107,8 +116,10 @@ func insertAutoIncr(engine *Engine, t *testing.T) {
user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now(),
Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true}
_, err := engine.Insert(&user)
+ fmt.Println(user.Uid)
if err != nil {
t.Error(err)
+ panic(err)
}
}
@@ -123,22 +134,8 @@ func insertMulti(engine *Engine, t *testing.T) {
_, err := engine.Insert(&users)
if err != nil {
t.Error(err)
+ panic(err)
}
-
- /*engine.InsertMany = false
-
- users = []Userinfo{
- {Username: "xlw9", Departname: "dev", Alias: "lunny9", Created: time.Now()},
- {Username: "xlw10", Departname: "dev", Alias: "lunny10", Created: time.Now()},
- {Username: "xlw99", Departname: "dev", Alias: "lunny2", Created: time.Now()},
- {Username: "xlw1010", Departname: "dev", Alias: "lunny3", Created: time.Now()},
- }
- _, err = engine.Insert(&users)
- if err != nil {
- t.Error(err)
- }
-
- engine.InsertMany = true*/
}
func insertTwoTable(engine *Engine, t *testing.T) {
@@ -148,6 +145,7 @@ func insertTwoTable(engine *Engine, t *testing.T) {
_, err := engine.Insert(&userinfo, &userdetail)
if err != nil {
t.Error(err)
+ panic(err)
}
}
@@ -157,12 +155,13 @@ func update(engine *Engine, t *testing.T) {
_, err := engine.Id(1).Update(&user)
if err != nil {
t.Error(err)
- return
+ panic(err)
}
_, err = engine.Update(&Userinfo{Username: "yyy"}, &user)
if err != nil {
t.Error(err)
+ panic(err)
}
}
@@ -171,6 +170,7 @@ func testdelete(engine *Engine, t *testing.T) {
_, err := engine.Delete(&user)
if err != nil {
t.Error(err)
+ panic(err)
}
}
@@ -180,6 +180,7 @@ func get(engine *Engine, t *testing.T) {
has, err := engine.Get(&user)
if err != nil {
t.Error(err)
+ panic(err)
}
if has {
fmt.Println(user)
@@ -194,6 +195,7 @@ func cascadeGet(engine *Engine, t *testing.T) {
has, err := engine.Get(&user)
if err != nil {
t.Error(err)
+ panic(err)
}
if has {
fmt.Println(user)
@@ -208,6 +210,7 @@ func find(engine *Engine, t *testing.T) {
err := engine.Find(&users)
if err != nil {
t.Error(err)
+ panic(err)
}
fmt.Println(users)
}
@@ -218,6 +221,7 @@ func findMap(engine *Engine, t *testing.T) {
err := engine.Find(&users)
if err != nil {
t.Error(err)
+ panic(err)
}
fmt.Println(users)
}
@@ -227,6 +231,7 @@ func count(engine *Engine, t *testing.T) {
total, err := engine.Count(&user)
if err != nil {
t.Error(err)
+ panic(err)
}
fmt.Printf("Total %d records!!!", total)
}
@@ -236,6 +241,7 @@ func where(engine *Engine, t *testing.T) {
err := engine.Where("id > ?", 2).Find(&users)
if err != nil {
t.Error(err)
+ panic(err)
}
fmt.Println(users)
}
@@ -245,7 +251,7 @@ func in(engine *Engine, t *testing.T) {
err := engine.In("id", 1, 2, 3).Find(&users)
if err != nil {
t.Error(err)
- return
+ panic(err)
}
fmt.Println(users)
@@ -253,7 +259,7 @@ func in(engine *Engine, t *testing.T) {
err = engine.Where("id > ?", 2).In("id", ids...).Find(&users)
if err != nil {
t.Error(err)
- return
+ panic(err)
}
fmt.Println(users)
}
@@ -263,6 +269,7 @@ func limit(engine *Engine, t *testing.T) {
err := engine.Limit(2, 1).Find(&users)
if err != nil {
t.Error(err)
+ panic(err)
}
fmt.Println(users)
}
@@ -272,6 +279,7 @@ func order(engine *Engine, t *testing.T) {
err := engine.OrderBy("id desc").Find(&users)
if err != nil {
t.Error(err)
+ panic(err)
}
fmt.Println(users)
}
@@ -281,6 +289,7 @@ func join(engine *Engine, t *testing.T) {
err := engine.Join("LEFT", "userdetail", "userinfo.id=userdetail.id").Find(&users)
if err != nil {
t.Error(err)
+ panic(err)
}
}
@@ -289,6 +298,7 @@ func having(engine *Engine, t *testing.T) {
err := engine.GroupBy("username").Having("username='xlw'").Find(&users)
if err != nil {
t.Error(err)
+ panic(err)
}
fmt.Println(users)
}
@@ -310,7 +320,7 @@ func transaction(engine *Engine, t *testing.T) {
err := session.Begin()
if err != nil {
t.Error(err)
- return
+ panic(err)
}
//session.IsAutoRollback = false
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
@@ -318,7 +328,7 @@ func transaction(engine *Engine, t *testing.T) {
if err != nil {
session.Rollback()
t.Error(err)
- return
+ panic(err)
}
user2 := Userinfo{Username: "yyy"}
_, err = session.Where("uid = ?", 0).Update(&user2)
@@ -333,14 +343,15 @@ func transaction(engine *Engine, t *testing.T) {
if err != nil {
session.Rollback()
t.Error(err)
- return
+ panic(err)
}
err = session.Commit()
if err != nil {
t.Error(err)
- return
+ panic(err)
}
+ panic(err)
}
func combineTransaction(engine *Engine, t *testing.T) {
@@ -360,7 +371,7 @@ func combineTransaction(engine *Engine, t *testing.T) {
err := session.Begin()
if err != nil {
t.Error(err)
- return
+ panic(err)
}
//session.IsAutoRollback = false
user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
@@ -368,32 +379,42 @@ func combineTransaction(engine *Engine, t *testing.T) {
if err != nil {
session.Rollback()
t.Error(err)
- return
+ panic(err)
}
user2 := Userinfo{Username: "zzz"}
_, err = session.Where("id = ?", 0).Update(&user2)
if err != nil {
session.Rollback()
t.Error(err)
- return
+ panic(err)
}
_, err = session.Exec("delete from userinfo where username = ?", user2.Username)
if err != nil {
session.Rollback()
t.Error(err)
- return
+ panic(err)
}
err = session.Commit()
if err != nil {
t.Error(err)
- return
+ panic(err)
}
}
func table(engine *Engine, t *testing.T) {
- engine.Table("user_user").CreateTable(&Userinfo{})
+ err := engine.DropTables("user_user")
+ if err != nil {
+ t.Error(err)
+ panic(err)
+ }
+
+ err = engine.Table("user_user").CreateTable(&Userinfo{})
+ if err != nil {
+ t.Error(err)
+ panic(err)
+ }
}
func createMultiTables(engine *Engine, t *testing.T) {
@@ -404,19 +425,29 @@ func createMultiTables(engine *Engine, t *testing.T) {
err := session.Begin()
if err != nil {
t.Error(err)
- return
+ panic(err)
}
for i := 0; i < 10; i++ {
- err = session.Table(fmt.Sprintf("user_%v", i)).CreateTable(user)
+ tableName := fmt.Sprintf("user_%v", i)
+
+ err = engine.DropTables(tableName)
if err != nil {
session.Rollback()
t.Error(err)
- return
+ panic(err)
+ }
+
+ err = session.Table(tableName).CreateTable(user)
+ if err != nil {
+ session.Rollback()
+ t.Error(err)
+ panic(err)
}
}
err = session.Commit()
if err != nil {
t.Error(err)
+ panic(err)
}
}
@@ -426,26 +457,118 @@ func tableOp(engine *Engine, t *testing.T) {
id, err := engine.Table(tableName).Insert(&user)
if err != nil {
t.Error(err)
+ panic(err)
}
_, err = engine.Table(tableName).Get(&Userinfo{Username: "tablexiao"})
if err != nil {
t.Error(err)
+ panic(err)
}
users := make([]Userinfo, 0)
err = engine.Table(tableName).Find(&users)
if err != nil {
t.Error(err)
+ panic(err)
}
_, err = engine.Table(tableName).Id(id).Update(&Userinfo{Username: "tableda"})
if err != nil {
t.Error(err)
+ panic(err)
}
_, err = engine.Table(tableName).Id(id).Delete(&Userinfo{})
if err != nil {
t.Error(err)
+ panic(err)
}
}
+
+func testCharst(engine *Engine, t *testing.T) {
+ err := engine.DropTables("user_charset")
+ if err != nil {
+ t.Error(err)
+ panic(err)
+ }
+
+ err = engine.Charset("utf8").Table("user_charset").CreateTable(&Userinfo{})
+ if err != nil {
+ t.Error(err)
+ panic(err)
+ }
+}
+
+func testStoreEngine(engine *Engine, t *testing.T) {
+ err := engine.DropTables("user_store_engine")
+ if err != nil {
+ t.Error(err)
+ panic(err)
+ }
+
+ err = engine.StoreEngine("InnoDB").Table("user_store_engine").CreateTable(&Userinfo{})
+ if err != nil {
+ t.Error(err)
+ panic(err)
+ }
+}
+
+type tempUser struct {
+ Id int64
+ Username string
+}
+
+func testCols(engine *Engine, t *testing.T) {
+ users := []Userinfo{}
+ err := engine.Cols("id, username").Find(&users)
+ if err != nil {
+ t.Error(err)
+ panic(err)
+ }
+
+ fmt.Println(users)
+
+ tmpUsers := []tempUser{}
+ err = engine.Table("userinfo").Cols("id, username").Find(&tmpUsers)
+ if err != nil {
+ t.Error(err)
+ panic(err)
+ }
+ fmt.Println(tmpUsers)
+}
+
+func testTrans(engine *Engine, t *testing.T) {
+}
+
+func testAll(engine *Engine, t *testing.T) {
+ directCreateTable(engine, t)
+ mapper(engine, t)
+ insert(engine, t)
+ query(engine, t)
+ exec(engine, t)
+ insertAutoIncr(engine, t)
+ insertMulti(engine, t)
+ insertTwoTable(engine, t)
+ update(engine, t)
+ testdelete(engine, t)
+ get(engine, t)
+ cascadeGet(engine, t)
+ find(engine, t)
+ findMap(engine, t)
+ count(engine, t)
+ where(engine, t)
+ in(engine, t)
+ limit(engine, t)
+ order(engine, t)
+ join(engine, t)
+ having(engine, t)
+ transaction(engine, t)
+ combineTransaction(engine, t)
+ table(engine, t)
+ createMultiTables(engine, t)
+ tableOp(engine, t)
+ testCols(engine, t)
+ testCharst(engine, t)
+ testStoreEngine(engine, t)
+}
diff --git a/deprecated.go b/deprecated.go
deleted file mode 100644
index b9a8003c..00000000
--- a/deprecated.go
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright 2013 The XORM Authors. All rights reserved.
-// Use of this source code is governed by a BSD
-// license that can be found in the LICENSE file.
-
-// Package xorm provides is a simple and powerful ORM for Go. It makes your
-// database operation simple.
-
-// Warning: All contents in this file will be removed from xorm some times after
-package xorm
-
-// @deprecated : please use NewSession instead
-func (engine *Engine) MakeSession() (Session, error) {
- s := engine.NewSession()
- return *s, nil
-}
-
-// @deprecated : please use NewEngine instead
-func Create(driverName string, dataSourceName string) Engine {
- engine, _ := NewEngine(driverName, dataSourceName)
- return *engine
-}
diff --git a/engine.go b/engine.go
index b6207706..bcd30357 100644
--- a/engine.go
+++ b/engine.go
@@ -10,6 +10,7 @@ package xorm
import (
"database/sql"
"fmt"
+ "io"
"reflect"
"strconv"
"strings"
@@ -17,18 +18,19 @@ import (
)
const (
- PQSQL = "pqsql"
- MSSQL = "mssql"
- SQLITE = "sqlite3"
- MYSQL = "mysql"
- MYMYSQL = "mymysql"
+ POSTGRES = "postgres"
+ SQLITE = "sqlite3"
+ MYSQL = "mysql"
+ MYMYSQL = "mymysql"
)
type dialect interface {
SqlType(t *Column) string
SupportInsertMany() bool
- QuoteIdentifier() string
- AutoIncrIdentifier() string
+ QuoteStr() string
+ AutoIncrStr() string
+ SupportEngine() bool
+ SupportCharset() bool
}
type Engine struct {
@@ -42,18 +44,28 @@ type Engine struct {
ShowSQL bool
pool IConnectPool
CacheMapping bool
+ Filters []Filter
+ Logger io.Writer
}
func (engine *Engine) SupportInsertMany() bool {
return engine.Dialect.SupportInsertMany()
}
-func (engine *Engine) QuoteIdentifier() string {
- return engine.Dialect.QuoteIdentifier()
+func (engine *Engine) QuoteStr() string {
+ return engine.Dialect.QuoteStr()
}
-func (engine *Engine) AutoIncrIdentifier() string {
- return engine.Dialect.AutoIncrIdentifier()
+func (engine *Engine) Quote(sql string) string {
+ return engine.Dialect.QuoteStr() + sql + engine.Dialect.QuoteStr()
+}
+
+func (engine *Engine) SqlType(c *Column) string {
+ return engine.Dialect.SqlType(c)
+}
+
+func (engine *Engine) AutoIncrStr() string {
+ return engine.Dialect.AutoIncrStr()
}
func (engine *Engine) SetPool(pool IConnectPool) error {
@@ -90,12 +102,20 @@ func (engine *Engine) Close() error {
func (engine *Engine) Test() error {
session := engine.NewSession()
defer session.Close()
- if engine.ShowSQL {
- fmt.Printf("PING DATABASE %v\n", engine.DriverName)
- }
+ engine.LogSQL("PING DATABASE", engine.DriverName)
return session.Ping()
}
+func (engine *Engine) LogSQL(contents ...interface{}) {
+ if engine.ShowSQL {
+ io.WriteString(engine.Logger, fmt.Sprintln(contents...))
+ }
+}
+
+func (engine *Engine) LogError(contents ...interface{}) {
+ io.WriteString(engine.Logger, fmt.Sprintln(contents...))
+}
+
func (engine *Engine) Sql(querystring string, args ...interface{}) *Session {
session := engine.NewSession()
return session.Sql(querystring, args...)
@@ -126,6 +146,16 @@ func (engine *Engine) StoreEngine(storeEngine string) *Session {
return session.StoreEngine(storeEngine)
}
+func (engine *Engine) Cols(columns ...string) *Session {
+ session := engine.NewSession()
+ return session.Cols(columns...)
+}
+
+func (engine *Engine) Trans(t string) *Session {
+ session := engine.NewSession()
+ return session.Trans(t)
+}
+
func (engine *Engine) In(column string, args ...interface{}) *Session {
session := engine.NewSession()
return session.In(column, args...)
@@ -273,6 +303,10 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
}
case k == "date":
col.SQLType = Date
+ case k == "float":
+ col.SQLType = Float
+ case k == "double":
+ col.SQLType = Double
case k == "datetime":
col.SQLType = DateTime
case k == "timestamp":
@@ -375,7 +409,8 @@ func (e *Engine) DropAll() error {
}
err = session.DropAll()
if err != nil {
- return session.Rollback()
+ session.Rollback()
+ return err
}
return session.Commit()
}
@@ -418,17 +453,8 @@ func (e *Engine) DropTables(beans ...interface{}) error {
func (e *Engine) CreateAll() error {
session := e.NewSession()
- err := session.Begin()
defer session.Close()
- if err != nil {
- return err
- }
-
- err = session.CreateAll()
- if err != nil {
- return session.Rollback()
- }
- return session.Commit()
+ return session.CreateAll()
}
func (engine *Engine) Exec(sql string, args ...interface{}) (sql.Result, error) {
@@ -449,6 +475,12 @@ func (engine *Engine) Insert(beans ...interface{}) (int64, error) {
return session.Insert(beans...)
}
+func (engine *Engine) InsertOne(bean interface{}) (int64, error) {
+ session := engine.NewSession()
+ defer session.Close()
+ return session.InsertOne(bean)
+}
+
func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64, error) {
session := engine.NewSession()
defer session.Close()
diff --git a/examples/goroutine.go b/examples/goroutine.go
index 23f6a8ce..70f8c5ab 100644
--- a/examples/goroutine.go
+++ b/examples/goroutine.go
@@ -22,7 +22,7 @@ func sqliteEngine() (*xorm.Engine, error) {
}
func mysqlEngine() (*xorm.Engine, error) {
- return xorm.NewEngine("mysql", "root:123@/test?charset=utf8")
+ return xorm.NewEngine("mysql", "root:@/test?charset=utf8")
}
var u *User = &User{}
diff --git a/filter.go b/filter.go
new file mode 100644
index 00000000..8c4d4ea0
--- /dev/null
+++ b/filter.go
@@ -0,0 +1,43 @@
+package xorm
+
+import (
+ "fmt"
+ "strings"
+)
+
+type Filter interface {
+ Do(sql string, session *Session) string
+}
+
+type PgSeqFilter struct {
+}
+
+func (s *PgSeqFilter) Do(sql string, session *Session) string {
+ segs := strings.Split(sql, "?")
+ size := len(segs)
+ res := ""
+ for i, c := range segs {
+ if i < size-1 {
+ res += c + fmt.Sprintf("$%v", i+1)
+ }
+ }
+ res += segs[size-1]
+ return res
+}
+
+type PgQuoteFilter struct {
+}
+
+func (s *PgQuoteFilter) Do(sql string, session *Session) string {
+ return strings.Replace(sql, "`", session.Engine.QuoteStr(), -1)
+}
+
+type IdFilter struct {
+}
+
+func (i *IdFilter) Do(sql string, session *Session) string {
+ if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" {
+ return strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1)
+ }
+ return sql
+}
diff --git a/mymysql_test.go b/mymysql_test.go
new file mode 100644
index 00000000..8164e29e
--- /dev/null
+++ b/mymysql_test.go
@@ -0,0 +1,23 @@
+package xorm
+
+import (
+ _ "github.com/ziutek/mymysql/godrv"
+ "testing"
+)
+
+/*
+CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET
+utf8 COLLATE utf8_general_ci;
+*/
+
+func TestMyMysql(t *testing.T) {
+ engine, err := NewEngine("mymysql", "xorm_test/root/")
+ defer engine.Close()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ engine.ShowSQL = true
+
+ testAll(engine, t)
+}
diff --git a/mysql.go b/mysql.go
index 0a7029b0..6fba8dc0 100644
--- a/mysql.go
+++ b/mysql.go
@@ -13,26 +13,48 @@ type mysql struct {
}
func (db *mysql) SqlType(c *Column) string {
+ var res string
switch t := c.SQLType; t {
- case Date, DateTime, TimeStamp:
- return "DATETIME"
- case Varchar:
- return t.Name + "(" + strconv.Itoa(c.Length) + ")"
- case Decimal:
- return t.Name + "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
+ case Bool:
+ res = TinyInt.Name
+ case Serial:
+ c.IsAutoIncrement = true
+ res = Int.Name
+ case BigSerial:
+ c.IsAutoIncrement = true
+ res = Integer.Name
+ case Bytea:
+ res = Blob.Name
default:
- return t.Name
+ res = t.Name
}
+
+ var hasLen1 bool = (c.Length > 0)
+ var hasLen2 bool = (c.Length2 > 0)
+ if hasLen1 {
+ res += "(" + strconv.Itoa(c.Length) + ")"
+ } else if hasLen2 {
+ res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
+ }
+ return res
}
func (db *mysql) SupportInsertMany() bool {
return true
}
-func (db *mysql) QuoteIdentifier() string {
+func (db *mysql) QuoteStr() string {
return "`"
}
-func (db *mysql) AutoIncrIdentifier() string {
+func (db *mysql) AutoIncrStr() string {
return "AUTO_INCREMENT"
}
+
+func (db *mysql) SupportEngine() bool {
+ return true
+}
+
+func (db *mysql) SupportCharset() bool {
+ return true
+}
diff --git a/mysql_test.go b/mysql_test.go
index 8e66b625..4b446c94 100644
--- a/mysql_test.go
+++ b/mysql_test.go
@@ -5,47 +5,19 @@ import (
"testing"
)
-var me Engine
-
/*
CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET
utf8 COLLATE utf8_general_ci;
*/
func TestMysql(t *testing.T) {
- // You should drop all tables before executing this testing
engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8")
+ defer engine.Close()
if err != nil {
t.Error(err)
return
}
- me = *engine
- me.ShowSQL = true
+ engine.ShowSQL = true
- directCreateTable(&me, t)
- mapper(&me, t)
- insert(&me, t)
- query(&me, t)
- exec(&me, t)
- insertAutoIncr(&me, t)
- insertMulti(&me, t)
- insertTwoTable(&me, t)
- update(&me, t)
- testdelete(&me, t)
- get(&me, t)
- cascadeGet(&me, t)
- find(&me, t)
- findMap(&me, t)
- count(&me, t)
- where(&me, t)
- in(&me, t)
- limit(&me, t)
- order(&me, t)
- join(&me, t)
- having(&me, t)
- transaction(&me, t)
- combineTransaction(&me, t)
- table(&me, t)
- createMultiTables(&me, t)
- tableOp(&me, t)
+ testAll(engine, t)
}
diff --git a/postgres.go b/postgres.go
new file mode 100644
index 00000000..2b7e3787
--- /dev/null
+++ b/postgres.go
@@ -0,0 +1,72 @@
+// Copyright 2013 The XORM Authors. All rights reserved.
+// Use of this source code is governed by a BSD
+// license that can be found in the LICENSE file.
+
+// Package xorm provides is a simple and powerful ORM for Go. It makes your
+// database operation simple.
+
+package xorm
+
+import "strconv"
+
+type postgres struct {
+}
+
+func (db *postgres) SqlType(c *Column) string {
+ var res string
+ switch t := c.SQLType; t {
+ case TinyInt:
+ res = SmallInt.Name
+ case MediumInt, Int, Integer:
+ return Integer.Name
+ case Serial, BigSerial:
+ c.IsAutoIncrement = true
+ res = t.Name
+ case Binary, VarBinary:
+ res = Bytea.Name
+ case DateTime:
+ res = TimeStamp.Name
+ case Float:
+ res = Real.Name
+ case TinyText, MediumText, LongText:
+ res = Text.Name
+ case Blob, TinyBlob, MediumBlob, LongBlob:
+ res = Bytea.Name
+ case Double:
+ return "DOUBLE PRECISION"
+ default:
+ if c.IsAutoIncrement {
+ return Serial.Name
+ }
+ res = t.Name
+ }
+
+ var hasLen1 bool = (c.Length > 0)
+ var hasLen2 bool = (c.Length2 > 0)
+ if hasLen1 {
+ res += "(" + strconv.Itoa(c.Length) + ")"
+ } else if hasLen2 {
+ res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
+ }
+ return res
+}
+
+func (db *postgres) SupportInsertMany() bool {
+ return true
+}
+
+func (db *postgres) QuoteStr() string {
+ return "\""
+}
+
+func (db *postgres) AutoIncrStr() string {
+ return ""
+}
+
+func (db *postgres) SupportEngine() bool {
+ return false
+}
+
+func (db *postgres) SupportCharset() bool {
+ return false
+}
diff --git a/postgres_test.go b/postgres_test.go
new file mode 100644
index 00000000..c994bb1b
--- /dev/null
+++ b/postgres_test.go
@@ -0,0 +1,18 @@
+package xorm
+
+import (
+ _ "github.com/bylevel/pq"
+ "testing"
+)
+
+func TestPostgres(t *testing.T) {
+ engine, err := NewEngine("postgres", "dbname=xorm_test sslmode=disable")
+ defer engine.Close()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ engine.ShowSQL = true
+
+ testAll(engine, t)
+}
diff --git a/session.go b/session.go
index 41f41ab2..7370a12a 100644
--- a/session.go
+++ b/session.go
@@ -24,6 +24,7 @@ type Session struct {
Statement Statement
IsAutoCommit bool
IsCommitedOrRollbacked bool
+ TransType string
}
func (session *Session) Init() {
@@ -69,6 +70,16 @@ func (session *Session) In(column string, args ...interface{}) *Session {
return session
}
+func (session *Session) Cols(columns ...string) *Session {
+ session.Statement.Cols(columns...)
+ return session
+}
+
+func (session *Session) Trans(t string) *Session {
+ session.TransType = t
+ return session
+}
+
func (session *Session) Limit(limit int, start ...int) *Session {
session.Statement.Limit(limit, start...)
return session
@@ -136,18 +147,15 @@ func (session *Session) Begin() error {
session.IsAutoCommit = false
session.IsCommitedOrRollbacked = false
session.Tx = tx
- if session.Engine.ShowSQL {
- fmt.Println("BEGIN TRANSACTION")
- }
+
+ session.Engine.LogSQL("BEGIN TRANSACTION")
}
return nil
}
func (session *Session) Rollback() error {
if !session.IsAutoCommit && !session.IsCommitedOrRollbacked {
- if session.Engine.ShowSQL {
- fmt.Println("ROLL BACK")
- }
+ session.Engine.LogSQL("ROLL BACK")
session.IsCommitedOrRollbacked = true
return session.Tx.Rollback()
}
@@ -156,9 +164,7 @@ func (session *Session) Rollback() error {
func (session *Session) Commit() error {
if !session.IsAutoCommit && !session.IsCommitedOrRollbacked {
- if session.Engine.ShowSQL {
- fmt.Println("COMMIT")
- }
+ session.Engine.LogSQL("COMMIT")
session.IsCommitedOrRollbacked = true
return session.Tx.Commit()
}
@@ -168,7 +174,7 @@ func (session *Session) Commit() error {
func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error {
dataStruct := reflect.Indirect(reflect.ValueOf(obj))
if dataStruct.Kind() != reflect.Struct {
- return errors.New("expected a pointer to a struct")
+ return errors.New("Expected a pointer to a struct")
}
table := session.Engine.Tables[Type(obj)]
@@ -181,7 +187,7 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
fieldPath := strings.Split(fieldName, ".")
var structField reflect.Value
if len(fieldPath) > 2 {
- fmt.Printf("xorm: Warning! Unsupported mutliderive %v\n", fieldName)
+ session.Engine.LogError("Unsupported mutliderive", fieldName)
continue
} else if len(fieldPath) == 2 {
parentField := dataStruct.FieldByName(fieldPath[0])
@@ -207,7 +213,7 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
case reflect.String:
v = string(data)
case reflect.Bool:
- v = string(data) == "1"
+ v = (string(data) == "1")
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
x, err := strconv.Atoi(string(data))
if err != nil {
@@ -269,14 +275,14 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
if has {
v = structInter.Elem().Interface()
} else {
- fmt.Println("cascade obj is not exist!")
+ session.Engine.LogError("cascade obj is not exist!")
continue
}
} else {
continue
}
} else {
- fmt.Println("unsupported struct type in Scan: " + structField.Type().String())
+ session.Engine.LogError("unsupported struct type in Scan: " + structField.Type().String())
continue
}
} else {
@@ -309,20 +315,17 @@ func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result,
func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) {
err := session.newDb()
- if session.IsAutoCommit {
- defer session.Close()
- }
if err != nil {
return nil, err
}
- if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" {
- sql = strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1)
- }
- if session.Engine.ShowSQL {
- fmt.Println(sql)
- fmt.Println(args)
+ for _, filter := range session.Engine.Filters {
+ sql = filter.Do(sql, session)
}
+
+ session.Engine.LogSQL(sql)
+ session.Engine.LogSQL(args)
+
if session.IsAutoCommit {
return session.innerExec(sql, args...)
}
@@ -331,13 +334,21 @@ func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error
// this function create a table according a bean
func (session *Session) CreateTable(bean interface{}) error {
- statement := session.Statement
- defer statement.Init()
- statement.RefTable = session.Engine.AutoMap(bean)
- sql := statement.genCreateSQL()
+ session.Statement.RefTable = session.Engine.AutoMap(bean)
+
+ err := session.newDb()
+ if err != nil {
+ return err
+ }
+
+ return session.createOneTable()
+}
+
+func (session *Session) createOneTable() error {
+ sql := session.Statement.genCreateSQL()
_, err := session.Exec(sql)
if err == nil {
- sqls := statement.genIndexSQL()
+ sqls := session.Statement.genIndexSQL()
for _, sql := range sqls {
_, err = session.Exec(sql)
if err != nil {
@@ -346,7 +357,7 @@ func (session *Session) CreateTable(bean interface{}) error {
}
}
if err == nil {
- sqls := statement.genUniqueSQL()
+ sqls := session.Statement.genUniqueSQL()
for _, sql := range sqls {
_, err = session.Exec(sql)
if err != nil {
@@ -357,26 +368,59 @@ func (session *Session) CreateTable(bean interface{}) error {
return err
}
+func (session *Session) CreateAll() error {
+ err := session.newDb()
+ if err != nil {
+ return err
+ }
+
+ for _, table := range session.Engine.Tables {
+ session.Statement.RefTable = table
+ err := session.createOneTable()
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
func (session *Session) DropTable(bean interface{}) error {
- statement := session.Statement
- defer statement.Init()
- statement.RefTable = session.Engine.AutoMap(bean)
- sql := statement.genDropSQL()
- _, err := session.Exec(sql)
+ err := session.newDb()
+ if err != nil {
+ return err
+ }
+
+ t := reflect.Indirect(reflect.ValueOf(bean)).Type()
+ defer session.Statement.Init()
+ if t.Kind() == reflect.String {
+ session.Statement.AltTableName = bean.(string)
+ } else if t.Kind() == reflect.Struct {
+ session.Statement.RefTable = session.Engine.AutoMap(bean)
+ } else {
+ return errors.New("Unsupported type")
+ }
+
+ sql := session.Statement.genDropSQL()
+ _, err = session.Exec(sql)
return err
}
func (session *Session) Get(bean interface{}) (bool, error) {
- statement := session.Statement
- defer statement.Init()
- statement.Limit(1)
+ err := session.newDb()
+ if err != nil {
+ return false, err
+ }
+
+ defer session.Statement.Init()
+ session.Statement.Limit(1)
var sql string
var args []interface{}
- if statement.RawSQL == "" {
- sql, args = statement.genGetSql(bean)
+ if session.Statement.RawSQL == "" {
+ sql, args = session.Statement.genGetSql(bean)
} else {
- sql = statement.RawSQL
- args = statement.RawParams
+ sql = session.Statement.RawSQL
+ args = session.Statement.RawParams
+ session.Engine.AutoMap(bean)
}
resultsSlice, err := session.Query(sql, args...)
if err != nil {
@@ -387,7 +431,6 @@ func (session *Session) Get(bean interface{}) (bool, error) {
}
results := resultsSlice[0]
- session.Engine.AutoMap(bean)
err = session.scanMapIntoStruct(bean, results)
if err != nil {
return false, err
@@ -400,15 +443,19 @@ func (session *Session) Get(bean interface{}) (bool, error) {
}
func (session *Session) Count(bean interface{}) (int64, error) {
- statement := session.Statement
+ err := session.newDb()
+ if err != nil {
+ return 0, err
+ }
+
defer session.Statement.Init()
var sql string
var args []interface{}
- if statement.RawSQL == "" {
- sql, args = statement.genCountSql(bean)
+ if session.Statement.RawSQL == "" {
+ sql, args = session.Statement.genCountSql(bean)
} else {
- sql = statement.RawSQL
- args = statement.RawParams
+ sql = session.Statement.RawSQL
+ args = session.Statement.RawParams
}
resultsSlice, err := session.Query(sql, args...)
@@ -429,7 +476,11 @@ func (session *Session) Count(bean interface{}) (int64, error) {
}
func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
- statement := session.Statement
+ err := session.newDb()
+ if err != nil {
+ return err
+ }
+
defer session.Statement.Init()
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map {
@@ -438,22 +489,26 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
sliceElementType := sliceValue.Type().Elem()
table := session.Engine.AutoMapType(sliceElementType)
- statement.RefTable = table
+ session.Statement.RefTable = table
if len(condiBean) > 0 {
colNames, args := BuildConditions(session.Engine, table, condiBean[0])
- statement.ColumnStr = strings.Join(colNames, " and ")
- statement.BeanArgs = args
+ session.Statement.ConditionStr = strings.Join(colNames, " and ")
+ session.Statement.BeanArgs = args
}
var sql string
var args []interface{}
- if statement.RawSQL == "" {
- sql = statement.generateSql()
- args = append(statement.Params, statement.BeanArgs...)
+ if session.Statement.RawSQL == "" {
+ var columnStr string = session.Statement.ColumnStr
+ if columnStr == "" {
+ columnStr = session.Statement.genColumnStr()
+ }
+ sql = session.Statement.genSelectSql(columnStr)
+ args = append(session.Statement.Params, session.Statement.BeanArgs...)
} else {
- sql = statement.RawSQL
- args = statement.RawParams
+ sql = session.Statement.RawSQL
+ args = session.Statement.RawParams
}
resultsSlice, err := session.Query(sql, args...)
@@ -496,20 +551,14 @@ func (session *Session) Ping() error {
return session.Db.Ping()
}
-func (session *Session) CreateAll() error {
- for _, table := range session.Engine.Tables {
- session.Statement.RefTable = table
- sql := session.Statement.genCreateSQL()
- _, err := session.Exec(sql)
- if err != nil {
- return err
- }
- }
- return nil
-}
-
func (session *Session) DropAll() error {
+ err := session.newDb()
+ if err != nil {
+ return err
+ }
+
for _, table := range session.Engine.Tables {
+ session.Statement.Init()
session.Statement.RefTable = table
sql := session.Statement.genDropSQL()
_, err := session.Exec(sql)
@@ -526,18 +575,13 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice
return nil, err
}
- if session.IsAutoCommit {
- defer session.Close()
+ for _, filter := range session.Engine.Filters {
+ sql = filter.Do(sql, session)
}
- // TODO: this statement should be invoke before Query
- if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" {
- sql = strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1)
- }
- if session.Engine.ShowSQL {
- fmt.Println(sql)
- fmt.Println(paramStr)
- }
+ session.Engine.LogSQL(sql)
+ session.Engine.LogSQL(paramStr)
+
s, err := session.Db.Prepare(sql)
if err != nil {
return nil, err
@@ -597,7 +641,7 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice
str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700")
result[key] = []byte(str)
} else {
- fmt.Print("Unsupported struct type")
+ session.Engine.LogError("Unsupported struct type")
}
}
//default:
@@ -625,7 +669,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
if sliceValue.Kind() == reflect.Slice {
if session.Engine.SupportInsertMany() {
- lastId, err = session.InsertMulti(bean)
+ lastId, err = session.innerInsertMulti(bean)
if err != nil {
if !isInTransaction {
err1 := session.Rollback()
@@ -639,7 +683,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
} else {
size := sliceValue.Len()
for i := 0; i < size; i++ {
- lastId, err = session.InsertOne(sliceValue.Index(i).Interface())
+ lastId, err = session.innerInsert(sliceValue.Index(i).Interface())
if err != nil {
if !isInTransaction {
err1 := session.Rollback()
@@ -653,7 +697,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
}
}
} else {
- lastId, err = session.InsertOne(bean)
+ lastId, err = session.innerInsert(bean)
if err != nil {
if !isInTransaction {
err1 := session.Rollback()
@@ -672,7 +716,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
return lastId, err
}
-func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
+func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) {
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice {
return -1, errors.New("needs a pointer to a slice")
@@ -698,20 +742,18 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
if i == 0 {
for _, col := range table.Columns {
fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName)
- val := fieldValue.Interface()
if col.IsAutoIncrement && fieldValue.Int() == 0 {
continue
}
if col.MapType == ONLYFROMDB {
continue
}
- if table, ok := session.Engine.Tables[fieldValue.Type()]; ok {
- pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName)
- //fmt.Println(pkField.Interface())
- args = append(args, pkField.Interface())
- } else {
- args = append(args, val)
+ arg, err := session.value2Interface(fieldValue)
+ if err != nil {
+ return 0, err
}
+
+ args = append(args, arg)
colNames = append(colNames, col.Name)
cols = append(cols, col)
colPlaces = append(colPlaces, "?")
@@ -719,30 +761,36 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
} else {
for _, col := range cols {
fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName)
- val := fieldValue.Interface()
if col.IsAutoIncrement && fieldValue.Int() == 0 {
continue
}
if col.MapType == ONLYFROMDB {
continue
}
- if table, ok := session.Engine.Tables[fieldValue.Type()]; ok {
- pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName)
- args = append(args, pkField.Interface())
- } else {
- args = append(args, val)
+ if session.Statement.ColumnStr != "" {
+ if _, ok := session.Statement.columnMap[col.Name]; !ok {
+ continue
+ }
}
+ arg, err := session.value2Interface(fieldValue)
+ if err != nil {
+ return 0, err
+ }
+
+ args = append(args, arg)
colPlaces = append(colPlaces, "?")
}
}
colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", "))
}
- statement := fmt.Sprintf("INSERT INTO %v%v%v (%v) VALUES (%v)",
- session.Engine.QuoteIdentifier(),
+ statement := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v);",
+ session.Engine.QuoteStr(),
session.Statement.TableName(),
- session.Engine.QuoteIdentifier(),
- strings.Join(colNames, ", "),
+ session.Engine.QuoteStr(),
+ session.Engine.QuoteStr(),
+ strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()),
+ session.Engine.QuoteStr(),
strings.Join(colMultiPlaces, "),("))
res, err := session.Exec(statement, args...)
@@ -759,94 +807,149 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
return id, nil
}
-func (session *Session) InsertOne(bean interface{}) (int64, error) {
+func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
+ err := session.newDb()
+ if session.IsAutoCommit {
+ defer session.Close()
+ }
+ if err != nil {
+ return 0, err
+ }
+
+ return session.innerInsertMulti(rowsSlicePtr)
+}
+
+func (session *Session) value2Interface(fieldValue reflect.Value) (interface{}, error) {
+ if fieldValue.Type().Kind() == reflect.Bool {
+ if fieldValue.Bool() {
+ return 1, nil
+ } else {
+ return 0, nil
+ }
+ } else if fieldValue.Type().String() == "time.Time" {
+ return fieldValue.Interface(), nil
+ } else if fieldValue.Type().Kind() == reflect.Struct {
+ if fieldValue.CanAddr() {
+ if fieldConvert, ok := fieldValue.Addr().Interface().(Conversion); ok {
+ data, err := fieldConvert.ToDB()
+ if err != nil {
+ return 0, err
+ } else {
+ return string(data), nil
+ }
+ }
+ }
+
+ if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok {
+ if fieldTable.PrimaryKey != "" {
+ pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumn().FieldName)
+ return pkField.Interface(), nil
+ } else {
+ return 0, errors.New("no primary key")
+ }
+ } else {
+ return 0, errors.New(fmt.Sprintf("Unsupported type %v", fieldValue.Type()))
+ }
+ } else {
+ return fieldValue.Interface(), nil
+ }
+}
+
+func (session *Session) innerInsert(bean interface{}) (int64, error) {
table := session.Engine.AutoMap(bean)
- //fmt.Printf("table: %v\n", table)
+
session.Statement.RefTable = table
colNames := make([]string, 0)
colPlaces := make([]string, 0)
var args = make([]interface{}, 0)
+
for _, col := range table.Columns {
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
- val := fieldValue.Interface()
if col.IsAutoIncrement && fieldValue.Int() == 0 {
continue
}
if col.MapType == ONLYFROMDB {
continue
}
- if fieldValue.Type().String() == "time.Time" {
- args = append(args, val)
- } else if fieldValue.Type().Kind() == reflect.Struct {
- if fieldValue.CanAddr() {
- if fieldConvert, ok := fieldValue.Addr().Interface().(Conversion); ok {
- data, err := fieldConvert.ToDB()
- if err != nil {
- return 0, err
- } else {
- args = append(args, string(data))
- }
- } else {
- if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok {
- if fieldTable.PrimaryKey != "" {
- pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumn().FieldName)
- args = append(args, pkField.Interface())
- } else {
- continue
- }
- } else {
- //args = append(args, val)
- continue
- }
- }
- } else {
+ if session.Statement.ColumnStr != "" {
+ if _, ok := session.Statement.columnMap[col.Name]; !ok {
continue
}
- } else {
- args = append(args, val)
}
+
+ arg, err := session.value2Interface(fieldValue)
+ if err != nil {
+ return 0, err
+ }
+
+ args = append(args, arg)
colNames = append(colNames, col.Name)
colPlaces = append(colPlaces, "?")
}
- sql := fmt.Sprintf("INSERT INTO %v%v%v (%v) VALUES (%v)",
- session.Engine.QuoteIdentifier(),
+ sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v);",
+ session.Engine.QuoteStr(),
session.Statement.TableName(),
- session.Engine.QuoteIdentifier(),
- strings.Join(colNames, ", "),
+ session.Engine.QuoteStr(),
+ session.Engine.QuoteStr(),
+ strings.Join(colNames, session.Engine.Quote(", ")),
+ session.Engine.QuoteStr(),
strings.Join(colPlaces, ", "))
res, err := session.Exec(sql, args...)
if err != nil {
- return -1, err
+ return 0, err
}
- id, err := res.LastInsertId()
- if err != nil {
- return -1, err
+ if table.PrimaryKey == "" {
+ return 0, nil
}
- if id > 0 && table.PrimaryKey != "" {
- pkValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(table.PKColumn().FieldName)
- if pkValue.CanSet() {
- var v interface{} = id
- switch pkValue.Type().Kind() {
- case reflect.Int8, reflect.Int16, reflect.Int32:
- v = int(id)
- pkValue.Set(reflect.ValueOf(v))
- case reflect.Int64:
- pkValue.Set(reflect.ValueOf(v))
- case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
- v = uint(id)
- pkValue.Set(reflect.ValueOf(v))
- }
- }
+ var id int64 = 0
+ pkValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(table.PKColumn().FieldName)
+ if pkValue.Int() != 0 || !pkValue.CanSet() {
+ return 0, nil
}
+ id, err = res.LastInsertId()
+ if err != nil || id <= 0 {
+ return 0, err
+ }
+
+ var v interface{} = id
+ switch pkValue.Type().Kind() {
+ case reflect.Int8, reflect.Int16, reflect.Int32:
+ v = int(id)
+ case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ v = uint(id)
+
+ }
+ pkValue.Set(reflect.ValueOf(v))
+
return id, nil
}
+func (session *Session) InsertOne(bean interface{}) (int64, error) {
+ err := session.newDb()
+ if session.IsAutoCommit {
+ defer session.Close()
+ }
+ if err != nil {
+ return 0, err
+ }
+
+ return session.innerInsert(bean)
+}
+
func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) {
+ err := session.newDb()
+ if session.IsAutoCommit {
+ defer session.Close()
+ }
+ if err != nil {
+ return 0, err
+ }
+
table := session.Engine.AutoMap(bean)
session.Statement.RefTable = table
colNames, args := BuildConditions(session.Engine, table, bean)
@@ -874,10 +977,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
}
- sql := fmt.Sprintf("UPDATE %v%v%v SET %v %v",
- session.Engine.QuoteIdentifier(),
- session.Statement.TableName(),
- session.Engine.QuoteIdentifier(),
+ sql := fmt.Sprintf("UPDATE %v SET %v %v",
+ session.Engine.Quote(session.Statement.TableName()),
strings.Join(colNames, ", "),
condition)
@@ -887,15 +988,23 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return -1, err
}
- id, err := res.RowsAffected()
+ rows, err := res.RowsAffected()
if err != nil {
return -1, err
}
- return id, nil
+ return rows, nil
}
func (session *Session) Delete(bean interface{}) (int64, error) {
+ err := session.newDb()
+ if session.IsAutoCommit {
+ defer session.Close()
+ }
+ if err != nil {
+ return 0, err
+ }
+
table := session.Engine.AutoMap(bean)
session.Statement.RefTable = table
colNames, args := BuildConditions(session.Engine, table, bean)
@@ -914,9 +1023,9 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
}
statement := fmt.Sprintf("DELETE FROM %v%v%v %v",
- session.Engine.QuoteIdentifier(),
+ session.Engine.QuoteStr(),
session.Statement.TableName(),
- session.Engine.QuoteIdentifier(),
+ session.Engine.QuoteStr(),
condition)
res, err := session.Exec(statement, append(st.Params, args...)...)
diff --git a/sqlite3.go b/sqlite3.go
index 13813612..48e5a636 100644
--- a/sqlite3.go
+++ b/sqlite3.go
@@ -12,18 +12,21 @@ type sqlite3 struct {
func (db *sqlite3) SqlType(c *Column) string {
switch t := c.SQLType; t {
- case Date, DateTime, TimeStamp:
- return "NUMERIC"
- case Char, Varchar, Text:
- return "TEXT"
- case TinyInt, SmallInt, MediumInt, Int, BigInt:
- return "INTEGER"
- case Float, Double:
- return "REAL"
- case Decimal:
- return "NUMERIC"
- case Blob:
- return "BLOB"
+ case Date, DateTime, TimeStamp, Time:
+ return Numeric.Name
+ case Char, Varchar, TinyText, Text, MediumText, LongText:
+ return Text.Name
+ case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool:
+ return Integer.Name
+ case Float, Double, Real:
+ return Real.Name
+ case Decimal, Numeric:
+ return Numeric.Name
+ case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary:
+ return Blob.Name
+ case Serial, BigSerial:
+ c.IsAutoIncrement = true
+ return Integer.Name
default:
return t.Name
}
@@ -33,10 +36,18 @@ func (db *sqlite3) SupportInsertMany() bool {
return true
}
-func (db *sqlite3) QuoteIdentifier() string {
+func (db *sqlite3) QuoteStr() string {
return "`"
}
-func (db *sqlite3) AutoIncrIdentifier() string {
+func (db *sqlite3) AutoIncrStr() string {
return "AUTOINCREMENT"
}
+
+func (db *sqlite3) SupportEngine() bool {
+ return false
+}
+
+func (db *sqlite3) SupportCharset() bool {
+ return false
+}
diff --git a/sqlite3_test.go b/sqlite3_test.go
index fe9af9f9..41bdb707 100644
--- a/sqlite3_test.go
+++ b/sqlite3_test.go
@@ -6,147 +6,15 @@ import (
"testing"
)
-var se *Engine
-
-func autoConn() {
- if se == nil {
- os.Remove("./test.db")
- se, _ = NewEngine("sqlite3", "./test.db")
- se.ShowSQL = true
+func TestSqlite3(t *testing.T) {
+ os.Remove("./test.db")
+ engine, err := NewEngine("sqlite3", "./test.db")
+ defer engine.Close()
+ if err != nil {
+ t.Error(err)
+ return
}
-}
+ engine.ShowSQL = true
-func TestSqliteCreateTable(t *testing.T) {
- autoConn()
- directCreateTable(se, t)
-}
-
-func TestSqliteMapper(t *testing.T) {
- autoConn()
- mapper(se, t)
-}
-
-func TestSqliteInsert(t *testing.T) {
- autoConn()
- insert(se, t)
-}
-
-func TestSqliteQuery(t *testing.T) {
- autoConn()
- query(se, t)
-}
-
-func TestSqliteExec(t *testing.T) {
- autoConn()
- exec(se, t)
-}
-
-func TestSqliteInsertAutoIncr(t *testing.T) {
- autoConn()
- insertAutoIncr(se, t)
-}
-
-func TestInsertMulti(t *testing.T) {
- autoConn()
- insertMulti(se, t)
-}
-
-func TestSqliteInsertMulti(t *testing.T) {
- autoConn()
- insertMulti(se, t)
-}
-
-func TestSqliteInsertTwoTable(t *testing.T) {
- autoConn()
- insertTwoTable(se, t)
-}
-
-func TestSqliteUpdate(t *testing.T) {
- autoConn()
- update(se, t)
-}
-
-func TestSqliteDelete(t *testing.T) {
- autoConn()
- testdelete(se, t)
-}
-
-func TestSqliteGet(t *testing.T) {
- autoConn()
- get(se, t)
-}
-
-func TestSqliteCascadeGet(t *testing.T) {
- autoConn()
- cascadeGet(se, t)
-}
-
-func TestSqliteFind(t *testing.T) {
- autoConn()
- find(se, t)
-}
-
-func TestSqliteFindMap(t *testing.T) {
- autoConn()
- findMap(se, t)
-}
-
-func TestSqliteCount(t *testing.T) {
- autoConn()
- count(se, t)
-}
-
-func TestSqliteWhere(t *testing.T) {
- autoConn()
- where(se, t)
-}
-
-func TestSqliteIn(t *testing.T) {
- autoConn()
- in(se, t)
-}
-
-func TestSqliteLimit(t *testing.T) {
- autoConn()
- limit(se, t)
-}
-
-func TestSqliteOrder(t *testing.T) {
- autoConn()
- order(se, t)
-}
-
-func TestSqliteJoin(t *testing.T) {
- autoConn()
- join(se, t)
-}
-
-func TestSqliteHaving(t *testing.T) {
- autoConn()
- having(se, t)
-}
-
-func TestSqliteTransaction(t *testing.T) {
- autoConn()
- transaction(se, t)
-}
-
-func TestSqliteCombineTransaction(t *testing.T) {
- autoConn()
- combineTransaction(se, t)
-}
-
-func TestSqliteTable(t *testing.T) {
- autoConn()
- table(se, t)
-}
-
-func TestSqliteCreateMultiTables(t *testing.T) {
- autoConn()
- createMultiTables(se, t)
-}
-
-func TestSqliteTableOp(t *testing.T) {
- autoConn()
- tableOp(se, t)
+ testAll(engine, t)
}
diff --git a/statement.go b/statement.go
index aa1cd510..cd64960a 100644
--- a/statement.go
+++ b/statement.go
@@ -27,6 +27,8 @@ type Statement struct {
GroupByStr string
HavingStr string
ColumnStr string
+ columnMap map[string]bool
+ ConditionStr string
AltTableName string
RawSQL string
RawParams []interface{}
@@ -57,6 +59,8 @@ func (statement *Statement) Init() {
statement.GroupByStr = ""
statement.HavingStr = ""
statement.ColumnStr = ""
+ statement.columnMap = make(map[string]bool)
+ statement.ConditionStr = ""
statement.AltTableName = ""
statement.RawSQL = ""
statement.RawParams = make([]interface{}, 0)
@@ -116,8 +120,7 @@ func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string,
} else {
args = append(args, val)
}
- colNames = append(colNames, fmt.Sprintf("%v%v%v = ?", engine.QuoteIdentifier(),
- col.Name, engine.QuoteIdentifier()))
+ colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
}
return colNames, args
@@ -155,6 +158,13 @@ func (statement *Statement) In(column string, args ...interface{}) {
}
}
+func (statement *Statement) Cols(columns ...string) {
+ statement.ColumnStr = strings.Join(columns, statement.Engine.Quote(", "))
+ for _, column := range columns {
+ statement.columnMap[column] = true
+ }
+}
+
func (statement *Statement) Limit(limit int, start ...int) {
statement.LimitN = limit
if len(start) > 0 {
@@ -176,65 +186,36 @@ func (statement *Statement) Join(join_operator, tablename, condition string) {
}
func (statement *Statement) GroupBy(keys string) {
- statement.GroupByStr = fmt.Sprintf("GROUP BY %v", keys)
+ statement.GroupByStr = keys
}
func (statement *Statement) Having(conditions string) {
statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
}
-func (statement *Statement) genColumnStr(col *Column) string {
- sql := "`" + col.Name + "` "
-
- sql += statement.Engine.Dialect.SqlType(col) + " "
-
- if col.IsPrimaryKey {
- sql += "PRIMARY KEY "
- }
-
- if col.IsAutoIncrement {
- sql += statement.Engine.AutoIncrIdentifier() + " "
- }
-
- if col.Nullable {
- sql += "NULL "
- } else {
- sql += "NOT NULL "
- }
-
- /*if col.UniqueType == SINGLEUNIQUE {
- sql += "UNIQUE "
- }*/
-
- if col.Default != "" {
- sql += "DEFAULT " + col.Default + " "
- }
- return sql
-}
-
-func (statement *Statement) selectColumnStr() string {
+func (statement *Statement) genColumnStr() string {
table := statement.RefTable
colNames := make([]string, 0)
for _, col := range table.Columns {
if col.MapType != ONLYTODB {
- colNames = append(colNames, statement.TableName()+"."+col.Name)
+ colNames = append(colNames, statement.Engine.Quote(statement.TableName())+"."+statement.Engine.Quote(col.Name))
}
}
return strings.Join(colNames, ", ")
}
func (statement *Statement) genCreateSQL() string {
- sql := "CREATE TABLE IF NOT EXISTS `" + statement.TableName() + "` ("
+ sql := "CREATE TABLE IF NOT EXISTS " + statement.Engine.Quote(statement.TableName()) + " ("
for _, col := range statement.RefTable.Columns {
- sql += statement.genColumnStr(&col)
+ sql += col.String(statement.Engine)
sql = strings.TrimSpace(sql)
sql += ", "
}
sql = sql[:len(sql)-2] + ")"
- if statement.StoreEngine != "" {
+ if statement.Engine.Dialect.SupportEngine() && statement.StoreEngine != "" {
sql += " ENGINE=" + statement.StoreEngine
}
- if statement.Charset != "" {
+ if statement.Engine.Dialect.SupportCharset() && statement.Charset != "" {
sql += " DEFAULT CHARSET " + statement.Charset
}
sql += ";"
@@ -262,24 +243,24 @@ func (statement *Statement) genUniqueSQL() []string {
}
func (statement *Statement) genDropSQL() string {
- sql := "DROP TABLE IF EXISTS `" + statement.TableName() + "`;"
+ sql := "DROP TABLE IF EXISTS " + statement.Engine.Quote(statement.TableName()) + ";"
return sql
}
-func (statement Statement) generateSql() string {
- columnStr := statement.selectColumnStr()
- return statement.genSelectSql(columnStr)
-}
-
func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) {
table := statement.Engine.AutoMap(bean)
statement.RefTable = table
colNames, args := BuildConditions(statement.Engine, table, bean)
- statement.ColumnStr = strings.Join(colNames, " and ")
+ statement.ConditionStr = strings.Join(colNames, " and ")
statement.BeanArgs = args
- return statement.generateSql(), append(statement.Params, statement.BeanArgs...)
+ var columnStr string = statement.ColumnStr
+ if columnStr == "" {
+ columnStr = statement.genColumnStr()
+ }
+
+ return statement.genSelectSql(columnStr), append(statement.Params, statement.BeanArgs...)
}
func (statement Statement) genCountSql(bean interface{}) (string, []interface{}) {
@@ -287,98 +268,42 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{})
statement.RefTable = table
colNames, args := BuildConditions(statement.Engine, table, bean)
- statement.ColumnStr = strings.Join(colNames, " and ")
+ statement.ConditionStr = strings.Join(colNames, " and ")
statement.BeanArgs = args
- return statement.genSelectSql("count(*) as total"), append(statement.Params, statement.BeanArgs...)
+ return statement.genSelectSql(fmt.Sprintf("count(*) as %v", statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...)
}
func (statement Statement) genSelectSql(columnStr string) (a string) {
- if statement.Engine.DriverName == MSSQL {
- if statement.Start > 0 {
- a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v",
- statement.RefTable.PKColumn().Name,
- columnStr,
- statement.TableName())
- if statement.WhereStr != "" {
- a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
- if statement.ColumnStr != "" {
- a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
- }
- } else if statement.ColumnStr != "" {
- a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
- }
- a = fmt.Sprintf("select %v from (%v) "+
- "as a where rownum between %v and %v",
- columnStr,
- a,
- statement.Start,
- statement.LimitN)
- } else if statement.LimitN > 0 {
- a = fmt.Sprintf("SELECT top %v %v FROM %v", statement.LimitN, columnStr, statement.TableName())
- if statement.WhereStr != "" {
- a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
- if statement.ColumnStr != "" {
- a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
- }
- } else if statement.ColumnStr != "" {
- a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
- }
- if statement.GroupByStr != "" {
- a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
- }
- if statement.HavingStr != "" {
- a = fmt.Sprintf("%v %v", a, statement.HavingStr)
- }
- if statement.OrderStr != "" {
- a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
- }
- } else {
- a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.TableName())
- if statement.WhereStr != "" {
- a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
- if statement.ColumnStr != "" {
- a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
- }
- } else if statement.ColumnStr != "" {
- a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
- }
- if statement.GroupByStr != "" {
- a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
- }
- if statement.HavingStr != "" {
- a = fmt.Sprintf("%v %v", a, statement.HavingStr)
- }
- if statement.OrderStr != "" {
- a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
- }
- }
- } else {
- a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.TableName())
- if statement.JoinStr != "" {
- a = fmt.Sprintf("%v %v", a, statement.JoinStr)
- }
- if statement.WhereStr != "" {
- a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
- if statement.ColumnStr != "" {
- a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
- }
- } else if statement.ColumnStr != "" {
- a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
- }
- if statement.GroupByStr != "" {
- a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
- }
- if statement.HavingStr != "" {
- a = fmt.Sprintf("%v %v", a, statement.HavingStr)
- }
- if statement.OrderStr != "" {
- a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
- }
- if statement.Start > 0 {
- a = fmt.Sprintf("%v LIMIT %v, %v", a, statement.Start, statement.LimitN)
- } else if statement.LimitN > 0 {
- a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
+ if statement.GroupByStr != "" {
+ columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
+ statement.GroupByStr = columnStr
+ }
+ a = fmt.Sprintf("SELECT %v FROM %v", columnStr,
+ statement.Engine.Quote(statement.TableName()))
+ if statement.JoinStr != "" {
+ a = fmt.Sprintf("%v %v", a, statement.JoinStr)
+ }
+ if statement.WhereStr != "" {
+ a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
+ if statement.ConditionStr != "" {
+ a = fmt.Sprintf("%v and %v", a, statement.ConditionStr)
}
+ } else if statement.ConditionStr != "" {
+ a = fmt.Sprintf("%v WHERE %v", a, statement.ConditionStr)
+ }
+ if statement.GroupByStr != "" {
+ a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr)
+ }
+ if statement.HavingStr != "" {
+ a = fmt.Sprintf("%v %v", a, statement.HavingStr)
+ }
+ if statement.OrderStr != "" {
+ a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
+ }
+ if statement.Start > 0 {
+ a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
+ } else if statement.LimitN > 0 {
+ a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
}
return
}
diff --git a/table.go b/table.go
index f363b004..e313ee54 100644
--- a/table.go
+++ b/table.go
@@ -21,21 +21,46 @@ type SQLType struct {
}
var (
+ Bit = SQLType{"BIT", 0, 0}
TinyInt = SQLType{"TINYINT", 0, 0}
SmallInt = SQLType{"SMALLINT", 0, 0}
MediumInt = SQLType{"MEDIUMINT", 0, 0}
- Int = SQLType{"INT", 11, 0}
+ Int = SQLType{"INT", 0, 0}
+ Integer = SQLType{"INTEGER", 0, 0}
BigInt = SQLType{"BIGINT", 0, 0}
- Char = SQLType{"CHAR", 1, 0}
- Varchar = SQLType{"VARCHAR", 64, 0}
- Text = SQLType{"TEXT", 16, 0}
- Date = SQLType{"DATE", 24, 0}
+
+ Char = SQLType{"CHAR", 0, 0}
+ Varchar = SQLType{"VARCHAR", 64, 0}
+ TinyText = SQLType{"TINYTEXT", 0, 0}
+ Text = SQLType{"TEXT", 0, 0}
+ MediumText = SQLType{"MEDIUMTEXT", 0, 0}
+ LongText = SQLType{"LONGTEXT", 0, 0}
+ Binary = SQLType{"BINARY", 0, 0}
+ VarBinary = SQLType{"VARBINARY", 0, 0}
+
+ Date = SQLType{"DATE", 0, 0}
DateTime = SQLType{"DATETIME", 0, 0}
- Decimal = SQLType{"DECIMAL", 26, 2}
- Float = SQLType{"FLOAT", 31, 0}
- Double = SQLType{"DOUBLE", 31, 0}
- Blob = SQLType{"BLOB", 0, 0}
+ Time = SQLType{"TIME", 0, 0}
TimeStamp = SQLType{"TIMESTAMP", 0, 0}
+
+ Decimal = SQLType{"DECIMAL", 26, 2}
+ Numeric = SQLType{"NUMERIC", 0, 0}
+
+ Real = SQLType{"REAL", 0, 0}
+ Float = SQLType{"FLOAT", 0, 0}
+ Double = SQLType{"DOUBLE", 0, 0}
+ //Money = SQLType{"MONEY", 0, 0}
+
+ TinyBlob = SQLType{"TINYBLOB", 0, 0}
+ Blob = SQLType{"BLOB", 0, 0}
+ MediumBlob = SQLType{"MEDIUMBLOB", 0, 0}
+ LongBlob = SQLType{"LONGBLOB", 0, 0}
+ Bytea = SQLType{"BYTEA", 0, 0}
+
+ Bool = SQLType{"BOOL", 0, 0}
+
+ Serial = SQLType{"SERIAL", 0, 0}
+ BigSerial = SQLType{"BIGSERIAL", 0, 0}
)
var b byte
@@ -106,6 +131,31 @@ type Column struct {
MapType int
}
+func (col *Column) String(engine *Engine) string {
+ sql := engine.Quote(col.Name) + " "
+
+ sql += engine.SqlType(col) + " "
+
+ if col.IsPrimaryKey {
+ sql += "PRIMARY KEY "
+ }
+
+ if col.IsAutoIncrement {
+ sql += engine.AutoIncrStr() + " "
+ }
+
+ if col.Nullable {
+ sql += "NULL "
+ } else {
+ sql += "NOT NULL "
+ }
+
+ if col.Default != "" {
+ sql += "DEFAULT " + col.Default + " "
+ }
+ return sql
+}
+
type Table struct {
Name string
Type reflect.Type
diff --git a/xorm.go b/xorm.go
index 94816f84..0e5a3958 100644
--- a/xorm.go
+++ b/xorm.go
@@ -8,18 +8,19 @@
package xorm
import (
- //"database/sql"
"errors"
"fmt"
+ "os"
"reflect"
"sync"
- //"time"
)
const (
- version string = "0.1.8"
+ version string = "0.1.9"
)
+// new a db manager according to the parameter. Currently support three
+// driver
func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
engine := &Engine{ShowSQL: false, DriverName: driverName, Mapper: SnakeMapper{},
DataSourceName: dataSourceName}
@@ -27,13 +28,22 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
engine.Tables = make(map[reflect.Type]*Table)
engine.mutex = &sync.Mutex{}
engine.TagIdentifier = "xorm"
+ engine.Filters = make([]Filter, 0)
if driverName == SQLITE {
engine.Dialect = &sqlite3{}
} else if driverName == MYSQL {
engine.Dialect = &mysql{}
+ } else if driverName == POSTGRES {
+ engine.Dialect = &postgres{}
+ engine.Filters = append(engine.Filters, &PgSeqFilter{})
+ engine.Filters = append(engine.Filters, &PgQuoteFilter{})
+ } else if driverName == MYMYSQL {
+ engine.Dialect = &mysql{}
} else {
return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName))
}
+ engine.Filters = append(engine.Filters, &IdFilter{})
+ engine.Logger = os.Stdout
//engine.Pool = NewSimpleConnectPool()
//engine.Pool = NewNoneConnectPool()