From b081b9196dc1b0a56024d0e069f9464283ba408a Mon Sep 17 00:00:00 2001 From: datbeohbbh Date: Sun, 5 Mar 2023 22:00:18 +0300 Subject: [PATCH] add `(*Engine) TransactionContext(...)` --- engine.go | 27 ++++++++++++++++++++++++ integrations/engine_test.go | 42 +++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/engine.go b/engine.go index fb19176e..88a1d010 100644 --- a/engine.go +++ b/engine.go @@ -1416,6 +1416,11 @@ func (engine *Engine) SetDefaultContext(ctx context.Context) { engine.defaultContext = ctx } +// GetDefaultContext get the default context +func (engine *Engine) GetDefaultContext() context.Context { + return engine.defaultContext +} + // PingContext tests if database is alive func (engine *Engine) PingContext(ctx context.Context) error { session := engine.NewSession() @@ -1443,3 +1448,25 @@ func (engine *Engine) Transaction(f func(*Session) (interface{}, error)) (interf return result, nil } + +// !datbeohbbh! Transaction Execute sql wrapped in a transaction with provided context +func (engine *Engine) TransactionContext(ctx context.Context, f func(context.Context, *Session) (interface{}, error)) (interface{}, error) { + session := engine.NewSession().Context(ctx) + defer session.Close() + + if err := session.Begin(); err != nil { + return nil, err + } + defer session.Rollback() + + result, err := f(ctx, session) + if err != nil { + return nil, err + } + + if err := session.Commit(); err != nil { + return nil, err + } + + return result, nil +} diff --git a/integrations/engine_test.go b/integrations/engine_test.go index 730a424e..009c0915 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -84,6 +84,48 @@ func TestAutoTransaction(t *testing.T) { assert.EqualValues(t, false, has) } +func TestTransactionContext(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestTxContext struct { + Id int64 `xorm:"autoincr pk"` + Msg string `xorm:"varchar(255)"` + Created time.Time `xorm:"created"` + } + + assert.NoError(t, testEngine.Sync(&TestTxContext{})) + + engine := testEngine.(*xorm.Engine) + ctx, cancel := context.WithTimeout(engine.GetDefaultContext(), 5*time.Second) + defer cancel() + + // will success + _, err := engine.TransactionContext(ctx, func(ctx context.Context, session *xorm.Session) (interface{}, error) { + _, err := session.Insert(TestTxContext{Msg: "hi"}) + assert.NoError(t, err) + + return nil, nil + }) + assert.NoError(t, err) + + has, err := engine.Exist(&TestTxContext{Msg: "hi"}) + assert.NoError(t, err) + assert.True(t, has) + + // will rollback + _, err = engine.TransactionContext(ctx, func(ctx context.Context, session *xorm.Session) (interface{}, error) { + _, err := session.Insert(TestTxContext{Msg: "hello"}) + assert.NoError(t, err) + + return nil, fmt.Errorf("rollback") + }) + assert.Error(t, err) + + has, err = engine.Exist(&TestTxContext{Msg: "hello"}) + assert.NoError(t, err) + assert.False(t, has) +} + func assertSync(t *testing.T, beans ...interface{}) { for _, bean := range beans { t.Run(testEngine.TableName(bean, true), func(t *testing.T) {