diff --git a/dialects/dialect.go b/dialects/dialect.go index 8e512c4f..50e5c38b 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -87,6 +87,8 @@ type Dialect interface { Filters() []Filter SetParams(params map[string]string) + + IsRetryable(err error) (canRetry bool) } // Base represents a basic dialect and all real dialects could embed this struct @@ -247,6 +249,11 @@ func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *Base) SetParams(params map[string]string) { } +// check if an error is retryable +func (db *Base) IsRetryable(err error) bool { + return true +} + var dialects = map[string]func() Dialect{} // RegisterDialect register database dialect diff --git a/engine.go b/engine.go index 0cbfdede..02e7c8fb 100644 --- a/engine.go +++ b/engine.go @@ -24,6 +24,7 @@ import ( "xorm.io/xorm/internal/utils" "xorm.io/xorm/log" "xorm.io/xorm/names" + "xorm.io/xorm/retry" "xorm.io/xorm/schemas" "xorm.io/xorm/tags" ) @@ -1433,3 +1434,58 @@ func (engine *Engine) Transaction(f func(*Session) (interface{}, error)) (interf return result, nil } + +// Do is a retryer of session +func (engine *Engine) Do(ctx context.Context, f func(context.Context, *Session) error, opts ...retry.RetryOption) error { + var ( + dialect = engine.Dialect() + attempts = 0 + ) + err := retry.Retry(ctx, dialect.IsRetryable, func(ctx context.Context) (err error) { + attempts++ + session := engine.NewSession().Context(ctx) + defer func() { + _ = session.Close() + }() + if err = f(ctx, session); err != nil { + return err + } + return nil + }, opts...) + if err != nil { + return fmt.Errorf("operation failed after %d attempts: %w", attempts, err) + } + return nil +} + +// DoTx is a retryer of session transactions +func (engine *Engine) DoTx(ctx context.Context, f func(context.Context, *Session) error, opts ...retry.RetryOption) error { + var ( + dialect = engine.Dialect() + attempts = 0 + ) + err := retry.Retry(ctx, dialect.IsRetryable, func(ctx context.Context) (err error) { + attempts++ + session := engine.NewSession().Context(ctx) + defer func() { + _ = session.Close() + }() + if err = session.Begin(); err != nil { + return err + } + defer func() { + _ = session.Rollback() + }() + if err = f(ctx, session); err != nil { + return err + } + if err = session.Commit(); err != nil { + return err + } + return nil + }, opts...) + if err != nil { + return fmt.Errorf("tx failed after %d attempts: %w", attempts, err) + } + return nil +} diff --git a/interface.go b/interface.go index 03dfd236..47579165 100644 --- a/interface.go +++ b/interface.go @@ -15,6 +15,7 @@ import ( "xorm.io/xorm/dialects" "xorm.io/xorm/log" "xorm.io/xorm/names" + "xorm.io/xorm/retry" "xorm.io/xorm/schemas" ) @@ -127,6 +128,9 @@ type EngineInterface interface { TableName(interface{}, ...bool) string UnMapType(reflect.Type) EnableSessionID(bool) + + Do(context.Context, func(context.Context, *Session) error, ...retry.RetryOption) error + DoTx(context.Context, func(context.Context, *Session) error, ...retry.RetryOption) error } var ( diff --git a/retry/backoff.go b/retry/backoff.go new file mode 100644 index 00000000..a30790e0 --- /dev/null +++ b/retry/backoff.go @@ -0,0 +1,73 @@ +// reference: https://aws.amazon.com/vi/blogs/architecture/exponential-backoff-and-jitter/ +package retry + +import ( + "math" + "math/rand" + "time" +) + +type BackoffInterface interface { + Wait(n int) <-chan time.Time + + Delay(i int) time.Duration +} + +type Backoff struct { + min time.Duration // default 5ms + max time.Duration // default 5s + jitter bool // default true +} + +func DefaultBackoff() *Backoff { + return &Backoff{ + min: 5 * time.Millisecond, + max: 5 * time.Second, + jitter: true, + } +} + +func NewBackoff(min, max time.Duration, jitter bool) *Backoff { + return &Backoff{ + min: min, + max: max, + jitter: jitter, + } +} + +func (b *Backoff) Wait(n int) <-chan time.Time { + return time.After(b.Delay(n)) +} + +// Decorrelated Jitter +func (b *Backoff) Delay(i int) time.Duration { + rand.New(rand.NewSource(time.Now().UnixNano())) + base := int64(b.min) + cap := int64(b.max) + + if base >= cap { + return time.Duration(cap) + } + + t := int(math.Log2(float64(cap)/float64(base))) + 1 + if i > t { + i = t + } + + bf := base * int64(1< cap { + bf = cap + } + + if !b.jitter { + return time.Duration(bf) + } + + w := (bf >> 1) + rand.Int63n((bf>>1)+1) + w = base + rand.Int63n(w*3-base+1) + if w > cap { + w = cap + } + + return time.Duration(w) +} diff --git a/retry/backoff_test.go b/retry/backoff_test.go new file mode 100644 index 00000000..3fdb0578 --- /dev/null +++ b/retry/backoff_test.go @@ -0,0 +1,74 @@ +package retry + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultBackoff(t *testing.T) { + bf := DefaultBackoff() + for i := 0; i < 64; i++ { + d := bf.Delay(i) + n := time.Now() + start := n.Add(bf.min) + end := n.Add(bf.max) + cur := n.Add(d) + assert.WithinRange(t, cur, start, end) + } +} + +func TestBackoff(t *testing.T) { + for _, v := range []struct { + min time.Duration + max time.Duration + jitter bool + attempts int + }{ + { + min: 5 * time.Microsecond, + max: 10 * time.Microsecond, + jitter: true, + attempts: 0, + }, + { + min: 10 * time.Millisecond, + max: 20 * time.Millisecond, + jitter: false, + attempts: 1, + }, + { + min: 20 * time.Microsecond, + max: 30 * time.Millisecond, + jitter: false, + attempts: 2, + }, + { + min: 30 * time.Second, + max: 40 * time.Second, + jitter: true, + attempts: 70, + }, + { + min: 10 * time.Millisecond, + max: 20 * time.Second, + jitter: true, + attempts: 10, + }, + { + min: 1 * time.Second, + max: 2 * time.Second, + jitter: false, + attempts: 30, + }, + } { + bf := NewBackoff(v.min, v.max, v.jitter) + d := bf.Delay(v.attempts) + n := time.Now() + start := n.Add(bf.min) + end := n.Add(bf.max) + cur := n.Add(d) + assert.WithinRange(t, cur, start, end) + } +} diff --git a/retry/retry.go b/retry/retry.go new file mode 100644 index 00000000..27202a26 --- /dev/null +++ b/retry/retry.go @@ -0,0 +1,123 @@ +// reference: https://github.com/ydb-platform/ydb-go-sdk/blob/master/retry/retry.go +package retry + +import ( + "context" + "errors" + "fmt" +) + +type retryOptions struct { + id string + idempotent bool + backoff BackoffInterface // default implement 'Decorrelated Jitter' algorithm + ctx context.Context +} + +var ( + ErrNonRetryable = errors.New("retry error: non-retryable operation") + ErrNonIdempotent = errors.New("retry error: non-idempotent operation") + ErrMaxRetriesLimitExceed = errors.New("retry error: max retries limit exceeded") +) + +// !datbeohbbh! This function can be dialect.IsRetryable(err) +// or your custom function that check if an error can be retried +type checkRetryable func(error) bool + +type retryOperation func(context.Context) error + +type RetryOption func(*retryOptions) + +type maxRetriesKey struct{} + +func WithMaxRetries(maxRetriesValue int) RetryOption { + return func(o *retryOptions) { + o.ctx = context.WithValue(o.ctx, maxRetriesKey{}, maxRetriesValue) + } +} + +func WithID(id string) RetryOption { + return func(o *retryOptions) { + o.id = id + } +} + +func WithIdempotent(idempotent bool) RetryOption { + return func(o *retryOptions) { + o.idempotent = idempotent + } +} + +func WithBackoff(backoff BackoffInterface) RetryOption { + return func(o *retryOptions) { + o.backoff = backoff + } +} + +func (opts *retryOptions) reachMaxRetries(attempts int) bool { + if mx, has := opts.ctx.Value(maxRetriesKey{}).(int); !has { + return false + } else { + return attempts > mx + } +} + +// !datbeohbbh! Retry provide the best effort fo retrying operation +// +// Retry implements internal busy loop until one of the following conditions is met: +// - context was canceled or deadlined +// - retry operation returned nil as error +// +// Warning: if deadline without deadline or cancellation func Retry will be worked infinite +func Retry(ctx context.Context, check checkRetryable, f retryOperation, opts ...RetryOption) error { + options := &retryOptions{ + ctx: ctx, + backoff: DefaultBackoff(), + } + for _, o := range opts { + if o != nil { + o(options) + } + } + + attempts := 0 + for !options.reachMaxRetries(attempts) { + attempts++ + select { + case <-ctx.Done(): + return ctx.Err() + default: + err := f(ctx) + if err == nil { + return nil + } + canRetry := check(err) + if !canRetry { + return fmt.Errorf("Retry process with id '%s': %w", + options.id, fmt.Errorf("%v: %w", err, ErrNonRetryable)) + } + if !options.idempotent { + return fmt.Errorf("Retry process with id '%s': %w", + options.id, fmt.Errorf("%v: %w", err, ErrNonIdempotent)) + } + if err = wait(ctx, options.backoff, attempts); err != nil { + return fmt.Errorf("Retry process with id '%s': %w", options.id, err) + } + } + } + return fmt.Errorf("Retry process with id '%s': %w", + options.id, + fmt.Errorf("%v: %w", + fmt.Errorf("max retries: %v", options.ctx.Value(maxRetriesKey{})), + ErrMaxRetriesLimitExceed, + )) +} + +func wait(ctx context.Context, backoff BackoffInterface, attempts int) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-backoff.Wait(attempts): + return nil + } +} diff --git a/retry/retry_test.go b/retry/retry_test.go new file mode 100644 index 00000000..ddb209eb --- /dev/null +++ b/retry/retry_test.go @@ -0,0 +1,161 @@ +package retry + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSetRetryOptions(t *testing.T) { + opts := []RetryOption{ + WithMaxRetries(10), + WithID("ut-test-retry"), + WithIdempotent(true), + WithBackoff(DefaultBackoff()), + } + + rt := &retryOptions{ + ctx: context.Background(), + } + for _, o := range opts { + if o != nil { + o(rt) + } + } + + val, ok := rt.ctx.Value(maxRetriesKey{}).(int) + assert.True(t, ok) + assert.EqualValues(t, 10, val) + + assert.Equal(t, "ut-test-retry", rt.id) + + assert.True(t, rt.idempotent) + + assert.EqualValues(t, DefaultBackoff(), rt.backoff) +} + +func TestMaxRetries(t *testing.T) { + const mxRetries int = 10 + + opts := []RetryOption{ + WithMaxRetries(mxRetries), + } + + rt := &retryOptions{ + ctx: context.Background(), + } + for _, o := range opts { + if o != nil { + o(rt) + } + } + + val, ok := rt.ctx.Value(maxRetriesKey{}).(int) + assert.True(t, ok) + assert.EqualValues(t, mxRetries, val) + + for i := 0; i < mxRetries; i++ { + assert.False(t, rt.reachMaxRetries(i)) + } + + assert.True(t, rt.reachMaxRetries(mxRetries+1)) +} + +func TestRetryTimeOut(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + err := Retry(ctx, func(err error) bool { + return true + }, func(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(5 * time.Millisecond): + return nil + } + }, WithIdempotent(true)) + + assert.True(t, errors.Is(err, context.DeadlineExceeded)) +} + +func TestRetryMaxRetriesExceeded(t *testing.T) { + ctx := context.Background() + + utErr := errors.New("ut-error") + + err := Retry(ctx, func(err error) bool { + return true + }, func(ctx context.Context) error { + return utErr + }, + WithMaxRetries(10), + WithIdempotent(true), + WithBackoff(NewBackoff(1*time.Millisecond, 2*time.Millisecond, true))) + + assert.Error(t, err) + assert.True(t, errors.Is(err, ErrMaxRetriesLimitExceed)) +} + +func TestRetryNonRetryable(t *testing.T) { + ctx := context.Background() + + utErr := errors.New("ut-error") + + err := Retry(ctx, func(err error) bool { + return false + }, func(ctx context.Context) error { + return utErr + }, + WithBackoff(NewBackoff(1*time.Millisecond, 2*time.Millisecond, true))) + + assert.Error(t, err) + assert.True(t, errors.Is(err, ErrNonRetryable)) +} + +func TestRetryIdempotent(t *testing.T) { + ctx := context.Background() + + utErr := errors.New("ut-error") + + err := Retry(ctx, func(err error) bool { + return true + }, func(ctx context.Context) error { + return utErr + }, + WithIdempotent(false), + WithBackoff(NewBackoff(1*time.Millisecond, 2*time.Millisecond, true))) + + assert.Error(t, err) + assert.True(t, errors.Is(err, ErrNonIdempotent)) +} + +func TestRetryOk(t *testing.T) { + const mxRetries int = 10 + ctx := context.Background() + + utErr := errors.New("ut-error") + + var c int = 0 + + err := Retry(ctx, func(err error) bool { + return true + }, func(ctx context.Context) error { + defer func() { + c += 1 + }() + if c == mxRetries { + return nil + } + return utErr + }, + WithMaxRetries(mxRetries), + WithIdempotent(true), + WithBackoff(NewBackoff(1*time.Millisecond, 2*time.Millisecond, true))) + + assert.NoError(t, err) + assert.Greater(t, c, mxRetries) +} diff --git a/tests/retry_test.go b/tests/retry_test.go new file mode 100644 index 00000000..8c3566dc --- /dev/null +++ b/tests/retry_test.go @@ -0,0 +1,72 @@ +package tests + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "xorm.io/xorm" + "xorm.io/xorm/retry" +) + +func TestRetry(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestRetry struct { + Id int64 `xorm:"int(11) pk"` + Name string `xorm:"varchar(255)"` + } + + assert.NoError(t, testEngine.Sync(new(TestRetry))) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := testEngine.Do(ctx, func(ctx context.Context, session *xorm.Session) error { + num, err := insertMultiDatas(1, + append([]TestRetry{}, TestRetry{1, "test1"}, TestRetry{2, "test2"}, TestRetry{3, "test3"})) + + if err != nil { + return err + } + + assert.EqualValues(t, 3, num) + return nil + }, retry.WithID("test-retry")) + + assert.NoError(t, err) +} + +func TestRetryTx(t *testing.T) { + assert.NoError(t, PrepareEngine()) + assertSync(t, new(Userinfo)) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := testEngine.DoTx( + ctx, + func(ctx context.Context, session *xorm.Session) error { + user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} + if _, err := session.Insert(&user1); err != nil { + return err + } + + user2 := Userinfo{Username: "zzz"} + if _, err := session.Where("`id` = ?", 0).Update(&user2); err != nil { + return err + } + + if _, err := session.Exec("delete from "+testEngine.Quote(testEngine.TableName("userinfo", true))+" where `username` = ?", user2.Username); err != nil { + return err + } + + return nil + }, + retry.WithID("test-retry-tx"), + retry.WithMaxRetries(5), + ) + + assert.NoError(t, err) +}