New Prepare useage (#2061)

Fix #2060, Three ways to use the `Prepare`.

The first

```go
engine.Prepare().Where().Get()
```

The second

```go
sess := engine.NewSession()
defer sess.Close()

sess.Prepare().Where().Get()

sess.Prepare().Where().Get()
```

The third
```go
sess := engine.NewSession()
defer sess.Close()

sess.Begin()

sess.Prepare().Where().Get()

sess.Prepare().Where().Get()

sess.Commit()
```

Or

```go
sess := engine.NewSession()
defer sess.Close()

sess.Begin()

sess.Prepare().Insert()

sess.Prepare().Insert()

sess.Commit()
```

Reviewed-on: https://gitea.com/xorm/xorm/pulls/2061
Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com>
Co-committed-by: Lunny Xiao <xiaolunwen@gmail.com>
This commit is contained in:
Lunny Xiao 2021-10-20 08:53:30 +08:00
parent b350c289f8
commit 40a135948b
5 changed files with 106 additions and 20 deletions

View File

@ -93,7 +93,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result
if err != nil { if err != nil {
return nil, err return nil, err
} }
res, err := s.Stmt.ExecContext(ctx, args) res, err := s.Stmt.ExecContext(ctx, args...)
hookCtx.End(ctx, res, err) hookCtx.End(ctx, res, err)
if err := s.db.afterProcess(hookCtx); err != nil { if err := s.db.afterProcess(hookCtx); err != nil {
return nil, err return nil, err

View File

@ -915,3 +915,64 @@ func TestGetVars(t *testing.T) {
assert.EqualValues(t, "xlw", name) assert.EqualValues(t, "xlw", name)
assert.EqualValues(t, 42, age) assert.EqualValues(t, 42, age)
} }
func TestGetWithPrepare(t *testing.T) {
type GetVarsWithPrepare struct {
Id int64
Name string
Age int
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(GetVarsWithPrepare))
_, err := testEngine.Insert(&GetVarsWithPrepare{
Name: "xlw",
Age: 42,
})
assert.NoError(t, err)
var v1 GetVarsWithPrepare
has, err := testEngine.Prepare().ID(1).Get(&v1)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "xlw", v1.Name)
assert.EqualValues(t, 42, v1.Age)
sess := testEngine.NewSession()
defer sess.Close()
var v2 GetVarsWithPrepare
has, err = sess.Prepare().ID(1).Get(&v2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "xlw", v2.Name)
assert.EqualValues(t, 42, v2.Age)
var v3 GetVarsWithPrepare
has, err = sess.Prepare().ID(1).Get(&v3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "xlw", v3.Name)
assert.EqualValues(t, 42, v3.Age)
err = sess.Begin()
assert.NoError(t, err)
cnt, err := sess.Prepare().Insert(&GetVarsWithPrepare{
Name: "xlw2",
Age: 12,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
cnt, err = sess.Prepare().Insert(&GetVarsWithPrepare{
Name: "xlw3",
Age: 13,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
err = sess.Commit()
assert.NoError(t, err)
}

View File

@ -99,6 +99,7 @@ type EngineInterface interface {
MapCacher(interface{}, caches.Cacher) error MapCacher(interface{}, caches.Cacher) error
NewSession() *Session NewSession() *Session
NoAutoTime() *Session NoAutoTime() *Session
Prepare() *Session
Quote(string) string Quote(string) string
SetCacher(string, caches.Cacher) SetCacher(string, caches.Cacher)
SetConnMaxLifetime(time.Duration) SetConnMaxLifetime(time.Duration)

View File

@ -79,7 +79,8 @@ type Session struct {
afterClosures []func(interface{}) afterClosures []func(interface{})
afterProcessors []executedProcessor afterProcessors []executedProcessor
stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr))
txStmtCache map[uint32]*core.Stmt // for tx statement
lastSQL string lastSQL string
lastSQLArgs []interface{} lastSQLArgs []interface{}
@ -130,6 +131,7 @@ func newSession(engine *Engine) *Session {
afterClosures: make([]func(interface{}), 0), afterClosures: make([]func(interface{}), 0),
afterProcessors: make([]executedProcessor, 0), afterProcessors: make([]executedProcessor, 0),
stmtCache: make(map[uint32]*core.Stmt), stmtCache: make(map[uint32]*core.Stmt),
txStmtCache: make(map[uint32]*core.Stmt),
lastSQL: "", lastSQL: "",
lastSQLArgs: make([]interface{}, 0), lastSQLArgs: make([]interface{}, 0),
@ -150,6 +152,12 @@ func (session *Session) Close() error {
} }
} }
for _, v := range session.txStmtCache {
if err := v.Close(); err != nil {
return err
}
}
if !session.isClosed { if !session.isClosed {
// When Close be called, if session is a transaction and do not call // When Close be called, if session is a transaction and do not call
// Commit or Rollback, then call Rollback. // Commit or Rollback, then call Rollback.
@ -160,6 +168,7 @@ func (session *Session) Close() error {
} }
session.tx = nil session.tx = nil
session.stmtCache = nil session.stmtCache = nil
session.txStmtCache = nil
session.isClosed = true session.isClosed = true
} }
return nil return nil
@ -200,6 +209,7 @@ func (session *Session) IsClosed() bool {
func (session *Session) resetStatement() { func (session *Session) resetStatement() {
if session.autoResetStatement { if session.autoResetStatement {
session.statement.Reset() session.statement.Reset()
session.prepareStmt = false
} }
} }
@ -370,6 +380,21 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt,
return return
} }
func (session *Session) doPrepareTx(sqlStr string) (stmt *core.Stmt, err error) {
crc := crc32.ChecksumIEEE([]byte(sqlStr))
// TODO try hash(sqlStr+len(sqlStr))
var has bool
stmt, has = session.txStmtCache[crc]
if !has {
stmt, err = session.tx.PrepareContext(session.ctx, sqlStr)
if err != nil {
return nil, err
}
session.txStmtCache[crc] = stmt
}
return
}
func getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) { func getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) {
var col = table.GetColumnIdx(colName, idx) var col = table.GetColumnIdx(colName, idx)
if col == nil { if col == nil {

View File

@ -46,25 +46,22 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row
return nil, err return nil, err
} }
rows, err := stmt.QueryContext(session.ctx, args...) return stmt.QueryContext(session.ctx, args...)
if err != nil {
return nil, err
}
return rows, nil
} }
rows, err := db.QueryContext(session.ctx, sqlStr, args...) return db.QueryContext(session.ctx, sqlStr, args...)
}
if session.prepareStmt {
stmt, err := session.doPrepareTx(sqlStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return rows, nil
return stmt.QueryContext(session.ctx, args...)
} }
rows, err := session.tx.QueryContext(session.ctx, sqlStr, args...) return session.tx.QueryContext(session.ctx, sqlStr, args...)
if err != nil {
return nil, err
}
return rows, nil
} }
func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row { func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row {
@ -160,6 +157,13 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
session.lastSQLArgs = args session.lastSQLArgs = args
if !session.isAutoCommit { if !session.isAutoCommit {
if session.prepareStmt {
stmt, err := session.doPrepareTx(sqlStr)
if err != nil {
return nil, err
}
return stmt.ExecContext(session.ctx, args...)
}
return session.tx.ExecContext(session.ctx, sqlStr, args...) return session.tx.ExecContext(session.ctx, sqlStr, args...)
} }
@ -168,12 +172,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
if err != nil { if err != nil {
return nil, err return nil, err
} }
return stmt.ExecContext(session.ctx, args...)
res, err := stmt.ExecContext(session.ctx, args...)
if err != nil {
return nil, err
}
return res, nil
} }
return session.DB().ExecContext(session.ctx, sqlStr, args...) return session.DB().ExecContext(session.ctx, sqlStr, args...)