From b0fd84832d1959c9f7874dc628622c94916fddd2 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 18 Sep 2018 14:45:26 +0800 Subject: [PATCH] add context cache feature --- context.go | 11 +++++++++++ engine.go | 6 ++++++ session.go | 5 +++++ session_get.go | 15 +++++++++++---- session_get_test.go | 37 +++++++++++++++++++++++++++++++++++++ 5 files changed, 70 insertions(+), 4 deletions(-) diff --git a/context.go b/context.go index 074ba35a..1ca3664f 100644 --- a/context.go +++ b/context.go @@ -24,3 +24,14 @@ func (session *Session) PingContext(ctx context.Context) error { session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) return session.DB().PingContext(ctx) } + +// WithContext cooperate with ctx +func (session *Session) WithContext(ctx context.Context) *Session { + session.context = ctx + return session +} + +// WithContext cooperate session with ctx +func WithContext(sess *Session, ctx context.Context) *Session { + return sess.WithContext(ctx) +} diff --git a/engine.go b/engine.go index 89a96d9f..68dcbae3 100644 --- a/engine.go +++ b/engine.go @@ -45,6 +45,7 @@ type Engine struct { DatabaseTZ *time.Location // The timezone of the database disableGlobalCache bool + enableContextCache bool tagHandlers map[string]tagHandler @@ -313,6 +314,11 @@ func (engine *Engine) NewSession() *Session { return session } +// EnableContextCache will enable or disable context cache +func (engine *Engine) EnableContextCache(enabled bool) { + engine.enableContextCache = enabled +} + // Close the engine func (engine *Engine) Close() error { return engine.db.Close() diff --git a/session.go b/session.go index 966cd2ce..fcfcbbdf 100644 --- a/session.go +++ b/session.go @@ -84,6 +84,11 @@ func (session *Session) Init() { session.lastSQL = "" session.lastSQLArgs = []interface{}{} + if session.engine.enableContextCache { + session.context = context.Background() + } else { + session.context = nil + } } // Close release the connection from pool diff --git a/session_get.go b/session_get.go index 2808a2c4..1107f097 100644 --- a/session_get.go +++ b/session_get.go @@ -5,8 +5,10 @@ package xorm import ( + "context" "database/sql" "errors" + "fmt" "reflect" "strconv" @@ -67,9 +69,14 @@ func (session *Session) get(bean interface{}) (bool, error) { } if session.context != nil { - //res := session.context.Value(fmt.Sprintf("%v-%v", sql, args)) - //runtime.deepcopy() - //&res + res := session.context.Value(fmt.Sprintf("%v-%v", sqlStr, args)) + if res != nil { + structValue := reflect.Indirect(reflect.ValueOf(bean)) + structValue.Set(reflect.Indirect(reflect.ValueOf(res))) + session.lastSQL = "" + session.lastSQLArgs = nil + return true, nil + } } has, err := session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...) @@ -77,7 +84,7 @@ func (session *Session) get(bean interface{}) (bool, error) { return has, err } if session.context != nil { - //session.context. + session.context = context.WithValue(session.context, fmt.Sprintf("%v-%v", sqlStr, args), bean) } return true, nil diff --git a/session_get_test.go b/session_get_test.go index 4ec7cf02..dacb3d06 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -5,6 +5,7 @@ package xorm import ( + "context" "database/sql" "fmt" "testing" @@ -319,3 +320,39 @@ func TestGetStructId(t *testing.T) { assert.True(t, has) assert.EqualValues(t, 2, maxid.Id) } + +func TestContextGet(t *testing.T) { + type ContextGetStruct struct { + Id int64 + Name string + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(ContextGetStruct)) + + _, err := testEngine.Insert(&ContextGetStruct{Name: "1"}) + assert.NoError(t, err) + + sess := WithContext(testEngine.NewSession(), context.Background()) + defer sess.Close() + + var c2 ContextGetStruct + has, err := sess.ID(1).Get(&c2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, c2.Id) + assert.EqualValues(t, "1", c2.Name) + sql, args := sess.LastSQL() + assert.True(t, len(sql) > 0) + assert.True(t, len(args) > 0) + + var c3 ContextGetStruct + has, err = sess.ID(1).Get(&c3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, c3.Id) + assert.EqualValues(t, "1", c3.Name) + sql, args = sess.LastSQL() + assert.True(t, len(sql) == 0) + assert.True(t, len(args) == 0) +}