diff --git a/engine.go b/engine.go index 80baaa4f..09fcd774 100644 --- a/engine.go +++ b/engine.go @@ -232,9 +232,7 @@ func (engine *Engine) Dialect() dialects.Dialect { // NewSession New a session func (engine *Engine) NewSession() *Session { - session := &Session{engine: engine} - session.Init() - return session + return newSession(engine) } // Close the engine diff --git a/error.go b/error.go index 3708ce4f..cfa5c819 100644 --- a/error.go +++ b/error.go @@ -9,6 +9,8 @@ import ( ) var ( + // ErrPtrSliceType represents a type error + ErrPtrSliceType = errors.New("A point to a slice is needed") // ErrParamsType params error ErrParamsType = errors.New("Params type error") // ErrTableNotFound table not found error diff --git a/session.go b/session.go index 13ebfe27..81afcc03 100644 --- a/session.go +++ b/session.go @@ -47,24 +47,24 @@ func (e ErrFieldIsNotValid) Error() string { return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) } -type sessionType int +type sessionType bool const ( - engineSession sessionType = iota - groupSession + engineSession sessionType = false + groupSession sessionType = true ) // Session keep a pointer to sql.DB and provides all execution of all // kind of database operations. type Session struct { - db *core.DB engine *Engine tx *core.Tx statement *statements.Statement isAutoCommit bool isCommitedOrRollbacked bool isAutoClose bool - + isClosed bool + prepareStmt bool // Automatically reset the statement after operations that execute a SQL // query such as Count(), Find(), Get(), ... autoResetStatement bool @@ -75,28 +75,19 @@ type Session struct { afterDeleteBeans map[interface{}]*[]func(interface{}) // -- - beforeClosures []func(interface{}) - afterClosures []func(interface{}) - + beforeClosures []func(interface{}) + afterClosures []func(interface{}) afterProcessors []executedProcessor - prepareStmt bool - stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) + stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) lastSQL string lastSQLArgs []interface{} - showSQL bool ctx context.Context sessionType sessionType } -// Clone copy all the session's content and return a new session -func (session *Session) Clone() *Session { - var sess = *session - return &sess -} - func newSessionID() string { hash := sha256.New() _, err := io.CopyN(hash, rand.Reader, 50) @@ -108,63 +99,77 @@ func newSessionID() string { return mdStr[0:20] } -// Init reset the session as the init status. -func (session *Session) Init() { - session.statement = statements.NewStatement( - session.engine.dialect, - session.engine.tagParser, - session.engine.DatabaseTZ, - ) - session.db = session.engine.db - session.isAutoCommit = true - session.isCommitedOrRollbacked = false - session.isAutoClose = false - session.autoResetStatement = true - session.prepareStmt = false - - // !nashtsai! is lazy init better? - session.afterInsertBeans = make(map[interface{}]*[]func(interface{}), 0) - session.afterUpdateBeans = make(map[interface{}]*[]func(interface{}), 0) - session.afterDeleteBeans = make(map[interface{}]*[]func(interface{}), 0) - session.beforeClosures = make([]func(interface{}), 0) - session.afterClosures = make([]func(interface{}), 0) - session.stmtCache = make(map[uint32]*core.Stmt) - - session.afterProcessors = make([]executedProcessor, 0) - - session.lastSQL = "" - session.lastSQLArgs = []interface{}{} - - if session.engine.logSessionID { - session.ctx = context.WithValue(session.engine.defaultContext, log.SessionIDKey, newSessionID()) +func newSession(engine *Engine) *Session { + var ctx context.Context + if engine.logSessionID { + ctx = context.WithValue(engine.defaultContext, log.SessionIDKey, newSessionID()) } else { - session.ctx = session.engine.defaultContext + ctx = engine.defaultContext + } + + return &Session{ + ctx: ctx, + engine: engine, + tx: nil, + statement: statements.NewStatement( + engine.dialect, + engine.tagParser, + engine.DatabaseTZ, + ), + isClosed: false, + isAutoCommit: true, + isCommitedOrRollbacked: false, + isAutoClose: false, + autoResetStatement: true, + prepareStmt: false, + + afterInsertBeans: make(map[interface{}]*[]func(interface{}), 0), + afterUpdateBeans: make(map[interface{}]*[]func(interface{}), 0), + afterDeleteBeans: make(map[interface{}]*[]func(interface{}), 0), + beforeClosures: make([]func(interface{}), 0), + afterClosures: make([]func(interface{}), 0), + afterProcessors: make([]executedProcessor, 0), + stmtCache: make(map[uint32]*core.Stmt), + + lastSQL: "", + lastSQLArgs: make([]interface{}, 0), + + sessionType: engineSession, } } // Close release the connection from pool -func (session *Session) Close() { +func (session *Session) Close() error { for _, v := range session.stmtCache { - v.Close() + if err := v.Close(); err != nil { + return err + } } - if session.db != nil { + if !session.isClosed { // When Close be called, if session is a transaction and do not call // Commit or Rollback, then call Rollback. if session.tx != nil && !session.isCommitedOrRollbacked { - session.Rollback() + if err := session.Rollback(); err != nil { + return err + } } session.tx = nil session.stmtCache = nil - session.db = nil + session.isClosed = true } + return nil +} + +func (session *Session) db() *core.DB { + return session.engine.db } func (session *Session) getQueryer() core.Queryer { if session.tx != nil { return session.tx } - return session.db + return session.db() } // ContextCache enable context cache or not @@ -175,7 +180,7 @@ func (session *Session) ContextCache(context contexts.ContextCache) *Session { // IsClosed returns if session is closed func (session *Session) IsClosed() bool { - return session.db == nil + return session.isClosed } func (session *Session) resetStatement() { @@ -320,11 +325,7 @@ func (session *Session) Having(conditions string) *Session { // DB db return the wrapper of sql.DB func (session *Session) DB() *core.DB { - if session.db == nil { - session.db = session.engine.DB() - session.stmtCache = make(map[uint32]*core.Stmt, 0) - } - return session.db + return session.db() } func cleanupProcessorsClosures(slices *[]func(interface{})) { diff --git a/session_get.go b/session_get.go index e56ef2d7..afedcd1f 100644 --- a/session_get.go +++ b/session_get.go @@ -242,7 +242,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, if err != nil { return false, err } - // close it before covert data + // close it before convert data rows.Close() dataStruct := utils.ReflectValue(bean) diff --git a/session_insert.go b/session_insert.go index 3a0a7066..1270d5db 100644 --- a/session_insert.go +++ b/session_insert.go @@ -112,13 +112,14 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return 0, ErrTableNotFound } - table := session.statement.RefTable - size := sliceValue.Len() - - var colNames []string - var colMultiPlaces []string - var args []interface{} - var cols []*schemas.Column + var ( + table = session.statement.RefTable + size = sliceValue.Len() + colNames []string + colMultiPlaces []string + args []interface{} + cols []*schemas.Column + ) for i := 0; i < size; i++ { v := sliceValue.Index(i) @@ -265,12 +266,11 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) if sliceValue.Kind() != reflect.Slice { - return 0, ErrParamsType - + return 0, ErrPtrSliceType } if sliceValue.Len() <= 0 { - return 0, nil + return 0, ErrNoElementsOnSlice } return session.innerInsertMulti(rowsSlicePtr) @@ -483,7 +483,7 @@ func (session *Session) cacheInsert(table string) error { if cacher == nil { return nil } - session.engine.logger.Debugf("[cache] clear sql: %v", table) + session.engine.logger.Debugf("[cache] clear SQL: %v", table) cacher.ClearIds(table) return nil }