diff --git a/core/tx.go b/core/tx.go index 9b2988af..a85a6874 100644 --- a/core/tx.go +++ b/core/tx.go @@ -18,7 +18,8 @@ var ( // Tx represents a transaction type Tx struct { *sql.Tx - db *DB + db *DB + ctx context.Context } func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { @@ -32,13 +33,41 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { if err := db.afterProcess(hookCtx); err != nil { return nil, err } - return &Tx{tx, db}, nil + return &Tx{tx, db, ctx}, nil } func (db *DB) Begin() (*Tx, error) { return db.BeginTx(context.Background(), nil) } +func (tx *Tx) Commit() error { + hookCtx := contexts.NewContextHook(tx.ctx, "COMMIT", nil) + ctx, err := tx.db.beforeProcess(hookCtx) + if err != nil { + return err + } + err = tx.Tx.Commit() + hookCtx.End(ctx, nil, err) + if err := tx.db.afterProcess(hookCtx); err != nil { + return err + } + return nil +} + +func (tx *Tx) Rollback() error { + hookCtx := contexts.NewContextHook(tx.ctx, "ROLLBACK", nil) + ctx, err := tx.db.beforeProcess(hookCtx) + if err != nil { + return err + } + err = tx.Tx.Rollback() + hookCtx.End(ctx, nil, err) + if err := tx.db.afterProcess(hookCtx); err != nil { + return err + } + return nil +} + func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { names := make(map[string]int) var i int