From 5750e3f90a96167c5f72053755d0f9e550350d97 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 20 Jan 2019 11:01:14 +0800 Subject: [PATCH] Add context support (#1193) * add context support * improve pingcontext tests --- engine.go | 3 ++ engine_context.go | 28 ++++++++++++++++++ context_test.go => engine_context_test.go | 5 ++-- interface.go | 2 ++ session.go | 6 +++- context.go => session_context.go | 13 ++++---- session_context_test.go | 36 +++++++++++++++++++++++ session_raw.go | 12 ++++---- session_schema.go | 2 +- session_tx.go | 2 +- xorm.go | 18 +++++++----- 11 files changed, 100 insertions(+), 27 deletions(-) create mode 100644 engine_context.go rename context_test.go => engine_context_test.go (72%) rename context.go => session_context.go (60%) create mode 100644 session_context_test.go diff --git a/engine.go b/engine.go index c1bf06e1..07649df7 100644 --- a/engine.go +++ b/engine.go @@ -7,6 +7,7 @@ package xorm import ( "bufio" "bytes" + "context" "database/sql" "encoding/gob" "errors" @@ -52,6 +53,8 @@ type Engine struct { cachers map[string]core.Cacher cacherLock sync.RWMutex + + defaultContext context.Context } func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { diff --git a/engine_context.go b/engine_context.go new file mode 100644 index 00000000..c6cbb76c --- /dev/null +++ b/engine_context.go @@ -0,0 +1,28 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.8 + +package xorm + +import "context" + +// Context creates a session with the context +func (engine *Engine) Context(ctx context.Context) *Session { + session := engine.NewSession() + session.isAutoClose = true + return session.Context(ctx) +} + +// SetDefaultContext set the default context +func (engine *Engine) SetDefaultContext(ctx context.Context) { + engine.defaultContext = ctx +} + +// PingContext tests if database is alive +func (engine *Engine) PingContext(ctx context.Context) error { + session := engine.NewSession() + defer session.Close() + return session.PingContext(ctx) +} diff --git a/context_test.go b/engine_context_test.go similarity index 72% rename from context_test.go rename to engine_context_test.go index 17437af5..cc564694 100644 --- a/context_test.go +++ b/engine_context_test.go @@ -17,9 +17,10 @@ import ( func TestPingContext(t *testing.T) { assert.NoError(t, prepareEngine()) - ctx, canceled := context.WithTimeout(context.Background(), 10*time.Second) + ctx, canceled := context.WithTimeout(context.Background(), time.Nanosecond) defer canceled() err := testEngine.(*Engine).PingContext(ctx) - assert.NoError(t, err) + assert.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") } diff --git a/interface.go b/interface.go index 33d2078e..4f084421 100644 --- a/interface.go +++ b/interface.go @@ -5,6 +5,7 @@ package xorm import ( + "context" "database/sql" "reflect" "time" @@ -73,6 +74,7 @@ type EngineInterface interface { Before(func(interface{})) *Session Charset(charset string) *Session ClearCache(...interface{}) error + Context(context.Context) *Session CreateTables(...interface{}) error DBMetas() ([]*core.Table, error) Dialect() core.Dialect diff --git a/session.go b/session.go index e3437b91..e3f7b989 100644 --- a/session.go +++ b/session.go @@ -5,6 +5,7 @@ package xorm import ( + "context" "database/sql" "encoding/json" "errors" @@ -52,6 +53,7 @@ type Session struct { lastSQLArgs []interface{} err error + ctx context.Context } // Clone copy all the session's content and return a new session @@ -82,6 +84,8 @@ func (session *Session) Init() { session.lastSQL = "" session.lastSQLArgs = []interface{}{} + + session.ctx = session.engine.defaultContext } // Close release the connection from pool @@ -275,7 +279,7 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, var has bool stmt, has = session.stmtCache[crc] if !has { - stmt, err = db.Prepare(sqlStr) + stmt, err = db.PrepareContext(session.ctx, sqlStr) if err != nil { return nil, err } diff --git a/context.go b/session_context.go similarity index 60% rename from context.go rename to session_context.go index 074ba35a..915f0568 100644 --- a/context.go +++ b/session_context.go @@ -1,18 +1,15 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. +// Copyright 2019 The Xorm Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build go1.8 - package xorm import "context" -// PingContext tests if database is alive -func (engine *Engine) PingContext(ctx context.Context) error { - session := engine.NewSession() - defer session.Close() - return session.PingContext(ctx) +// Context sets the context on this session +func (session *Session) Context(ctx context.Context) *Session { + session.ctx = ctx + return session } // PingContext test if database is ok diff --git a/session_context_test.go b/session_context_test.go new file mode 100644 index 00000000..3dec24ac --- /dev/null +++ b/session_context_test.go @@ -0,0 +1,36 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestQueryContext(t *testing.T) { + type ContextQueryStruct struct { + Id int64 + Name string + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(ContextQueryStruct)) + + _, err := testEngine.Insert(&ContextQueryStruct{Name: "1"}) + assert.NoError(t, err) + + sess := testEngine.NewSession() + defer sess.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond) + defer cancel() + has, err := testEngine.Context(ctx).Exist(&ContextQueryStruct{Name: "1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") + assert.False(t, has) +} diff --git a/session_raw.go b/session_raw.go index 47823d67..23ef0a16 100644 --- a/session_raw.go +++ b/session_raw.go @@ -62,21 +62,21 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row return nil, err } - rows, err := stmt.Query(args...) + rows, err := stmt.QueryContext(session.ctx, args...) if err != nil { return nil, err } return rows, nil } - rows, err := db.Query(sqlStr, args...) + rows, err := db.QueryContext(session.ctx, sqlStr, args...) if err != nil { return nil, err } return rows, nil } - rows, err := session.tx.Query(sqlStr, args...) + rows, err := session.tx.QueryContext(session.ctx, sqlStr, args...) if err != nil { return nil, err } @@ -175,7 +175,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er } if !session.isAutoCommit { - return session.tx.Exec(sqlStr, args...) + return session.tx.ExecContext(session.ctx, sqlStr, args...) } if session.prepareStmt { @@ -184,14 +184,14 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er return nil, err } - res, err := stmt.Exec(args...) + res, err := stmt.ExecContext(session.ctx, args...) if err != nil { return nil, err } return res, nil } - return session.DB().Exec(sqlStr, args...) + return session.DB().ExecContext(session.ctx, sqlStr, args...) } func convertSQLOrArgs(sqlorArgs ...interface{}) (string, []interface{}, error) { diff --git a/session_schema.go b/session_schema.go index 369ec72a..7629906f 100644 --- a/session_schema.go +++ b/session_schema.go @@ -19,7 +19,7 @@ func (session *Session) Ping() error { } session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) - return session.DB().Ping() + return session.DB().PingContext(session.ctx) } // CreateTable create a table according a bean diff --git a/session_tx.go b/session_tx.go index c8d759a3..ee3d473f 100644 --- a/session_tx.go +++ b/session_tx.go @@ -7,7 +7,7 @@ package xorm // Begin a transaction func (session *Session) Begin() error { if session.isAutoCommit { - tx, err := session.DB().Begin() + tx, err := session.DB().BeginTx(session.ctx, nil) if err != nil { return err } diff --git a/xorm.go b/xorm.go index 739de8d4..157c9d34 100644 --- a/xorm.go +++ b/xorm.go @@ -7,6 +7,7 @@ package xorm import ( + "context" "fmt" "os" "reflect" @@ -85,14 +86,15 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { } engine := &Engine{ - db: db, - dialect: dialect, - Tables: make(map[reflect.Type]*core.Table), - mutex: &sync.RWMutex{}, - TagIdentifier: "xorm", - TZLocation: time.Local, - tagHandlers: defaultTagHandlers, - cachers: make(map[string]core.Cacher), + db: db, + dialect: dialect, + Tables: make(map[reflect.Type]*core.Table), + mutex: &sync.RWMutex{}, + TagIdentifier: "xorm", + TZLocation: time.Local, + tagHandlers: defaultTagHandlers, + cachers: make(map[string]core.Cacher), + defaultContext: context.Background(), } if uri.DbType == core.SQLITE {