go routine support;raw sql support

This commit is contained in:
Lunny Xiao 2013-06-16 11:05:16 +08:00
parent e68c855ee3
commit fb6f84bc90
10 changed files with 555 additions and 218 deletions

View File

@ -4,16 +4,12 @@ package xorm
// @deprecation : please use NewSession instead // @deprecation : please use NewSession instead
func (engine *Engine) MakeSession() (Session, error) { func (engine *Engine) MakeSession() (Session, error) {
s, err := engine.NewSession() s := engine.NewSession()
if err == nil { return *s, nil
return *s, err
} else {
return Session{}, err
}
} }
// @deprecation : please use NewEngine instead // @deprecation : please use NewEngine instead
func Create(driverName string, dataSourceName string) Engine { func Create(driverName string, dataSourceName string) Engine {
engine := NewEngine(driverName, dataSourceName) engine, _ := NewEngine(driverName, dataSourceName)
return *engine return *engine
} }

211
engine.go
View File

@ -6,6 +6,7 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"sync"
) )
const ( const (
@ -26,12 +27,13 @@ type Engine struct {
DriverName string DriverName string
DataSourceName string DataSourceName string
Dialect dialect Dialect dialect
Tables map[reflect.Type]Table Tables map[reflect.Type]*Table
mutex *sync.Mutex
AutoIncrement string AutoIncrement string
ShowSQL bool ShowSQL bool
InsertMany bool InsertMany bool
QuoteIdentifier string QuoteIdentifier string
Statement Statement Pool IConnectionPool
} }
func Type(bean interface{}) reflect.Type { func Type(bean interface{}) reflect.Type {
@ -50,78 +52,89 @@ func (e *Engine) OpenDB() (*sql.DB, error) {
return sql.Open(e.DriverName, e.DataSourceName) return sql.Open(e.DriverName, e.DataSourceName)
} }
func (engine *Engine) NewSession() (session *Session, err error) { func (engine *Engine) NewSession() *Session {
db, err := engine.OpenDB() session := &Session{Engine: engine}
if err != nil {
return nil, err
}
session = &Session{Engine: engine, Db: db}
session.Init() session.Init()
return return session
} }
func (engine *Engine) Test() error { func (engine *Engine) Test() error {
session, err := engine.NewSession() session := engine.NewSession()
if err != nil { defer session.Close()
return err return session.Ping()
}
return session.Db.Ping()
} }
func (engine *Engine) Where(querystring string, args ...interface{}) *Engine { func (engine *Engine) Sql(querystring string, args ...interface{}) *Session {
engine.Statement.Where(querystring, args...) session := engine.NewSession()
return engine session.Sql(querystring, args...)
return session
} }
func (engine *Engine) Id(id int64) *Engine { func (engine *Engine) Where(querystring string, args ...interface{}) *Session {
engine.Statement.Id(id) session := engine.NewSession()
return engine session.Where(querystring, args...)
return session
} }
func (engine *Engine) In(column string, args ...interface{}) *Engine { func (engine *Engine) Id(id int64) *Session {
engine.Statement.In(column, args...) session := engine.NewSession()
return engine session.Id(id)
return session
} }
func (engine *Engine) Table(tableName string) *Engine { func (engine *Engine) In(column string, args ...interface{}) *Session {
engine.Statement.Table(tableName) session := engine.NewSession()
return engine session.In(column, args...)
return session
} }
func (engine *Engine) Limit(limit int, start ...int) *Engine { func (engine *Engine) Table(tableName string) *Session {
engine.Statement.Limit(limit, start...) session := engine.NewSession()
return engine session.Table(tableName)
return session
} }
func (engine *Engine) OrderBy(order string) *Engine { func (engine *Engine) Limit(limit int, start ...int) *Session {
engine.Statement.OrderBy(order) session := engine.NewSession()
return engine session.Limit(limit, start...)
return session
}
func (engine *Engine) OrderBy(order string) *Session {
session := engine.NewSession()
session.OrderBy(order)
return session
} }
//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN //The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (engine *Engine) Join(join_operator, tablename, condition string) *Engine { func (engine *Engine) Join(join_operator, tablename, condition string) *Session {
engine.Statement.Join(join_operator, tablename, condition) session := engine.NewSession()
return engine session.Join(join_operator, tablename, condition)
return session
} }
func (engine *Engine) GroupBy(keys string) *Engine { func (engine *Engine) GroupBy(keys string) *Session {
engine.Statement.GroupBy(keys) session := engine.NewSession()
return engine session.GroupBy(keys)
return session
} }
func (engine *Engine) Having(conditions string) *Engine { func (engine *Engine) Having(conditions string) *Session {
engine.Statement.Having(conditions) session := engine.NewSession()
return engine session.Having(conditions)
return session
} }
// some lock needed
func (engine *Engine) AutoMapType(t reflect.Type) *Table { func (engine *Engine) AutoMapType(t reflect.Type) *Table {
engine.mutex.Lock()
defer engine.mutex.Unlock()
table, ok := engine.Tables[t] table, ok := engine.Tables[t]
if !ok { if !ok {
table = engine.MapType(t) table = engine.MapType(t)
engine.Tables[t] = table //engine.Tables[t] = table
} }
return &table return table
} }
func (engine *Engine) AutoMap(bean interface{}) *Table { func (engine *Engine) AutoMap(bean interface{}) *Table {
@ -129,8 +142,8 @@ func (engine *Engine) AutoMap(bean interface{}) *Table {
return engine.AutoMapType(t) return engine.AutoMapType(t)
} }
func (engine *Engine) MapType(t reflect.Type) Table { func (engine *Engine) MapType(t reflect.Type) *Table {
table := Table{Name: engine.Mapper.Obj2Table(t.Name()), Type: t} table := &Table{Name: engine.Mapper.Obj2Table(t.Name()), Type: t}
table.Columns = make(map[string]Column) table.Columns = make(map[string]Column)
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
@ -226,7 +239,10 @@ func (engine *Engine) MapType(t reflect.Type) Table {
return table return table
} }
// Map should use after all operation because it's not thread safe
func (engine *Engine) Map(beans ...interface{}) (e error) { func (engine *Engine) Map(beans ...interface{}) (e error) {
engine.mutex.Lock()
defer engine.mutex.Unlock()
for _, bean := range beans { for _, bean := range beans {
t := Type(bean) t := Type(bean)
if _, ok := engine.Tables[t]; !ok { if _, ok := engine.Tables[t]; !ok {
@ -237,6 +253,8 @@ func (engine *Engine) Map(beans ...interface{}) (e error) {
} }
func (engine *Engine) UnMap(beans ...interface{}) (e error) { func (engine *Engine) UnMap(beans ...interface{}) (e error) {
engine.mutex.Lock()
defer engine.mutex.Unlock()
for _, bean := range beans { for _, bean := range beans {
t := Type(bean) t := Type(bean)
if _, ok := engine.Tables[t]; ok { if _, ok := engine.Tables[t]; ok {
@ -247,37 +265,24 @@ func (engine *Engine) UnMap(beans ...interface{}) (e error) {
} }
func (e *Engine) DropAll() error { func (e *Engine) DropAll() error {
session, err := e.MakeSession() session := e.NewSession()
session.Begin()
defer session.Close() defer session.Close()
err := session.Begin()
if err != nil { if err != nil {
return err return err
} }
err = session.DropAll()
for _, table := range e.Tables { if err != nil {
e.Statement.RefTable = &table return session.Rollback()
sql := e.Statement.genDropSQL()
_, err = session.Exec(sql)
if err != nil {
session.Rollback()
return err
}
} }
return session.Commit() return session.Commit()
} }
func (e *Engine) CreateTables(beans ...interface{}) error { func (e *Engine) CreateTables(beans ...interface{}) error {
session, err := e.MakeSession() session := e.NewSession()
if err != nil {
return err
}
defer session.Close() defer session.Close()
err = session.Begin() err := session.Begin()
if err != nil {
return err
}
session.Statement = e.Statement
defer e.Statement.Init()
if err != nil { if err != nil {
return err return err
} }
@ -292,106 +297,64 @@ func (e *Engine) CreateTables(beans ...interface{}) error {
} }
func (e *Engine) CreateAll() error { func (e *Engine) CreateAll() error {
session, err := e.MakeSession() session := e.NewSession()
session.Begin() err := session.Begin()
defer session.Close() defer session.Close()
if err != nil { if err != nil {
return err return err
} }
for _, table := range e.Tables { err = session.CreateAll()
e.Statement.RefTable = &table if err != nil {
sql := e.Statement.genCreateSQL() return session.Rollback()
_, err = session.Exec(sql)
if err != nil {
session.Rollback()
break
}
} }
session.Commit() return session.Commit()
return err
} }
func (engine *Engine) Exec(sql string, args ...interface{}) (sql.Result, error) { func (engine *Engine) Exec(sql string, args ...interface{}) (sql.Result, error) {
session, err := engine.MakeSession() session := engine.NewSession()
defer session.Close() defer session.Close()
if err != nil {
return nil, err
}
return session.Exec(sql, args...) return session.Exec(sql, args...)
} }
func (engine *Engine) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { func (engine *Engine) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
session, err := engine.MakeSession() session := engine.NewSession()
defer session.Close() defer session.Close()
if err != nil {
return nil, err
}
return session.Query(sql, paramStr...) return session.Query(sql, paramStr...)
} }
func (engine *Engine) Insert(beans ...interface{}) (int64, error) { func (engine *Engine) Insert(beans ...interface{}) (int64, error) {
session, err := engine.MakeSession() session := engine.NewSession()
defer session.Close() defer session.Close()
if err != nil {
return -1, err
}
defer engine.Statement.Init()
session.Statement = engine.Statement
return session.Insert(beans...) return session.Insert(beans...)
} }
func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64, error) { func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64, error) {
session, err := engine.MakeSession() session := engine.NewSession()
defer session.Close() defer session.Close()
if err != nil {
return -1, err
}
defer engine.Statement.Init()
session.Statement = engine.Statement
return session.Update(bean, condiBeans...) return session.Update(bean, condiBeans...)
} }
func (engine *Engine) Delete(bean interface{}) (int64, error) { func (engine *Engine) Delete(bean interface{}) (int64, error) {
session, err := engine.MakeSession() session := engine.NewSession()
defer session.Close() defer session.Close()
if err != nil {
return -1, err
}
defer engine.Statement.Init()
session.Statement = engine.Statement
return session.Delete(bean) return session.Delete(bean)
} }
func (engine *Engine) Get(bean interface{}) error { func (engine *Engine) Get(bean interface{}) (bool, error) {
session, err := engine.MakeSession() session := engine.NewSession()
defer session.Close() defer session.Close()
if err != nil {
return err
}
defer engine.Statement.Init()
session.Statement = engine.Statement
return session.Get(bean) return session.Get(bean)
} }
func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error { func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error {
session, err := engine.MakeSession() session := engine.NewSession()
defer session.Close() defer session.Close()
if err != nil {
return err
}
defer engine.Statement.Init()
session.Statement = engine.Statement
return session.Find(beans, condiBeans...) return session.Find(beans, condiBeans...)
} }
func (engine *Engine) Count(bean interface{}) (int64, error) { func (engine *Engine) Count(bean interface{}) (int64, error) {
session, err := engine.MakeSession() session := engine.NewSession()
defer session.Close() defer session.Close()
if err != nil {
return 0, err
}
defer engine.Statement.Init()
session.Statement = engine.Statement
return session.Count(bean) return session.Count(bean)
} }

88
examples/goroutine.go Normal file
View File

@ -0,0 +1,88 @@
package main
import (
//xorm "github.com/lunny/xorm"
"fmt"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"os"
//"time"
xorm "xorm"
)
type User struct {
Id int
Name string
}
func sqliteEngine() (*xorm.Engine, error) {
os.Remove("./test.db")
return xorm.NewEngine("sqlite3", "./goroutine.db")
}
func mysqlEngine() (*xorm.Engine, error) {
return xorm.NewEngine("mysql", "root:123@/test?charset=utf8")
}
func main() {
engine, err := sqliteEngine()
// engine, err := mysqlEngine()
if err != nil {
fmt.Println(err)
return
}
u := &User{}
err = engine.CreateTables(u)
if err != nil {
fmt.Println(err)
return
}
size := 10
queue := make(chan int, size)
for i := 0; i < size; i++ {
go func(x int) {
//x := i
err := engine.Test()
if err != nil {
fmt.Println(err)
} else {
err = engine.Map(u)
if err != nil {
fmt.Println("Map user failed")
} else {
for j := 0; j < 10; j++ {
if x+j < 2 {
_, err = engine.Get(u)
} else if x+j < 4 {
users := make([]User, 0)
err = engine.Find(&users)
} else if x+j < 8 {
_, err = engine.Count(u)
} else if x+j < 16 {
_, err = engine.Insert(&User{Name: "xlw"})
} else if x+j < 32 {
_, err = engine.Id(1).Delete(u)
}
if err != nil {
fmt.Println(err)
queue <- x
return
}
}
fmt.Printf("%v success!\n", x)
}
}
queue <- x
}(i)
}
for i := 0; i < size; i++ {
<-queue
}
fmt.Println("end")
}

View File

@ -9,7 +9,7 @@ var me Engine
func TestMysql(t *testing.T) { func TestMysql(t *testing.T) {
// You should drop all tables before executing this testing // You should drop all tables before executing this testing
me = Create("mysql", "root:@/xorm_test?charset=utf8") me = Create("mysql", "root:123@/test?charset=utf8")
me.ShowSQL = true me.ShowSQL = true
directCreateTable(&me, t) directCreateTable(&me, t)

78
pool.go Normal file
View File

@ -0,0 +1,78 @@
package xorm
import (
"database/sql"
//"fmt"
//"sync"
//"time"
)
type IConnectionPool interface {
RetrieveDB(engine *Engine) (*sql.DB, error)
ReleaseDB(engine *Engine, db *sql.DB)
}
type NoneConnectPool struct {
}
func (p NoneConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) {
db, err = engine.OpenDB()
return
}
func (p NoneConnectPool) ReleaseDB(engine *Engine, db *sql.DB) {
db.Close()
}
/*
var (
total int = 0
)
type SimpleConnectPool struct {
releasedSessions []*sql.DB
cur int
usingSessions map[*sql.DB]time.Time
maxWaitTimeOut int
mutex *sync.Mutex
}
func (p SimpleConnectPool) RetrieveDB(engine *Engine) (*sql.DB, error) {
p.mutex.Lock()
defer p.mutex.Unlock()
var db *sql.DB = nil
var err error = nil
fmt.Printf("%x, rbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingSessions))
if p.cur < 0 {
total = total + 1
fmt.Printf("new %v\n", total)
db, err = engine.OpenDB()
if err != nil {
return nil, err
}
p.usingSessions[db] = time.Now()
} else {
db = p.releasedSessions[p.cur]
p.usingSessions[db] = time.Now()
p.releasedSessions[p.cur] = nil
p.cur = p.cur - 1
fmt.Println("release one")
}
fmt.Printf("%x, rend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingSessions))
return db, nil
}
func (p SimpleConnectPool) ReleaseDB(engine *Engine, db *sql.DB) {
p.mutex.Lock()
defer p.mutex.Unlock()
fmt.Printf("%x, lbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingSessions))
if p.cur >= 29 {
db.Close()
} else {
p.cur = p.cur + 1
p.releasedSessions[p.cur] = db
}
delete(p.usingSessions, db)
fmt.Printf("%x, lend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingSessions))
}*/

View File

@ -21,6 +21,7 @@ type Session struct {
func (session *Session) Init() { func (session *Session) Init() {
session.Statement = Statement{Engine: session.Engine} session.Statement = Statement{Engine: session.Engine}
session.Statement.Init()
session.IsAutoCommit = true session.IsAutoCommit = true
session.IsCommitedOrRollbacked = false session.IsCommitedOrRollbacked = false
} }
@ -28,11 +29,19 @@ func (session *Session) Init() {
func (session *Session) Close() { func (session *Session) Close() {
defer func() { defer func() {
if session.Db != nil { if session.Db != nil {
session.Db.Close() session.Engine.Pool.ReleaseDB(session.Engine, session.Db)
session.Db = nil
session.Tx = nil
session.Init()
} }
}() }()
} }
func (session *Session) Sql(querystring string, args ...interface{}) *Session {
session.Statement.Sql(querystring, args...)
return session
}
func (session *Session) Where(querystring string, args ...interface{}) *Session { func (session *Session) Where(querystring string, args ...interface{}) *Session {
session.Statement.Where(querystring, args...) session.Statement.Where(querystring, args...)
return session return session
@ -86,7 +95,22 @@ func (session *Session) Having(conditions string) *Session {
return session return session
} }
func (session *Session) newDb() error {
if session.Db == nil {
db, err := session.Engine.Pool.RetrieveDB(session.Engine)
if err != nil {
return err
}
session.Db = db
}
return nil
}
func (session *Session) Begin() error { func (session *Session) Begin() error {
err := session.newDb()
if err != nil {
return err
}
if session.IsAutoCommit { if session.IsAutoCommit {
tx, err := session.Db.Begin() tx, err := session.Db.Begin()
if err != nil { if err != nil {
@ -189,31 +213,38 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
v = x v = x
} else if session.Statement.UseCascade { } else if session.Statement.UseCascade {
session.Engine.AutoMapType(structField.Type()) table := session.Engine.AutoMapType(structField.Type())
if _, ok := session.Engine.Tables[structField.Type()]; ok { if table != nil {
x, err := strconv.ParseInt(string(data), 10, 64) x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil { if err != nil {
return errors.New("arg " + key + " as int: " + err.Error()) return errors.New("arg " + key + " as int: " + err.Error())
} }
if x != 0 { if x != 0 {
structInter := reflect.New(structField.Type()) structInter := reflect.New(structField.Type())
st := session.Statement
session.Statement.Init() session.Statement.Init()
err = session.Id(x).Get(structInter.Interface()) has, err := session.Id(x).Get(structInter.Interface())
if err != nil { if err != nil {
session.Statement = st
return err return err
} }
if has {
v = structInter.Elem().Interface() v = structInter.Elem().Interface()
session.Statement = st
} else {
fmt.Println("cascade obj is not exist!")
session.Statement = st
continue
}
} else { } else {
//fmt.Println("zero value of struct type " + structField.Type().String())
continue continue
} }
} else { } else {
fmt.Println("unsupported struct type in Scan: " + structField.Type().String()) fmt.Println("unsupported struct type in Scan: " + structField.Type().String())
continue continue
} }
} else {
continue
} }
default: default:
return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String())
@ -241,6 +272,11 @@ func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result,
} }
func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) {
err := session.newDb()
if err != nil {
return nil, err
}
if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" {
sql = strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1) sql = strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1)
} }
@ -263,37 +299,48 @@ func (session *Session) CreateTable(bean interface{}) error {
return err return err
} }
func (session *Session) Get(bean interface{}) error { func (session *Session) Get(bean interface{}) (bool, error) {
statement := session.Statement statement := session.Statement
defer statement.Init() defer statement.Init()
statement.Limit(1) statement.Limit(1)
var sql string
fmt.Println(bean) var args []interface{}
if statement.RawSQL == "" {
sql, args := statement.genGetSql(bean) sql, args = statement.genGetSql(bean)
} else {
sql = statement.RawSQL
args = statement.RawParams
}
resultsSlice, err := session.Query(sql, args...) resultsSlice, err := session.Query(sql, args...)
if err != nil { if err != nil {
return err return false, err
} }
if len(resultsSlice) == 0 { if len(resultsSlice) == 0 {
return nil return false, nil
} else if len(resultsSlice) == 1 { } else if len(resultsSlice) == 1 {
results := resultsSlice[0] results := resultsSlice[0]
err := session.scanMapIntoStruct(bean, results) err := session.scanMapIntoStruct(bean, results)
if err != nil { if err != nil {
return err return false, err
} }
} else { } else {
return errors.New("More than one record") return false, errors.New("More than one record")
} }
return nil return true, nil
} }
func (session *Session) Count(bean interface{}) (int64, error) { func (session *Session) Count(bean interface{}) (int64, error) {
statement := session.Statement statement := session.Statement
defer session.Statement.Init() defer session.Statement.Init()
sql, args := statement.genCountSql(bean) var sql string
var args []interface{}
if statement.RawSQL == "" {
sql, args = statement.genCountSql(bean)
} else {
sql = statement.RawSQL
args = statement.RawParams
}
resultsSlice, err := session.Query(sql, args...) resultsSlice, err := session.Query(sql, args...)
if err != nil { if err != nil {
@ -301,9 +348,12 @@ func (session *Session) Count(bean interface{}) (int64, error) {
} }
var total int64 = 0 var total int64 = 0
for _, results := range resultsSlice { if len(resultsSlice) > 0 {
total, err = strconv.ParseInt(string(results["total"]), 10, 64) results := resultsSlice[0]
break for _, value := range results {
total, err = strconv.ParseInt(string(value), 10, 64)
break
}
} }
return int64(total), err return int64(total), err
@ -327,8 +377,17 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
statement.BeanArgs = args statement.BeanArgs = args
} }
sql := statement.generateSql() var sql string
resultsSlice, err := session.Query(sql, append(statement.Params, statement.BeanArgs...)...) var args []interface{}
if statement.RawSQL == "" {
sql = statement.generateSql()
args = append(statement.Params, statement.BeanArgs...)
} else {
sql = statement.RawSQL
args = statement.RawParams
}
resultsSlice, err := session.Query(sql, args...)
if err != nil { if err != nil {
return err return err
@ -359,7 +418,45 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
return nil return nil
} }
func (session *Session) Ping() error {
err := session.newDb()
if err != nil {
return err
}
return session.Db.Ping()
}
func (session *Session) CreateAll() error {
for _, table := range session.Engine.Tables {
session.Statement.RefTable = table
sql := session.Statement.genCreateSQL()
_, err := session.Exec(sql)
if err != nil {
return err
}
}
return nil
}
func (session *Session) DropAll() error {
for _, table := range session.Engine.Tables {
session.Statement.RefTable = table
sql := session.Statement.genDropSQL()
_, err := session.Exec(sql)
if err != nil {
return err
}
}
return nil
}
func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
err = session.newDb()
if err != nil {
return nil, err
}
if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" {
sql = strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1) sql = strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1)
} }
@ -635,7 +732,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
statement := fmt.Sprintf("UPDATE %v%v%v SET %v %v", sql := fmt.Sprintf("UPDATE %v%v%v SET %v %v",
session.Engine.QuoteIdentifier, session.Engine.QuoteIdentifier,
session.Statement.TableName(), session.Statement.TableName(),
session.Engine.QuoteIdentifier, session.Engine.QuoteIdentifier,
@ -643,7 +740,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
condition) condition)
eargs := append(append(args, st.Params...), condiArgs...) eargs := append(append(args, st.Params...), condiArgs...)
res, err := session.Exec(statement, eargs...) res, err := session.Exec(sql, eargs...)
if err != nil { if err != nil {
return -1, err return -1, err
} }

View File

@ -6,65 +6,147 @@ import (
"testing" "testing"
) )
var se Engine var se *Engine
func TestSqlite(t *testing.T) { func autoConn() {
os.Remove("./test.db") if se == nil {
se = Create("sqlite3", "./test.db") os.Remove("./test.db")
se.ShowSQL = true se, _ = NewEngine("sqlite3", "./test.db")
se.ShowSQL = true
}
} }
func TestSqliteCreateTable(t *testing.T) { func TestSqliteCreateTable(t *testing.T) {
directCreateTable(&se, t) autoConn()
directCreateTable(se, t)
} }
func TestSqliteMapper(t *testing.T) { func TestSqliteMapper(t *testing.T) {
mapper(&se, t) autoConn()
mapper(se, t)
} }
func TestSqliteInsert(t *testing.T) { func TestSqliteInsert(t *testing.T) {
insert(&se, t) autoConn()
insert(se, t)
} }
func TestSqliteQuery(t *testing.T) { func TestSqliteQuery(t *testing.T) {
query(&se, t) autoConn()
query(se, t)
} }
func TestSqliteExec(t *testing.T) { func TestSqliteExec(t *testing.T) {
exec(&se, t) autoConn()
exec(se, t)
} }
func TestSqliteInsertAutoIncr(t *testing.T) { func TestSqliteInsertAutoIncr(t *testing.T) {
insertAutoIncr(&se, t) autoConn()
insertAutoIncr(se, t)
} }
type sss struct { func TestInsertMulti(t *testing.T) {
} autoConn()
insertMulti(se, t)
func (s sss) TestInsertMulti(t *testing.T) {
insertMulti(&se, t)
} }
func TestSqliteInsertMulti(t *testing.T) { func TestSqliteInsertMulti(t *testing.T) {
insertMulti(&se, t) autoConn()
insertMulti(se, t)
insertTwoTable(&se, t) }
update(&se, t)
testdelete(&se, t) func TestSqliteInsertTwoTable(t *testing.T) {
get(&se, t) autoConn()
cascadeGet(&se, t) insertTwoTable(se, t)
find(&se, t) }
findMap(&se, t)
count(&se, t) func TestSqliteUpdate(t *testing.T) {
where(&se, t) autoConn()
in(&se, t) update(se, t)
limit(&se, t) }
order(&se, t)
join(&se, t) func TestSqliteDelete(t *testing.T) {
having(&se, t) autoConn()
transaction(&se, t) testdelete(se, t)
combineTransaction(&se, t) }
table(&se, t)
createMultiTables(&se, t) func TestSqliteGet(t *testing.T) {
tableOp(&se, t) autoConn()
get(se, t)
}
func TestSqliteCascadeGet(t *testing.T) {
autoConn()
cascadeGet(se, t)
}
func TestSqliteFind(t *testing.T) {
autoConn()
find(se, t)
}
func TestSqliteFindMap(t *testing.T) {
autoConn()
findMap(se, t)
}
func TestSqliteCount(t *testing.T) {
autoConn()
count(se, t)
}
func TestSqliteWhere(t *testing.T) {
autoConn()
where(se, t)
}
func TestSqliteIn(t *testing.T) {
autoConn()
in(se, t)
}
func TestSqliteLimit(t *testing.T) {
autoConn()
limit(se, t)
}
func TestSqliteOrder(t *testing.T) {
autoConn()
order(se, t)
}
func TestSqliteJoin(t *testing.T) {
autoConn()
join(se, t)
}
func TestSqliteHaving(t *testing.T) {
autoConn()
having(se, t)
}
func TestSqliteTransaction(t *testing.T) {
autoConn()
transaction(se, t)
}
func TestSqliteCombineTransaction(t *testing.T) {
autoConn()
combineTransaction(se, t)
}
func TestSqliteTable(t *testing.T) {
autoConn()
table(se, t)
}
func TestSqliteCreateMultiTables(t *testing.T) {
autoConn()
createMultiTables(se, t)
}
func TestSqliteTableOp(t *testing.T) {
autoConn()
tableOp(se, t)
} }

View File

@ -21,6 +21,8 @@ type Statement struct {
HavingStr string HavingStr string
ColumnStr string ColumnStr string
AltTableName string AltTableName string
RawSQL string
RawParams []interface{}
UseCascade bool UseCascade bool
BeanArgs []interface{} BeanArgs []interface{}
} }
@ -46,9 +48,16 @@ func (statement *Statement) Init() {
statement.HavingStr = "" statement.HavingStr = ""
statement.ColumnStr = "" statement.ColumnStr = ""
statement.AltTableName = "" statement.AltTableName = ""
statement.RawSQL = ""
statement.RawParams = make([]interface{}, 0)
statement.BeanArgs = make([]interface{}, 0) statement.BeanArgs = make([]interface{}, 0)
} }
func (statement *Statement) Sql(querystring string, args ...interface{}) {
statement.RawSQL = querystring
statement.RawParams = args
}
func (statement *Statement) Where(querystring string, args ...interface{}) { func (statement *Statement) Where(querystring string, args ...interface{}) {
statement.WhereStr = querystring statement.WhereStr = querystring
statement.Params = args statement.Params = args

View File

@ -171,21 +171,29 @@ func testdelete(engine *Engine, t *testing.T) {
func get(engine *Engine, t *testing.T) { func get(engine *Engine, t *testing.T) {
user := Userinfo{Uid: 2} user := Userinfo{Uid: 2}
err := engine.Get(&user) has, err := engine.Get(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
fmt.Println(user) if has {
fmt.Println(user)
} else {
fmt.Println("no record id is 2")
}
} }
func cascadeGet(engine *Engine, t *testing.T) { func cascadeGet(engine *Engine, t *testing.T) {
user := Userinfo{Uid: 11} user := Userinfo{Uid: 11}
err := engine.Get(&user) has, err := engine.Get(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
fmt.Println(user) if has {
fmt.Println(user)
} else {
fmt.Println("no record id is 2")
}
} }
func find(engine *Engine, t *testing.T) { func find(engine *Engine, t *testing.T) {
@ -290,14 +298,14 @@ func transaction(engine *Engine, t *testing.T) {
counter() counter()
defer counter() defer counter()
session, err := engine.MakeSession() session := engine.NewSession()
defer session.Close() defer session.Close()
err := session.Begin()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
session.Begin()
//session.IsAutoRollback = false //session.IsAutoRollback = false
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1) _, err = session.Insert(&user1)
@ -340,14 +348,14 @@ func combineTransaction(engine *Engine, t *testing.T) {
counter() counter()
defer counter() defer counter()
session, err := engine.MakeSession() session := engine.NewSession()
defer session.Close() defer session.Close()
err := session.Begin()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
session.Begin()
//session.IsAutoRollback = false //session.IsAutoRollback = false
user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1) _, err = session.Insert(&user1)
@ -379,19 +387,19 @@ func combineTransaction(engine *Engine, t *testing.T) {
} }
func table(engine *Engine, t *testing.T) { func table(engine *Engine, t *testing.T) {
engine.Table("user_user").CreateTables(&Userinfo{}) engine.Table("user_user").CreateTable(&Userinfo{})
} }
func createMultiTables(engine *Engine, t *testing.T) { func createMultiTables(engine *Engine, t *testing.T) {
session, err := engine.MakeSession() session := engine.NewSession()
defer session.Close() defer session.Close()
user := &Userinfo{}
err := session.Begin()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
user := &Userinfo{}
session.Begin()
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
err = session.Table(fmt.Sprintf("user_%v", i)).CreateTable(user) err = session.Table(fmt.Sprintf("user_%v", i)).CreateTable(user)
if err != nil { if err != nil {
@ -414,7 +422,7 @@ func tableOp(engine *Engine, t *testing.T) {
t.Error(err) t.Error(err)
} }
err = engine.Table(tableName).Get(&Userinfo{Username: "tablexiao"}) _, err = engine.Table(tableName).Get(&Userinfo{Username: "tablexiao"})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }

28
xorm.go
View File

@ -1,26 +1,42 @@
package xorm package xorm
import ( import (
//"database/sql"
"errors"
"fmt"
"reflect" "reflect"
"sync"
//"time"
) )
func NewEngine(driverName string, dataSourceName string) *Engine { func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
engine := &Engine{ShowSQL: false, DriverName: driverName, Mapper: SnakeMapper{}, engine := &Engine{ShowSQL: false, DriverName: driverName, Mapper: SnakeMapper{},
DataSourceName: dataSourceName} DataSourceName: dataSourceName}
engine.Tables = make(map[reflect.Type]Table) engine.Tables = make(map[reflect.Type]*Table)
engine.Statement.Engine = engine engine.mutex = &sync.Mutex{}
engine.InsertMany = true engine.InsertMany = true
engine.TagIdentifier = "xorm" engine.TagIdentifier = "xorm"
engine.QuoteIdentifier = "`"
if driverName == SQLITE { if driverName == SQLITE {
engine.Dialect = sqlite3{} engine.Dialect = sqlite3{}
engine.AutoIncrement = "AUTOINCREMENT" engine.AutoIncrement = "AUTOINCREMENT"
} else { //engine.Pool = NoneConnectPool{}
} else if driverName == MYSQL {
engine.Dialect = mysql{} engine.Dialect = mysql{}
engine.AutoIncrement = "AUTO_INCREMENT" engine.AutoIncrement = "AUTO_INCREMENT"
} else {
return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName))
} }
engine.QuoteIdentifier = "`" /*engine.Pool = SimpleConnectPool{
releasedSessions: make([]*sql.DB, 30),
usingSessions: map[*sql.DB]time.Time{},
cur: -1,
maxWaitTimeOut: 14400,
mutex: &sync.Mutex{},
}*/
engine.Pool = NoneConnectPool{}
return engine return engine, nil
} }