Add context support (#1193)

* add context support

* improve pingcontext tests
This commit is contained in:
Lunny Xiao 2019-01-20 11:01:14 +08:00 committed by GitHub
parent 229c3aaf04
commit 5750e3f90a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 100 additions and 27 deletions

View File

@ -7,6 +7,7 @@ package xorm
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"database/sql" "database/sql"
"encoding/gob" "encoding/gob"
"errors" "errors"
@ -52,6 +53,8 @@ type Engine struct {
cachers map[string]core.Cacher cachers map[string]core.Cacher
cacherLock sync.RWMutex cacherLock sync.RWMutex
defaultContext context.Context
} }
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {

28
engine_context.go Normal file
View File

@ -0,0 +1,28 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package xorm
import "context"
// Context creates a session with the context
func (engine *Engine) Context(ctx context.Context) *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.Context(ctx)
}
// SetDefaultContext set the default context
func (engine *Engine) SetDefaultContext(ctx context.Context) {
engine.defaultContext = ctx
}
// PingContext tests if database is alive
func (engine *Engine) PingContext(ctx context.Context) error {
session := engine.NewSession()
defer session.Close()
return session.PingContext(ctx)
}

View File

@ -17,9 +17,10 @@ import (
func TestPingContext(t *testing.T) { func TestPingContext(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
ctx, canceled := context.WithTimeout(context.Background(), 10*time.Second) ctx, canceled := context.WithTimeout(context.Background(), time.Nanosecond)
defer canceled() defer canceled()
err := testEngine.(*Engine).PingContext(ctx) err := testEngine.(*Engine).PingContext(ctx)
assert.NoError(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "context deadline exceeded")
} }

View File

@ -5,6 +5,7 @@
package xorm package xorm
import ( import (
"context"
"database/sql" "database/sql"
"reflect" "reflect"
"time" "time"
@ -73,6 +74,7 @@ type EngineInterface interface {
Before(func(interface{})) *Session Before(func(interface{})) *Session
Charset(charset string) *Session Charset(charset string) *Session
ClearCache(...interface{}) error ClearCache(...interface{}) error
Context(context.Context) *Session
CreateTables(...interface{}) error CreateTables(...interface{}) error
DBMetas() ([]*core.Table, error) DBMetas() ([]*core.Table, error)
Dialect() core.Dialect Dialect() core.Dialect

View File

@ -5,6 +5,7 @@
package xorm package xorm
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
@ -52,6 +53,7 @@ type Session struct {
lastSQLArgs []interface{} lastSQLArgs []interface{}
err error err error
ctx context.Context
} }
// Clone copy all the session's content and return a new session // Clone copy all the session's content and return a new session
@ -82,6 +84,8 @@ func (session *Session) Init() {
session.lastSQL = "" session.lastSQL = ""
session.lastSQLArgs = []interface{}{} session.lastSQLArgs = []interface{}{}
session.ctx = session.engine.defaultContext
} }
// Close release the connection from pool // Close release the connection from pool
@ -275,7 +279,7 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt,
var has bool var has bool
stmt, has = session.stmtCache[crc] stmt, has = session.stmtCache[crc]
if !has { if !has {
stmt, err = db.Prepare(sqlStr) stmt, err = db.PrepareContext(session.ctx, sqlStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,18 +1,15 @@
// Copyright 2017 The Xorm Authors. All rights reserved. // Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build go1.8
package xorm package xorm
import "context" import "context"
// PingContext tests if database is alive // Context sets the context on this session
func (engine *Engine) PingContext(ctx context.Context) error { func (session *Session) Context(ctx context.Context) *Session {
session := engine.NewSession() session.ctx = ctx
defer session.Close() return session
return session.PingContext(ctx)
} }
// PingContext test if database is ok // PingContext test if database is ok

36
session_context_test.go Normal file
View File

@ -0,0 +1,36 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestQueryContext(t *testing.T) {
type ContextQueryStruct struct {
Id int64
Name string
}
assert.NoError(t, prepareEngine())
assertSync(t, new(ContextQueryStruct))
_, err := testEngine.Insert(&ContextQueryStruct{Name: "1"})
assert.NoError(t, err)
sess := testEngine.NewSession()
defer sess.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
has, err := testEngine.Context(ctx).Exist(&ContextQueryStruct{Name: "1"})
assert.Error(t, err)
assert.Contains(t, err.Error(), "context deadline exceeded")
assert.False(t, has)
}

View File

@ -62,21 +62,21 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row
return nil, err return nil, err
} }
rows, err := stmt.Query(args...) rows, err := stmt.QueryContext(session.ctx, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return rows, nil return rows, nil
} }
rows, err := db.Query(sqlStr, args...) rows, err := db.QueryContext(session.ctx, sqlStr, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return rows, nil return rows, nil
} }
rows, err := session.tx.Query(sqlStr, args...) rows, err := session.tx.QueryContext(session.ctx, sqlStr, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -175,7 +175,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
} }
if !session.isAutoCommit { if !session.isAutoCommit {
return session.tx.Exec(sqlStr, args...) return session.tx.ExecContext(session.ctx, sqlStr, args...)
} }
if session.prepareStmt { if session.prepareStmt {
@ -184,14 +184,14 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
return nil, err return nil, err
} }
res, err := stmt.Exec(args...) res, err := stmt.ExecContext(session.ctx, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return res, nil return res, nil
} }
return session.DB().Exec(sqlStr, args...) return session.DB().ExecContext(session.ctx, sqlStr, args...)
} }
func convertSQLOrArgs(sqlorArgs ...interface{}) (string, []interface{}, error) { func convertSQLOrArgs(sqlorArgs ...interface{}) (string, []interface{}, error) {

View File

@ -19,7 +19,7 @@ func (session *Session) Ping() error {
} }
session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
return session.DB().Ping() return session.DB().PingContext(session.ctx)
} }
// CreateTable create a table according a bean // CreateTable create a table according a bean

View File

@ -7,7 +7,7 @@ package xorm
// Begin a transaction // Begin a transaction
func (session *Session) Begin() error { func (session *Session) Begin() error {
if session.isAutoCommit { if session.isAutoCommit {
tx, err := session.DB().Begin() tx, err := session.DB().BeginTx(session.ctx, nil)
if err != nil { if err != nil {
return err return err
} }

18
xorm.go
View File

@ -7,6 +7,7 @@
package xorm package xorm
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
@ -85,14 +86,15 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
} }
engine := &Engine{ engine := &Engine{
db: db, db: db,
dialect: dialect, dialect: dialect,
Tables: make(map[reflect.Type]*core.Table), Tables: make(map[reflect.Type]*core.Table),
mutex: &sync.RWMutex{}, mutex: &sync.RWMutex{},
TagIdentifier: "xorm", TagIdentifier: "xorm",
TZLocation: time.Local, TZLocation: time.Local,
tagHandlers: defaultTagHandlers, tagHandlers: defaultTagHandlers,
cachers: make(map[string]core.Cacher), cachers: make(map[string]core.Cacher),
defaultContext: context.Background(),
} }
if uri.DbType == core.SQLITE { if uri.DbType == core.SQLITE {