diff --git a/dialects/filter.go b/dialects/filter.go index add8cc7d..6968b6ce 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -7,8 +7,6 @@ package dialects import ( "fmt" "strings" - - "xorm.io/xorm/schemas" ) // Filter is an interface to filter SQL @@ -16,48 +14,6 @@ type Filter interface { Do(sql string) string } -// QuoteFilter filter SQL replace ` to database's own quote character -type QuoteFilter struct { - quoter schemas.Quoter -} - -func (s *QuoteFilter) Do(sql string) string { - if s.quoter.IsEmpty() { - return sql - } - - var buf strings.Builder - buf.Grow(len(sql)) - - var beginSingleQuote bool - for i := 0; i < len(sql); i++ { - if !beginSingleQuote && sql[i] == '`' { - var j = i + 1 - for ; j < len(sql); j++ { - if sql[j] == '`' { - break - } - } - word := sql[i+1 : j] - isReserved := s.quoter.IsReserved(word) - if isReserved { - buf.WriteByte(s.quoter.Prefix) - } - buf.WriteString(word) - if isReserved { - buf.WriteByte(s.quoter.Suffix) - } - i = j - } else { - if sql[i] == '\'' { - beginSingleQuote = !beginSingleQuote - } - buf.WriteByte(sql[i]) - } - } - return buf.String() -} - // SeqFilter filter SQL replace ?, ? ... to $1, $2 ... type SeqFilter struct { Prefix string diff --git a/dialects/mssql.go b/dialects/mssql.go index 6ba2cd97..af903a7a 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -525,7 +525,7 @@ func (db *mssql) ForUpdateSQL(query string) string { } func (db *mssql) Filters() []Filter { - return []Filter{&QuoteFilter{db.Quoter()}} + return []Filter{ /*&QuoteFilter{db.Quoter()}*/ } } type odbcDriver struct { diff --git a/dialects/oracle.go b/dialects/oracle.go index 045ad99b..48ead247 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -793,7 +793,7 @@ func (db *oracle) GetIndexes(ctx context.Context, tableName string) (map[string] func (db *oracle) Filters() []Filter { return []Filter{ - &QuoteFilter{db.Quoter()}, + /*&QuoteFilter{db.Quoter()},*/ &SeqFilter{Prefix: ":", Start: 1}, } } diff --git a/dialects/postgres.go b/dialects/postgres.go index e393452f..736d66a8 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -1231,7 +1231,7 @@ func (db *postgres) GetIndexes(ctx context.Context, tableName string) (map[strin } func (db *postgres) Filters() []Filter { - return []Filter{&QuoteFilter{db.Quoter()}, &SeqFilter{Prefix: "$", Start: 1}} + return []Filter{ /*&QuoteFilter{db.Quoter()}, */ &SeqFilter{Prefix: "$", Start: 1}} } type pqDriver struct { diff --git a/internal/statements/query.go b/internal/statements/query.go index 8f7aeebb..31fb0f96 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -20,7 +20,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int } if statement.RawSQL != "" { - return statement.RawSQL, statement.RawParams, nil + return statement.GenRawSQL(), statement.RawParams, nil } if len(statement.TableName()) <= 0 { @@ -74,7 +74,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { if statement.RawSQL != "" { - return statement.RawSQL, statement.RawParams, nil + return statement.GenRawSQL(), statement.RawParams, nil } statement.SetRefBean(bean) @@ -153,7 +153,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interface{}, error) { if statement.RawSQL != "" { - return statement.RawSQL, statement.RawParams, nil + return statement.GenRawSQL(), statement.RawParams, nil } var condArgs []interface{} @@ -313,7 +313,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) { if statement.RawSQL != "" { - return statement.RawSQL, statement.RawParams, nil + return statement.GenRawSQL(), statement.RawParams, nil } var sqlStr string @@ -382,7 +382,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) { if statement.RawSQL != "" { - return statement.RawSQL, statement.RawParams, nil + return statement.GenRawSQL(), statement.RawParams, nil } var sqlStr string diff --git a/internal/statements/statement.go b/internal/statements/statement.go index f8cb6f2d..99a99c58 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -98,6 +98,15 @@ func (statement *Statement) omitStr() string { return statement.dialect.Quoter().Join(statement.OmitColumnMap, " ,") } +// GenRawSQL generates correct raw sql +func (statement *Statement) GenRawSQL() string { + if statement.RawSQL == "" || statement.dialect.URI().DBType == schemas.MYSQL || + statement.dialect.URI().DBType == schemas.SQLITE { + return statement.RawSQL + } + return statement.dialect.Quoter().Replace(statement.RawSQL) +} + func (statement *Statement) SetContextCache(ctxCache contexts.ContextCache) { statement.Context = ctxCache } diff --git a/rows.go b/rows.go index aa5e66e3..a56ea1c9 100644 --- a/rows.go +++ b/rows.go @@ -80,7 +80,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { return nil, err } } else { - sqlStr = rows.session.statement.RawSQL + sqlStr = rows.session.statement.GenRawSQL() args = rows.session.statement.RawParams } diff --git a/schemas/quote.go b/schemas/quote.go index 10436270..2a03152e 100644 --- a/schemas/quote.go +++ b/schemas/quote.go @@ -196,3 +196,41 @@ func (q Quoter) Strings(s []string) []string { } return res } + +// Replace replaces common quote(`) as the quotes on the sql +func (q Quoter) Replace(sql string) string { + if q.IsEmpty() { + return sql + } + + var buf strings.Builder + buf.Grow(len(sql)) + + var beginSingleQuote bool + for i := 0; i < len(sql); i++ { + if !beginSingleQuote && sql[i] == CommanQuoteMark { + var j = i + 1 + for ; j < len(sql); j++ { + if sql[j] == CommanQuoteMark { + break + } + } + word := sql[i+1 : j] + isReserved := q.IsReserved(word) + if isReserved { + buf.WriteByte(q.Prefix) + } + buf.WriteString(word) + if isReserved { + buf.WriteByte(q.Suffix) + } + i = j + } else { + if sql[i] == '\'' { + beginSingleQuote = !beginSingleQuote + } + buf.WriteByte(sql[i]) + } + } + return buf.String() +} diff --git a/session_get.go b/session_get.go index 76918194..e56ef2d7 100644 --- a/session_get.go +++ b/session_get.go @@ -59,7 +59,7 @@ func (session *Session) get(bean interface{}) (bool, error) { return false, err } } else { - sqlStr = session.statement.RawSQL + sqlStr = session.statement.GenRawSQL() args = session.statement.RawParams }