Change filters
This commit is contained in:
parent
db72bb1f1b
commit
37b5dc2384
|
@ -13,20 +13,20 @@ import (
|
||||||
|
|
||||||
// Filter is an interface to filter SQL
|
// Filter is an interface to filter SQL
|
||||||
type Filter interface {
|
type Filter interface {
|
||||||
Do(sql string, dialect Dialect, table *schemas.Table) string
|
Do(sql string) string
|
||||||
}
|
}
|
||||||
|
|
||||||
// QuoteFilter filter SQL replace ` to database's own quote character
|
// QuoteFilter filter SQL replace ` to database's own quote character
|
||||||
type QuoteFilter struct {
|
type QuoteFilter struct {
|
||||||
|
quoter schemas.Quoter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *QuoteFilter) Do(sql string, dialect Dialect, table *schemas.Table) string {
|
func (s *QuoteFilter) Do(sql string) string {
|
||||||
quoter := dialect.Quoter()
|
if s.quoter.IsEmpty() {
|
||||||
if quoter.IsEmpty() {
|
|
||||||
return sql
|
return sql
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix, suffix := quoter[0][0], quoter[1][0]
|
prefix, suffix := s.quoter[0][0], s.quoter[1][0]
|
||||||
raw := []byte(sql)
|
raw := []byte(sql)
|
||||||
for i, cnt := 0, 0; i < len(raw); i = i + 1 {
|
for i, cnt := 0, 0; i < len(raw); i = i + 1 {
|
||||||
if raw[i] == '`' {
|
if raw[i] == '`' {
|
||||||
|
@ -66,6 +66,6 @@ func convertQuestionMark(sql, prefix string, start int) string {
|
||||||
return buf.String()
|
return buf.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SeqFilter) Do(sql string, dialect Dialect, table *schemas.Table) string {
|
func (s *SeqFilter) Do(sql string) string {
|
||||||
return convertQuestionMark(sql, s.Prefix, s.Start)
|
return convertQuestionMark(sql, s.Prefix, s.Start)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,21 +3,15 @@ package dialects
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"xorm.io/xorm/schemas"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
type quoterOnly struct {
|
|
||||||
Dialect
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *quoterOnly) Quote(item string) string {
|
|
||||||
return "[" + item + "]"
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQuoteFilter_Do(t *testing.T) {
|
func TestQuoteFilter_Do(t *testing.T) {
|
||||||
f := QuoteFilter{}
|
f := QuoteFilter{schemas.Quoter{"[", "]"}}
|
||||||
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
|
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
|
||||||
res := f.Do(sql, new(quoterOnly), nil)
|
res := f.Do(sql)
|
||||||
assert.EqualValues(t,
|
assert.EqualValues(t,
|
||||||
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
|
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
|
||||||
res,
|
res,
|
||||||
|
|
|
@ -534,7 +534,7 @@ func (db *mssql) ForUpdateSQL(query string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mssql) Filters() []Filter {
|
func (db *mssql) Filters() []Filter {
|
||||||
return []Filter{&QuoteFilter{}}
|
return []Filter{&QuoteFilter{db.Quoter()}}
|
||||||
}
|
}
|
||||||
|
|
||||||
type odbcDriver struct {
|
type odbcDriver struct {
|
||||||
|
|
|
@ -848,7 +848,10 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*schemas.Index, error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *oracle) Filters() []Filter {
|
func (db *oracle) Filters() []Filter {
|
||||||
return []Filter{&QuoteFilter{}, &SeqFilter{Prefix: ":", Start: 1}}
|
return []Filter{
|
||||||
|
&QuoteFilter{db.Quoter()},
|
||||||
|
&SeqFilter{Prefix: ":", Start: 1},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type goracleDriver struct {
|
type goracleDriver struct {
|
||||||
|
|
|
@ -1159,7 +1159,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*schemas.Index, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) Filters() []Filter {
|
func (db *postgres) Filters() []Filter {
|
||||||
return []Filter{&QuoteFilter{}, &SeqFilter{Prefix: "$", Start: 1}}
|
return []Filter{&QuoteFilter{db.Quoter()}, &SeqFilter{Prefix: "$", Start: 1}}
|
||||||
}
|
}
|
||||||
|
|
||||||
type pqDriver struct {
|
type pqDriver struct {
|
||||||
|
|
|
@ -20,7 +20,7 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, filter := range session.engine.dialect.Filters() {
|
for _, filter := range session.engine.dialect.Filters() {
|
||||||
sqlStr = filter.Do(sqlStr, session.engine.dialect, table)
|
sqlStr = filter.Do(sqlStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
newsql := session.statement.convertIDSQL(sqlStr)
|
newsql := session.statement.convertIDSQL(sqlStr)
|
||||||
|
|
|
@ -335,7 +335,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, filter := range session.engine.dialect.Filters() {
|
for _, filter := range session.engine.dialect.Filters() {
|
||||||
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
|
sqlStr = filter.Do(sqlStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
newsql := session.statement.convertIDSQL(sqlStr)
|
newsql := session.statement.convertIDSQL(sqlStr)
|
||||||
|
|
|
@ -272,7 +272,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, filter := range session.engine.dialect.Filters() {
|
for _, filter := range session.engine.dialect.Filters() {
|
||||||
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
|
sqlStr = filter.Do(sqlStr)
|
||||||
}
|
}
|
||||||
newsql := session.statement.convertIDSQL(sqlStr)
|
newsql := session.statement.convertIDSQL(sqlStr)
|
||||||
if newsql == "" {
|
if newsql == "" {
|
||||||
|
|
|
@ -15,7 +15,7 @@ import (
|
||||||
|
|
||||||
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
|
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
|
||||||
for _, filter := range session.engine.dialect.Filters() {
|
for _, filter := range session.engine.dialect.Filters() {
|
||||||
*sqlStr = filter.Do(*sqlStr, session.engine.dialect, session.statement.RefTable)
|
*sqlStr = filter.Do(*sqlStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
session.lastSQL = *sqlStr
|
session.lastSQL = *sqlStr
|
||||||
|
|
|
@ -28,7 +28,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
|
||||||
return ErrCacheFailed
|
return ErrCacheFailed
|
||||||
}
|
}
|
||||||
for _, filter := range session.engine.dialect.Filters() {
|
for _, filter := range session.engine.dialect.Filters() {
|
||||||
newsql = filter.Do(newsql, session.engine.dialect, table)
|
newsql = filter.Do(newsql)
|
||||||
}
|
}
|
||||||
session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql)
|
session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue