diff --git a/session.go b/session.go index e8165a41..a983237b 100644 --- a/session.go +++ b/session.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "hash/crc32" "reflect" "strconv" "strings" @@ -33,6 +34,8 @@ type Session struct { beforeClosures []func(interface{}) afterClosures []func(interface{}) + + stmtCache map[uint32]*sql.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) } // Method Init reset the session as the init status. @@ -53,14 +56,17 @@ func (session *Session) Init() { // Method Close release the connection from pool func (session *Session) Close() { - defer func() { - if session.Db != nil { - session.Engine.Pool.ReleaseDB(session.Engine, session.Db) - session.Db = nil - session.Tx = nil - session.Init() - } - }() + for _, v := range session.stmtCache { + v.Close() + } + + if session.Db != nil { + session.Engine.Pool.ReleaseDB(session.Engine, session.Db) + session.Db = nil + session.Tx = nil + session.stmtCache = nil + session.Init() + } } // Method Sql provides raw sql input parameter. When you have a complex SQL statement @@ -256,6 +262,7 @@ func (session *Session) newDb() error { return err } session.Db = db + session.stmtCache = make(map[uint32]*sql.Stmt, 0) } return nil } @@ -394,13 +401,13 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b //Execute sql func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Result, error) { - rs, err := session.Db.Prepare(sqlStr) + stmt, err := session.doPrepare(sqlStr) if err != nil { return nil, err } - defer rs.Close() + //defer stmt.Close() - res, err := rs.Exec(args...) + res, err := stmt.Exec(args...) if err != nil { return nil, err } @@ -866,6 +873,21 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { return nil } +func (session *Session) doPrepare(sqlStr string) (stmt *sql.Stmt, err error) { + crc := crc32.ChecksumIEEE([]byte(sqlStr)) + // TODO try hash(sqlStr+len(sqlStr)) + var has bool + stmt, has = session.stmtCache[crc] + if !has { + stmt, err = session.Db.Prepare(sqlStr) + if err != nil { + return nil, err + } + session.stmtCache[crc] = stmt + } + return +} + // get retrieve one record from database, bean's non-empty fields // will be as conditions func (session *Session) Get(bean interface{}) (bool, error) { @@ -901,11 +923,11 @@ func (session *Session) Get(bean interface{}) (bool, error) { var rawRows *sql.Rows session.queryPreprocess(&sqlStr, args...) if session.IsAutoCommit { - stmt, err := session.Db.Prepare(sqlStr) + stmt, err := session.doPrepare(sqlStr) if err != nil { return false, err } - defer stmt.Close() + // defer stmt.Close() // !nashtsai! don't close due to stmt is cached and bounded to this session rawRows, err = stmt.Query(args...) } else { rawRows, err = session.Tx.Query(sqlStr, args...) @@ -1070,11 +1092,10 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) // defer rawRows.Close() if session.IsAutoCommit { - stmt, err = session.Db.Prepare(sqlStr) + stmt, err = session.doPrepare(sqlStr) if err != nil { return err } - defer stmt.Close() rawRows, err = stmt.Query(args...) } else { rawRows, err = session.Tx.Query(sqlStr, args...) @@ -1165,19 +1186,19 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) return nil } -func (session *Session) queryRows(rawStmt **sql.Stmt, rawRows **sql.Rows, sqlStr string, args ...interface{}) error { - var err error - if session.IsAutoCommit { - *rawStmt, err = session.Db.Prepare(sqlStr) - if err != nil { - return err - } - *rawRows, err = (*rawStmt).Query(args...) - } else { - *rawRows, err = session.Tx.Query(sqlStr, args...) - } - return err -} +// func (session *Session) queryRows(rawStmt **sql.Stmt, rawRows **sql.Rows, sqlStr string, args ...interface{}) error { +// var err error +// if session.IsAutoCommit { +// *rawStmt, err = session.doPrepare(sqlStr) +// if err != nil { +// return err +// } +// *rawRows, err = (*rawStmt).Query(args...) +// } else { +// *rawRows, err = session.Tx.Query(sqlStr, args...) +// } +// return err +// } // Test if database is ok func (session *Session) Ping() error { diff --git a/tests/mysql_ddl.sql b/tests/mysql_ddl.sql index 92c95829..db20aa33 100644 --- a/tests/mysql_ddl.sql +++ b/tests/mysql_ddl.sql @@ -2,3 +2,4 @@ --DROP DATABASE xorm_test2; CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET utf8 COLLATE utf8_general_ci; CREATE DATABASE IF NOT EXISTS xorm_test2 CHARACTER SET utf8 COLLATE utf8_general_ci; +CREATE DATABASE IF NOT EXISTS xorm_test3 CHARACTER SET utf8 COLLATE utf8_general_ci;