From ccc5c0abd42df230609f0faa5b15aea358fb02db Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 19 Oct 2021 21:35:04 +0800 Subject: [PATCH] Also fix prepare with exec --- core/stmt.go | 2 +- integrations/session_get_test.go | 26 +++++++++++++++++++++++--- session.go | 8 ++++++++ session_raw.go | 14 ++++++++------ 4 files changed, 40 insertions(+), 10 deletions(-) diff --git a/core/stmt.go b/core/stmt.go index 260843d5..3247efed 100644 --- a/core/stmt.go +++ b/core/stmt.go @@ -93,7 +93,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result if err != nil { return nil, err } - res, err := s.Stmt.ExecContext(ctx, args) + res, err := s.Stmt.ExecContext(ctx, args...) hookCtx.End(ctx, res, err) if err := s.db.afterProcess(hookCtx); err != nil { return nil, err diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index c73f9ea6..5d1558f4 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -933,7 +933,7 @@ func TestGetWithPrepare(t *testing.T) { assert.NoError(t, err) var v1 GetVarsWithPrepare - has, err := testEngine.Prepare().Get(&v1) + has, err := testEngine.Prepare().ID(1).Get(&v1) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "xlw", v1.Name) @@ -943,16 +943,36 @@ func TestGetWithPrepare(t *testing.T) { defer sess.Close() var v2 GetVarsWithPrepare - has, err = sess.Prepare().Get(&v2) + has, err = sess.Prepare().ID(1).Get(&v2) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "xlw", v2.Name) assert.EqualValues(t, 42, v2.Age) var v3 GetVarsWithPrepare - has, err = sess.Prepare().Get(&v3) + has, err = sess.Prepare().ID(1).Get(&v3) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "xlw", v3.Name) assert.EqualValues(t, 42, v3.Age) + + err = sess.Begin() + assert.NoError(t, err) + + cnt, err := sess.Prepare().Insert(&GetVarsWithPrepare{ + Name: "xlw2", + Age: 12, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = sess.Prepare().Insert(&GetVarsWithPrepare{ + Name: "xlw3", + Age: 13, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + err = sess.Commit() + assert.NoError(t, err) } diff --git a/session.go b/session.go index da4576c8..2c916335 100644 --- a/session.go +++ b/session.go @@ -131,6 +131,7 @@ func newSession(engine *Engine) *Session { afterClosures: make([]func(interface{}), 0), afterProcessors: make([]executedProcessor, 0), stmtCache: make(map[uint32]*core.Stmt), + txStmtCache: make(map[uint32]*core.Stmt), lastSQL: "", lastSQLArgs: make([]interface{}, 0), @@ -151,6 +152,12 @@ func (session *Session) Close() error { } } + for _, v := range session.txStmtCache { + if err := v.Close(); err != nil { + return err + } + } + if !session.isClosed { // When Close be called, if session is a transaction and do not call // Commit or Rollback, then call Rollback. @@ -161,6 +168,7 @@ func (session *Session) Close() error { } session.tx = nil session.stmtCache = nil + session.txStmtCache = nil session.isClosed = true } return nil diff --git a/session_raw.go b/session_raw.go index bce1f575..acb106a5 100644 --- a/session_raw.go +++ b/session_raw.go @@ -157,6 +157,13 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er session.lastSQLArgs = args if !session.isAutoCommit { + if session.prepareStmt { + stmt, err := session.doPrepareTx(sqlStr) + if err != nil { + return nil, err + } + return stmt.ExecContext(session.ctx, args...) + } return session.tx.ExecContext(session.ctx, sqlStr, args...) } @@ -165,12 +172,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er if err != nil { return nil, err } - - res, err := stmt.ExecContext(session.ctx, args...) - if err != nil { - return nil, err - } - return res, nil + return stmt.ExecContext(session.ctx, args...) } return session.DB().ExecContext(session.ctx, sqlStr, args...)