diff --git a/contexts/hook.go b/contexts/hook.go new file mode 100644 index 00000000..71ad8e87 --- /dev/null +++ b/contexts/hook.go @@ -0,0 +1,75 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package contexts + +import ( + "context" + "database/sql" + "time" +) + +// ContextHook represents a hook context +type ContextHook struct { + start time.Time + Ctx context.Context + SQL string // log content or SQL + Args []interface{} // if it's a SQL, it's the arguments + Result sql.Result + ExecuteTime time.Duration + Err error // SQL executed error +} + +// NewContextHook return context for hook +func NewContextHook(ctx context.Context, sql string, args []interface{}) *ContextHook { + return &ContextHook{ + start: time.Now(), + Ctx: ctx, + SQL: sql, + Args: args, + } +} + +func (c *ContextHook) End(ctx context.Context, result sql.Result, err error) { + c.Ctx = ctx + c.Result = result + c.Err = err + c.ExecuteTime = time.Now().Sub(c.start) +} + +type Hook interface { + BeforeProcess(c *ContextHook) (context.Context, error) + AfterProcess(c *ContextHook) error +} + +type Hooks struct { + hooks []Hook +} + +func (h *Hooks) AddHook(hooks ...Hook) { + h.hooks = append(h.hooks, hooks...) +} + +func (h *Hooks) BeforeProcess(c *ContextHook) (context.Context, error) { + ctx := c.Ctx + for _, h := range h.hooks { + var err error + ctx, err = h.BeforeProcess(c) + if err != nil { + return nil, err + } + } + return ctx, nil +} + +func (h *Hooks) AfterProcess(c *ContextHook) error { + firstErr := c.Err + for _, h := range h.hooks { + err := h.AfterProcess(c) + if err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} diff --git a/contexts/hook_test.go b/contexts/hook_test.go new file mode 100644 index 00000000..96c54e33 --- /dev/null +++ b/contexts/hook_test.go @@ -0,0 +1,140 @@ +package contexts + +import ( + "context" + "errors" + "testing" +) + +type testHook struct { + before func(c *ContextHook) (context.Context, error) + after func(c *ContextHook) error +} + +func (h *testHook) BeforeProcess(c *ContextHook) (context.Context, error) { + if h.before != nil { + return h.before(c) + } + return c.Ctx, nil +} + +func (h *testHook) AfterProcess(c *ContextHook) error { + if h.after != nil { + return h.after(c) + } + return c.Err +} + +var _ Hook = &testHook{} + +func TestBeforeProcess(t *testing.T) { + expectErr := errors.New("before error") + tests := []struct { + msg string + hooks []Hook + expect error + }{ + { + msg: "first hook return err", + hooks: []Hook{ + &testHook{ + before: func(c *ContextHook) (ctx context.Context, err error) { + return c.Ctx, expectErr + }, + }, + &testHook{ + before: func(c *ContextHook) (ctx context.Context, err error) { + return c.Ctx, nil + }, + }, + }, + expect: expectErr, + }, + { + msg: "second hook return err", + hooks: []Hook{ + &testHook{ + before: func(c *ContextHook) (ctx context.Context, err error) { + return c.Ctx, nil + }, + }, + &testHook{ + before: func(c *ContextHook) (ctx context.Context, err error) { + return c.Ctx, expectErr + }, + }, + }, + expect: expectErr, + }, + } + + for _, tt := range tests { + t.Run(tt.msg, func(t *testing.T) { + hooks := Hooks{} + hooks.AddHook(tt.hooks...) + _, err := hooks.BeforeProcess(&ContextHook{ + Ctx: context.Background(), + }) + if err != tt.expect { + t.Errorf("got %v, expect %v", err, tt.expect) + } + }) + } +} + +func TestAfterProcess(t *testing.T) { + expectErr := errors.New("expect err") + tests := []struct { + msg string + ctx *ContextHook + hooks []Hook + expect error + }{ + { + msg: "context has err", + ctx: &ContextHook{ + Ctx: context.Background(), + Err: expectErr, + }, + hooks: []Hook{ + &testHook{ + after: func(c *ContextHook) error { + return errors.New("hook err") + }, + }, + }, + expect: expectErr, + }, + { + msg: "last hook has err", + ctx: &ContextHook{ + Ctx: context.Background(), + Err: nil, + }, + hooks: []Hook{ + &testHook{ + after: func(c *ContextHook) error { + return nil + }, + }, + &testHook{ + after: func(c *ContextHook) error { + return expectErr + }, + }, + }, + expect: expectErr, + }, + } + + for _, tt := range tests { + t.Run(tt.msg, func(t *testing.T) { + hooks := Hooks{} + hooks.AddHook(tt.hooks...) + err := hooks.AfterProcess(tt.ctx) + if err != tt.expect { + t.Errorf("got %v, expect %v", err, tt.expect) + } + }) + } +} diff --git a/core/db.go b/core/db.go index 671d1dc2..50c64c6f 100644 --- a/core/db.go +++ b/core/db.go @@ -12,8 +12,8 @@ import ( "reflect" "regexp" "sync" - "time" + "xorm.io/xorm/contexts" "xorm.io/xorm/log" "xorm.io/xorm/names" ) @@ -88,6 +88,7 @@ type DB struct { reflectCache map[reflect.Type]*cacheStruct reflectCacheMutex sync.RWMutex Logger log.ContextLogger + hooks contexts.Hooks } // Open opens a database @@ -140,26 +141,14 @@ func (db *DB) reflectNew(typ reflect.Type) reflect.Value { // QueryContext overwrites sql.DB.QueryContext func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { - start := time.Now() - showSQL := db.NeedLogSQL(ctx) - if showSQL { - db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - }) + hookCtx := contexts.NewContextHook(ctx, query, args) + ctx, err := db.beforeProcess(hookCtx) + if err != nil { + return nil, err } rows, err := db.DB.QueryContext(ctx, query, args...) - if showSQL { - db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } - if err != nil { + hookCtx.End(ctx, nil, err) + if err := db.afterProcess(hookCtx); err != nil { if rows != nil { rows.Close() } @@ -239,7 +228,7 @@ var ( re = regexp.MustCompile(`[?](\w+)`) ) -// ExecMapContext exec map with context.Context +// ExecMapContext exec map with context.ContextHook // insert into (name) values (?) // insert into (name) values (?name) func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) { @@ -263,28 +252,42 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{ } func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - start := time.Now() - showSQL := db.NeedLogSQL(ctx) - if showSQL { - db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - }) + hookCtx := contexts.NewContextHook(ctx, query, args) + ctx, err := db.beforeProcess(hookCtx) + if err != nil { + return nil, err } res, err := db.DB.ExecContext(ctx, query, args...) - if showSQL { - db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) + hookCtx.End(ctx, res, err) + if err := db.afterProcess(hookCtx); err != nil { + return nil, err } - return res, err + return res, nil } func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) { return db.ExecStructContext(context.Background(), query, st) } + +func (db *DB) beforeProcess(c *contexts.ContextHook) (context.Context, error) { + if db.NeedLogSQL(c.Ctx) { + db.Logger.BeforeSQL(log.LogContext(*c)) + } + ctx, err := db.hooks.BeforeProcess(c) + if err != nil { + return nil, err + } + return ctx, nil +} + +func (db *DB) afterProcess(c *contexts.ContextHook) error { + err := db.hooks.AfterProcess(c) + if db.NeedLogSQL(c.Ctx) { + db.Logger.AfterSQL(log.LogContext(*c)) + } + return err +} + +func (db *DB) AddHook(h ...contexts.Hook) { + db.hooks.AddHook(h...) +} diff --git a/core/stmt.go b/core/stmt.go index ebf2af73..d46ac9c6 100644 --- a/core/stmt.go +++ b/core/stmt.go @@ -9,9 +9,8 @@ import ( "database/sql" "errors" "reflect" - "time" - "xorm.io/xorm/log" + "xorm.io/xorm/contexts" ) // Stmt reprents a stmt objects @@ -30,28 +29,16 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { i++ return "?" }) - - start := time.Now() - showSQL := db.NeedLogSQL(ctx) - if showSQL { - db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: "PREPARE", - }) - } - stmt, err := db.DB.PrepareContext(ctx, query) - if showSQL { - db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: "PREPARE", - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } + hookCtx := contexts.NewContextHook(ctx, "PREPARE", nil) + ctx, err := db.beforeProcess(hookCtx) if err != nil { return nil, err } - + stmt, err := db.DB.PrepareContext(ctx, query) + hookCtx.End(ctx, nil, err) + if err := db.afterProcess(hookCtx); err != nil { + return nil, err + } return &Stmt{stmt, db, names, query}, nil } @@ -94,49 +81,28 @@ func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) { } func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { - start := time.Now() - showSQL := s.db.NeedLogSQL(ctx) - if showSQL { - s.db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: s.query, - Args: args, - }) + hookCtx := contexts.NewContextHook(ctx, s.query, args) + ctx, err := s.db.beforeProcess(hookCtx) + if err != nil { + return nil, err } res, err := s.Stmt.ExecContext(ctx, args) - if showSQL { - s.db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: s.query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) + hookCtx.End(ctx, res, err) + if err := s.db.afterProcess(hookCtx); err != nil { + return nil, err } - return res, err + return res, nil } func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { - start := time.Now() - showSQL := s.db.NeedLogSQL(ctx) - if showSQL { - s.db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: s.query, - Args: args, - }) + hookCtx := contexts.NewContextHook(ctx, s.query, args) + ctx, err := s.db.beforeProcess(hookCtx) + if err != nil { + return nil, err } rows, err := s.Stmt.QueryContext(ctx, args...) - if showSQL { - s.db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: s.query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } - if err != nil { + hookCtx.End(ctx, nil, err) + if err := s.db.afterProcess(hookCtx); err != nil { return nil, err } return &Rows{rows, s.db}, nil diff --git a/core/tx.go b/core/tx.go index 99a8097d..9b2988af 100644 --- a/core/tx.go +++ b/core/tx.go @@ -7,9 +7,8 @@ package core import ( "context" "database/sql" - "time" - "xorm.io/xorm/log" + "xorm.io/xorm/contexts" ) var ( @@ -23,24 +22,14 @@ type Tx struct { } func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { - start := time.Now() - showSQL := db.NeedLogSQL(ctx) - if showSQL { - db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: "BEGIN TRANSACTION", - }) + hookCtx := contexts.NewContextHook(ctx, "BEGIN TRANSACTION", nil) + ctx, err := db.beforeProcess(hookCtx) + if err != nil { + return nil, err } tx, err := db.DB.BeginTx(ctx, opts) - if showSQL { - db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: "BEGIN TRANSACTION", - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } - if err != nil { + hookCtx.End(ctx, nil, err) + if err := db.afterProcess(hookCtx); err != nil { return nil, err } return &Tx{tx, db}, nil @@ -58,25 +47,14 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { i++ return "?" }) - - start := time.Now() - showSQL := tx.db.NeedLogSQL(ctx) - if showSQL { - tx.db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: "PREPARE", - }) + hookCtx := contexts.NewContextHook(ctx, "PREPARE", nil) + ctx, err := tx.db.beforeProcess(hookCtx) + if err != nil { + return nil, err } stmt, err := tx.Tx.PrepareContext(ctx, query) - if showSQL { - tx.db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: "PREPARE", - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } - if err != nil { + hookCtx.End(ctx, nil, err) + if err := tx.db.afterProcess(hookCtx); err != nil { return nil, err } return &Stmt{stmt, tx.db, names, query}, nil @@ -116,24 +94,15 @@ func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{ } func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - start := time.Now() - showSQL := tx.db.NeedLogSQL(ctx) - if showSQL { - tx.db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - }) + hookCtx := contexts.NewContextHook(ctx, query, args) + ctx, err := tx.db.beforeProcess(hookCtx) + if err != nil { + return nil, err } res, err := tx.Tx.ExecContext(ctx, query, args...) - if showSQL { - tx.db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) + hookCtx.End(ctx, res, err) + if err := tx.db.afterProcess(hookCtx); err != nil { + return nil, err } return res, err } @@ -143,26 +112,14 @@ func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) { } func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { - start := time.Now() - showSQL := tx.db.NeedLogSQL(ctx) - if showSQL { - tx.db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - }) + hookCtx := contexts.NewContextHook(ctx, query, args) + ctx, err := tx.db.beforeProcess(hookCtx) + if err != nil { + return nil, err } rows, err := tx.Tx.QueryContext(ctx, query, args...) - if showSQL { - tx.db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } - if err != nil { + hookCtx.End(ctx, nil, err) + if err := tx.db.afterProcess(hookCtx); err != nil { if rows != nil { rows.Close() } diff --git a/engine.go b/engine.go index d99e15db..7399f41a 100644 --- a/engine.go +++ b/engine.go @@ -18,6 +18,7 @@ import ( "time" "xorm.io/xorm/caches" + "xorm.io/xorm/contexts" "xorm.io/xorm/core" "xorm.io/xorm/dialects" "xorm.io/xorm/internal/utils" @@ -1287,6 +1288,10 @@ func (engine *Engine) SetSchema(schema string) { engine.dialect.URI().SetSchema(schema) } +func (engine *Engine) AddHook(hook contexts.Hook) { + engine.db.AddHook(hook) +} + // Unscoped always disable struct tag "deleted" func (engine *Engine) Unscoped() *Session { session := engine.NewSession() @@ -1298,7 +1303,7 @@ func (engine *Engine) tbNameWithSchema(v string) string { return dialects.TableNameWithSchema(engine.dialect, v) } -// Context creates a session with the context +// ContextHook creates a session with the context func (engine *Engine) Context(ctx context.Context) *Session { session := engine.NewSession() session.isAutoClose = true diff --git a/engine_group.go b/engine_group.go index 02a57ab4..cdd9dd44 100644 --- a/engine_group.go +++ b/engine_group.go @@ -9,6 +9,7 @@ import ( "time" "xorm.io/xorm/caches" + "xorm.io/xorm/contexts" "xorm.io/xorm/dialects" "xorm.io/xorm/log" "xorm.io/xorm/names" @@ -78,7 +79,7 @@ func (eg *EngineGroup) Close() error { return nil } -// Context returned a group session +// ContextHook returned a group session func (eg *EngineGroup) Context(ctx context.Context) *Session { sess := eg.NewSession() sess.isAutoClose = true @@ -143,6 +144,13 @@ func (eg *EngineGroup) SetLogger(logger interface{}) { } } +func (eg *EngineGroup) AddHook(hook contexts.Hook) { + eg.Engine.AddHook(hook) + for i := 0; i < len(eg.slaves); i++ { + eg.slaves[i].AddHook(hook) + } +} + // SetLogLevel sets the logger level func (eg *EngineGroup) SetLogLevel(level log.LogLevel) { eg.Engine.SetLogLevel(level) diff --git a/interface.go b/interface.go index 262a2cfe..6aac4ae8 100644 --- a/interface.go +++ b/interface.go @@ -11,6 +11,7 @@ import ( "time" "xorm.io/xorm/caches" + "xorm.io/xorm/contexts" "xorm.io/xorm/dialects" "xorm.io/xorm/log" "xorm.io/xorm/names" @@ -111,6 +112,7 @@ type EngineInterface interface { SetTableMapper(names.Mapper) SetTZDatabase(tz *time.Location) SetTZLocation(tz *time.Location) + AddHook(hook contexts.Hook) ShowSQL(show ...bool) Sync(...interface{}) error Sync2(...interface{}) error diff --git a/log/logger_context.go b/log/logger_context.go index faed26d6..6b7252ef 100644 --- a/log/logger_context.go +++ b/log/logger_context.go @@ -5,19 +5,13 @@ package log import ( - "context" "fmt" - "time" + + "xorm.io/xorm/contexts" ) // LogContext represents a log context -type LogContext struct { - Ctx context.Context - SQL string // log content or SQL - Args []interface{} // if it's a SQL, it's the arguments - ExecuteTime time.Duration - Err error // SQL executed error -} +type LogContext contexts.ContextHook // SQLLogger represents an interface to log SQL type SQLLogger interface { diff --git a/session.go b/session.go index 9f47d9b4..761b1415 100644 --- a/session.go +++ b/session.go @@ -887,7 +887,7 @@ func (session *Session) incrVersionFieldValue(fieldValue *reflect.Value) { } } -// Context sets the context on this session +// ContextHook sets the context on this session func (session *Session) Context(ctx context.Context) *Session { session.ctx = ctx return session