go routine support;raw sql support

This commit is contained in:
Lunny Xiao 2013-06-16 11:05:16 +08:00
parent e68c855ee3
commit fb6f84bc90
10 changed files with 555 additions and 218 deletions

View File

@ -4,16 +4,12 @@ package xorm
// @deprecation : please use NewSession instead
func (engine *Engine) MakeSession() (Session, error) {
s, err := engine.NewSession()
if err == nil {
return *s, err
} else {
return Session{}, err
}
s := engine.NewSession()
return *s, nil
}
// @deprecation : please use NewEngine instead
func Create(driverName string, dataSourceName string) Engine {
engine := NewEngine(driverName, dataSourceName)
engine, _ := NewEngine(driverName, dataSourceName)
return *engine
}

207
engine.go
View File

@ -6,6 +6,7 @@ import (
"reflect"
"strconv"
"strings"
"sync"
)
const (
@ -26,12 +27,13 @@ type Engine struct {
DriverName string
DataSourceName string
Dialect dialect
Tables map[reflect.Type]Table
Tables map[reflect.Type]*Table
mutex *sync.Mutex
AutoIncrement string
ShowSQL bool
InsertMany bool
QuoteIdentifier string
Statement Statement
Pool IConnectionPool
}
func Type(bean interface{}) reflect.Type {
@ -50,78 +52,89 @@ func (e *Engine) OpenDB() (*sql.DB, error) {
return sql.Open(e.DriverName, e.DataSourceName)
}
func (engine *Engine) NewSession() (session *Session, err error) {
db, err := engine.OpenDB()
if err != nil {
return nil, err
}
session = &Session{Engine: engine, Db: db}
func (engine *Engine) NewSession() *Session {
session := &Session{Engine: engine}
session.Init()
return
return session
}
func (engine *Engine) Test() error {
session, err := engine.NewSession()
if err != nil {
return err
}
return session.Db.Ping()
session := engine.NewSession()
defer session.Close()
return session.Ping()
}
func (engine *Engine) Where(querystring string, args ...interface{}) *Engine {
engine.Statement.Where(querystring, args...)
return engine
func (engine *Engine) Sql(querystring string, args ...interface{}) *Session {
session := engine.NewSession()
session.Sql(querystring, args...)
return session
}
func (engine *Engine) Id(id int64) *Engine {
engine.Statement.Id(id)
return engine
func (engine *Engine) Where(querystring string, args ...interface{}) *Session {
session := engine.NewSession()
session.Where(querystring, args...)
return session
}
func (engine *Engine) In(column string, args ...interface{}) *Engine {
engine.Statement.In(column, args...)
return engine
func (engine *Engine) Id(id int64) *Session {
session := engine.NewSession()
session.Id(id)
return session
}
func (engine *Engine) Table(tableName string) *Engine {
engine.Statement.Table(tableName)
return engine
func (engine *Engine) In(column string, args ...interface{}) *Session {
session := engine.NewSession()
session.In(column, args...)
return session
}
func (engine *Engine) Limit(limit int, start ...int) *Engine {
engine.Statement.Limit(limit, start...)
return engine
func (engine *Engine) Table(tableName string) *Session {
session := engine.NewSession()
session.Table(tableName)
return session
}
func (engine *Engine) OrderBy(order string) *Engine {
engine.Statement.OrderBy(order)
return engine
func (engine *Engine) Limit(limit int, start ...int) *Session {
session := engine.NewSession()
session.Limit(limit, start...)
return session
}
func (engine *Engine) OrderBy(order string) *Session {
session := engine.NewSession()
session.OrderBy(order)
return session
}
//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (engine *Engine) Join(join_operator, tablename, condition string) *Engine {
engine.Statement.Join(join_operator, tablename, condition)
return engine
func (engine *Engine) Join(join_operator, tablename, condition string) *Session {
session := engine.NewSession()
session.Join(join_operator, tablename, condition)
return session
}
func (engine *Engine) GroupBy(keys string) *Engine {
engine.Statement.GroupBy(keys)
return engine
func (engine *Engine) GroupBy(keys string) *Session {
session := engine.NewSession()
session.GroupBy(keys)
return session
}
func (engine *Engine) Having(conditions string) *Engine {
engine.Statement.Having(conditions)
return engine
func (engine *Engine) Having(conditions string) *Session {
session := engine.NewSession()
session.Having(conditions)
return session
}
// some lock needed
func (engine *Engine) AutoMapType(t reflect.Type) *Table {
engine.mutex.Lock()
defer engine.mutex.Unlock()
table, ok := engine.Tables[t]
if !ok {
table = engine.MapType(t)
engine.Tables[t] = table
//engine.Tables[t] = table
}
return &table
return table
}
func (engine *Engine) AutoMap(bean interface{}) *Table {
@ -129,8 +142,8 @@ func (engine *Engine) AutoMap(bean interface{}) *Table {
return engine.AutoMapType(t)
}
func (engine *Engine) MapType(t reflect.Type) Table {
table := Table{Name: engine.Mapper.Obj2Table(t.Name()), Type: t}
func (engine *Engine) MapType(t reflect.Type) *Table {
table := &Table{Name: engine.Mapper.Obj2Table(t.Name()), Type: t}
table.Columns = make(map[string]Column)
for i := 0; i < t.NumField(); i++ {
@ -226,7 +239,10 @@ func (engine *Engine) MapType(t reflect.Type) Table {
return table
}
// Map should use after all operation because it's not thread safe
func (engine *Engine) Map(beans ...interface{}) (e error) {
engine.mutex.Lock()
defer engine.mutex.Unlock()
for _, bean := range beans {
t := Type(bean)
if _, ok := engine.Tables[t]; !ok {
@ -237,6 +253,8 @@ func (engine *Engine) Map(beans ...interface{}) (e error) {
}
func (engine *Engine) UnMap(beans ...interface{}) (e error) {
engine.mutex.Lock()
defer engine.mutex.Unlock()
for _, bean := range beans {
t := Type(bean)
if _, ok := engine.Tables[t]; ok {
@ -247,37 +265,24 @@ func (engine *Engine) UnMap(beans ...interface{}) (e error) {
}
func (e *Engine) DropAll() error {
session, err := e.MakeSession()
session.Begin()
session := e.NewSession()
defer session.Close()
if err != nil {
return err
}
for _, table := range e.Tables {
e.Statement.RefTable = &table
sql := e.Statement.genDropSQL()
_, err = session.Exec(sql)
err := session.Begin()
if err != nil {
session.Rollback()
return err
}
err = session.DropAll()
if err != nil {
return session.Rollback()
}
return session.Commit()
}
func (e *Engine) CreateTables(beans ...interface{}) error {
session, err := e.MakeSession()
if err != nil {
return err
}
session := e.NewSession()
defer session.Close()
err = session.Begin()
if err != nil {
return err
}
session.Statement = e.Statement
defer e.Statement.Init()
err := session.Begin()
if err != nil {
return err
}
@ -292,106 +297,64 @@ func (e *Engine) CreateTables(beans ...interface{}) error {
}
func (e *Engine) CreateAll() error {
session, err := e.MakeSession()
session.Begin()
session := e.NewSession()
err := session.Begin()
defer session.Close()
if err != nil {
return err
}
for _, table := range e.Tables {
e.Statement.RefTable = &table
sql := e.Statement.genCreateSQL()
_, err = session.Exec(sql)
err = session.CreateAll()
if err != nil {
session.Rollback()
break
return session.Rollback()
}
}
session.Commit()
return err
return session.Commit()
}
func (engine *Engine) Exec(sql string, args ...interface{}) (sql.Result, error) {
session, err := engine.MakeSession()
session := engine.NewSession()
defer session.Close()
if err != nil {
return nil, err
}
return session.Exec(sql, args...)
}
func (engine *Engine) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
session, err := engine.MakeSession()
session := engine.NewSession()
defer session.Close()
if err != nil {
return nil, err
}
return session.Query(sql, paramStr...)
}
func (engine *Engine) Insert(beans ...interface{}) (int64, error) {
session, err := engine.MakeSession()
session := engine.NewSession()
defer session.Close()
if err != nil {
return -1, err
}
defer engine.Statement.Init()
session.Statement = engine.Statement
return session.Insert(beans...)
}
func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64, error) {
session, err := engine.MakeSession()
session := engine.NewSession()
defer session.Close()
if err != nil {
return -1, err
}
defer engine.Statement.Init()
session.Statement = engine.Statement
return session.Update(bean, condiBeans...)
}
func (engine *Engine) Delete(bean interface{}) (int64, error) {
session, err := engine.MakeSession()
session := engine.NewSession()
defer session.Close()
if err != nil {
return -1, err
}
defer engine.Statement.Init()
session.Statement = engine.Statement
return session.Delete(bean)
}
func (engine *Engine) Get(bean interface{}) error {
session, err := engine.MakeSession()
func (engine *Engine) Get(bean interface{}) (bool, error) {
session := engine.NewSession()
defer session.Close()
if err != nil {
return err
}
defer engine.Statement.Init()
session.Statement = engine.Statement
return session.Get(bean)
}
func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error {
session, err := engine.MakeSession()
session := engine.NewSession()
defer session.Close()
if err != nil {
return err
}
defer engine.Statement.Init()
session.Statement = engine.Statement
return session.Find(beans, condiBeans...)
}
func (engine *Engine) Count(bean interface{}) (int64, error) {
session, err := engine.MakeSession()
session := engine.NewSession()
defer session.Close()
if err != nil {
return 0, err
}
defer engine.Statement.Init()
session.Statement = engine.Statement
return session.Count(bean)
}

88
examples/goroutine.go Normal file
View File

@ -0,0 +1,88 @@
package main
import (
//xorm "github.com/lunny/xorm"
"fmt"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"os"
//"time"
xorm "xorm"
)
type User struct {
Id int
Name string
}
func sqliteEngine() (*xorm.Engine, error) {
os.Remove("./test.db")
return xorm.NewEngine("sqlite3", "./goroutine.db")
}
func mysqlEngine() (*xorm.Engine, error) {
return xorm.NewEngine("mysql", "root:123@/test?charset=utf8")
}
func main() {
engine, err := sqliteEngine()
// engine, err := mysqlEngine()
if err != nil {
fmt.Println(err)
return
}
u := &User{}
err = engine.CreateTables(u)
if err != nil {
fmt.Println(err)
return
}
size := 10
queue := make(chan int, size)
for i := 0; i < size; i++ {
go func(x int) {
//x := i
err := engine.Test()
if err != nil {
fmt.Println(err)
} else {
err = engine.Map(u)
if err != nil {
fmt.Println("Map user failed")
} else {
for j := 0; j < 10; j++ {
if x+j < 2 {
_, err = engine.Get(u)
} else if x+j < 4 {
users := make([]User, 0)
err = engine.Find(&users)
} else if x+j < 8 {
_, err = engine.Count(u)
} else if x+j < 16 {
_, err = engine.Insert(&User{Name: "xlw"})
} else if x+j < 32 {
_, err = engine.Id(1).Delete(u)
}
if err != nil {
fmt.Println(err)
queue <- x
return
}
}
fmt.Printf("%v success!\n", x)
}
}
queue <- x
}(i)
}
for i := 0; i < size; i++ {
<-queue
}
fmt.Println("end")
}

View File

@ -9,7 +9,7 @@ var me Engine
func TestMysql(t *testing.T) {
// You should drop all tables before executing this testing
me = Create("mysql", "root:@/xorm_test?charset=utf8")
me = Create("mysql", "root:123@/test?charset=utf8")
me.ShowSQL = true
directCreateTable(&me, t)

78
pool.go Normal file
View File

@ -0,0 +1,78 @@
package xorm
import (
"database/sql"
//"fmt"
//"sync"
//"time"
)
type IConnectionPool interface {
RetrieveDB(engine *Engine) (*sql.DB, error)
ReleaseDB(engine *Engine, db *sql.DB)
}
type NoneConnectPool struct {
}
func (p NoneConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) {
db, err = engine.OpenDB()
return
}
func (p NoneConnectPool) ReleaseDB(engine *Engine, db *sql.DB) {
db.Close()
}
/*
var (
total int = 0
)
type SimpleConnectPool struct {
releasedSessions []*sql.DB
cur int
usingSessions map[*sql.DB]time.Time
maxWaitTimeOut int
mutex *sync.Mutex
}
func (p SimpleConnectPool) RetrieveDB(engine *Engine) (*sql.DB, error) {
p.mutex.Lock()
defer p.mutex.Unlock()
var db *sql.DB = nil
var err error = nil
fmt.Printf("%x, rbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingSessions))
if p.cur < 0 {
total = total + 1
fmt.Printf("new %v\n", total)
db, err = engine.OpenDB()
if err != nil {
return nil, err
}
p.usingSessions[db] = time.Now()
} else {
db = p.releasedSessions[p.cur]
p.usingSessions[db] = time.Now()
p.releasedSessions[p.cur] = nil
p.cur = p.cur - 1
fmt.Println("release one")
}
fmt.Printf("%x, rend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingSessions))
return db, nil
}
func (p SimpleConnectPool) ReleaseDB(engine *Engine, db *sql.DB) {
p.mutex.Lock()
defer p.mutex.Unlock()
fmt.Printf("%x, lbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingSessions))
if p.cur >= 29 {
db.Close()
} else {
p.cur = p.cur + 1
p.releasedSessions[p.cur] = db
}
delete(p.usingSessions, db)
fmt.Printf("%x, lend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingSessions))
}*/

View File

@ -21,6 +21,7 @@ type Session struct {
func (session *Session) Init() {
session.Statement = Statement{Engine: session.Engine}
session.Statement.Init()
session.IsAutoCommit = true
session.IsCommitedOrRollbacked = false
}
@ -28,11 +29,19 @@ func (session *Session) Init() {
func (session *Session) Close() {
defer func() {
if session.Db != nil {
session.Db.Close()
session.Engine.Pool.ReleaseDB(session.Engine, session.Db)
session.Db = nil
session.Tx = nil
session.Init()
}
}()
}
func (session *Session) Sql(querystring string, args ...interface{}) *Session {
session.Statement.Sql(querystring, args...)
return session
}
func (session *Session) Where(querystring string, args ...interface{}) *Session {
session.Statement.Where(querystring, args...)
return session
@ -86,7 +95,22 @@ func (session *Session) Having(conditions string) *Session {
return session
}
func (session *Session) newDb() error {
if session.Db == nil {
db, err := session.Engine.Pool.RetrieveDB(session.Engine)
if err != nil {
return err
}
session.Db = db
}
return nil
}
func (session *Session) Begin() error {
err := session.newDb()
if err != nil {
return err
}
if session.IsAutoCommit {
tx, err := session.Db.Begin()
if err != nil {
@ -189,31 +213,38 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
v = x
} else if session.Statement.UseCascade {
session.Engine.AutoMapType(structField.Type())
if _, ok := session.Engine.Tables[structField.Type()]; ok {
table := session.Engine.AutoMapType(structField.Type())
if table != nil {
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
if x != 0 {
structInter := reflect.New(structField.Type())
st := session.Statement
session.Statement.Init()
err = session.Id(x).Get(structInter.Interface())
has, err := session.Id(x).Get(structInter.Interface())
if err != nil {
session.Statement = st
return err
}
if has {
v = structInter.Elem().Interface()
session.Statement = st
} else {
fmt.Println("cascade obj is not exist!")
session.Statement = st
continue
}
} else {
//fmt.Println("zero value of struct type " + structField.Type().String())
continue
}
} else {
fmt.Println("unsupported struct type in Scan: " + structField.Type().String())
continue
}
} else {
continue
}
default:
return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String())
@ -241,6 +272,11 @@ func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result,
}
func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) {
err := session.newDb()
if err != nil {
return nil, err
}
if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" {
sql = strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1)
}
@ -263,37 +299,48 @@ func (session *Session) CreateTable(bean interface{}) error {
return err
}
func (session *Session) Get(bean interface{}) error {
func (session *Session) Get(bean interface{}) (bool, error) {
statement := session.Statement
defer statement.Init()
statement.Limit(1)
fmt.Println(bean)
sql, args := statement.genGetSql(bean)
var sql string
var args []interface{}
if statement.RawSQL == "" {
sql, args = statement.genGetSql(bean)
} else {
sql = statement.RawSQL
args = statement.RawParams
}
resultsSlice, err := session.Query(sql, args...)
if err != nil {
return err
return false, err
}
if len(resultsSlice) == 0 {
return nil
return false, nil
} else if len(resultsSlice) == 1 {
results := resultsSlice[0]
err := session.scanMapIntoStruct(bean, results)
if err != nil {
return err
return false, err
}
} else {
return errors.New("More than one record")
return false, errors.New("More than one record")
}
return nil
return true, nil
}
func (session *Session) Count(bean interface{}) (int64, error) {
statement := session.Statement
defer session.Statement.Init()
sql, args := statement.genCountSql(bean)
var sql string
var args []interface{}
if statement.RawSQL == "" {
sql, args = statement.genCountSql(bean)
} else {
sql = statement.RawSQL
args = statement.RawParams
}
resultsSlice, err := session.Query(sql, args...)
if err != nil {
@ -301,10 +348,13 @@ func (session *Session) Count(bean interface{}) (int64, error) {
}
var total int64 = 0
for _, results := range resultsSlice {
total, err = strconv.ParseInt(string(results["total"]), 10, 64)
if len(resultsSlice) > 0 {
results := resultsSlice[0]
for _, value := range results {
total, err = strconv.ParseInt(string(value), 10, 64)
break
}
}
return int64(total), err
}
@ -327,8 +377,17 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
statement.BeanArgs = args
}
sql := statement.generateSql()
resultsSlice, err := session.Query(sql, append(statement.Params, statement.BeanArgs...)...)
var sql string
var args []interface{}
if statement.RawSQL == "" {
sql = statement.generateSql()
args = append(statement.Params, statement.BeanArgs...)
} else {
sql = statement.RawSQL
args = statement.RawParams
}
resultsSlice, err := session.Query(sql, args...)
if err != nil {
return err
@ -359,7 +418,45 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
return nil
}
func (session *Session) Ping() error {
err := session.newDb()
if err != nil {
return err
}
return session.Db.Ping()
}
func (session *Session) CreateAll() error {
for _, table := range session.Engine.Tables {
session.Statement.RefTable = table
sql := session.Statement.genCreateSQL()
_, err := session.Exec(sql)
if err != nil {
return err
}
}
return nil
}
func (session *Session) DropAll() error {
for _, table := range session.Engine.Tables {
session.Statement.RefTable = table
sql := session.Statement.genDropSQL()
_, err := session.Exec(sql)
if err != nil {
return err
}
}
return nil
}
func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
err = session.newDb()
if err != nil {
return nil, err
}
if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" {
sql = strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1)
}
@ -635,7 +732,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
}
statement := fmt.Sprintf("UPDATE %v%v%v SET %v %v",
sql := fmt.Sprintf("UPDATE %v%v%v SET %v %v",
session.Engine.QuoteIdentifier,
session.Statement.TableName(),
session.Engine.QuoteIdentifier,
@ -643,7 +740,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
condition)
eargs := append(append(args, st.Params...), condiArgs...)
res, err := session.Exec(statement, eargs...)
res, err := session.Exec(sql, eargs...)
if err != nil {
return -1, err
}

View File

@ -6,65 +6,147 @@ import (
"testing"
)
var se Engine
var se *Engine
func TestSqlite(t *testing.T) {
func autoConn() {
if se == nil {
os.Remove("./test.db")
se = Create("sqlite3", "./test.db")
se, _ = NewEngine("sqlite3", "./test.db")
se.ShowSQL = true
}
}
func TestSqliteCreateTable(t *testing.T) {
directCreateTable(&se, t)
autoConn()
directCreateTable(se, t)
}
func TestSqliteMapper(t *testing.T) {
mapper(&se, t)
autoConn()
mapper(se, t)
}
func TestSqliteInsert(t *testing.T) {
insert(&se, t)
autoConn()
insert(se, t)
}
func TestSqliteQuery(t *testing.T) {
query(&se, t)
autoConn()
query(se, t)
}
func TestSqliteExec(t *testing.T) {
exec(&se, t)
autoConn()
exec(se, t)
}
func TestSqliteInsertAutoIncr(t *testing.T) {
insertAutoIncr(&se, t)
autoConn()
insertAutoIncr(se, t)
}
type sss struct {
}
func (s sss) TestInsertMulti(t *testing.T) {
insertMulti(&se, t)
func TestInsertMulti(t *testing.T) {
autoConn()
insertMulti(se, t)
}
func TestSqliteInsertMulti(t *testing.T) {
insertMulti(&se, t)
insertTwoTable(&se, t)
update(&se, t)
testdelete(&se, t)
get(&se, t)
cascadeGet(&se, t)
find(&se, t)
findMap(&se, t)
count(&se, t)
where(&se, t)
in(&se, t)
limit(&se, t)
order(&se, t)
join(&se, t)
having(&se, t)
transaction(&se, t)
combineTransaction(&se, t)
table(&se, t)
createMultiTables(&se, t)
tableOp(&se, t)
autoConn()
insertMulti(se, t)
}
func TestSqliteInsertTwoTable(t *testing.T) {
autoConn()
insertTwoTable(se, t)
}
func TestSqliteUpdate(t *testing.T) {
autoConn()
update(se, t)
}
func TestSqliteDelete(t *testing.T) {
autoConn()
testdelete(se, t)
}
func TestSqliteGet(t *testing.T) {
autoConn()
get(se, t)
}
func TestSqliteCascadeGet(t *testing.T) {
autoConn()
cascadeGet(se, t)
}
func TestSqliteFind(t *testing.T) {
autoConn()
find(se, t)
}
func TestSqliteFindMap(t *testing.T) {
autoConn()
findMap(se, t)
}
func TestSqliteCount(t *testing.T) {
autoConn()
count(se, t)
}
func TestSqliteWhere(t *testing.T) {
autoConn()
where(se, t)
}
func TestSqliteIn(t *testing.T) {
autoConn()
in(se, t)
}
func TestSqliteLimit(t *testing.T) {
autoConn()
limit(se, t)
}
func TestSqliteOrder(t *testing.T) {
autoConn()
order(se, t)
}
func TestSqliteJoin(t *testing.T) {
autoConn()
join(se, t)
}
func TestSqliteHaving(t *testing.T) {
autoConn()
having(se, t)
}
func TestSqliteTransaction(t *testing.T) {
autoConn()
transaction(se, t)
}
func TestSqliteCombineTransaction(t *testing.T) {
autoConn()
combineTransaction(se, t)
}
func TestSqliteTable(t *testing.T) {
autoConn()
table(se, t)
}
func TestSqliteCreateMultiTables(t *testing.T) {
autoConn()
createMultiTables(se, t)
}
func TestSqliteTableOp(t *testing.T) {
autoConn()
tableOp(se, t)
}

View File

@ -21,6 +21,8 @@ type Statement struct {
HavingStr string
ColumnStr string
AltTableName string
RawSQL string
RawParams []interface{}
UseCascade bool
BeanArgs []interface{}
}
@ -46,9 +48,16 @@ func (statement *Statement) Init() {
statement.HavingStr = ""
statement.ColumnStr = ""
statement.AltTableName = ""
statement.RawSQL = ""
statement.RawParams = make([]interface{}, 0)
statement.BeanArgs = make([]interface{}, 0)
}
func (statement *Statement) Sql(querystring string, args ...interface{}) {
statement.RawSQL = querystring
statement.RawParams = args
}
func (statement *Statement) Where(querystring string, args ...interface{}) {
statement.WhereStr = querystring
statement.Params = args

View File

@ -171,21 +171,29 @@ func testdelete(engine *Engine, t *testing.T) {
func get(engine *Engine, t *testing.T) {
user := Userinfo{Uid: 2}
err := engine.Get(&user)
has, err := engine.Get(&user)
if err != nil {
t.Error(err)
}
if has {
fmt.Println(user)
} else {
fmt.Println("no record id is 2")
}
}
func cascadeGet(engine *Engine, t *testing.T) {
user := Userinfo{Uid: 11}
err := engine.Get(&user)
has, err := engine.Get(&user)
if err != nil {
t.Error(err)
}
if has {
fmt.Println(user)
} else {
fmt.Println("no record id is 2")
}
}
func find(engine *Engine, t *testing.T) {
@ -290,14 +298,14 @@ func transaction(engine *Engine, t *testing.T) {
counter()
defer counter()
session, err := engine.MakeSession()
session := engine.NewSession()
defer session.Close()
err := session.Begin()
if err != nil {
t.Error(err)
return
}
session.Begin()
//session.IsAutoRollback = false
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1)
@ -340,14 +348,14 @@ func combineTransaction(engine *Engine, t *testing.T) {
counter()
defer counter()
session, err := engine.MakeSession()
session := engine.NewSession()
defer session.Close()
err := session.Begin()
if err != nil {
t.Error(err)
return
}
session.Begin()
//session.IsAutoRollback = false
user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1)
@ -379,19 +387,19 @@ func combineTransaction(engine *Engine, t *testing.T) {
}
func table(engine *Engine, t *testing.T) {
engine.Table("user_user").CreateTables(&Userinfo{})
engine.Table("user_user").CreateTable(&Userinfo{})
}
func createMultiTables(engine *Engine, t *testing.T) {
session, err := engine.MakeSession()
session := engine.NewSession()
defer session.Close()
user := &Userinfo{}
err := session.Begin()
if err != nil {
t.Error(err)
return
}
user := &Userinfo{}
session.Begin()
for i := 0; i < 10; i++ {
err = session.Table(fmt.Sprintf("user_%v", i)).CreateTable(user)
if err != nil {
@ -414,7 +422,7 @@ func tableOp(engine *Engine, t *testing.T) {
t.Error(err)
}
err = engine.Table(tableName).Get(&Userinfo{Username: "tablexiao"})
_, err = engine.Table(tableName).Get(&Userinfo{Username: "tablexiao"})
if err != nil {
t.Error(err)
}

28
xorm.go
View File

@ -1,26 +1,42 @@
package xorm
import (
//"database/sql"
"errors"
"fmt"
"reflect"
"sync"
//"time"
)
func NewEngine(driverName string, dataSourceName string) *Engine {
func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
engine := &Engine{ShowSQL: false, DriverName: driverName, Mapper: SnakeMapper{},
DataSourceName: dataSourceName}
engine.Tables = make(map[reflect.Type]Table)
engine.Statement.Engine = engine
engine.Tables = make(map[reflect.Type]*Table)
engine.mutex = &sync.Mutex{}
engine.InsertMany = true
engine.TagIdentifier = "xorm"
engine.QuoteIdentifier = "`"
if driverName == SQLITE {
engine.Dialect = sqlite3{}
engine.AutoIncrement = "AUTOINCREMENT"
} else {
//engine.Pool = NoneConnectPool{}
} else if driverName == MYSQL {
engine.Dialect = mysql{}
engine.AutoIncrement = "AUTO_INCREMENT"
} else {
return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName))
}
engine.QuoteIdentifier = "`"
/*engine.Pool = SimpleConnectPool{
releasedSessions: make([]*sql.DB, 30),
usingSessions: map[*sql.DB]time.Time{},
cur: -1,
maxWaitTimeOut: 14400,
mutex: &sync.Mutex{},
}*/
engine.Pool = NoneConnectPool{}
return engine
return engine, nil
}