diff --git a/engine.go b/engine.go index 2a0d7d46..8ed6903c 100644 --- a/engine.go +++ b/engine.go @@ -287,8 +287,8 @@ func (e *Engine) DropAll() error { func (e *Engine) CreateTables(beans ...interface{}) error { session := e.NewSession() - defer session.Close() err := session.Begin() + defer session.Close() if err != nil { return err } diff --git a/examples/goroutine.go b/examples/goroutine.go index 4db301a2..5251d71e 100644 --- a/examples/goroutine.go +++ b/examples/goroutine.go @@ -7,11 +7,12 @@ import ( _ "github.com/mattn/go-sqlite3" "os" //"time" + "sync/atomic" xorm "xorm" ) type User struct { - Id int + Id int64 Name string } @@ -24,18 +25,10 @@ func mysqlEngine() (*xorm.Engine, error) { return xorm.NewEngine("mysql", "root:123@/test?charset=utf8") } -func main() { - engine, err := sqliteEngine() - // engine, err := mysqlEngine() +var u *User = &User{} - if err != nil { - fmt.Println(err) - return - } - - u := &User{} - - err = engine.CreateTables(u) +func test(engine *xorm.Engine) { + err := engine.CreateTables(u) if err != nil { fmt.Println(err) return @@ -84,5 +77,24 @@ func main() { for i := 0; i < size; i++ { <-queue } + + conns := atomic.LoadInt32(&xorm.ConnectionNum) + fmt.Println("connection number:", conns) fmt.Println("end") } + +func main() { + engine, err := sqliteEngine() + if err != nil { + fmt.Println(err) + return + } + test(engine) + + engine, err = mysqlEngine() + if err != nil { + fmt.Println(err) + return + } + test(engine) +} diff --git a/pool.go b/pool.go index 34e51b65..e270bee3 100644 --- a/pool.go +++ b/pool.go @@ -2,8 +2,9 @@ package xorm import ( "database/sql" - //"fmt" + "fmt" //"sync" + "sync/atomic" //"time" ) @@ -15,21 +16,22 @@ type IConnectionPool interface { type NoneConnectPool struct { } +var ConnectionNum int32 = 0 + func (p NoneConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) { + atomic.AddInt32(&ConnectionNum, 1) db, err = engine.OpenDB() + fmt.Printf("--open a connection--%x\n", &db) return } func (p NoneConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { + atomic.AddInt32(&ConnectionNum, -1) + fmt.Printf("--close a connection--%x\n", &db) db.Close() } -/* -var ( - total int = 0 -) - -type SimpleConnectPool struct { +/*type SimpleConnectPool struct { releasedSessions []*sql.DB cur int usingSessions map[*sql.DB]time.Time @@ -44,8 +46,8 @@ func (p SimpleConnectPool) RetrieveDB(engine *Engine) (*sql.DB, error) { var err error = nil fmt.Printf("%x, rbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingSessions)) if p.cur < 0 { - total = total + 1 - fmt.Printf("new %v\n", total) + ConnectionNum = ConnectionNum + 1 + fmt.Printf("new %v\n", ConnectionNum) db, err = engine.OpenDB() if err != nil { return nil, err @@ -68,6 +70,7 @@ func (p SimpleConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { defer p.mutex.Unlock() fmt.Printf("%x, lbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingSessions)) if p.cur >= 29 { + ConnectionNum = ConnectionNum - 1 db.Close() } else { p.cur = p.cur + 1 diff --git a/session.go b/session.go index 7f7d465f..4f9e5d0c 100644 --- a/session.go +++ b/session.go @@ -221,19 +221,16 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b } if x != 0 { structInter := reflect.New(structField.Type()) - st := session.Statement - session.Statement.Init() - has, err := session.Id(x).Get(structInter.Interface()) + newsession := session.Engine.NewSession() + defer newsession.Close() + has, err := newsession.Id(x).Get(structInter.Interface()) if err != nil { - session.Statement = st return err } if has { v = structInter.Elem().Interface() - session.Statement = st } else { fmt.Println("cascade obj is not exist!") - session.Statement = st continue } } else { @@ -273,6 +270,9 @@ func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { err := session.newDb() + if session.IsAutoCommit { + defer session.Close() + } if err != nil { return nil, err } @@ -457,6 +457,10 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice return nil, err } + if session.IsAutoCommit { + defer session.Close() + } + if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { sql = strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1) } @@ -538,7 +542,11 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { isInTransaction := !session.IsAutoCommit if !isInTransaction { - session.Begin() + err = session.Begin() + defer session.Close() + if err != nil { + return 0, err + } } for _, bean := range beans { @@ -548,7 +556,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { lastId, err = session.InsertMulti(bean) if err != nil { if !isInTransaction { - session.Rollback() + err = session.Rollback() } return lastId, err } @@ -558,7 +566,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { lastId, err = session.InsertOne(sliceValue.Index(i).Interface()) if err != nil { if !isInTransaction { - session.Rollback() + err = session.Rollback() } return lastId, err } @@ -568,7 +576,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { lastId, err = session.InsertOne(bean) if err != nil { if !isInTransaction { - session.Rollback() + err = session.Rollback() } return lastId, err } @@ -707,7 +715,17 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { pkValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(table.PKColumn().FieldName) if pkValue.CanSet() { var v interface{} = id - pkValue.Set(reflect.ValueOf(v)) + switch pkValue.Type().Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32: + v = int(id) + pkValue.Set(reflect.ValueOf(v)) + case reflect.Int64: + pkValue.Set(reflect.ValueOf(v)) + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v = uint(id) + pkValue.Set(reflect.ValueOf(v)) + } + } } diff --git a/xorm.go b/xorm.go index 75d7fa6f..ba46aa89 100644 --- a/xorm.go +++ b/xorm.go @@ -29,14 +29,14 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName)) } - /*engine.Pool = SimpleConnectPool{ + /*engine.Pool = &SimpleConnectPool{ releasedSessions: make([]*sql.DB, 30), usingSessions: map[*sql.DB]time.Time{}, cur: -1, maxWaitTimeOut: 14400, mutex: &sync.Mutex{}, }*/ - engine.Pool = NoneConnectPool{} + engine.Pool = &NoneConnectPool{} return engine, nil }