From d00e239c3c06b431ff3e56b2faa1a59a1e93cc0a Mon Sep 17 00:00:00 2001 From: LinkinStars Date: Tue, 30 May 2023 18:04:57 +0800 Subject: [PATCH] Filter support passing context (#2200) --- dialects/filter.go | 5 +++-- session_delete.go | 2 +- session_find.go | 2 +- session_get.go | 2 +- session_raw.go | 2 +- session_update.go | 2 +- 6 files changed, 8 insertions(+), 7 deletions(-) diff --git a/dialects/filter.go b/dialects/filter.go index bfe2e93e..ff9a44e8 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -5,13 +5,14 @@ package dialects import ( + "context" "fmt" "strings" ) // Filter is an interface to filter SQL type Filter interface { - Do(sql string) string + Do(ctx context.Context, sql string) string } // SeqFilter filter SQL replace ?, ? ... to $1, $2 ... @@ -71,6 +72,6 @@ func convertQuestionMark(sql, prefix string, start int) string { } // Do implements Filter -func (s *SeqFilter) Do(sql string) string { +func (s *SeqFilter) Do(ctx context.Context, sql string) string { return convertQuestionMark(sql, s.Prefix, s.Start) } diff --git a/session_delete.go b/session_delete.go index d36b9e52..6c63d9b0 100644 --- a/session_delete.go +++ b/session_delete.go @@ -30,7 +30,7 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri } for _, filter := range session.engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr) + sqlStr = filter.Do(session.ctx, sqlStr) } newsql := session.statement.ConvertIDSQL(sqlStr) diff --git a/session_find.go b/session_find.go index 2270454b..3341eafe 100644 --- a/session_find.go +++ b/session_find.go @@ -283,7 +283,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in } for _, filter := range session.engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr) + sqlStr = filter.Do(session.ctx, sqlStr) } newsql := session.statement.ConvertIDSQL(sqlStr) diff --git a/session_get.go b/session_get.go index 9bb92a8b..96e362e9 100644 --- a/session_get.go +++ b/session_get.go @@ -278,7 +278,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf } for _, filter := range session.engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr) + sqlStr = filter.Do(session.ctx, sqlStr) } newsql := session.statement.ConvertIDSQL(sqlStr) if newsql == "" { diff --git a/session_raw.go b/session_raw.go index add584d0..99f6be99 100644 --- a/session_raw.go +++ b/session_raw.go @@ -13,7 +13,7 @@ import ( func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { for _, filter := range session.engine.dialect.Filters() { - *sqlStr = filter.Do(*sqlStr) + *sqlStr = filter.Do(session.ctx, *sqlStr) } session.lastSQL = *sqlStr diff --git a/session_update.go b/session_update.go index e7104710..1f80e70f 100644 --- a/session_update.go +++ b/session_update.go @@ -34,7 +34,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri return ErrCacheFailed } for _, filter := range session.engine.dialect.Filters() { - newsql = filter.Do(newsql) + newsql = filter.Do(session.ctx, newsql) } session.engine.logger.Debugf("[cache] new sql: %v, %v", oldhead, newsql)