Change filters

This commit is contained in:
Lunny Xiao 2020-02-27 10:35:01 +08:00
parent db72bb1f1b
commit 37b5dc2384
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
10 changed files with 21 additions and 24 deletions

View File

@ -13,20 +13,20 @@ import (
// Filter is an interface to filter SQL // Filter is an interface to filter SQL
type Filter interface { type Filter interface {
Do(sql string, dialect Dialect, table *schemas.Table) string Do(sql string) string
} }
// QuoteFilter filter SQL replace ` to database's own quote character // QuoteFilter filter SQL replace ` to database's own quote character
type QuoteFilter struct { type QuoteFilter struct {
quoter schemas.Quoter
} }
func (s *QuoteFilter) Do(sql string, dialect Dialect, table *schemas.Table) string { func (s *QuoteFilter) Do(sql string) string {
quoter := dialect.Quoter() if s.quoter.IsEmpty() {
if quoter.IsEmpty() {
return sql return sql
} }
prefix, suffix := quoter[0][0], quoter[1][0] prefix, suffix := s.quoter[0][0], s.quoter[1][0]
raw := []byte(sql) raw := []byte(sql)
for i, cnt := 0, 0; i < len(raw); i = i + 1 { for i, cnt := 0, 0; i < len(raw); i = i + 1 {
if raw[i] == '`' { if raw[i] == '`' {
@ -66,6 +66,6 @@ func convertQuestionMark(sql, prefix string, start int) string {
return buf.String() return buf.String()
} }
func (s *SeqFilter) Do(sql string, dialect Dialect, table *schemas.Table) string { func (s *SeqFilter) Do(sql string) string {
return convertQuestionMark(sql, s.Prefix, s.Start) return convertQuestionMark(sql, s.Prefix, s.Start)
} }

View File

@ -3,21 +3,15 @@ package dialects
import ( import (
"testing" "testing"
"xorm.io/xorm/schemas"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
type quoterOnly struct {
Dialect
}
func (q *quoterOnly) Quote(item string) string {
return "[" + item + "]"
}
func TestQuoteFilter_Do(t *testing.T) { func TestQuoteFilter_Do(t *testing.T) {
f := QuoteFilter{} f := QuoteFilter{schemas.Quoter{"[", "]"}}
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
res := f.Do(sql, new(quoterOnly), nil) res := f.Do(sql)
assert.EqualValues(t, assert.EqualValues(t,
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?", "SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
res, res,

View File

@ -534,7 +534,7 @@ func (db *mssql) ForUpdateSQL(query string) string {
} }
func (db *mssql) Filters() []Filter { func (db *mssql) Filters() []Filter {
return []Filter{&QuoteFilter{}} return []Filter{&QuoteFilter{db.Quoter()}}
} }
type odbcDriver struct { type odbcDriver struct {

View File

@ -848,7 +848,10 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*schemas.Index, error
} }
func (db *oracle) Filters() []Filter { func (db *oracle) Filters() []Filter {
return []Filter{&QuoteFilter{}, &SeqFilter{Prefix: ":", Start: 1}} return []Filter{
&QuoteFilter{db.Quoter()},
&SeqFilter{Prefix: ":", Start: 1},
}
} }
type goracleDriver struct { type goracleDriver struct {

View File

@ -1159,7 +1159,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*schemas.Index, err
} }
func (db *postgres) Filters() []Filter { func (db *postgres) Filters() []Filter {
return []Filter{&QuoteFilter{}, &SeqFilter{Prefix: "$", Start: 1}} return []Filter{&QuoteFilter{db.Quoter()}, &SeqFilter{Prefix: "$", Start: 1}}
} }
type pqDriver struct { type pqDriver struct {

View File

@ -20,7 +20,7 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri
} }
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.engine.dialect, table) sqlStr = filter.Do(sqlStr)
} }
newsql := session.statement.convertIDSQL(sqlStr) newsql := session.statement.convertIDSQL(sqlStr)

View File

@ -335,7 +335,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
} }
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) sqlStr = filter.Do(sqlStr)
} }
newsql := session.statement.convertIDSQL(sqlStr) newsql := session.statement.convertIDSQL(sqlStr)

View File

@ -272,7 +272,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
} }
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) sqlStr = filter.Do(sqlStr)
} }
newsql := session.statement.convertIDSQL(sqlStr) newsql := session.statement.convertIDSQL(sqlStr)
if newsql == "" { if newsql == "" {

View File

@ -15,7 +15,7 @@ import (
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
*sqlStr = filter.Do(*sqlStr, session.engine.dialect, session.statement.RefTable) *sqlStr = filter.Do(*sqlStr)
} }
session.lastSQL = *sqlStr session.lastSQL = *sqlStr

View File

@ -28,7 +28,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
return ErrCacheFailed return ErrCacheFailed
} }
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
newsql = filter.Do(newsql, session.engine.dialect, table) newsql = filter.Do(newsql)
} }
session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql) session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql)