fix mysql test and sqlite test

This commit is contained in:
Lunny Xiao 2013-05-06 16:01:17 +08:00
parent 95927253d4
commit 74e0e3b175
7 changed files with 580 additions and 392 deletions

1
.gitignore vendored
View File

@ -2,6 +2,7 @@
*.o *.o
*.a *.a
*.so *.so
*.db
# Folders # Folders
_obj _obj

View File

@ -1,5 +1,5 @@
xorm # xorm
===== ===========
[中文](README_CN.md) [中文](README_CN.md)
@ -17,71 +17,72 @@ Mysql:[github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL)
SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
### Installing xorm ## Installing xorm
go get github.com/lunny/xorm go get github.com/lunny/xorm
### Quick Start ## Quick Start
1.Create an database engine (for example: mysql) 1.Create a database engine (for example: mysql)
```go
engine := xorm.Create("mysql://root:123@localhost/test") engine := xorm.Create("mysql://root:123@localhost/test")
```
2.Define your struct 2.Define your struct
```go
type User struct { type User struct {
Id int Id int
Name string Name string
Age int `xorm:"-"` Age int `xorm:"-"`
} }
```
for Simple Task, just use engine's functions: for Simple Task, just use engine's functions:
begin start, you should create a database and then we create the tables before beginning, you should create a database in mysql and then we will create the tables.
```go
err := engine.CreateTables(&User{}) err := engine.CreateTables(&User{})
```
then, insert an struct to table then, insert an struct to table
```go
id, err := engine.Insert(&User{Name:"lunny"}) id, err := engine.Insert(&User{Name:"lunny"})
```
or you want to update this struct or you want to update this struct
```go
user := User{Id:1, Name:"xlw"} user := User{Id:1, Name:"xlw"}
rows, err := engine.Update(&user) rows, err := engine.Update(&user)
```
3.Fetch a single object by user 3.Fetch a single object by user
```go
var user = User{Id:27} var user = User{Id:27}
engine.Get(&user) engine.Get(&user)
var user = User{Name:"xlw"} var user = User{Name:"xlw"}
engine.Get(&user) engine.Get(&user)
```
##Deep Use
for deep use, you should create a session, this func will create a connection to db for deep use, you should create a session, this func will create a connection to db
```go
session, err := engine.MakeSession() session, err := engine.MakeSession()
defer session.Close() defer session.Close()
if err != nil { if err != nil {
return return
} }
```
1.Fetch a single object by where 1.Fetch a single object by where
```go
var user Userinfo var user Userinfo
session.Where("id=?", 27).Get(&user) session.Where("id=?", 27).Get(&user)
@ -93,11 +94,11 @@ session.Where("name = ?", "john").Get(&user3) // more complex query
var user4 Userinfo var user4 Userinfo
session.Where("name = ? and age < ?", "john", 88).Get(&user4) // even more complex session.Where("name = ? and age < ?", "john", 88).Get(&user4) // even more complex
```
2.Fetch multiple objects 2.Fetch multiple objects
```go
var allusers []Userinfo var allusers []Userinfo
err := session.Where("id > ?", "3").Limit(10,20).Find(&allusers) //Get id>3 limit 10 offset 20 err := session.Where("id > ?", "3").Limit(10,20).Find(&allusers) //Get id>3 limit 10 offset 20
@ -106,9 +107,9 @@ err := session.Where("id > ?", "3").Limit(10).Find(&tenusers) //Get id>3 limit 1
var everyone []Userinfo var everyone []Userinfo
err := session.Find(&everyone) err := session.Find(&everyone)
```
###***About Map Rules***
##***Mapping Rules***
1.Struct and struct's fields name should be Pascal style, and the table and column's name default is us 1.Struct and struct's fields name should be Pascal style, and the table and column's name default is us
for example: for example:
The structs Name 'UserInfo' will turn into the table name 'user_info', the same as the keyname. The structs Name 'UserInfo' will turn into the table name 'user_info', the same as the keyname.
@ -117,13 +118,22 @@ If the keyname is 'UserName' will turn into the select colum 'user_name'
2.You have two method to change the rule. One is implement your own Map interface according IMapper, you can find the interface in mapper.go and set it to engine.Mapper 2.You have two method to change the rule. One is implement your own Map interface according IMapper, you can find the interface in mapper.go and set it to engine.Mapper
another is use field tag, field tag support the below keywords: another is use field tag, field tag support the below keywords:
[name] column name * [name] column name
pk the field is a primary key * pk the field is a primary key
int(11)/varchar(50) column type * int(11)/varchar(50) column type
autoincr auto incrment * autoincr auto incrment
[not ]null if column can be null value * [not ]null if column can be null value
unique unique * unique unique
- this field is not map as a table column * \- this field is not map as a table column
##FAQ
1.How the xorm tag use both with json?
use space
type User struct {
User string `json:"user" orm:"user_id"`
}
## LICENSE ## LICENSE

207
engine.go
View File

@ -6,22 +6,6 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time"
)
type SQLType struct {
Name string
DefaultLength int
}
var (
Int = SQLType{"int", 11}
Char = SQLType{"char", 1}
Varchar = SQLType{"varchar", 50}
Date = SQLType{"date", 24}
Decimal = SQLType{"decimal", 26}
Float = SQLType{"float", 31}
Double = SQLType{"double", 31}
) )
const ( const (
@ -32,51 +16,6 @@ const (
MYMYSQL = "mymysql" MYMYSQL = "mymysql"
) )
type Column struct {
Name string
FieldName string
SQLType SQLType
Length int
Nullable bool
Default string
IsUnique bool
IsPrimaryKey bool
AutoIncrement bool
}
type Table struct {
Name string
Type reflect.Type
Columns map[string]Column
PrimaryKey string
}
func (table *Table) ColumnStr() string {
colNames := make([]string, 0)
for _, col := range table.Columns {
if col.Name == "" {
continue
}
colNames = append(colNames, col.Name)
}
return strings.Join(colNames, ", ")
}
func (table *Table) PlaceHolders() string {
colNames := make([]string, 0)
for _, col := range table.Columns {
if col.Name == "" {
continue
}
colNames = append(colNames, "?")
}
return strings.Join(colNames, ", ")
}
func (table *Table) PKColumn() Column {
return table.Columns[table.PrimaryKey]
}
type Engine struct { type Engine struct {
Mapper IMapper Mapper IMapper
Protocol string Protocol string
@ -91,21 +30,27 @@ type Engine struct {
AutoIncrement string AutoIncrement string
ShowSQL bool ShowSQL bool
QuoteIdentifier string QuoteIdentifier string
Statement Statement
}
func Type(bean interface{}) reflect.Type {
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
return reflect.TypeOf(sliceValue.Interface())
} }
func (e *Engine) OpenDB() (db *sql.DB, err error) { func (e *Engine) OpenDB() (db *sql.DB, err error) {
db = nil db = nil
err = nil err = nil
if e.Protocol == "sqlite" { if e.Protocol == SQLITE {
// 'sqlite:///foo.db' // 'sqlite:///foo.db'
db, err = sql.Open("sqlite3", e.Others) db, err = sql.Open("sqlite3", e.Others)
// 'sqlite:///:memory:' // 'sqlite:///:memory:'
} else if e.Protocol == "mysql" { } else if e.Protocol == MYSQL {
// 'mysql://<username>:<passwd>@<host>/<dbname>?charset=<encoding>' // 'mysql://<username>:<passwd>@<host>/<dbname>?charset=<encoding>'
connstr := strings.Join([]string{e.UserName, ":", connstr := strings.Join([]string{e.UserName, ":",
e.Password, "@tcp(", e.Host, ":3306)/", e.DBName, "?charset=", e.Charset}, "") e.Password, "@tcp(", e.Host, ":3306)/", e.DBName, "?charset=", e.Charset}, "")
db, err = sql.Open(e.Protocol, connstr) db, err = sql.Open(e.Protocol, connstr)
} else if e.Protocol == "mymysql" { } else if e.Protocol == MYMYSQL {
// DBNAME/USER/PASSWD // DBNAME/USER/PASSWD
connstr := strings.Join([]string{e.DBName, e.UserName, e.Password}, "/") connstr := strings.Join([]string{e.DBName, e.UserName, e.Password}, "/")
db, err = sql.Open(e.Protocol, connstr) db, err = sql.Open(e.Protocol, connstr)
@ -123,34 +68,71 @@ func (engine *Engine) MakeSession() (session Session, err error) {
if err != nil { if err != nil {
return Session{}, err return Session{}, err
} }
if engine.Protocol == "pgsql" { if engine.Protocol == PQSQL {
engine.QuoteIdentifier = "\"" engine.QuoteIdentifier = "\""
session = Session{Engine: engine, Db: db, ParamIteration: 1} session = Session{Engine: engine, Db: db}
} else if engine.Protocol == "mssql" { } else if engine.Protocol == MSSQL {
engine.QuoteIdentifier = "" engine.QuoteIdentifier = ""
session = Session{Engine: engine, Db: db, ParamIteration: 1} session = Session{Engine: engine, Db: db}
} else { } else {
engine.QuoteIdentifier = "`" engine.QuoteIdentifier = "`"
session = Session{Engine: engine, Db: db, ParamIteration: 1} session = Session{Engine: engine, Db: db}
} }
session.Mapper = engine.Mapper session.Mapper = engine.Mapper
session.Init() session.Init()
return return
} }
func (sqlType SQLType) genSQL(length int) string { func (engine *Engine) Where(querystring string, args ...interface{}) *Engine {
if sqlType == Date { engine.Statement.Where(querystring, args...)
return " datetime " return engine
} }
return sqlType.Name + "(" + strconv.Itoa(length) + ")"
func (engine *Engine) Limit(limit int, start ...int) *Engine {
engine.Statement.Limit(limit, start...)
return engine
}
func (engine *Engine) OrderBy(order string) *Engine {
engine.Statement.OrderBy(order)
return engine
}
//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (engine *Engine) Join(join_operator, tablename, condition string) *Engine {
engine.Statement.Join(join_operator, tablename, condition)
return engine
}
func (engine *Engine) GroupBy(keys string) *Engine {
engine.Statement.GroupBy(keys)
return engine
}
func (engine *Engine) Having(conditions string) *Engine {
engine.Statement.Having(conditions)
return engine
}
func (e *Engine) genColumnStr(col *Column) string {
sql := "`" + col.Name + "` "
if col.SQLType == Date {
sql += " datetime "
} else {
if e.Protocol == SQLITE && col.IsPrimaryKey {
sql += "integer"
} else {
sql += col.SQLType.Name
}
if e.Protocol != SQLITE {
if col.SQLType != Decimal {
sql += "(" + strconv.Itoa(col.Length) + ")"
} else {
sql += "(" + strconv.Itoa(col.Length) + "," + strconv.Itoa(col.Length2) + ")"
}
}
} }
func (e *Engine) genCreateSQL(table *Table) string {
sql := "CREATE TABLE IF NOT EXISTS `" + table.Name + "` ("
//fmt.Println(session.Mapper.Obj2Table(session.PrimaryKey))
for _, col := range table.Columns {
if col.Name != "" {
sql += "`" + col.Name + "` " + col.SQLType.genSQL(col.Length) + " "
if col.Nullable { if col.Nullable {
sql += " NULL " sql += " NULL "
} else { } else {
@ -160,14 +142,21 @@ func (e *Engine) genCreateSQL(table *Table) string {
if col.IsPrimaryKey { if col.IsPrimaryKey {
sql += "PRIMARY KEY " sql += "PRIMARY KEY "
} }
if col.AutoIncrement { if col.IsAutoIncrement {
sql += e.AutoIncrement + " " sql += e.AutoIncrement + " "
} }
if col.IsUnique { if col.IsUnique {
sql += "Unique " sql += "Unique "
} }
sql += "," return sql
} }
func (e *Engine) genCreateSQL(table *Table) string {
sql := "CREATE TABLE IF NOT EXISTS `" + table.Name + "` ("
//fmt.Println(session.Mapper.Obj2Table(session.PrimaryKey))
for _, col := range table.Columns {
sql += e.genColumnStr(&col)
sql += ","
} }
sql = sql[:len(sql)-2] + ");" sql = sql[:len(sql)-2] + ");"
if e.ShowSQL { if e.ShowSQL {
@ -184,27 +173,6 @@ func (e *Engine) genDropSQL(table *Table) string {
return sql return sql
} }
func Type(bean interface{}) reflect.Type {
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
return reflect.TypeOf(sliceValue.Interface())
}
func Type2SQLType(t reflect.Type) (st SQLType) {
switch k := t.Kind(); k {
case reflect.Int, reflect.Int32, reflect.Int64:
st = Int
case reflect.String:
st = Varchar
case reflect.Struct:
if t == reflect.TypeOf(time.Time{}) {
st = Date
}
default:
st = Varchar
}
return
}
/* /*
map an object into a table object map an object into a table object
*/ */
@ -242,7 +210,7 @@ func (engine *Engine) MapType(t reflect.Type) Table {
case "null": case "null":
col.Nullable = (tags[j-1] != "not") col.Nullable = (tags[j-1] != "not")
case "autoincr": case "autoincr":
col.AutoIncrement = true col.IsAutoIncrement = true
case "default": case "default":
col.Default = tags[j+1] col.Default = tags[j+1]
case "int": case "int":
@ -255,6 +223,7 @@ func (engine *Engine) MapType(t reflect.Type) Table {
if col.SQLType.Name == "" { if col.SQLType.Name == "" {
col.SQLType = Type2SQLType(fieldType) col.SQLType = Type2SQLType(fieldType)
col.Length = col.SQLType.DefaultLength col.Length = col.SQLType.DefaultLength
col.Length2 = col.SQLType.DefaultLength2
} }
if col.Name == "" { if col.Name == "" {
@ -266,7 +235,7 @@ func (engine *Engine) MapType(t reflect.Type) Table {
if col.Name == "" { if col.Name == "" {
sqlType := Type2SQLType(fieldType) sqlType := Type2SQLType(fieldType)
col = Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, col = Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType,
sqlType.DefaultLength, true, "", false, false, false} sqlType.DefaultLength, sqlType.DefaultLength2, true, "", false, false, false}
} }
table.Columns[col.Name] = col table.Columns[col.Name] = col
if strings.ToLower(t.Field(i).Name) == "id" { if strings.ToLower(t.Field(i).Name) == "id" {
@ -278,7 +247,7 @@ func (engine *Engine) MapType(t reflect.Type) Table {
if pkstr != "" { if pkstr != "" {
col := table.Columns[pkstr] col := table.Columns[pkstr]
col.IsPrimaryKey = true col.IsPrimaryKey = true
col.AutoIncrement = true col.IsAutoIncrement = true
col.Nullable = false col.Nullable = false
col.Length = Int.DefaultLength col.Length = Int.DefaultLength
table.PrimaryKey = col.Name table.PrimaryKey = col.Name
@ -292,7 +261,6 @@ func (engine *Engine) MapType(t reflect.Type) Table {
func (engine *Engine) Map(beans ...interface{}) (e error) { func (engine *Engine) Map(beans ...interface{}) (e error) {
for _, bean := range beans { for _, bean := range beans {
//t := getBeanType(bean)
tableName := engine.Mapper.Obj2Table(StructName(bean)) tableName := engine.Mapper.Obj2Table(StructName(bean))
if _, ok := engine.Tables[tableName]; !ok { if _, ok := engine.Tables[tableName]; !ok {
table := engine.MapOne(bean) table := engine.MapOne(bean)
@ -304,7 +272,6 @@ func (engine *Engine) Map(beans ...interface{}) (e error) {
func (engine *Engine) UnMap(beans ...interface{}) (e error) { func (engine *Engine) UnMap(beans ...interface{}) (e error) {
for _, bean := range beans { for _, bean := range beans {
//t := getBeanType(bean)
tableName := engine.Mapper.Obj2Table(StructName(bean)) tableName := engine.Mapper.Obj2Table(StructName(bean))
if _, ok := engine.Tables[tableName]; ok { if _, ok := engine.Tables[tableName]; ok {
delete(engine.Tables, tableName) delete(engine.Tables, tableName)
@ -380,7 +347,9 @@ func (engine *Engine) Insert(beans ...interface{}) (int64, error) {
if err != nil { if err != nil {
return -1, err return -1, err
} }
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Insert(beans...) return session.Insert(beans...)
} }
@ -390,7 +359,9 @@ func (engine *Engine) Update(bean interface{}) (int64, error) {
if err != nil { if err != nil {
return -1, err return -1, err
} }
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Update(bean) return session.Update(bean)
} }
@ -400,7 +371,9 @@ func (engine *Engine) Delete(bean interface{}) (int64, error) {
if err != nil { if err != nil {
return -1, err return -1, err
} }
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Delete(bean) return session.Delete(bean)
} }
@ -410,7 +383,9 @@ func (engine *Engine) Get(bean interface{}) error {
if err != nil { if err != nil {
return err return err
} }
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Get(bean) return session.Get(bean)
} }
@ -420,7 +395,9 @@ func (engine *Engine) Find(beans interface{}) error {
if err != nil { if err != nil {
return err return err
} }
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Find(beans) return session.Find(beans)
} }
@ -430,6 +407,8 @@ func (engine *Engine) Count(bean interface{}) (int64, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Count(bean) return session.Count(bean)
} }

View File

@ -37,30 +37,28 @@ func Type2StructName(v reflect.Type) string {
type Session struct { type Session struct {
Db *sql.DB Db *sql.DB
Engine *Engine Engine *Engine
Tx *sql.Tx
Statements []Statement Statements []Statement
Mapper IMapper Mapper IMapper
AutoCommit bool IsAutoCommit bool
ParamIteration int IsAutoRollback bool
CurStatementIdx int CurStatementIdx int
} }
func (session *Session) Init() { func (session *Session) Init() {
session.Statements = make([]Statement, 0) session.Statements = make([]Statement, 0)
/*session.Statement.TableName = ""
session.Statement.LimitStr = 0
session.Statement.OffsetStr = 0
session.Statement.WhereStr = ""
session.Statement.ParamStr = make([]interface{}, 0)
session.Statement.OrderStr = ""
session.Statement.JoinStr = ""
session.Statement.GroupByStr = ""
session.Statement.HavingStr = ""*/
session.CurStatementIdx = -1 session.CurStatementIdx = -1
session.IsAutoCommit = true
session.ParamIteration = 1 session.IsAutoRollback = false
} }
func (session *Session) Close() { func (session *Session) Close() {
rollbackfunc := func() {
if session.IsAutoRollback {
session.Rollback()
}
}
defer rollbackfunc()
defer session.Db.Close() defer session.Db.Close()
} }
@ -78,102 +76,84 @@ func (session *Session) AutoStatement() *Statement {
return session.CurrentStatement() return session.CurrentStatement()
} }
//Execute sql func (session *Session) Where(querystring string, args ...interface{}) *Session {
func (session *Session) Exec(finalQueryString string, args ...interface{}) (sql.Result, error) {
rs, err := session.Db.Prepare(finalQueryString)
if err != nil {
return nil, err
}
defer rs.Close()
res, err := rs.Exec(args...)
if err != nil {
return nil, err
}
return res, nil
}
func (session *Session) Where(querystring interface{}, args ...interface{}) *Session {
statement := session.AutoStatement() statement := session.AutoStatement()
switch querystring := querystring.(type) { statement.Where(querystring, args...)
case string:
statement.WhereStr = querystring
case int:
if session.Engine.Protocol == "pgsql" {
statement.WhereStr = fmt.Sprintf("%v%v%v = $%v", session.Engine.QuoteIdentifier, statement.Table.PKColumn().Name, session.Engine.QuoteIdentifier, session.ParamIteration)
} else {
statement.WhereStr = fmt.Sprintf("%v%v%v = ?", session.Engine.QuoteIdentifier, statement.Table.PKColumn().Name, session.Engine.QuoteIdentifier)
session.ParamIteration++
}
args = append(args, querystring)
}
statement.ParamStr = args
return session return session
} }
func (session *Session) Limit(start int, size ...int) *Session { func (session *Session) Limit(limit int, start ...int) *Session {
session.AutoStatement().LimitStr = start statement := session.AutoStatement()
if len(size) > 0 { statement.Limit(limit, start...)
session.CurrentStatement().OffsetStr = size[0]
}
return session
}
func (session *Session) Offset(offset int) *Session {
session.AutoStatement().OffsetStr = offset
return session return session
} }
func (session *Session) OrderBy(order string) *Session { func (session *Session) OrderBy(order string) *Session {
session.AutoStatement().OrderStr = order statement := session.AutoStatement()
statement.OrderBy(order)
return session return session
} }
//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN //The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (session *Session) Join(join_operator, tablename, condition string) *Session { func (session *Session) Join(join_operator, tablename, condition string) *Session {
if session.AutoStatement().JoinStr != "" { statement := session.AutoStatement()
session.CurrentStatement().JoinStr = session.CurrentStatement().JoinStr + fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition) statement.Join(join_operator, tablename, condition)
} else {
session.CurrentStatement().JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition)
}
return session return session
} }
func (session *Session) GroupBy(keys string) *Session { func (session *Session) GroupBy(keys string) *Session {
session.AutoStatement().GroupByStr = fmt.Sprintf("GROUP BY %v", keys) statement := session.AutoStatement()
statement.GroupBy(keys)
return session return session
} }
func (session *Session) Having(conditions string) *Session { func (session *Session) Having(conditions string) *Session {
session.AutoStatement().HavingStr = fmt.Sprintf("HAVING %v", conditions) statement := session.AutoStatement()
statement.Having(conditions)
return session return session
} }
func (session *Session) Begin() { func (session *Session) Begin() error {
session.IsAutoCommit = false
session.IsAutoRollback = true
tx, err := session.Db.Begin()
session.Tx = tx
return err
} }
func (session *Session) Rollback() { func (session *Session) Rollback() error {
return session.Tx.Rollback()
} }
func (session *Session) Commit() { func (session *Session) Commit() error {
for _, statement := range session.Statements { return session.Tx.Commit()
sql := statement.generateSql()
session.Exec(sql)
}
} }
func (session *Session) TableName(bean interface{}) string { func (session *Session) TableName(bean interface{}) string {
return session.Mapper.Obj2Table(StructName(bean)) return session.Mapper.Obj2Table(StructName(bean))
} }
func (session *Session) SetStatement(statement *Statement) {
if session.CurStatementIdx == len(session.Statements)-1 {
session.Statements = append(session.Statements, *statement)
} else {
session.Statements[session.CurStatementIdx+1] = *statement
}
session.CurStatementIdx = session.CurStatementIdx + 1
}
func (session *Session) newStatement() { func (session *Session) newStatement() {
state := Statement{} if session.CurStatementIdx == len(session.Statements)-1 {
state.Session = session state := Statement{Session: session}
state.Init()
session.Statements = append(session.Statements, state) session.Statements = append(session.Statements, state)
session.CurStatementIdx = len(session.Statements) - 1 }
session.CurStatementIdx = session.CurStatementIdx + 1
}
func (session *Session) clearStatment() {
session.Statements[session.CurStatementIdx].Init()
session.CurStatementIdx = session.CurStatementIdx - 1
} }
func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error { func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error {
@ -250,13 +230,32 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
return nil return nil
} }
func (session *Session) Get(output interface{}) error { //Execute sql
func (session *Session) Exec(finalQueryString string, args ...interface{}) (sql.Result, error) {
rs, err := session.Db.Prepare(finalQueryString)
if err != nil {
return nil, err
}
defer rs.Close()
res, err := rs.Exec(args...)
if err != nil {
return nil, err
}
return res, nil
}
func (session *Session) Get(bean interface{}) error {
statement := session.AutoStatement() statement := session.AutoStatement()
session.Limit(1) statement.Limit(1)
tableName := session.TableName(output) tableName := session.TableName(bean)
table := session.Engine.Tables[tableName] table := session.Engine.Tables[tableName]
statement.Table = &table statement.Table = &table
colNames, args := session.BuildConditions(&table, bean)
statement.ColumnStr = strings.Join(colNames, " and ")
statement.BeanArgs = args
resultsSlice, err := session.FindMap(statement) resultsSlice, err := session.FindMap(statement)
if err != nil { if err != nil {
return err return err
@ -265,7 +264,7 @@ func (session *Session) Get(output interface{}) error {
return nil return nil
} else if len(resultsSlice) == 1 { } else if len(resultsSlice) == 1 {
results := resultsSlice[0] results := resultsSlice[0]
err := session.scanMapIntoStruct(output, results) err := session.scanMapIntoStruct(bean, results)
if err != nil { if err != nil {
return err return err
} }
@ -277,12 +276,15 @@ func (session *Session) Get(output interface{}) error {
func (session *Session) Count(bean interface{}) (int64, error) { func (session *Session) Count(bean interface{}) (int64, error) {
statement := session.AutoStatement() statement := session.AutoStatement()
session.Limit(1)
tableName := session.TableName(bean) tableName := session.TableName(bean)
table := session.Engine.Tables[tableName] table := session.Engine.Tables[tableName]
statement.Table = &table statement.Table = &table
resultsSlice, err := session.SQL2Map(statement.genCountSql(), statement.ParamStr) colNames, args := session.BuildConditions(&table, bean)
statement.ColumnStr = strings.Join(colNames, " and ")
statement.BeanArgs = args
resultsSlice, err := session.SQL2Map(statement.genCountSql(), append(statement.Params, statement.BeanArgs...))
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -346,6 +348,7 @@ func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSli
} }
for res.Next() { for res.Next() {
result := make(map[string][]byte) result := make(map[string][]byte)
//scanResultContainers := make([]interface{}, len(fields))
var scanResultContainers []interface{} var scanResultContainers []interface{}
for i := 0; i < len(fields); i++ { for i := 0; i < len(fields); i++ {
var scanResultContainer interface{} var scanResultContainer interface{}
@ -367,7 +370,6 @@ func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSli
switch aa.Kind() { switch aa.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
str = strconv.FormatInt(vv.Int(), 10) str = strconv.FormatInt(vv.Int(), 10)
result[key] = []byte(str) result[key] = []byte(str)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
str = strconv.FormatUint(vv.Uint(), 10) str = strconv.FormatUint(vv.Uint(), 10)
@ -397,7 +399,7 @@ func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSli
func (session *Session) FindMap(statement *Statement) (resultsSlice []map[string][]byte, err error) { func (session *Session) FindMap(statement *Statement) (resultsSlice []map[string][]byte, err error) {
sqls := statement.generateSql() sqls := statement.generateSql()
return session.SQL2Map(sqls, statement.ParamStr) return session.SQL2Map(sqls, append(statement.Params, statement.BeanArgs...))
} }
func (session *Session) Insert(beans ...interface{}) (int64, error) { func (session *Session) Insert(beans ...interface{}) (int64, error) {
@ -421,7 +423,7 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
for _, col := range table.Columns { for _, col := range table.Columns {
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
val := fieldValue.Interface() val := fieldValue.Interface()
if col.AutoIncrement { if col.IsAutoIncrement {
if fieldValue.Int() == 0 { if fieldValue.Int() == 0 {
continue continue
} }
@ -442,7 +444,14 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
fmt.Println(statement) fmt.Println(statement)
} }
res, err := session.Exec(statement, args...) var res sql.Result
var err error
if session.IsAutoCommit {
res, err = session.Exec(statement, args...)
} else {
res, err = session.Tx.Exec(statement, args...)
}
if err != nil { if err != nil {
return -1, err return -1, err
} }
@ -455,16 +464,10 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
return id, nil return id, nil
} }
func (session *Session) Update(bean interface{}) (int64, error) { func (session *Session) BuildConditions(table *Table, bean interface{}) ([]string, []interface{}) {
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
colNames := make([]string, 0) colNames := make([]string, 0)
var args = make([]interface{}, 0) var args = make([]interface{}, 0)
for _, col := range table.Columns { for _, col := range table.Columns {
if col.Name == "" {
continue
}
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
fieldType := reflect.TypeOf(fieldValue.Interface()) fieldType := reflect.TypeOf(fieldValue.Interface())
val := fieldValue.Interface() val := fieldValue.Interface()
@ -491,9 +494,18 @@ func (session *Session) Update(bean interface{}) (int64, error) {
colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?") colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?")
} }
return colNames, args
}
func (session *Session) Update(bean interface{}) (int64, error) {
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
colNames, args := session.BuildConditions(&table, bean)
var condition = "" var condition = ""
st := session.CurrentStatement() st := session.AutoStatement()
if st != nil && st.WhereStr != "" { defer session.clearStatment()
if st.WhereStr != "" {
condition = fmt.Sprintf("WHERE %v", st.WhereStr) condition = fmt.Sprintf("WHERE %v", st.WhereStr)
} }
@ -516,7 +528,15 @@ func (session *Session) Update(bean interface{}) (int64, error) {
fmt.Println(statement) fmt.Println(statement)
} }
res, err := session.Exec(statement, args...) var res sql.Result
var err error
if session.IsAutoCommit {
fmt.Println("session.Exec")
res, err = session.Exec(statement, append(args, st.Params...)...)
} else {
fmt.Println("tx.Exec")
res, err = session.Tx.Exec(statement, append(args, st.Params...)...)
}
if err != nil { if err != nil {
return -1, err return -1, err
} }
@ -532,42 +552,12 @@ func (session *Session) Update(bean interface{}) (int64, error) {
func (session *Session) Delete(bean interface{}) (int64, error) { func (session *Session) Delete(bean interface{}) (int64, error) {
tableName := session.TableName(bean) tableName := session.TableName(bean)
table := session.Engine.Tables[tableName] table := session.Engine.Tables[tableName]
colNames, args := session.BuildConditions(&table, bean)
colNames := make([]string, 0)
var args = make([]interface{}, 0)
for _, col := range table.Columns {
if col.Name == "" {
continue
}
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
fieldType := reflect.TypeOf(fieldValue.Interface())
val := fieldValue.Interface()
switch fieldType.Kind() {
case reflect.String:
if fieldValue.String() == "" {
continue
}
case reflect.Int, reflect.Int32, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Struct:
if fieldType == reflect.TypeOf(time.Now()) {
t := fieldValue.Interface().(time.Time)
if t.IsZero() {
continue
}
}
default:
continue
}
args = append(args, val)
colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?")
}
var condition = "" var condition = ""
st := session.CurrentStatement() st := session.AutoStatement()
if st != nil && st.WhereStr != "" { defer session.clearStatment()
if st.WhereStr != "" {
condition = fmt.Sprintf("WHERE %v", st.WhereStr) condition = fmt.Sprintf("WHERE %v", st.WhereStr)
if len(colNames) > 0 { if len(colNames) > 0 {
condition += " and " condition += " and "
@ -587,7 +577,13 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
fmt.Println(statement) fmt.Println(statement)
} }
res, err := session.Exec(statement, args...) var res sql.Result
var err error
if session.IsAutoCommit {
res, err = session.Exec(statement, append(st.Params, args...)...)
} else {
res, err = session.Tx.Exec(statement, append(st.Params, args...)...)
}
if err != nil { if err != nil {
return -1, err return -1, err
} }

View File

@ -5,40 +5,66 @@ import (
) )
type Statement struct { type Statement struct {
TableName string
Table *Table Table *Table
Session *Session Session *Session
LimitStr int Start int
OffsetStr int LimitN int
WhereStr string WhereStr string
ParamStr []interface{} Params []interface{}
OrderStr string OrderStr string
JoinStr string JoinStr string
GroupByStr string GroupByStr string
HavingStr string HavingStr string
ColumnStr string
BeanArgs []interface{}
} }
func (statement *Statement) Limit(start int, size ...int) *Statement { func (statement *Statement) Init() {
statement.LimitStr = start statement.Table = nil
if len(size) > 0 { statement.Session = nil
statement.OffsetStr = size[0] statement.Start = 0
} statement.LimitN = 0
return statement statement.WhereStr = ""
statement.Params = make([]interface{}, 0)
statement.OrderStr = ""
statement.JoinStr = ""
statement.GroupByStr = ""
statement.HavingStr = ""
statement.ColumnStr = ""
statement.BeanArgs = make([]interface{}, 0)
} }
func (statement *Statement) Offset(offset int) *Statement { func (statement *Statement) Where(querystring string, args ...interface{}) {
statement.OffsetStr = offset statement.WhereStr = querystring
return statement statement.Params = args
} }
func (statement *Statement) OrderBy(order string) *Statement { func (statement *Statement) Limit(limit int, start ...int) {
statement.LimitN = limit
if len(start) > 0 {
statement.Start = start[0]
}
}
func (statement *Statement) OrderBy(order string) {
statement.OrderStr = order statement.OrderStr = order
return statement
} }
func (statement *Statement) Select(colums string) *Statement { //The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
//statement.ColumnStr = colums func (statement *Statement) Join(join_operator, tablename, condition string) {
return statement if statement.JoinStr != "" {
statement.JoinStr = statement.JoinStr + fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition)
} else {
statement.JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition)
}
}
func (statement *Statement) GroupBy(keys string) {
statement.GroupByStr = fmt.Sprintf("GROUP BY %v", keys)
}
func (statement *Statement) Having(conditions string) {
statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
} }
func (statement Statement) generateSql() string { func (statement Statement) generateSql() string {
@ -50,27 +76,41 @@ func (statement Statement) genCountSql() string {
return statement.genSelectSql("count(*) as total") return statement.genSelectSql("count(*) as total")
} }
func (statement Statement) genExecSql() string {
return ""
}
func (statement Statement) genSelectSql(columnStr string) (a string) { func (statement Statement) genSelectSql(columnStr string) (a string) {
session := statement.Session session := statement.Session
if session.Engine.Protocol == "mssql" { if session.Engine.Protocol == "mssql" {
if statement.OffsetStr > 0 { if statement.Start > 0 {
a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v", a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v",
statement.Table.PKColumn().Name, statement.Table.PKColumn().Name,
columnStr, columnStr,
statement.Table.Name) statement.Table.Name)
if statement.WhereStr != "" { if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
}
} else if statement.ColumnStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
} }
a = fmt.Sprintf("select %v from (%v) "+ a = fmt.Sprintf("select %v from (%v) "+
"as a where rownum between %v and %v", "as a where rownum between %v and %v",
columnStr, columnStr,
a, a,
statement.OffsetStr, statement.Start,
statement.LimitStr) statement.LimitN)
} else if statement.LimitStr > 0 { } else if statement.LimitN > 0 {
a = fmt.Sprintf("SELECT top %v %v FROM %v", statement.LimitStr, columnStr, statement.Table.Name) a = fmt.Sprintf("SELECT top %v %v FROM %v", statement.LimitN, columnStr, statement.Table.Name)
if statement.WhereStr != "" { if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
}
} else if statement.ColumnStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
} }
if statement.GroupByStr != "" { if statement.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, statement.GroupByStr) a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
@ -85,6 +125,11 @@ func (statement Statement) genSelectSql(columnStr string) (a string) {
a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.Table.Name) a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.Table.Name)
if statement.WhereStr != "" { if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
}
} else if statement.ColumnStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
} }
if statement.GroupByStr != "" { if statement.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, statement.GroupByStr) a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
@ -103,6 +148,11 @@ func (statement Statement) genSelectSql(columnStr string) (a string) {
} }
if statement.WhereStr != "" { if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
}
} else if statement.ColumnStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
} }
if statement.GroupByStr != "" { if statement.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, statement.GroupByStr) a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
@ -113,10 +163,10 @@ func (statement Statement) genSelectSql(columnStr string) (a string) {
if statement.OrderStr != "" { if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
} }
if statement.OffsetStr > 0 { if statement.Start > 0 {
a = fmt.Sprintf("%v LIMIT %v, %v", a, statement.OffsetStr, statement.LimitStr) a = fmt.Sprintf("%v LIMIT %v, %v", a, statement.Start, statement.LimitN)
} else if statement.LimitStr > 0 { } else if statement.LimitN > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitStr) a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
} }
} }
return return

90
table.go Normal file
View File

@ -0,0 +1,90 @@
package xorm
import (
"reflect"
"strconv"
"strings"
"time"
)
type SQLType struct {
Name string
DefaultLength int
DefaultLength2 int
}
var (
Int = SQLType{"int", 11, 0}
Char = SQLType{"char", 1, 0}
Bool = SQLType{"int", 1, 0}
Varchar = SQLType{"varchar", 50, 0}
Date = SQLType{"date", 24, 0}
Decimal = SQLType{"decimal", 26, 2}
Float = SQLType{"float", 31, 0}
Double = SQLType{"double", 31, 0}
)
func (sqlType SQLType) genSQL(length int) string {
if sqlType == Date {
return " datetime "
}
return sqlType.Name + "(" + strconv.Itoa(length) + ")"
}
func Type2SQLType(t reflect.Type) (st SQLType) {
switch k := t.Kind(); k {
case reflect.Int, reflect.Int32, reflect.Int64:
st = Int
case reflect.Bool:
st = Bool
case reflect.String:
st = Varchar
case reflect.Struct:
if t == reflect.TypeOf(time.Time{}) {
st = Date
}
default:
st = Varchar
}
return
}
type Column struct {
Name string
FieldName string
SQLType SQLType
Length int
Length2 int
Nullable bool
Default string
IsUnique bool
IsPrimaryKey bool
IsAutoIncrement bool
}
type Table struct {
Name string
Type reflect.Type
Columns map[string]Column
PrimaryKey string
}
func (table *Table) ColumnStr() string {
colNames := make([]string, 0)
for _, col := range table.Columns {
colNames = append(colNames, col.Name)
}
return strings.Join(colNames, ", ")
}
/*func (table *Table) PlaceHolders() string {
colNames := make([]string, 0)
for _, col := range table.Columns {
colNames = append(colNames, "?")
}
return strings.Join(colNames, ", ")
}*/
func (table *Table) PKColumn() Column {
return table.Columns[table.PrimaryKey]
}

View File

@ -3,8 +3,7 @@ package xorm_test
import ( import (
"fmt" "fmt"
_ "github.com/Go-SQL-Driver/MySQL" _ "github.com/Go-SQL-Driver/MySQL"
//_ "github.com/ziutek/mymysql/godrv" _ "github.com/mattn/go-sqlite3"
//_ "github.com/mattn/go-sqlite3"
"testing" "testing"
"time" "time"
"xorm" "xorm"
@ -36,21 +35,14 @@ type Userinfo struct {
var engine xorm.Engine var engine xorm.Engine
func TestCreateEngine(t *testing.T) { func directCreateTable(t *testing.T) {
engine = xorm.Create("mysql://root:123@localhost/test")
//engine = orm.Create("mymysql://root:123@localhost/test")
//engine = orm.Create("sqlite:///test.db")
engine.ShowSQL = true
}
func TestDirectCreateTable(t *testing.T) {
err := engine.CreateTables(&Userinfo{}) err := engine.CreateTables(&Userinfo{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
} }
func TestMapper(t *testing.T) { func mapper(t *testing.T) {
err := engine.UnMap(&Userinfo{}) err := engine.UnMap(&Userinfo{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -72,7 +64,7 @@ func TestMapper(t *testing.T) {
} }
} }
func TestInsert(t *testing.T) { func insert(t *testing.T) {
user := Userinfo{1, "xiaolunwen", "dev", "lunny", time.Now()} user := Userinfo{1, "xiaolunwen", "dev", "lunny", time.Now()}
_, err := engine.Insert(&user) _, err := engine.Insert(&user)
if err != nil { if err != nil {
@ -80,7 +72,7 @@ func TestInsert(t *testing.T) {
} }
} }
func TestInsertAutoIncr(t *testing.T) { func insertAutoIncr(t *testing.T) {
// auto increment insert // auto increment insert
user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now()} user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err := engine.Insert(&user) _, err := engine.Insert(&user)
@ -89,7 +81,7 @@ func TestInsertAutoIncr(t *testing.T) {
} }
} }
func TestInsertMulti(t *testing.T) { func insertMulti(t *testing.T) {
user1 := Userinfo{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()} user1 := Userinfo{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}
user2 := Userinfo{Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()} user2 := Userinfo{Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}
_, err := engine.Insert(&user1, &user2) _, err := engine.Insert(&user1, &user2)
@ -98,7 +90,7 @@ func TestInsertMulti(t *testing.T) {
} }
} }
func TestUpdate(t *testing.T) { func update(t *testing.T) {
// update by id // update by id
user := Userinfo{Uid: 1, Username: "xxx"} user := Userinfo{Uid: 1, Username: "xxx"}
_, err := engine.Update(&user) _, err := engine.Update(&user)
@ -107,7 +99,7 @@ func TestUpdate(t *testing.T) {
} }
} }
func TestDelete(t *testing.T) { func delete(t *testing.T) {
user := Userinfo{Uid: 1} user := Userinfo{Uid: 1}
_, err := engine.Delete(&user) _, err := engine.Delete(&user)
if err != nil { if err != nil {
@ -115,7 +107,7 @@ func TestDelete(t *testing.T) {
} }
} }
func TestGet(t *testing.T) { func get(t *testing.T) {
user := Userinfo{Uid: 2} user := Userinfo{Uid: 2}
err := engine.Get(&user) err := engine.Get(&user)
@ -125,7 +117,7 @@ func TestGet(t *testing.T) {
fmt.Println(user) fmt.Println(user)
} }
func TestFind(t *testing.T) { func find(t *testing.T) {
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
err := engine.Find(&users) err := engine.Find(&users)
@ -135,8 +127,8 @@ func TestFind(t *testing.T) {
fmt.Println(users) fmt.Println(users)
} }
func TestCount(t *testing.T) { func count(t *testing.T) {
user := Userinfo{} user := Userinfo{Departname: "dev"}
total, err := engine.Count(&user) total, err := engine.Count(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -144,47 +136,117 @@ func TestCount(t *testing.T) {
fmt.Printf("Total %d records!!!", total) fmt.Printf("Total %d records!!!", total)
} }
func TestWhere(t *testing.T) { func where(t *testing.T) {
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
session, err := engine.MakeSession() err := engine.Where("id > ?", 2).Find(&users)
defer session.Close()
if err != nil {
t.Error(err)
}
err = session.Where("id > ?", 2).Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
fmt.Println(users) fmt.Println(users)
} }
func TestLimit(t *testing.T) { func limit(t *testing.T) {
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
session, err := engine.MakeSession() err := engine.Limit(2, 1).Find(&users)
defer session.Close()
if err != nil {
t.Error(err)
}
err = session.Limit(2, 1).Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
fmt.Println(users) fmt.Println(users)
} }
func TestOrder(t *testing.T) { func order(t *testing.T) {
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
session, err := engine.MakeSession() err := engine.OrderBy("id desc").Find(&users)
defer session.Close()
if err != nil {
t.Error(err)
}
err = session.OrderBy("id desc").Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
fmt.Println(users) fmt.Println(users)
} }
func TestTransaction(*testing.T) { func transaction(t *testing.T) {
counter := func() {
total, err := engine.Count(&Userinfo{})
if err != nil {
t.Error(err)
}
fmt.Printf("----now total %v records\n", total)
}
counter()
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
t.Error(err)
return
}
defer counter()
session.Begin()
session.IsAutoRollback = true
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1)
if err != nil {
t.Error(err)
return
}
user2 := Userinfo{Username: "yyy"}
_, err = session.Where("id = ?", 2).Update(&user2)
if err != nil {
t.Error(err)
return
}
_, err = session.Delete(&user2)
if err != nil {
t.Error(err)
return
}
err = session.Commit()
if err != nil {
t.Error(err)
return
}
}
func TestMysql(t *testing.T) {
engine = xorm.Create("mysql://root:123@localhost/test")
engine.ShowSQL = true
directCreateTable(t)
mapper(t)
insert(t)
insertAutoIncr(t)
insertMulti(t)
update(t)
delete(t)
get(t)
find(t)
count(t)
where(t)
limit(t)
order(t)
transaction(t)
}
func TestSqlite(t *testing.T) {
engine = xorm.Create("sqlite:///test.db")
engine.ShowSQL = true
directCreateTable(t)
mapper(t)
insert(t)
insertAutoIncr(t)
insertMulti(t)
update(t)
delete(t)
get(t)
find(t)
count(t)
where(t)
limit(t)
order(t)
transaction(t)
} }