add hook for engine

This commit is contained in:
yuxiao.lu 2020-04-02 15:58:04 +08:00
parent cfa88b908c
commit 7e4dc9cc57
8 changed files with 154 additions and 135 deletions

View File

@ -88,6 +88,7 @@ type DB struct {
reflectCache map[reflect.Type]*cacheStruct
reflectCacheMutex sync.RWMutex
Logger log.ContextLogger
hooks []Hook
}
// Open opens a database
@ -139,27 +140,22 @@ func (db *DB) reflectNew(typ reflect.Type) reflect.Value {
}
// QueryContext overwrites sql.DB.QueryContext
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
func (db *DB) QueryContext(parentCtx context.Context, query string, args ...interface{}) (*Rows, error) {
logCtx := log.LogContext{
Ctx: parentCtx,
SQL: query,
Args: args,
}
start := time.Now()
showSQL := db.NeedLogSQL(ctx)
if showSQL {
db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
}
rows, err := db.DB.QueryContext(ctx, query, args...)
if showSQL {
db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
ctx, err := db.beforeProcess(logCtx)
if err != nil {
return nil, err
}
logCtx.Ctx = ctx
rows, err := db.DB.QueryContext(ctx, query, args...)
logCtx.ExecuteTime = time.Now().Sub(start)
logCtx.Err = err
if err := db.afterProcess(logCtx); err != nil {
if rows != nil {
rows.Close()
}
@ -262,29 +258,60 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{
return db.ExecContext(ctx, query, args...)
}
func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
func (db *DB) ExecContext(parentCtx context.Context, query string, args ...interface{}) (sql.Result, error) {
logCtx := log.LogContext{
Ctx: parentCtx,
SQL: query,
Args: args,
}
start := time.Now()
showSQL := db.NeedLogSQL(ctx)
if showSQL {
db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
ctx, err := db.beforeProcess(logCtx)
if err != nil {
return nil, err
}
logCtx.Ctx = ctx
res, err := db.DB.ExecContext(ctx, query, args...)
if showSQL {
db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
logCtx.Err = err
logCtx.ExecuteTime = time.Now().Sub(start)
if err := db.afterProcess(logCtx); err != nil {
return nil, err
}
return res, err
return res, nil
}
func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) {
return db.ExecStructContext(context.Background(), query, st)
}
func (db *DB) beforeProcess(logCtx log.LogContext) (context.Context, error) {
ctx := logCtx.Ctx
if db.NeedLogSQL(ctx) {
db.Logger.BeforeSQL(logCtx)
}
for _, h := range db.hooks {
var err error
ctx, err = h.BeforeProcess(ctx, logCtx.SQL, logCtx.Args...)
if err != nil {
return nil, err
}
}
return ctx, nil
}
func (db *DB) afterProcess(logCtx log.LogContext) error {
firstErr := logCtx.Err
for _, h := range db.hooks {
err := h.AfterProcess(&logCtx)
if err != nil && firstErr == nil {
firstErr = err
}
}
if db.NeedLogSQL(logCtx.Ctx) {
db.Logger.AfterSQL(logCtx)
}
return firstErr
}
func (db *DB) AddHook(hook Hook) {
db.hooks = append(db.hooks, hook)
}

View File

@ -3,8 +3,15 @@ package core
import (
"context"
"database/sql"
"xorm.io/xorm/log"
)
type Hook interface {
BeforeProcess(ctx context.Context, query string, args ...interface{}) (context.Context, error)
AfterProcess(logContext *log.LogContext) error
}
// Queryer represents an interface to query a SQL to get data from database
type Queryer interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error)

View File

@ -30,28 +30,21 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
i++
return "?"
})
logCtx := log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
}
start := time.Now()
showSQL := db.NeedLogSQL(ctx)
if showSQL {
db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
})
}
stmt, err := db.DB.PrepareContext(ctx, query)
if showSQL {
db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
ctx, err := db.beforeProcess(logCtx)
if err != nil {
return nil, err
}
stmt, err := db.DB.PrepareContext(ctx, query)
logCtx.Err = err
logCtx.ExecuteTime = time.Now().Sub(start)
if err := db.afterProcess(logCtx); err != nil {
return nil, err
}
return &Stmt{stmt, db, names, query}, nil
}
@ -94,49 +87,40 @@ func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
}
func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
logCtx := log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
}
start := time.Now()
showSQL := s.db.NeedLogSQL(ctx)
if showSQL {
s.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
})
ctx, err := s.db.beforeProcess(logCtx)
if err != nil {
return nil, err
}
res, err := s.Stmt.ExecContext(ctx, args)
if showSQL {
s.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
logCtx.ExecuteTime = time.Now().Sub(start)
logCtx.Err = err
if err := s.db.afterProcess(logCtx); err != nil {
return nil, err
}
return res, err
return res, nil
}
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
logCtx := log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
}
start := time.Now()
showSQL := s.db.NeedLogSQL(ctx)
if showSQL {
s.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
})
ctx, err := s.db.beforeProcess(logCtx)
if err != nil {
return nil, err
}
rows, err := s.Stmt.QueryContext(ctx, args...)
if showSQL {
s.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
logCtx.ExecuteTime = time.Now().Sub(start)
logCtx.Err = err
if err := s.db.afterProcess(logCtx); err != nil {
return nil, err
}
return &Rows{rows, s.db}, nil

View File

@ -58,25 +58,19 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
i++
return "?"
})
logCtx := log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
}
start := time.Now()
showSQL := tx.db.NeedLogSQL(ctx)
if showSQL {
tx.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
})
ctx, err := tx.db.beforeProcess(logCtx)
if err != nil {
return nil, err
}
stmt, err := tx.Tx.PrepareContext(ctx, query)
if showSQL {
tx.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
logCtx.Err = err
logCtx.ExecuteTime = time.Now().Sub(start)
if err := tx.db.afterProcess(logCtx); err != nil {
return nil, err
}
return &Stmt{stmt, tx.db, names, query}, nil
@ -117,23 +111,20 @@ func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
start := time.Now()
showSQL := tx.db.NeedLogSQL(ctx)
if showSQL {
tx.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
logCtx := log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
}
ctx, err := tx.db.beforeProcess(logCtx)
if err != nil {
return nil, err
}
res, err := tx.Tx.ExecContext(ctx, query, args...)
if showSQL {
tx.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
logCtx.ExecuteTime = time.Now().Sub(start)
logCtx.Err = err
if err := tx.db.afterProcess(logCtx); err != nil {
return nil, err
}
return res, err
}
@ -143,26 +134,20 @@ func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) {
}
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
logCtx := log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
}
start := time.Now()
showSQL := tx.db.NeedLogSQL(ctx)
if showSQL {
tx.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
ctx, err := tx.db.beforeProcess(logCtx)
if err != nil {
return nil, err
}
rows, err := tx.Tx.QueryContext(ctx, query, args...)
if showSQL {
tx.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
logCtx.Err = err
logCtx.ExecuteTime = time.Now().Sub(start)
if err := tx.db.afterProcess(logCtx); err != nil {
if rows != nil {
rows.Close()
}

View File

@ -1287,6 +1287,10 @@ func (engine *Engine) SetSchema(schema string) {
engine.dialect.URI().SetSchema(schema)
}
func (engine *Engine) AddHook(hook core.Hook) {
engine.db.AddHook(hook)
}
// Unscoped always disable struct tag "deleted"
func (engine *Engine) Unscoped() *Session {
session := engine.NewSession()

View File

@ -9,6 +9,7 @@ import (
"time"
"xorm.io/xorm/caches"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
@ -143,6 +144,13 @@ func (eg *EngineGroup) SetLogger(logger interface{}) {
}
}
func (eg *EngineGroup) AddHook(hook core.Hook) {
eg.Engine.AddHook(hook)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].AddHook(hook)
}
}
// SetLogLevel sets the logger level
func (eg *EngineGroup) SetLogLevel(level log.LogLevel) {
eg.Engine.SetLogLevel(level)

View File

@ -11,6 +11,7 @@ import (
"time"
"xorm.io/xorm/caches"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
@ -111,6 +112,7 @@ type EngineInterface interface {
SetTableMapper(names.Mapper)
SetTZDatabase(tz *time.Location)
SetTZLocation(tz *time.Location)
AddHook(hook core.Hook)
ShowSQL(show ...bool)
Sync(...interface{}) error
Sync2(...interface{}) error

View File

@ -6,6 +6,7 @@ package log
import (
"context"
"database/sql"
"fmt"
"time"
)
@ -15,6 +16,7 @@ type LogContext struct {
Ctx context.Context
SQL string // log content or SQL
Args []interface{} // if it's a SQL, it's the arguments
Result sql.Result
ExecuteTime time.Duration
Err error // SQL executed error
}