diff --git a/core/stmt.go b/core/stmt.go index 260843d5..3247efed 100644 --- a/core/stmt.go +++ b/core/stmt.go @@ -93,7 +93,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result if err != nil { return nil, err } - res, err := s.Stmt.ExecContext(ctx, args) + res, err := s.Stmt.ExecContext(ctx, args...) hookCtx.End(ctx, res, err) if err := s.db.afterProcess(hookCtx); err != nil { return nil, err diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index 601d4a26..5d1558f4 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -915,3 +915,64 @@ func TestGetVars(t *testing.T) { assert.EqualValues(t, "xlw", name) assert.EqualValues(t, 42, age) } + +func TestGetWithPrepare(t *testing.T) { + type GetVarsWithPrepare struct { + Id int64 + Name string + Age int + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetVarsWithPrepare)) + + _, err := testEngine.Insert(&GetVarsWithPrepare{ + Name: "xlw", + Age: 42, + }) + assert.NoError(t, err) + + var v1 GetVarsWithPrepare + has, err := testEngine.Prepare().ID(1).Get(&v1) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "xlw", v1.Name) + assert.EqualValues(t, 42, v1.Age) + + sess := testEngine.NewSession() + defer sess.Close() + + var v2 GetVarsWithPrepare + has, err = sess.Prepare().ID(1).Get(&v2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "xlw", v2.Name) + assert.EqualValues(t, 42, v2.Age) + + var v3 GetVarsWithPrepare + has, err = sess.Prepare().ID(1).Get(&v3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "xlw", v3.Name) + assert.EqualValues(t, 42, v3.Age) + + err = sess.Begin() + assert.NoError(t, err) + + cnt, err := sess.Prepare().Insert(&GetVarsWithPrepare{ + Name: "xlw2", + Age: 12, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = sess.Prepare().Insert(&GetVarsWithPrepare{ + Name: "xlw3", + Age: 13, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + err = sess.Commit() + assert.NoError(t, err) +} diff --git a/interface.go b/interface.go index 42dc9a0a..b9e88505 100644 --- a/interface.go +++ b/interface.go @@ -99,6 +99,7 @@ type EngineInterface interface { MapCacher(interface{}, caches.Cacher) error NewSession() *Session NoAutoTime() *Session + Prepare() *Session Quote(string) string SetCacher(string, caches.Cacher) SetConnMaxLifetime(time.Duration) diff --git a/session.go b/session.go index f51fd41b..2c916335 100644 --- a/session.go +++ b/session.go @@ -79,7 +79,8 @@ type Session struct { afterClosures []func(interface{}) afterProcessors []executedProcessor - stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) + stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) + txStmtCache map[uint32]*core.Stmt // for tx statement lastSQL string lastSQLArgs []interface{} @@ -130,6 +131,7 @@ func newSession(engine *Engine) *Session { afterClosures: make([]func(interface{}), 0), afterProcessors: make([]executedProcessor, 0), stmtCache: make(map[uint32]*core.Stmt), + txStmtCache: make(map[uint32]*core.Stmt), lastSQL: "", lastSQLArgs: make([]interface{}, 0), @@ -150,6 +152,12 @@ func (session *Session) Close() error { } } + for _, v := range session.txStmtCache { + if err := v.Close(); err != nil { + return err + } + } + if !session.isClosed { // When Close be called, if session is a transaction and do not call // Commit or Rollback, then call Rollback. @@ -160,6 +168,7 @@ func (session *Session) Close() error { } session.tx = nil session.stmtCache = nil + session.txStmtCache = nil session.isClosed = true } return nil @@ -200,6 +209,7 @@ func (session *Session) IsClosed() bool { func (session *Session) resetStatement() { if session.autoResetStatement { session.statement.Reset() + session.prepareStmt = false } } @@ -370,6 +380,21 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, return } +func (session *Session) doPrepareTx(sqlStr string) (stmt *core.Stmt, err error) { + crc := crc32.ChecksumIEEE([]byte(sqlStr)) + // TODO try hash(sqlStr+len(sqlStr)) + var has bool + stmt, has = session.txStmtCache[crc] + if !has { + stmt, err = session.tx.PrepareContext(session.ctx, sqlStr) + if err != nil { + return nil, err + } + session.txStmtCache[crc] = stmt + } + return +} + func getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) { var col = table.GetColumnIdx(colName, idx) if col == nil { diff --git a/session_raw.go b/session_raw.go index cee29fc7..acb106a5 100644 --- a/session_raw.go +++ b/session_raw.go @@ -46,25 +46,22 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row return nil, err } - rows, err := stmt.QueryContext(session.ctx, args...) - if err != nil { - return nil, err - } - return rows, nil + return stmt.QueryContext(session.ctx, args...) } - rows, err := db.QueryContext(session.ctx, sqlStr, args...) + return db.QueryContext(session.ctx, sqlStr, args...) + } + + if session.prepareStmt { + stmt, err := session.doPrepareTx(sqlStr) if err != nil { return nil, err } - return rows, nil + + return stmt.QueryContext(session.ctx, args...) } - rows, err := session.tx.QueryContext(session.ctx, sqlStr, args...) - if err != nil { - return nil, err - } - return rows, nil + return session.tx.QueryContext(session.ctx, sqlStr, args...) } func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row { @@ -160,6 +157,13 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er session.lastSQLArgs = args if !session.isAutoCommit { + if session.prepareStmt { + stmt, err := session.doPrepareTx(sqlStr) + if err != nil { + return nil, err + } + return stmt.ExecContext(session.ctx, args...) + } return session.tx.ExecContext(session.ctx, sqlStr, args...) } @@ -168,12 +172,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er if err != nil { return nil, err } - - res, err := stmt.ExecContext(session.ctx, args...) - if err != nil { - return nil, err - } - return res, nil + return stmt.ExecContext(session.ctx, args...) } return session.DB().ExecContext(session.ctx, sqlStr, args...)