diff --git a/deprecated.go b/deprecated.go index c960bcf5..aede52b1 100644 --- a/deprecated.go +++ b/deprecated.go @@ -4,16 +4,12 @@ package xorm // @deprecation : please use NewSession instead func (engine *Engine) MakeSession() (Session, error) { - s, err := engine.NewSession() - if err == nil { - return *s, err - } else { - return Session{}, err - } + s := engine.NewSession() + return *s, nil } // @deprecation : please use NewEngine instead func Create(driverName string, dataSourceName string) Engine { - engine := NewEngine(driverName, dataSourceName) + engine, _ := NewEngine(driverName, dataSourceName) return *engine } diff --git a/engine.go b/engine.go index 3d7e6d56..7f994c6f 100644 --- a/engine.go +++ b/engine.go @@ -6,6 +6,7 @@ import ( "reflect" "strconv" "strings" + "sync" ) const ( @@ -26,12 +27,13 @@ type Engine struct { DriverName string DataSourceName string Dialect dialect - Tables map[reflect.Type]Table + Tables map[reflect.Type]*Table + mutex *sync.Mutex AutoIncrement string ShowSQL bool InsertMany bool QuoteIdentifier string - Statement Statement + Pool IConnectionPool } func Type(bean interface{}) reflect.Type { @@ -50,78 +52,89 @@ func (e *Engine) OpenDB() (*sql.DB, error) { return sql.Open(e.DriverName, e.DataSourceName) } -func (engine *Engine) NewSession() (session *Session, err error) { - db, err := engine.OpenDB() - if err != nil { - return nil, err - } - - session = &Session{Engine: engine, Db: db} +func (engine *Engine) NewSession() *Session { + session := &Session{Engine: engine} session.Init() - return + return session } func (engine *Engine) Test() error { - session, err := engine.NewSession() - if err != nil { - return err - } - return session.Db.Ping() + session := engine.NewSession() + defer session.Close() + return session.Ping() } -func (engine *Engine) Where(querystring string, args ...interface{}) *Engine { - engine.Statement.Where(querystring, args...) - return engine +func (engine *Engine) Sql(querystring string, args ...interface{}) *Session { + session := engine.NewSession() + session.Sql(querystring, args...) + return session } -func (engine *Engine) Id(id int64) *Engine { - engine.Statement.Id(id) - return engine +func (engine *Engine) Where(querystring string, args ...interface{}) *Session { + session := engine.NewSession() + session.Where(querystring, args...) + return session } -func (engine *Engine) In(column string, args ...interface{}) *Engine { - engine.Statement.In(column, args...) - return engine +func (engine *Engine) Id(id int64) *Session { + session := engine.NewSession() + session.Id(id) + return session } -func (engine *Engine) Table(tableName string) *Engine { - engine.Statement.Table(tableName) - return engine +func (engine *Engine) In(column string, args ...interface{}) *Session { + session := engine.NewSession() + session.In(column, args...) + return session } -func (engine *Engine) Limit(limit int, start ...int) *Engine { - engine.Statement.Limit(limit, start...) - return engine +func (engine *Engine) Table(tableName string) *Session { + session := engine.NewSession() + session.Table(tableName) + return session } -func (engine *Engine) OrderBy(order string) *Engine { - engine.Statement.OrderBy(order) - return engine +func (engine *Engine) Limit(limit int, start ...int) *Session { + session := engine.NewSession() + session.Limit(limit, start...) + return session +} + +func (engine *Engine) OrderBy(order string) *Session { + session := engine.NewSession() + session.OrderBy(order) + return session } //The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (engine *Engine) Join(join_operator, tablename, condition string) *Engine { - engine.Statement.Join(join_operator, tablename, condition) - return engine +func (engine *Engine) Join(join_operator, tablename, condition string) *Session { + session := engine.NewSession() + session.Join(join_operator, tablename, condition) + return session } -func (engine *Engine) GroupBy(keys string) *Engine { - engine.Statement.GroupBy(keys) - return engine +func (engine *Engine) GroupBy(keys string) *Session { + session := engine.NewSession() + session.GroupBy(keys) + return session } -func (engine *Engine) Having(conditions string) *Engine { - engine.Statement.Having(conditions) - return engine +func (engine *Engine) Having(conditions string) *Session { + session := engine.NewSession() + session.Having(conditions) + return session } +// some lock needed func (engine *Engine) AutoMapType(t reflect.Type) *Table { + engine.mutex.Lock() + defer engine.mutex.Unlock() table, ok := engine.Tables[t] if !ok { table = engine.MapType(t) - engine.Tables[t] = table + //engine.Tables[t] = table } - return &table + return table } func (engine *Engine) AutoMap(bean interface{}) *Table { @@ -129,8 +142,8 @@ func (engine *Engine) AutoMap(bean interface{}) *Table { return engine.AutoMapType(t) } -func (engine *Engine) MapType(t reflect.Type) Table { - table := Table{Name: engine.Mapper.Obj2Table(t.Name()), Type: t} +func (engine *Engine) MapType(t reflect.Type) *Table { + table := &Table{Name: engine.Mapper.Obj2Table(t.Name()), Type: t} table.Columns = make(map[string]Column) for i := 0; i < t.NumField(); i++ { @@ -226,7 +239,10 @@ func (engine *Engine) MapType(t reflect.Type) Table { return table } +// Map should use after all operation because it's not thread safe func (engine *Engine) Map(beans ...interface{}) (e error) { + engine.mutex.Lock() + defer engine.mutex.Unlock() for _, bean := range beans { t := Type(bean) if _, ok := engine.Tables[t]; !ok { @@ -237,6 +253,8 @@ func (engine *Engine) Map(beans ...interface{}) (e error) { } func (engine *Engine) UnMap(beans ...interface{}) (e error) { + engine.mutex.Lock() + defer engine.mutex.Unlock() for _, bean := range beans { t := Type(bean) if _, ok := engine.Tables[t]; ok { @@ -247,37 +265,24 @@ func (engine *Engine) UnMap(beans ...interface{}) (e error) { } func (e *Engine) DropAll() error { - session, err := e.MakeSession() - session.Begin() + session := e.NewSession() defer session.Close() + + err := session.Begin() if err != nil { return err } - - for _, table := range e.Tables { - e.Statement.RefTable = &table - sql := e.Statement.genDropSQL() - _, err = session.Exec(sql) - if err != nil { - session.Rollback() - return err - } + err = session.DropAll() + if err != nil { + return session.Rollback() } return session.Commit() } func (e *Engine) CreateTables(beans ...interface{}) error { - session, err := e.MakeSession() - if err != nil { - return err - } + session := e.NewSession() defer session.Close() - err = session.Begin() - if err != nil { - return err - } - session.Statement = e.Statement - defer e.Statement.Init() + err := session.Begin() if err != nil { return err } @@ -292,106 +297,64 @@ func (e *Engine) CreateTables(beans ...interface{}) error { } func (e *Engine) CreateAll() error { - session, err := e.MakeSession() - session.Begin() + session := e.NewSession() + err := session.Begin() defer session.Close() if err != nil { return err } - for _, table := range e.Tables { - e.Statement.RefTable = &table - sql := e.Statement.genCreateSQL() - _, err = session.Exec(sql) - if err != nil { - session.Rollback() - break - } + err = session.CreateAll() + if err != nil { + return session.Rollback() } - session.Commit() - return err + return session.Commit() } func (engine *Engine) Exec(sql string, args ...interface{}) (sql.Result, error) { - session, err := engine.MakeSession() + session := engine.NewSession() defer session.Close() - if err != nil { - return nil, err - } return session.Exec(sql, args...) } func (engine *Engine) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { - session, err := engine.MakeSession() + session := engine.NewSession() defer session.Close() - if err != nil { - return nil, err - } return session.Query(sql, paramStr...) } func (engine *Engine) Insert(beans ...interface{}) (int64, error) { - session, err := engine.MakeSession() + session := engine.NewSession() defer session.Close() - if err != nil { - return -1, err - } - defer engine.Statement.Init() - session.Statement = engine.Statement return session.Insert(beans...) } func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64, error) { - session, err := engine.MakeSession() + session := engine.NewSession() defer session.Close() - if err != nil { - return -1, err - } - defer engine.Statement.Init() - session.Statement = engine.Statement return session.Update(bean, condiBeans...) } func (engine *Engine) Delete(bean interface{}) (int64, error) { - session, err := engine.MakeSession() + session := engine.NewSession() defer session.Close() - if err != nil { - return -1, err - } - defer engine.Statement.Init() - session.Statement = engine.Statement return session.Delete(bean) } -func (engine *Engine) Get(bean interface{}) error { - session, err := engine.MakeSession() +func (engine *Engine) Get(bean interface{}) (bool, error) { + session := engine.NewSession() defer session.Close() - if err != nil { - return err - } - defer engine.Statement.Init() - session.Statement = engine.Statement return session.Get(bean) } func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error { - session, err := engine.MakeSession() + session := engine.NewSession() defer session.Close() - if err != nil { - return err - } - defer engine.Statement.Init() - session.Statement = engine.Statement return session.Find(beans, condiBeans...) } func (engine *Engine) Count(bean interface{}) (int64, error) { - session, err := engine.MakeSession() + session := engine.NewSession() defer session.Close() - if err != nil { - return 0, err - } - defer engine.Statement.Init() - session.Statement = engine.Statement return session.Count(bean) } diff --git a/examples/goroutine.go b/examples/goroutine.go new file mode 100644 index 00000000..4db301a2 --- /dev/null +++ b/examples/goroutine.go @@ -0,0 +1,88 @@ +package main + +import ( + //xorm "github.com/lunny/xorm" + "fmt" + _ "github.com/go-sql-driver/mysql" + _ "github.com/mattn/go-sqlite3" + "os" + //"time" + xorm "xorm" +) + +type User struct { + Id int + Name string +} + +func sqliteEngine() (*xorm.Engine, error) { + os.Remove("./test.db") + return xorm.NewEngine("sqlite3", "./goroutine.db") +} + +func mysqlEngine() (*xorm.Engine, error) { + return xorm.NewEngine("mysql", "root:123@/test?charset=utf8") +} + +func main() { + engine, err := sqliteEngine() + // engine, err := mysqlEngine() + + if err != nil { + fmt.Println(err) + return + } + + u := &User{} + + err = engine.CreateTables(u) + if err != nil { + fmt.Println(err) + return + } + + size := 10 + queue := make(chan int, size) + + for i := 0; i < size; i++ { + go func(x int) { + //x := i + err := engine.Test() + if err != nil { + fmt.Println(err) + } else { + err = engine.Map(u) + if err != nil { + fmt.Println("Map user failed") + } else { + for j := 0; j < 10; j++ { + if x+j < 2 { + _, err = engine.Get(u) + } else if x+j < 4 { + users := make([]User, 0) + err = engine.Find(&users) + } else if x+j < 8 { + _, err = engine.Count(u) + } else if x+j < 16 { + _, err = engine.Insert(&User{Name: "xlw"}) + } else if x+j < 32 { + _, err = engine.Id(1).Delete(u) + } + if err != nil { + fmt.Println(err) + queue <- x + return + } + } + fmt.Printf("%v success!\n", x) + } + } + queue <- x + }(i) + } + + for i := 0; i < size; i++ { + <-queue + } + fmt.Println("end") +} diff --git a/mysql_test.go b/mysql_test.go index 78ff3df7..8f54c8ce 100644 --- a/mysql_test.go +++ b/mysql_test.go @@ -9,7 +9,7 @@ var me Engine func TestMysql(t *testing.T) { // You should drop all tables before executing this testing - me = Create("mysql", "root:@/xorm_test?charset=utf8") + me = Create("mysql", "root:123@/test?charset=utf8") me.ShowSQL = true directCreateTable(&me, t) diff --git a/pool.go b/pool.go new file mode 100644 index 00000000..34e51b65 --- /dev/null +++ b/pool.go @@ -0,0 +1,78 @@ +package xorm + +import ( + "database/sql" + //"fmt" + //"sync" + //"time" +) + +type IConnectionPool interface { + RetrieveDB(engine *Engine) (*sql.DB, error) + ReleaseDB(engine *Engine, db *sql.DB) +} + +type NoneConnectPool struct { +} + +func (p NoneConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) { + db, err = engine.OpenDB() + return +} + +func (p NoneConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { + db.Close() +} + +/* +var ( + total int = 0 +) + +type SimpleConnectPool struct { + releasedSessions []*sql.DB + cur int + usingSessions map[*sql.DB]time.Time + maxWaitTimeOut int + mutex *sync.Mutex +} + +func (p SimpleConnectPool) RetrieveDB(engine *Engine) (*sql.DB, error) { + p.mutex.Lock() + defer p.mutex.Unlock() + var db *sql.DB = nil + var err error = nil + fmt.Printf("%x, rbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingSessions)) + if p.cur < 0 { + total = total + 1 + fmt.Printf("new %v\n", total) + db, err = engine.OpenDB() + if err != nil { + return nil, err + } + p.usingSessions[db] = time.Now() + } else { + db = p.releasedSessions[p.cur] + p.usingSessions[db] = time.Now() + p.releasedSessions[p.cur] = nil + p.cur = p.cur - 1 + fmt.Println("release one") + } + + fmt.Printf("%x, rend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingSessions)) + return db, nil +} + +func (p SimpleConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { + p.mutex.Lock() + defer p.mutex.Unlock() + fmt.Printf("%x, lbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingSessions)) + if p.cur >= 29 { + db.Close() + } else { + p.cur = p.cur + 1 + p.releasedSessions[p.cur] = db + } + delete(p.usingSessions, db) + fmt.Printf("%x, lend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingSessions)) +}*/ diff --git a/session.go b/session.go index 3f21cd95..7a518395 100644 --- a/session.go +++ b/session.go @@ -21,6 +21,7 @@ type Session struct { func (session *Session) Init() { session.Statement = Statement{Engine: session.Engine} + session.Statement.Init() session.IsAutoCommit = true session.IsCommitedOrRollbacked = false } @@ -28,11 +29,19 @@ func (session *Session) Init() { func (session *Session) Close() { defer func() { if session.Db != nil { - session.Db.Close() + session.Engine.Pool.ReleaseDB(session.Engine, session.Db) + session.Db = nil + session.Tx = nil + session.Init() } }() } +func (session *Session) Sql(querystring string, args ...interface{}) *Session { + session.Statement.Sql(querystring, args...) + return session +} + func (session *Session) Where(querystring string, args ...interface{}) *Session { session.Statement.Where(querystring, args...) return session @@ -86,7 +95,22 @@ func (session *Session) Having(conditions string) *Session { return session } +func (session *Session) newDb() error { + if session.Db == nil { + db, err := session.Engine.Pool.RetrieveDB(session.Engine) + if err != nil { + return err + } + session.Db = db + } + return nil +} + func (session *Session) Begin() error { + err := session.newDb() + if err != nil { + return err + } if session.IsAutoCommit { tx, err := session.Db.Begin() if err != nil { @@ -189,31 +213,38 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b v = x } else if session.Statement.UseCascade { - session.Engine.AutoMapType(structField.Type()) - if _, ok := session.Engine.Tables[structField.Type()]; ok { + table := session.Engine.AutoMapType(structField.Type()) + if table != nil { x, err := strconv.ParseInt(string(data), 10, 64) if err != nil { return errors.New("arg " + key + " as int: " + err.Error()) } - if x != 0 { structInter := reflect.New(structField.Type()) + st := session.Statement session.Statement.Init() - err = session.Id(x).Get(structInter.Interface()) + has, err := session.Id(x).Get(structInter.Interface()) if err != nil { + session.Statement = st return err } - - v = structInter.Elem().Interface() + if has { + v = structInter.Elem().Interface() + session.Statement = st + } else { + fmt.Println("cascade obj is not exist!") + session.Statement = st + continue + } } else { - //fmt.Println("zero value of struct type " + structField.Type().String()) continue } - } else { fmt.Println("unsupported struct type in Scan: " + structField.Type().String()) continue } + } else { + continue } default: return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) @@ -241,6 +272,11 @@ func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, } func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { + err := session.newDb() + if err != nil { + return nil, err + } + if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { sql = strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1) } @@ -263,37 +299,48 @@ func (session *Session) CreateTable(bean interface{}) error { return err } -func (session *Session) Get(bean interface{}) error { +func (session *Session) Get(bean interface{}) (bool, error) { statement := session.Statement defer statement.Init() statement.Limit(1) - - fmt.Println(bean) - - sql, args := statement.genGetSql(bean) + var sql string + var args []interface{} + if statement.RawSQL == "" { + sql, args = statement.genGetSql(bean) + } else { + sql = statement.RawSQL + args = statement.RawParams + } resultsSlice, err := session.Query(sql, args...) if err != nil { - return err + return false, err } if len(resultsSlice) == 0 { - return nil + return false, nil } else if len(resultsSlice) == 1 { results := resultsSlice[0] err := session.scanMapIntoStruct(bean, results) if err != nil { - return err + return false, err } } else { - return errors.New("More than one record") + return false, errors.New("More than one record") } - return nil + return true, nil } func (session *Session) Count(bean interface{}) (int64, error) { statement := session.Statement defer session.Statement.Init() - sql, args := statement.genCountSql(bean) + var sql string + var args []interface{} + if statement.RawSQL == "" { + sql, args = statement.genCountSql(bean) + } else { + sql = statement.RawSQL + args = statement.RawParams + } resultsSlice, err := session.Query(sql, args...) if err != nil { @@ -301,9 +348,12 @@ func (session *Session) Count(bean interface{}) (int64, error) { } var total int64 = 0 - for _, results := range resultsSlice { - total, err = strconv.ParseInt(string(results["total"]), 10, 64) - break + if len(resultsSlice) > 0 { + results := resultsSlice[0] + for _, value := range results { + total, err = strconv.ParseInt(string(value), 10, 64) + break + } } return int64(total), err @@ -327,8 +377,17 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) statement.BeanArgs = args } - sql := statement.generateSql() - resultsSlice, err := session.Query(sql, append(statement.Params, statement.BeanArgs...)...) + var sql string + var args []interface{} + if statement.RawSQL == "" { + sql = statement.generateSql() + args = append(statement.Params, statement.BeanArgs...) + } else { + sql = statement.RawSQL + args = statement.RawParams + } + + resultsSlice, err := session.Query(sql, args...) if err != nil { return err @@ -359,7 +418,45 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) return nil } +func (session *Session) Ping() error { + err := session.newDb() + if err != nil { + return err + } + + return session.Db.Ping() +} + +func (session *Session) CreateAll() error { + for _, table := range session.Engine.Tables { + session.Statement.RefTable = table + sql := session.Statement.genCreateSQL() + _, err := session.Exec(sql) + if err != nil { + return err + } + } + return nil +} + +func (session *Session) DropAll() error { + for _, table := range session.Engine.Tables { + session.Statement.RefTable = table + sql := session.Statement.genDropSQL() + _, err := session.Exec(sql) + if err != nil { + return err + } + } + return nil +} + func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { + err = session.newDb() + if err != nil { + return nil, err + } + if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { sql = strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1) } @@ -635,7 +732,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - statement := fmt.Sprintf("UPDATE %v%v%v SET %v %v", + sql := fmt.Sprintf("UPDATE %v%v%v SET %v %v", session.Engine.QuoteIdentifier, session.Statement.TableName(), session.Engine.QuoteIdentifier, @@ -643,7 +740,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 condition) eargs := append(append(args, st.Params...), condiArgs...) - res, err := session.Exec(statement, eargs...) + res, err := session.Exec(sql, eargs...) if err != nil { return -1, err } diff --git a/sqlite3_test.go b/sqlite3_test.go index 694e481e..fe9af9f9 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -6,65 +6,147 @@ import ( "testing" ) -var se Engine +var se *Engine -func TestSqlite(t *testing.T) { - os.Remove("./test.db") - se = Create("sqlite3", "./test.db") - se.ShowSQL = true +func autoConn() { + if se == nil { + os.Remove("./test.db") + se, _ = NewEngine("sqlite3", "./test.db") + se.ShowSQL = true + } } func TestSqliteCreateTable(t *testing.T) { - directCreateTable(&se, t) + autoConn() + directCreateTable(se, t) } func TestSqliteMapper(t *testing.T) { - mapper(&se, t) + autoConn() + mapper(se, t) } func TestSqliteInsert(t *testing.T) { - insert(&se, t) + autoConn() + insert(se, t) } func TestSqliteQuery(t *testing.T) { - query(&se, t) + autoConn() + query(se, t) } func TestSqliteExec(t *testing.T) { - exec(&se, t) + autoConn() + exec(se, t) } func TestSqliteInsertAutoIncr(t *testing.T) { - insertAutoIncr(&se, t) + autoConn() + insertAutoIncr(se, t) } -type sss struct { -} - -func (s sss) TestInsertMulti(t *testing.T) { - insertMulti(&se, t) +func TestInsertMulti(t *testing.T) { + autoConn() + insertMulti(se, t) } func TestSqliteInsertMulti(t *testing.T) { - insertMulti(&se, t) - - insertTwoTable(&se, t) - update(&se, t) - testdelete(&se, t) - get(&se, t) - cascadeGet(&se, t) - find(&se, t) - findMap(&se, t) - count(&se, t) - where(&se, t) - in(&se, t) - limit(&se, t) - order(&se, t) - join(&se, t) - having(&se, t) - transaction(&se, t) - combineTransaction(&se, t) - table(&se, t) - createMultiTables(&se, t) - tableOp(&se, t) + autoConn() + insertMulti(se, t) +} + +func TestSqliteInsertTwoTable(t *testing.T) { + autoConn() + insertTwoTable(se, t) +} + +func TestSqliteUpdate(t *testing.T) { + autoConn() + update(se, t) +} + +func TestSqliteDelete(t *testing.T) { + autoConn() + testdelete(se, t) +} + +func TestSqliteGet(t *testing.T) { + autoConn() + get(se, t) +} + +func TestSqliteCascadeGet(t *testing.T) { + autoConn() + cascadeGet(se, t) +} + +func TestSqliteFind(t *testing.T) { + autoConn() + find(se, t) +} + +func TestSqliteFindMap(t *testing.T) { + autoConn() + findMap(se, t) +} + +func TestSqliteCount(t *testing.T) { + autoConn() + count(se, t) +} + +func TestSqliteWhere(t *testing.T) { + autoConn() + where(se, t) +} + +func TestSqliteIn(t *testing.T) { + autoConn() + in(se, t) +} + +func TestSqliteLimit(t *testing.T) { + autoConn() + limit(se, t) +} + +func TestSqliteOrder(t *testing.T) { + autoConn() + order(se, t) +} + +func TestSqliteJoin(t *testing.T) { + autoConn() + join(se, t) +} + +func TestSqliteHaving(t *testing.T) { + autoConn() + having(se, t) +} + +func TestSqliteTransaction(t *testing.T) { + autoConn() + transaction(se, t) +} + +func TestSqliteCombineTransaction(t *testing.T) { + autoConn() + combineTransaction(se, t) +} + +func TestSqliteTable(t *testing.T) { + autoConn() + table(se, t) +} + +func TestSqliteCreateMultiTables(t *testing.T) { + autoConn() + createMultiTables(se, t) +} + +func TestSqliteTableOp(t *testing.T) { + autoConn() + tableOp(se, t) } diff --git a/statement.go b/statement.go index dc18aa90..2e4b1f7c 100644 --- a/statement.go +++ b/statement.go @@ -21,6 +21,8 @@ type Statement struct { HavingStr string ColumnStr string AltTableName string + RawSQL string + RawParams []interface{} UseCascade bool BeanArgs []interface{} } @@ -46,9 +48,16 @@ func (statement *Statement) Init() { statement.HavingStr = "" statement.ColumnStr = "" statement.AltTableName = "" + statement.RawSQL = "" + statement.RawParams = make([]interface{}, 0) statement.BeanArgs = make([]interface{}, 0) } +func (statement *Statement) Sql(querystring string, args ...interface{}) { + statement.RawSQL = querystring + statement.RawParams = args +} + func (statement *Statement) Where(querystring string, args ...interface{}) { statement.WhereStr = querystring statement.Params = args diff --git a/testbase.go b/testbase.go index 17ed6b0b..60924b65 100644 --- a/testbase.go +++ b/testbase.go @@ -171,21 +171,29 @@ func testdelete(engine *Engine, t *testing.T) { func get(engine *Engine, t *testing.T) { user := Userinfo{Uid: 2} - err := engine.Get(&user) + has, err := engine.Get(&user) if err != nil { t.Error(err) } - fmt.Println(user) + if has { + fmt.Println(user) + } else { + fmt.Println("no record id is 2") + } } func cascadeGet(engine *Engine, t *testing.T) { user := Userinfo{Uid: 11} - err := engine.Get(&user) + has, err := engine.Get(&user) if err != nil { t.Error(err) } - fmt.Println(user) + if has { + fmt.Println(user) + } else { + fmt.Println("no record id is 2") + } } func find(engine *Engine, t *testing.T) { @@ -290,14 +298,14 @@ func transaction(engine *Engine, t *testing.T) { counter() defer counter() - session, err := engine.MakeSession() + session := engine.NewSession() defer session.Close() + + err := session.Begin() if err != nil { t.Error(err) return } - - session.Begin() //session.IsAutoRollback = false user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} _, err = session.Insert(&user1) @@ -340,14 +348,14 @@ func combineTransaction(engine *Engine, t *testing.T) { counter() defer counter() - session, err := engine.MakeSession() + session := engine.NewSession() defer session.Close() + + err := session.Begin() if err != nil { t.Error(err) return } - - session.Begin() //session.IsAutoRollback = false user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} _, err = session.Insert(&user1) @@ -379,19 +387,19 @@ func combineTransaction(engine *Engine, t *testing.T) { } func table(engine *Engine, t *testing.T) { - engine.Table("user_user").CreateTables(&Userinfo{}) + engine.Table("user_user").CreateTable(&Userinfo{}) } func createMultiTables(engine *Engine, t *testing.T) { - session, err := engine.MakeSession() + session := engine.NewSession() defer session.Close() + + user := &Userinfo{} + err := session.Begin() if err != nil { t.Error(err) return } - - user := &Userinfo{} - session.Begin() for i := 0; i < 10; i++ { err = session.Table(fmt.Sprintf("user_%v", i)).CreateTable(user) if err != nil { @@ -414,7 +422,7 @@ func tableOp(engine *Engine, t *testing.T) { t.Error(err) } - err = engine.Table(tableName).Get(&Userinfo{Username: "tablexiao"}) + _, err = engine.Table(tableName).Get(&Userinfo{Username: "tablexiao"}) if err != nil { t.Error(err) } diff --git a/xorm.go b/xorm.go index 0203d459..75d7fa6f 100644 --- a/xorm.go +++ b/xorm.go @@ -1,26 +1,42 @@ package xorm import ( + //"database/sql" + "errors" + "fmt" "reflect" + "sync" + //"time" ) -func NewEngine(driverName string, dataSourceName string) *Engine { +func NewEngine(driverName string, dataSourceName string) (*Engine, error) { engine := &Engine{ShowSQL: false, DriverName: driverName, Mapper: SnakeMapper{}, DataSourceName: dataSourceName} - engine.Tables = make(map[reflect.Type]Table) - engine.Statement.Engine = engine + engine.Tables = make(map[reflect.Type]*Table) + engine.mutex = &sync.Mutex{} engine.InsertMany = true engine.TagIdentifier = "xorm" + engine.QuoteIdentifier = "`" if driverName == SQLITE { engine.Dialect = sqlite3{} engine.AutoIncrement = "AUTOINCREMENT" - } else { + //engine.Pool = NoneConnectPool{} + } else if driverName == MYSQL { engine.Dialect = mysql{} engine.AutoIncrement = "AUTO_INCREMENT" + } else { + return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName)) } - engine.QuoteIdentifier = "`" + /*engine.Pool = SimpleConnectPool{ + releasedSessions: make([]*sql.DB, 30), + usingSessions: map[*sql.DB]time.Time{}, + cur: -1, + maxWaitTimeOut: 14400, + mutex: &sync.Mutex{}, + }*/ + engine.Pool = NoneConnectPool{} - return engine + return engine, nil }