add context cache feature

This commit is contained in:
Lunny Xiao 2018-09-18 14:45:26 +08:00
parent 3542b3a933
commit b0fd84832d
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
5 changed files with 70 additions and 4 deletions

View File

@ -24,3 +24,14 @@ func (session *Session) PingContext(ctx context.Context) error {
session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
return session.DB().PingContext(ctx) return session.DB().PingContext(ctx)
} }
// WithContext cooperate with ctx
func (session *Session) WithContext(ctx context.Context) *Session {
session.context = ctx
return session
}
// WithContext cooperate session with ctx
func WithContext(sess *Session, ctx context.Context) *Session {
return sess.WithContext(ctx)
}

View File

@ -45,6 +45,7 @@ type Engine struct {
DatabaseTZ *time.Location // The timezone of the database DatabaseTZ *time.Location // The timezone of the database
disableGlobalCache bool disableGlobalCache bool
enableContextCache bool
tagHandlers map[string]tagHandler tagHandlers map[string]tagHandler
@ -313,6 +314,11 @@ func (engine *Engine) NewSession() *Session {
return session return session
} }
// EnableContextCache will enable or disable context cache
func (engine *Engine) EnableContextCache(enabled bool) {
engine.enableContextCache = enabled
}
// Close the engine // Close the engine
func (engine *Engine) Close() error { func (engine *Engine) Close() error {
return engine.db.Close() return engine.db.Close()

View File

@ -84,6 +84,11 @@ func (session *Session) Init() {
session.lastSQL = "" session.lastSQL = ""
session.lastSQLArgs = []interface{}{} session.lastSQLArgs = []interface{}{}
if session.engine.enableContextCache {
session.context = context.Background()
} else {
session.context = nil
}
} }
// Close release the connection from pool // Close release the connection from pool

View File

@ -5,8 +5,10 @@
package xorm package xorm
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"reflect" "reflect"
"strconv" "strconv"
@ -67,9 +69,14 @@ func (session *Session) get(bean interface{}) (bool, error) {
} }
if session.context != nil { if session.context != nil {
//res := session.context.Value(fmt.Sprintf("%v-%v", sql, args)) res := session.context.Value(fmt.Sprintf("%v-%v", sqlStr, args))
//runtime.deepcopy() if res != nil {
//&res structValue := reflect.Indirect(reflect.ValueOf(bean))
structValue.Set(reflect.Indirect(reflect.ValueOf(res)))
session.lastSQL = ""
session.lastSQLArgs = nil
return true, nil
}
} }
has, err := session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...) has, err := session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...)
@ -77,7 +84,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
return has, err return has, err
} }
if session.context != nil { if session.context != nil {
//session.context. session.context = context.WithValue(session.context, fmt.Sprintf("%v-%v", sqlStr, args), bean)
} }
return true, nil return true, nil

View File

@ -5,6 +5,7 @@
package xorm package xorm
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"testing" "testing"
@ -319,3 +320,39 @@ func TestGetStructId(t *testing.T) {
assert.True(t, has) assert.True(t, has)
assert.EqualValues(t, 2, maxid.Id) assert.EqualValues(t, 2, maxid.Id)
} }
func TestContextGet(t *testing.T) {
type ContextGetStruct struct {
Id int64
Name string
}
assert.NoError(t, prepareEngine())
assertSync(t, new(ContextGetStruct))
_, err := testEngine.Insert(&ContextGetStruct{Name: "1"})
assert.NoError(t, err)
sess := WithContext(testEngine.NewSession(), context.Background())
defer sess.Close()
var c2 ContextGetStruct
has, err := sess.ID(1).Get(&c2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 1, c2.Id)
assert.EqualValues(t, "1", c2.Name)
sql, args := sess.LastSQL()
assert.True(t, len(sql) > 0)
assert.True(t, len(args) > 0)
var c3 ContextGetStruct
has, err = sess.ID(1).Get(&c3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 1, c3.Id)
assert.EqualValues(t, "1", c3.Name)
sql, args = sess.LastSQL()
assert.True(t, len(sql) == 0)
assert.True(t, len(args) == 0)
}