diff --git a/core/db.go b/core/db.go index 671d1dc2..931b062a 100644 --- a/core/db.go +++ b/core/db.go @@ -88,6 +88,7 @@ type DB struct { reflectCache map[reflect.Type]*cacheStruct reflectCacheMutex sync.RWMutex Logger log.ContextLogger + hooks []Hook } // Open opens a database @@ -139,27 +140,22 @@ func (db *DB) reflectNew(typ reflect.Type) reflect.Value { } // QueryContext overwrites sql.DB.QueryContext -func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { +func (db *DB) QueryContext(parentCtx context.Context, query string, args ...interface{}) (*Rows, error) { + logCtx := log.LogContext{ + Ctx: parentCtx, + SQL: query, + Args: args, + } start := time.Now() - showSQL := db.NeedLogSQL(ctx) - if showSQL { - db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - }) - } - rows, err := db.DB.QueryContext(ctx, query, args...) - if showSQL { - db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } + ctx, err := db.beforeProcess(logCtx) if err != nil { + return nil, err + } + logCtx.Ctx = ctx + rows, err := db.DB.QueryContext(ctx, query, args...) + logCtx.ExecuteTime = time.Now().Sub(start) + logCtx.Err = err + if err := db.afterProcess(logCtx); err != nil { if rows != nil { rows.Close() } @@ -262,29 +258,60 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{ return db.ExecContext(ctx, query, args...) } -func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { +func (db *DB) ExecContext(parentCtx context.Context, query string, args ...interface{}) (sql.Result, error) { + logCtx := log.LogContext{ + Ctx: parentCtx, + SQL: query, + Args: args, + } start := time.Now() - showSQL := db.NeedLogSQL(ctx) - if showSQL { - db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - }) + ctx, err := db.beforeProcess(logCtx) + if err != nil { + return nil, err } + logCtx.Ctx = ctx res, err := db.DB.ExecContext(ctx, query, args...) - if showSQL { - db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) + logCtx.Err = err + logCtx.ExecuteTime = time.Now().Sub(start) + if err := db.afterProcess(logCtx); err != nil { + return nil, err } - return res, err + return res, nil } func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) { return db.ExecStructContext(context.Background(), query, st) } + +func (db *DB) beforeProcess(logCtx log.LogContext) (context.Context, error) { + ctx := logCtx.Ctx + if db.NeedLogSQL(ctx) { + db.Logger.BeforeSQL(logCtx) + } + for _, h := range db.hooks { + var err error + ctx, err = h.BeforeProcess(ctx, logCtx.SQL, logCtx.Args...) + if err != nil { + return nil, err + } + } + return ctx, nil +} + +func (db *DB) afterProcess(logCtx log.LogContext) error { + firstErr := logCtx.Err + for _, h := range db.hooks { + err := h.AfterProcess(&logCtx) + if err != nil && firstErr == nil { + firstErr = err + } + } + if db.NeedLogSQL(logCtx.Ctx) { + db.Logger.AfterSQL(logCtx) + } + return firstErr +} + +func (db *DB) AddHook(hook Hook) { + db.hooks = append(db.hooks, hook) +} diff --git a/core/interface.go b/core/interface.go index a5c8e4e2..73ca2d2b 100644 --- a/core/interface.go +++ b/core/interface.go @@ -3,8 +3,15 @@ package core import ( "context" "database/sql" + + "xorm.io/xorm/log" ) +type Hook interface { + BeforeProcess(ctx context.Context, query string, args ...interface{}) (context.Context, error) + AfterProcess(logContext *log.LogContext) error +} + // Queryer represents an interface to query a SQL to get data from database type Queryer interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) diff --git a/core/stmt.go b/core/stmt.go index ebf2af73..754a8f89 100644 --- a/core/stmt.go +++ b/core/stmt.go @@ -30,28 +30,21 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { i++ return "?" }) - + logCtx := log.LogContext{ + Ctx: ctx, + SQL: "PREPARE", + } start := time.Now() - showSQL := db.NeedLogSQL(ctx) - if showSQL { - db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: "PREPARE", - }) - } - stmt, err := db.DB.PrepareContext(ctx, query) - if showSQL { - db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: "PREPARE", - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } + ctx, err := db.beforeProcess(logCtx) if err != nil { return nil, err } - + stmt, err := db.DB.PrepareContext(ctx, query) + logCtx.Err = err + logCtx.ExecuteTime = time.Now().Sub(start) + if err := db.afterProcess(logCtx); err != nil { + return nil, err + } return &Stmt{stmt, db, names, query}, nil } @@ -94,49 +87,40 @@ func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) { } func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { + logCtx := log.LogContext{ + Ctx: ctx, + SQL: s.query, + Args: args, + } start := time.Now() - showSQL := s.db.NeedLogSQL(ctx) - if showSQL { - s.db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: s.query, - Args: args, - }) + ctx, err := s.db.beforeProcess(logCtx) + if err != nil { + return nil, err } res, err := s.Stmt.ExecContext(ctx, args) - if showSQL { - s.db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: s.query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) + logCtx.ExecuteTime = time.Now().Sub(start) + logCtx.Err = err + if err := s.db.afterProcess(logCtx); err != nil { + return nil, err } - return res, err + return res, nil } func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { + logCtx := log.LogContext{ + Ctx: ctx, + SQL: s.query, + Args: args, + } start := time.Now() - showSQL := s.db.NeedLogSQL(ctx) - if showSQL { - s.db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: s.query, - Args: args, - }) + ctx, err := s.db.beforeProcess(logCtx) + if err != nil { + return nil, err } rows, err := s.Stmt.QueryContext(ctx, args...) - if showSQL { - s.db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: s.query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } - if err != nil { + logCtx.ExecuteTime = time.Now().Sub(start) + logCtx.Err = err + if err := s.db.afterProcess(logCtx); err != nil { return nil, err } return &Rows{rows, s.db}, nil diff --git a/core/tx.go b/core/tx.go index 99a8097d..0b491659 100644 --- a/core/tx.go +++ b/core/tx.go @@ -58,25 +58,19 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { i++ return "?" }) - + logCtx := log.LogContext{ + Ctx: ctx, + SQL: "PREPARE", + } start := time.Now() - showSQL := tx.db.NeedLogSQL(ctx) - if showSQL { - tx.db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: "PREPARE", - }) + ctx, err := tx.db.beforeProcess(logCtx) + if err != nil { + return nil, err } stmt, err := tx.Tx.PrepareContext(ctx, query) - if showSQL { - tx.db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: "PREPARE", - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } - if err != nil { + logCtx.Err = err + logCtx.ExecuteTime = time.Now().Sub(start) + if err := tx.db.afterProcess(logCtx); err != nil { return nil, err } return &Stmt{stmt, tx.db, names, query}, nil @@ -117,23 +111,20 @@ func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{ func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { start := time.Now() - showSQL := tx.db.NeedLogSQL(ctx) - if showSQL { - tx.db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - }) + logCtx := log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + } + ctx, err := tx.db.beforeProcess(logCtx) + if err != nil { + return nil, err } res, err := tx.Tx.ExecContext(ctx, query, args...) - if showSQL { - tx.db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) + logCtx.ExecuteTime = time.Now().Sub(start) + logCtx.Err = err + if err := tx.db.afterProcess(logCtx); err != nil { + return nil, err } return res, err } @@ -143,26 +134,20 @@ func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) { } func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + logCtx := log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + } start := time.Now() - showSQL := tx.db.NeedLogSQL(ctx) - if showSQL { - tx.db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - }) + ctx, err := tx.db.beforeProcess(logCtx) + if err != nil { + return nil, err } rows, err := tx.Tx.QueryContext(ctx, query, args...) - if showSQL { - tx.db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } - if err != nil { + logCtx.Err = err + logCtx.ExecuteTime = time.Now().Sub(start) + if err := tx.db.afterProcess(logCtx); err != nil { if rows != nil { rows.Close() } diff --git a/engine.go b/engine.go index d99e15db..7d196249 100644 --- a/engine.go +++ b/engine.go @@ -1287,6 +1287,10 @@ func (engine *Engine) SetSchema(schema string) { engine.dialect.URI().SetSchema(schema) } +func (engine *Engine) AddHook(hook core.Hook) { + engine.db.AddHook(hook) +} + // Unscoped always disable struct tag "deleted" func (engine *Engine) Unscoped() *Session { session := engine.NewSession() diff --git a/engine_group.go b/engine_group.go index 02a57ab4..a4a1f6a6 100644 --- a/engine_group.go +++ b/engine_group.go @@ -9,6 +9,7 @@ import ( "time" "xorm.io/xorm/caches" + "xorm.io/xorm/core" "xorm.io/xorm/dialects" "xorm.io/xorm/log" "xorm.io/xorm/names" @@ -143,6 +144,13 @@ func (eg *EngineGroup) SetLogger(logger interface{}) { } } +func (eg *EngineGroup) AddHook(hook core.Hook) { + eg.Engine.AddHook(hook) + for i := 0; i < len(eg.slaves); i++ { + eg.slaves[i].AddHook(hook) + } +} + // SetLogLevel sets the logger level func (eg *EngineGroup) SetLogLevel(level log.LogLevel) { eg.Engine.SetLogLevel(level) diff --git a/interface.go b/interface.go index 262a2cfe..5f8c4b2e 100644 --- a/interface.go +++ b/interface.go @@ -11,6 +11,7 @@ import ( "time" "xorm.io/xorm/caches" + "xorm.io/xorm/core" "xorm.io/xorm/dialects" "xorm.io/xorm/log" "xorm.io/xorm/names" @@ -111,6 +112,7 @@ type EngineInterface interface { SetTableMapper(names.Mapper) SetTZDatabase(tz *time.Location) SetTZLocation(tz *time.Location) + AddHook(hook core.Hook) ShowSQL(show ...bool) Sync(...interface{}) error Sync2(...interface{}) error diff --git a/log/logger_context.go b/log/logger_context.go index faed26d6..715f86c6 100644 --- a/log/logger_context.go +++ b/log/logger_context.go @@ -6,6 +6,7 @@ package log import ( "context" + "database/sql" "fmt" "time" ) @@ -15,6 +16,7 @@ type LogContext struct { Ctx context.Context SQL string // log content or SQL Args []interface{} // if it's a SQL, it's the arguments + Result sql.Result ExecuteTime time.Duration Err error // SQL executed error }