From 5870dbaab0a85f2dc4bb4a6b6ea38c4857a500ab Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 8 May 2013 21:42:22 +0800 Subject: [PATCH] add sql execution support --- README.md | 147 ++++++++++++++++++++++----- README_CN.md | 281 +++++++++++++++++++++++++++++++++++---------------- engine.go | 44 +++++--- session.go | 252 ++++++++++++++++++--------------------------- statement.go | 31 +----- table.go | 2 +- xorm.go | 1 + xorm_test.go | 136 +++++++++++++++++++++++-- 8 files changed, 574 insertions(+), 320 deletions(-) diff --git a/README.md b/README.md index 5c245e5b..7c4f932f 100644 --- a/README.md +++ b/README.md @@ -7,12 +7,10 @@ xorm is an ORM for Go. It lets you map Go structs to tables in a database. Right now, it interfaces with Mysql/SQLite. The goal however is to add support for PostgreSQL/DB2/MS ADODB/ODBC/Oracle in the future. -All in all, it's not entirely ready for advanced use yet, but it's getting there. +All in all, it's not entirely ready for product use yet, but it's getting there. Drivers for Go's sql package which support database/sql includes: -Mysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) - Mysql: [github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL) SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) @@ -55,19 +53,45 @@ then, insert an struct to table or you want to update this struct - user := User{Id:1, Name:"xlw"} - rows, err := engine.Update(&user) + user := User{Name:"xlw"} + rows, err := engine.Update(&user, &User{Id:1}) + // rows, err := engine.Where("id = ?", 1).Update(&user) 3.Fetch a single object by user var user = User{Id:27} - engine.Get(&user) + err := engine.Get(&user) var user = User{Name:"xlw"} - engine.Get(&user) + err := engine.Get(&user) + +4.Fetch multipe objects, use Find: + + var allusers []Userinfo + err := engine.Where("id > ?", "3").Limit(10,20).Find(&allusers) //Get id>3 limit 10 offset 20 + var tenusers []Userinfo + err := engine.Limit(10).Find(&tenusers, &Userinfo{Name:"xlw"}) //Get All Name="xlw" limit 10 if omit offset the default is 0 + + var everyone []Userinfo + err := engine.Find(&everyone) + +5.Delete and Count: + + err := engine.Delete(&User{Id:1}) + + total, err := engine.Count(&User{Name:"xlw"}) + +##Origin Use +Of course, the basic usage is also provided. + + sql := "select * from userinfo" + results, err := engine.Query(sql) + + sql = "update userinfo set username=? where id=?" + res, err := engine.Exec(sql, "xiaolun", 1) ##Deep Use for deep use, you should create a session, this func will create a connection to db @@ -82,49 +106,120 @@ for deep use, you should create a session, this func will create a connection to 1.Fetch a single object by where - var user Userinfo session.Where("id=?", 27).Get(&user) var user2 Userinfo - session.Where(3).Get(&user2) // this is shorthand for the version above - - var user3 Userinfo session.Where("name = ?", "john").Get(&user3) // more complex query - var user4 Userinfo + var user3 Userinfo session.Where("name = ? and age < ?", "john", 88).Get(&user4) // even more complex 2.Fetch multiple objects - var allusers []Userinfo err := session.Where("id > ?", "3").Limit(10,20).Find(&allusers) //Get id>3 limit 10 offset 20 var tenusers []Userinfo - err := session.Where("id > ?", "3").Limit(10).Find(&tenusers) //Get id>3 limit 10 if omit offset the default is 0 + err := session.Limit(10).Find(&tenusers, &Userinfo{Name:"xlw"}) //Get All Name="xlw" limit 10 if omit offset the default is 0 var everyone []Userinfo - err := session.Find(&everyone) + err := session.Find(&everyone) + +3.Transaction + + // add Begin() before any action + session.Begin() + user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} + _, err = session.Insert(&user1) + if err != nil { + session.Rollback() + return + } + user2 := Userinfo{Username: "yyy"} + _, err = session.Where("id = ?", 2).Update(&user2) + if err != nil { + session.Rollback() + return + } + + _, err = session.Delete(&user2) + if err != nil { + session.Rollback() + return + } + + // add Commit() after all actions + err = session.Commit() + if err != nil { + return + } + +4.Mixed Transaction + + // add Begin() before any action + session.Begin() + user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} + _, err = session.Insert(&user1) + if err != nil { + session.Rollback() + return + } + user2 := Userinfo{Username: "yyy"} + _, err = session.Where("id = ?", 2).Update(&user2) + if err != nil { + session.Rollback() + return + } + + _, err = session.Exec("delete from userinfo where username = ?", user2.Username) + if err != nil { + session.Rollback() + return + } + + // add Commit() after all actions + err = session.Commit() + if err != nil { + return + } - -##***Mapping Rules*** +##Mapping Rules 1.Struct and struct's fields name should be Pascal style, and the table and column's name default is us -for example: + +For example: The structs Name 'UserInfo' will turn into the table name 'user_info', the same as the keyname. If the keyname is 'UserName' will turn into the select colum 'user_name' 2.You have two method to change the rule. One is implement your own Map interface according IMapper, you can find the interface in mapper.go and set it to engine.Mapper -another is use field tag, field tag support the below keywords: -* [name] column name -* pk the field is a primary key -* int(11)/varchar(50) column type -* autoincr auto incrment -* [not ]null if column can be null value -* unique unique -* \- this field is not map as a table column +another is use field tag, field tag support the below keywords: + + + + + + + + + + + + + + + + + + + + + + + +
namecolumn name
pkthe field is a primary key
int(11)/varchar(50)column type
autoincrauto incrment
[not ]nullif column can be null value
uniqueunique
-this field is not map as a table column
+ ##FAQ 1.How the xorm tag use both with json? diff --git a/README_CN.md b/README_CN.md index 06c60210..1b9ab420 100644 --- a/README_CN.md +++ b/README_CN.md @@ -1,127 +1,234 @@ -xorm -===== +# xorm +=========== [English](README.md) -xorm 是一个Go语言的ORM(对象关系模型). It lets you map Go structs to tables in a database. +xorm是一个Go语言的ORM库. 通过它可以简化对数据库的操作。 + +目前仅支持Mysql和SQLite,当然我们的目标是支持PostgreSQL/DB2/MS ADODB/ODBC/Oracle等等。 -Right now, it interfaces with Mysql/SQLite. The goal however is to add support for PostgreSQL/DB2/MS ADODB/ODBC/Oracle in the future. +但是,目前的版本还不可用于正式版本。 -All in all, it's not entirely ready for advanced use yet, but it's getting there. +目前支持的Go数据库驱动如下: -Drivers for Go's sql package which support database/sql includes: +Mysql: [github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL) -Mysql:[github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) +SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) -Mysql:[github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL) +## 安装 + + go get github.com/lunny/xorm -SQLite:[github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) +## 快速开始 -### Installing xorm - go get github.com/lunny/xorm +1.创建数据库引擎 (例如: mysql) -### Quick Start + engine := xorm.Create("mysql://root:123@localhost/test") -1. Create an database engine (for example: mysql) -```go -engine := xorm.Create("mysql://root:123@localhost/test") -``` +2.定义你的Struct -2. Define your struct -```go -type User struct { - Id int - Name string - Age int `xorm:"-"` -} -``` + type User struct { + Id int + Name string + Age int `xorm:"-"` + } -for Simple Task, just use engine's functions: -begin start, you should create a database and then we create the tables +对于简单的任务,可以只用engine一个对象就可以完成操作。 +首先,需要创建一个数据库,然后使用以下语句创建一个Struct对应的表。 + + + err := engine.CreateTables(&User{}) -```go -err := engine.CreateTables(&User{}) -``` -then, insert an struct to table +然后,可以将一个结构体作为一条记录插入到表中。 -```go -id, err := engine.Insert(&User{Name:"lunny"}) -``` -or you want to update this struct + id, err := engine.Insert(&User{Name:"lunny"}) -```go -user := User{Id:1, Name:"xlw"} -rows, err := engine.Update(&user) -``` -3. Fetch a single object by user -```go -var user = User{Id:27} -engine.Get(&user) +或者执行更新操作: -var user = User{Name:"xlw"} -engine.Get(&user) -``` -for deep use, you should create a session, this func will create a connection to db + user := User{Name:"xlw"} + rows, err := engine.Update(&user, &User{Id:1}) + // rows, err := engine.Where("id = ?", 1).Update(&user) -```go -session, err := engine.MakeSession() -defer session.Close() -if err != nil { - return -} -``` -1. Fetch a single object by where -```go -var user Userinfo -session.Where("id=?", 27).Get(&user) +3.获取单个对象,可以用Get方法: -var user2 Userinfo -session.Where(3).Get(&user2) // this is shorthand for the version above -var user3 Userinfo -session.Where("name = ?", "john").Get(&user3) // more complex query + var user = User{Id:27} + err := engine.Get(&user) -var user4 Userinfo -session.Where("name = ? and age < ?", "john", 88).Get(&user4) // even more complex -``` + var user = User{Name:"xlw"} + err := engine.Get(&user) + +4.获取多个对象,可以用Find方法: + + var allusers []Userinfo + err := engine.Where("id > ?", "3").Limit(10,20).Find(&allusers) //Get id>3 limit 10 offset 20 -2. Fetch multiple objects + var tenusers []Userinfo + err := engine.Limit(10).Find(&tenusers, &Userinfo{Name:"xlw"}) //Get All Name="xlw" limit 10 if omit offset the default is 0 -```go -var allusers []Userinfo -err := session.Where("id > ?", "3").Limit(10,20).Find(&allusers) //Get id>3 limit 10 offset 20 + var everyone []Userinfo + err := engine.Find(&everyone) + +5.另外还有Delete和Count方法: + + err := engine.Delete(&User{Id:1}) + + total, err := engine.Count(&User{Name:"xlw"}) + +##Origin Use +当然,如果你想直接使用SQL语句进行操作,也是允许的。 + + sql := "select * from userinfo" + results, err := engine.Query(sql) + + sql = "update userinfo set username=? where id=?" + res, err := engine.Exec(sql, "xiaolun", 1) -var tenusers []Userinfo -err := session.Where("id > ?", "3").Limit(10).Find(&tenusers) //Get id>3 limit 10 if omit offset the default is 0 +##Deep Use +更高级的用法,我们必须要使用session对象,session对象在创建时会创建一个数据库连接。 -var everyone []Userinfo -err := session.Find(&everyone) -``` -###***About Map Rules*** -1. Struct and struct's fields name should be Pascal style, and the table and column's name default is us -for example: -The structs Name 'UserInfo' will turn into the table name 'user_info', the same as the keyname. -If the keyname is 'UserName' will turn into the select colum 'user_name' + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return + } -2. You have two method to change the rule. One is implement your own Map interface according IMapper, you can find the interface in mapper.go and set it to engine.Mapper -another is use field tag, field tag support the below keywords: -[name] column name -pk the field is a primary key -int(11)/varchar(50) column type -autoincr auto incrment -[not ]null if column can be null value -unique unique -- this field is not map as a table column +1.session对象同样也可以查询 + + var user Userinfo + session.Where("id=?", 27).Get(&user) + + var user2 Userinfo + session.Where("name = ?", "john").Get(&user3) // more complex query + + var user3 Userinfo + session.Where("name = ? and age < ?", "john", 88).Get(&user4) // even more complex + + +2.获取多个对象 + + var allusers []Userinfo + err := session.Where("id > ?", "3").Limit(10,20).Find(&allusers) //Get id>3 limit 10 offset 20 + + var tenusers []Userinfo + err := session.Limit(10).Find(&tenusers, &Userinfo{Name:"xlw"}) //Get All Name="xlw" limit 10 if omit offset the default is 0 + + var everyone []Userinfo + err := session.Find(&everyone) + +3.事务处理 + + // add Begin() before any action + session.Begin() + user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} + _, err = session.Insert(&user1) + if err != nil { + session.Rollback() + return + } + user2 := Userinfo{Username: "yyy"} + _, err = session.Where("id = ?", 2).Update(&user2) + if err != nil { + session.Rollback() + return + } + + _, err = session.Delete(&user2) + if err != nil { + session.Rollback() + return + } + + // add Commit() after all actions + err = session.Commit() + if err != nil { + return + } + +4.混合型事务,这个事务中,既有直接的SQL语句,又有其它方法: + + // add Begin() before any action + session.Begin() + user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} + _, err = session.Insert(&user1) + if err != nil { + session.Rollback() + return + } + user2 := Userinfo{Username: "yyy"} + _, err = session.Where("id = ?", 2).Update(&user2) + if err != nil { + session.Rollback() + return + } + + _, err = session.Exec("delete from userinfo where username = ?", user2.Username) + if err != nil { + session.Rollback() + return + } + + // add Commit() after all actions + err = session.Commit() + if err != nil { + return + } + +##Mapping Rules +1.Struct 和 Struct 的field名字应该为Pascal式命名,默认的映射规则将转换成用下划线连接的命名规则,这个映射是自动进行的,当然,你可以通过修改Engine或者Session的成员IMapper来改变它。 + +例如: + +结构体的名字UserInfo将会自动对应数据库中的名为user_info的表。 +UserInfo中的成员UserName将会自动对应名为user_name的字段。 + +2.当然你也可以改变这个规则,这有两种方法。一是实现你自己的IMapper,你可以在mapper.go中查看到这个借口。然后设置到 engine.Mapper,这将影响所有的Session,或者你可以设置到某一个session,那么只会影响到这个session对应的操作。 + +另外一种方法就通过Field Tag来进行改变,关于Field Tag请参考Go的语言文档,如下列出了Tag中可用的关键字及其对应的意义: + + + + + + + + + + + + + + + + + + + + + + + +
name当前field对应的字段的名称,可选
pk是否是Primary Key
int(11)/varchar(50)字段类型
autoincr是否是自增
[not ]null是否可以为空
unique是否是唯一
-这个Field将不进行字段映射
+ + +##FAQ +1.xorm的tag和json的tag如何同时起作用? + + 使用空格分开 + + type User struct { + User string `json:"user" orm:"user_id"` + } ## LICENSE diff --git a/engine.go b/engine.go index c2665fb3..e22dc9df 100644 --- a/engine.go +++ b/engine.go @@ -341,6 +341,24 @@ func (e *Engine) CreateAll() error { return err } +func (engine *Engine) Exec(sql string, args ...interface{}) (sql.Result, error) { + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return nil, err + } + return session.Exec(sql, args...) +} + +func (engine *Engine) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return nil, err + } + return session.Query(sql, paramStr...) +} + func (engine *Engine) Insert(beans ...interface{}) (int64, error) { session, err := engine.MakeSession() defer session.Close() @@ -348,21 +366,19 @@ func (engine *Engine) Insert(beans ...interface{}) (int64, error) { return -1, err } defer engine.Statement.Init() - engine.Statement.Session = &session - session.SetStatement(&engine.Statement) + session.Statement = engine.Statement return session.Insert(beans...) } -func (engine *Engine) Update(bean interface{}) (int64, error) { +func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64, error) { session, err := engine.MakeSession() defer session.Close() if err != nil { return -1, err } defer engine.Statement.Init() - engine.Statement.Session = &session - session.SetStatement(&engine.Statement) - return session.Update(bean) + session.Statement = engine.Statement + return session.Update(bean, condiBeans...) } func (engine *Engine) Delete(bean interface{}) (int64, error) { @@ -372,8 +388,7 @@ func (engine *Engine) Delete(bean interface{}) (int64, error) { return -1, err } defer engine.Statement.Init() - engine.Statement.Session = &session - session.SetStatement(&engine.Statement) + session.Statement = engine.Statement return session.Delete(bean) } @@ -384,21 +399,19 @@ func (engine *Engine) Get(bean interface{}) error { return err } defer engine.Statement.Init() - engine.Statement.Session = &session - session.SetStatement(&engine.Statement) + session.Statement = engine.Statement return session.Get(bean) } -func (engine *Engine) Find(beans interface{}) error { +func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error { session, err := engine.MakeSession() defer session.Close() if err != nil { return err } defer engine.Statement.Init() - engine.Statement.Session = &session - session.SetStatement(&engine.Statement) - return session.Find(beans) + session.Statement = engine.Statement + return session.Find(beans, condiBeans...) } func (engine *Engine) Count(bean interface{}) (int64, error) { @@ -408,7 +421,6 @@ func (engine *Engine) Count(bean interface{}) (int64, error) { return 0, err } defer engine.Statement.Init() - engine.Statement.Session = &session - session.SetStatement(&engine.Statement) + session.Statement = engine.Statement return session.Count(bean) } diff --git a/session.go b/session.go index 6aaa7f89..c2c9a22a 100644 --- a/session.go +++ b/session.go @@ -35,97 +35,75 @@ func Type2StructName(v reflect.Type) string { } type Session struct { - Db *sql.DB - Engine *Engine - Tx *sql.Tx - Statements []Statement - Mapper IMapper - IsAutoCommit bool - IsAutoRollback bool - CurStatementIdx int + Db *sql.DB + Engine *Engine + Tx *sql.Tx + Statement Statement + Mapper IMapper + IsAutoCommit bool } func (session *Session) Init() { - session.Statements = make([]Statement, 0) - session.CurStatementIdx = -1 + session.Statement = Statement{} session.IsAutoCommit = true - session.IsAutoRollback = false } func (session *Session) Close() { - rollbackfunc := func() { - if session.IsAutoRollback { - session.Rollback() - } - } - defer rollbackfunc() defer session.Db.Close() } -func (session *Session) CurrentStatement() *Statement { - if session.CurStatementIdx > -1 { - return &session.Statements[session.CurStatementIdx] - } - return nil -} - -func (session *Session) AutoStatement() *Statement { - if session.CurStatementIdx == -1 { - session.newStatement() - } - return session.CurrentStatement() -} - func (session *Session) Where(querystring string, args ...interface{}) *Session { - statement := session.AutoStatement() - statement.Where(querystring, args...) + session.Statement.Where(querystring, args...) return session } func (session *Session) Limit(limit int, start ...int) *Session { - statement := session.AutoStatement() - statement.Limit(limit, start...) + session.Statement.Limit(limit, start...) return session } func (session *Session) OrderBy(order string) *Session { - statement := session.AutoStatement() - statement.OrderBy(order) + session.Statement.OrderBy(order) return session } //The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (session *Session) Join(join_operator, tablename, condition string) *Session { - statement := session.AutoStatement() - statement.Join(join_operator, tablename, condition) + session.Statement.Join(join_operator, tablename, condition) return session } func (session *Session) GroupBy(keys string) *Session { - statement := session.AutoStatement() - statement.GroupBy(keys) + session.Statement.GroupBy(keys) return session } func (session *Session) Having(conditions string) *Session { - statement := session.AutoStatement() - statement.Having(conditions) + session.Statement.Having(conditions) return session } func (session *Session) Begin() error { session.IsAutoCommit = false - session.IsAutoRollback = true tx, err := session.Db.Begin() session.Tx = tx + if session.Engine.ShowSQL { + fmt.Println("BEGIN TRANSACTION") + } return err } func (session *Session) Rollback() error { + if session.Engine.ShowSQL { + fmt.Println("ROLL BACK") + } return session.Tx.Rollback() } func (session *Session) Commit() error { + if session.Engine.ShowSQL { + fmt.Println("COMMIT") + } return session.Tx.Commit() } @@ -133,27 +111,10 @@ func (session *Session) TableName(bean interface{}) string { return session.Mapper.Obj2Table(StructName(bean)) } -func (session *Session) SetStatement(statement *Statement) { - if session.CurStatementIdx == len(session.Statements)-1 { - session.Statements = append(session.Statements, *statement) - } else { - session.Statements[session.CurStatementIdx+1] = *statement - } - session.CurStatementIdx = session.CurStatementIdx + 1 -} - -func (session *Session) newStatement() { - if session.CurStatementIdx == len(session.Statements)-1 { - state := Statement{Session: session} - state.Init() - session.Statements = append(session.Statements, state) - } - session.CurStatementIdx = session.CurStatementIdx + 1 -} - -func (session *Session) clearStatment() { - session.Statements[session.CurStatementIdx].Init() - session.CurStatementIdx = session.CurStatementIdx - 1 +func (session *Session) Bean2Table(bean interface{}) *Table { + tablName := session.TableName(bean) + table := session.Engine.Tables[tablName] + return &table } func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error { @@ -162,8 +123,7 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b return errors.New("expected a pointer to a struct") } - tablName := session.TableName(obj) - table := session.Engine.Tables[tablName] + table := session.Bean2Table(obj) for key, data := range objMap { structField := dataStruct.FieldByName(table.Columns[key].FieldName) @@ -231,8 +191,8 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b } //Execute sql -func (session *Session) Exec(finalQueryString string, args ...interface{}) (sql.Result, error) { - rs, err := session.Db.Prepare(finalQueryString) +func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, error) { + rs, err := session.Db.Prepare(sql) if err != nil { return nil, err } @@ -245,18 +205,30 @@ func (session *Session) Exec(finalQueryString string, args ...interface{}) (sql. return res, nil } -func (session *Session) Get(bean interface{}) error { - statement := session.AutoStatement() - statement.Limit(1) - tableName := session.TableName(bean) - table := session.Engine.Tables[tableName] - statement.Table = &table +func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { + if session.Engine.ShowSQL { + fmt.Println(sql) + } + if session.IsAutoCommit { + return session.innerExec(sql, args...) + } + return session.Tx.Exec(sql, args...) +} - colNames, args := session.BuildConditions(&table, bean) +func (session *Session) Get(bean interface{}) error { + statement := session.Statement + defer session.Statement.Init() + statement.Limit(1) + table := session.Bean2Table(bean) + statement.Table = table + + colNames, args := session.BuildConditions(table, bean) statement.ColumnStr = strings.Join(colNames, " and ") statement.BeanArgs = args - resultsSlice, err := session.FindMap(statement) + sql := statement.generateSql() + resultsSlice, err := session.Query(sql, append(statement.Params, statement.BeanArgs...)...) + if err != nil { return err } @@ -275,16 +247,16 @@ func (session *Session) Get(bean interface{}) error { } func (session *Session) Count(bean interface{}) (int64, error) { - statement := session.AutoStatement() - tableName := session.TableName(bean) - table := session.Engine.Tables[tableName] - statement.Table = &table + statement := session.Statement + defer session.Statement.Init() + table := session.Bean2Table(bean) + statement.Table = table - colNames, args := session.BuildConditions(&table, bean) + colNames, args := session.BuildConditions(table, bean) statement.ColumnStr = strings.Join(colNames, " and ") statement.BeanArgs = args - resultsSlice, err := session.SQL2Map(statement.genCountSql(), append(statement.Params, statement.BeanArgs...)) + resultsSlice, err := session.Query(statement.genCountSql(), append(statement.Params, statement.BeanArgs...)...) if err != nil { return 0, err } @@ -298,9 +270,9 @@ func (session *Session) Count(bean interface{}) (int64, error) { return int64(total), err } -func (session *Session) Find(rowsSlicePtr interface{}) error { - statement := session.AutoStatement() - +func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) error { + statement := session.Statement + defer session.Statement.Init() sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) if sliceValue.Kind() != reflect.Slice { return errors.New("needs a pointer to a slice") @@ -312,7 +284,15 @@ func (session *Session) Find(rowsSlicePtr interface{}) error { table := session.Engine.Tables[tableName] statement.Table = &table - resultsSlice, err := session.FindMap(statement) + if len(condiBean) > 0 { + colNames, args := session.BuildConditions(&table, condiBean[0]) + statement.ColumnStr = strings.Join(colNames, " and ") + statement.BeanArgs = args + } + + sql := statement.generateSql() + resultsSlice, err := session.Query(sql, append(statement.Params, statement.BeanArgs...)...) + if err != nil { return err } @@ -328,7 +308,7 @@ func (session *Session) Find(rowsSlicePtr interface{}) error { return nil } -func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSlice []map[string][]byte, err error) { +func (session *Session) Query(sqls string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { if session.Engine.ShowSQL { fmt.Println(sqls) } @@ -397,11 +377,6 @@ func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSli return resultsSlice, nil } -func (session *Session) FindMap(statement *Statement) (resultsSlice []map[string][]byte, err error) { - sqls := statement.generateSql() - return session.SQL2Map(sqls, append(statement.Params, statement.BeanArgs...)) -} - func (session *Session) Insert(beans ...interface{}) (int64, error) { var lastId int64 = -1 for _, bean := range beans { @@ -414,8 +389,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { } func (session *Session) InsertOne(bean interface{}) (int64, error) { - tableName := session.TableName(bean) - table := session.Engine.Tables[tableName] + table := session.Bean2Table(bean) colNames := make([]string, 0) colPlaces := make([]string, 0) @@ -423,10 +397,8 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { for _, col := range table.Columns { fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) val := fieldValue.Interface() - if col.IsAutoIncrement { - if fieldValue.Int() == 0 { - continue - } + if col.IsAutoIncrement && fieldValue.Int() == 0 { + continue } args = append(args, val) colNames = append(colNames, col.Name) @@ -435,23 +407,12 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { statement := fmt.Sprintf("INSERT INTO %v%v%v (%v) VALUES (%v)", session.Engine.QuoteIdentifier, - tableName, + table.Name, session.Engine.QuoteIdentifier, strings.Join(colNames, ", "), strings.Join(colPlaces, ", ")) - if session.Engine.ShowSQL { - fmt.Println(statement) - } - - var res sql.Result - var err error - - if session.IsAutoCommit { - res, err = session.Exec(statement, args...) - } else { - res, err = session.Tx.Exec(statement, args...) - } + res, err := session.Exec(statement, args...) if err != nil { return -1, err } @@ -461,6 +422,7 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { if err != nil { return -1, err } + return id, nil } @@ -497,46 +459,42 @@ func (session *Session) BuildConditions(table *Table, bean interface{}) ([]strin return colNames, args } -func (session *Session) Update(bean interface{}) (int64, error) { - tableName := session.TableName(bean) - table := session.Engine.Tables[tableName] - colNames, args := session.BuildConditions(&table, bean) +func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) { + table := session.Bean2Table(bean) + colNames, args := session.BuildConditions(table, bean) + var condiColNames []string + var condiArgs []interface{} + + if len(condiBean) > 0 { + condiColNames, condiArgs = session.BuildConditions(table, condiBean[0]) + } var condition = "" - st := session.AutoStatement() - defer session.clearStatment() + st := session.Statement + defer session.Statement.Init() if st.WhereStr != "" { condition = fmt.Sprintf("WHERE %v", st.WhereStr) } if condition == "" { - fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(table.PKColumn().FieldName) - if fieldValue.Int() != 0 { - condition = fmt.Sprintf("WHERE %v = ?", table.PKColumn().Name) - args = append(args, fieldValue.Interface()) + if len(condiColNames) > 0 { + condition = fmt.Sprintf("WHERE %v ", strings.Join(condiColNames, " and ")) + } + } else { + if len(condiColNames) > 0 { + condition = fmt.Sprintf("%v and %v", condition, strings.Join(condiColNames, " and ")) } } statement := fmt.Sprintf("UPDATE %v%v%v SET %v %v", session.Engine.QuoteIdentifier, - tableName, + table.Name, session.Engine.QuoteIdentifier, strings.Join(colNames, ", "), condition) - if session.Engine.ShowSQL { - fmt.Println(statement) - } - - var res sql.Result - var err error - if session.IsAutoCommit { - fmt.Println("session.Exec") - res, err = session.Exec(statement, append(args, st.Params...)...) - } else { - fmt.Println("tx.Exec") - res, err = session.Tx.Exec(statement, append(args, st.Params...)...) - } + eargs := append(append(args, st.Params...), condiArgs...) + res, err := session.Exec(statement, eargs...) if err != nil { return -1, err } @@ -550,13 +508,12 @@ func (session *Session) Update(bean interface{}) (int64, error) { } func (session *Session) Delete(bean interface{}) (int64, error) { - tableName := session.TableName(bean) - table := session.Engine.Tables[tableName] - colNames, args := session.BuildConditions(&table, bean) + table := session.Bean2Table(bean) + colNames, args := session.BuildConditions(table, bean) var condition = "" - st := session.AutoStatement() - defer session.clearStatment() + st := session.Statement + defer session.Statement.Init() if st.WhereStr != "" { condition = fmt.Sprintf("WHERE %v", st.WhereStr) if len(colNames) > 0 { @@ -569,21 +526,12 @@ func (session *Session) Delete(bean interface{}) (int64, error) { statement := fmt.Sprintf("DELETE FROM %v%v%v %v", session.Engine.QuoteIdentifier, - tableName, + table.Name, session.Engine.QuoteIdentifier, condition) - if session.Engine.ShowSQL { - fmt.Println(statement) - } + res, err := session.Exec(statement, append(st.Params, args...)...) - var res sql.Result - var err error - if session.IsAutoCommit { - res, err = session.Exec(statement, append(st.Params, args...)...) - } else { - res, err = session.Tx.Exec(statement, append(st.Params, args...)...) - } if err != nil { return -1, err } diff --git a/statement.go b/statement.go index f683eba1..a771cc25 100644 --- a/statement.go +++ b/statement.go @@ -6,7 +6,7 @@ import ( type Statement struct { Table *Table - Session *Session + Engine *Engine Start int LimitN int WhereStr string @@ -21,7 +21,6 @@ type Statement struct { func (statement *Statement) Init() { statement.Table = nil - statement.Session = nil statement.Start = 0 statement.LimitN = 0 statement.WhereStr = "" @@ -76,13 +75,8 @@ func (statement Statement) genCountSql() string { return statement.genSelectSql("count(*) as total") } -func (statement Statement) genExecSql() string { - return "" -} - func (statement Statement) genSelectSql(columnStr string) (a string) { - session := statement.Session - if session.Engine.Protocol == "mssql" { + if statement.Engine.Protocol == "mssql" { if statement.Start > 0 { a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v", statement.Table.PKColumn().Name, @@ -171,24 +165,3 @@ func (statement Statement) genSelectSql(columnStr string) (a string) { } return } - -/*func (statement *Statement) genInsertSQL() string { - table = statement.Table - colNames := make([]string, len(table.Columns)) - for idx, col := range table.Columns { - if col.Name == "" { - continue - } - colNames[idx] = col.Name - } - return strings.Join(colNames, ", ") - - colNames := make([]string, len(table.Columns)) - for idx, col := range table.Columns { - if col.Name == "" { - continue - } - colNames[idx] = "?" - } - strings.Join(colNames, ", ") -}*/ diff --git a/table.go b/table.go index 5503ed99..ed18d177 100644 --- a/table.go +++ b/table.go @@ -72,7 +72,7 @@ type Table struct { func (table *Table) ColumnStr() string { colNames := make([]string, 0) for _, col := range table.Columns { - colNames = append(colNames, col.Name) + colNames = append(colNames, table.Name+"."+col.Name) } return strings.Join(colNames, ", ") } diff --git a/xorm.go b/xorm.go index 2acebd95..d79a0d42 100644 --- a/xorm.go +++ b/xorm.go @@ -12,6 +12,7 @@ func Create(schema string) Engine { engine := Engine{} engine.Mapper = SnakeMapper{} engine.Tables = make(map[string]Table) + engine.Statement.Engine = &engine l := strings.Split(schema, "://") if len(l) == 2 { engine.Protocol = l[0] diff --git a/xorm_test.go b/xorm_test.go index a0f181a5..0a3b9131 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -33,6 +33,12 @@ type Userinfo struct { Created time.Time } +type Userdetail struct { + Uid int `xorm:"id pk not null"` + Intro string + Profile string +} + var engine xorm.Engine func directCreateTable(t *testing.T) { @@ -48,7 +54,7 @@ func mapper(t *testing.T) { t.Error(err) } - err = engine.Map(&Userinfo{}) + err = engine.Map(&Userinfo{}, &Userdetail{}) if err != nil { t.Error(err) } @@ -72,6 +78,24 @@ func insert(t *testing.T) { } } +func query(t *testing.T) { + sql := "select * from userinfo" + results, err := engine.Query(sql) + if err != nil { + t.Error(err) + } + fmt.Println(results) +} + +func exec(t *testing.T) { + sql := "update userinfo set username=? where id=?" + res, err := engine.Exec(sql, "xiaolun", 1) + if err != nil { + t.Error(err) + } + fmt.Println(res) +} + func insertAutoIncr(t *testing.T) { // auto increment insert user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now()} @@ -90,10 +114,26 @@ func insertMulti(t *testing.T) { } } +func insertTwoTable(t *testing.T) { + userinfo := Userinfo{Username: "xlw3", Departname: "dev", Alias: "lunny4", Created: time.Now()} + uid, err := engine.Insert(&userinfo) + if err != nil { + t.Error(err) + return + } + + userdetail := Userdetail{Uid: int(uid), Intro: "I'm a very beautiful women.", Profile: "sfsaf"} + _, err = engine.Insert(&userdetail) + if err != nil { + t.Error(err) + } +} + func update(t *testing.T) { // update by id - user := Userinfo{Uid: 1, Username: "xxx"} - _, err := engine.Update(&user) + user := Userinfo{Username: "xxx"} + condiUser := Userinfo{Uid: 1} + _, err := engine.Update(&user, &condiUser) if err != nil { t.Error(err) } @@ -163,6 +203,23 @@ func order(t *testing.T) { fmt.Println(users) } +func join(t *testing.T) { + users := make([]Userinfo, 0) + err := engine.Join("LEFT", "userdetail", "userinfo.id=userdetail.id").Find(&users) + if err != nil { + t.Error(err) + } +} + +func having(t *testing.T) { + users := make([]Userinfo, 0) + err := engine.GroupBy("username").Having("username='xlw'").Find(&users) + if err != nil { + t.Error(err) + } + fmt.Println(users) +} + func transaction(t *testing.T) { counter := func() { total, err := engine.Count(&Userinfo{}) @@ -173,7 +230,7 @@ func transaction(t *testing.T) { } counter() - + defer counter() session, err := engine.MakeSession() defer session.Close() if err != nil { @@ -181,25 +238,76 @@ func transaction(t *testing.T) { return } - defer counter() - session.Begin() - session.IsAutoRollback = true + //session.IsAutoRollback = false user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} _, err = session.Insert(&user1) if err != nil { + session.Rollback() t.Error(err) return } user2 := Userinfo{Username: "yyy"} - _, err = session.Where("id = ?", 2).Update(&user2) + _, err = session.Where("uid = ?", 0).Update(&user2) if err != nil { - t.Error(err) + session.Rollback() + fmt.Println(err) + //t.Error(err) return } _, err = session.Delete(&user2) if err != nil { + session.Rollback() + t.Error(err) + return + } + + err = session.Commit() + if err != nil { + t.Error(err) + return + } +} + +func combineTransaction(t *testing.T) { + counter := func() { + total, err := engine.Count(&Userinfo{}) + if err != nil { + t.Error(err) + } + fmt.Printf("----now total %v records\n", total) + } + + counter() + defer counter() + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + t.Error(err) + return + } + + session.Begin() + //session.IsAutoRollback = false + user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} + _, err = session.Insert(&user1) + if err != nil { + session.Rollback() + t.Error(err) + return + } + user2 := Userinfo{Username: "zzz"} + _, err = session.Where("id = ?", 0).Update(&user2) + if err != nil { + session.Rollback() + t.Error(err) + return + } + + _, err = session.Exec("delete from userinfo where username = ?", user2.Username) + if err != nil { + session.Rollback() t.Error(err) return } @@ -218,6 +326,8 @@ func TestMysql(t *testing.T) { directCreateTable(t) mapper(t) insert(t) + query(t) + exec(t) insertAutoIncr(t) insertMulti(t) update(t) @@ -228,7 +338,10 @@ func TestMysql(t *testing.T) { where(t) limit(t) order(t) + join(t) + having(t) transaction(t) + combineTransaction(t) } func TestSqlite(t *testing.T) { @@ -238,6 +351,8 @@ func TestSqlite(t *testing.T) { directCreateTable(t) mapper(t) insert(t) + query(t) + exec(t) insertAutoIncr(t) insertMulti(t) update(t) @@ -248,5 +363,8 @@ func TestSqlite(t *testing.T) { where(t) limit(t) order(t) + join(t) + having(t) transaction(t) + combineTransaction(t) }