diff --git a/dialects/filter.go b/dialects/filter.go index 4795edb7..0f9b4107 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -13,20 +13,20 @@ import ( // Filter is an interface to filter SQL type Filter interface { - Do(sql string, dialect Dialect, table *schemas.Table) string + Do(sql string) string } // QuoteFilter filter SQL replace ` to database's own quote character type QuoteFilter struct { + quoter schemas.Quoter } -func (s *QuoteFilter) Do(sql string, dialect Dialect, table *schemas.Table) string { - quoter := dialect.Quoter() - if quoter.IsEmpty() { +func (s *QuoteFilter) Do(sql string) string { + if s.quoter.IsEmpty() { return sql } - prefix, suffix := quoter[0][0], quoter[1][0] + prefix, suffix := s.quoter[0][0], s.quoter[1][0] raw := []byte(sql) for i, cnt := 0, 0; i < len(raw); i = i + 1 { if raw[i] == '`' { @@ -66,6 +66,6 @@ func convertQuestionMark(sql, prefix string, start int) string { return buf.String() } -func (s *SeqFilter) Do(sql string, dialect Dialect, table *schemas.Table) string { +func (s *SeqFilter) Do(sql string) string { return convertQuestionMark(sql, s.Prefix, s.Start) } diff --git a/dialects/filter_test.go b/dialects/filter_test.go index e5430bab..ac110a69 100644 --- a/dialects/filter_test.go +++ b/dialects/filter_test.go @@ -3,21 +3,15 @@ package dialects import ( "testing" + "xorm.io/xorm/schemas" + "github.com/stretchr/testify/assert" ) -type quoterOnly struct { - Dialect -} - -func (q *quoterOnly) Quote(item string) string { - return "[" + item + "]" -} - func TestQuoteFilter_Do(t *testing.T) { - f := QuoteFilter{} + f := QuoteFilter{schemas.Quoter{"[", "]"}} sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" - res := f.Do(sql, new(quoterOnly), nil) + res := f.Do(sql) assert.EqualValues(t, "SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?", res, diff --git a/dialects/mssql.go b/dialects/mssql.go index d473d975..74a3bb63 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -534,7 +534,7 @@ func (db *mssql) ForUpdateSQL(query string) string { } func (db *mssql) Filters() []Filter { - return []Filter{&QuoteFilter{}} + return []Filter{&QuoteFilter{db.Quoter()}} } type odbcDriver struct { diff --git a/dialects/oracle.go b/dialects/oracle.go index bf9ee2af..46f7aca2 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -848,7 +848,10 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*schemas.Index, error } func (db *oracle) Filters() []Filter { - return []Filter{&QuoteFilter{}, &SeqFilter{Prefix: ":", Start: 1}} + return []Filter{ + &QuoteFilter{db.Quoter()}, + &SeqFilter{Prefix: ":", Start: 1}, + } } type goracleDriver struct { diff --git a/dialects/postgres.go b/dialects/postgres.go index f161fdfa..cab7eaef 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -1159,7 +1159,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*schemas.Index, err } func (db *postgres) Filters() []Filter { - return []Filter{&QuoteFilter{}, &SeqFilter{Prefix: "$", Start: 1}} + return []Filter{&QuoteFilter{db.Quoter()}, &SeqFilter{Prefix: "$", Start: 1}} } type pqDriver struct { diff --git a/session_delete.go b/session_delete.go index 94cf833d..a639d61f 100644 --- a/session_delete.go +++ b/session_delete.go @@ -20,7 +20,7 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri } for _, filter := range session.engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr, session.engine.dialect, table) + sqlStr = filter.Do(sqlStr) } newsql := session.statement.convertIDSQL(sqlStr) diff --git a/session_find.go b/session_find.go index cdf086d0..fd1d49b1 100644 --- a/session_find.go +++ b/session_find.go @@ -335,7 +335,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in } for _, filter := range session.engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) + sqlStr = filter.Do(sqlStr) } newsql := session.statement.convertIDSQL(sqlStr) diff --git a/session_get.go b/session_get.go index fd66c438..bf91eacf 100644 --- a/session_get.go +++ b/session_get.go @@ -272,7 +272,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf } for _, filter := range session.engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) + sqlStr = filter.Do(sqlStr) } newsql := session.statement.convertIDSQL(sqlStr) if newsql == "" { diff --git a/session_raw.go b/session_raw.go index 4a2e2777..51487779 100644 --- a/session_raw.go +++ b/session_raw.go @@ -15,7 +15,7 @@ import ( func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { for _, filter := range session.engine.dialect.Filters() { - *sqlStr = filter.Do(*sqlStr, session.engine.dialect, session.statement.RefTable) + *sqlStr = filter.Do(*sqlStr) } session.lastSQL = *sqlStr diff --git a/session_update.go b/session_update.go index c1f1e0bf..4330afae 100644 --- a/session_update.go +++ b/session_update.go @@ -28,7 +28,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri return ErrCacheFailed } for _, filter := range session.engine.dialect.Filters() { - newsql = filter.Do(newsql, session.engine.dialect, table) + newsql = filter.Do(newsql) } session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql)