Only replace quotes when necessary
This commit is contained in:
parent
00b65c6d99
commit
7495ca7297
|
@ -7,8 +7,6 @@ package dialects
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"xorm.io/xorm/schemas"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Filter is an interface to filter SQL
|
// Filter is an interface to filter SQL
|
||||||
|
@ -16,48 +14,6 @@ type Filter interface {
|
||||||
Do(sql string) string
|
Do(sql string) string
|
||||||
}
|
}
|
||||||
|
|
||||||
// QuoteFilter filter SQL replace ` to database's own quote character
|
|
||||||
type QuoteFilter struct {
|
|
||||||
quoter schemas.Quoter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *QuoteFilter) Do(sql string) string {
|
|
||||||
if s.quoter.IsEmpty() {
|
|
||||||
return sql
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf strings.Builder
|
|
||||||
buf.Grow(len(sql))
|
|
||||||
|
|
||||||
var beginSingleQuote bool
|
|
||||||
for i := 0; i < len(sql); i++ {
|
|
||||||
if !beginSingleQuote && sql[i] == '`' {
|
|
||||||
var j = i + 1
|
|
||||||
for ; j < len(sql); j++ {
|
|
||||||
if sql[j] == '`' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
word := sql[i+1 : j]
|
|
||||||
isReserved := s.quoter.IsReserved(word)
|
|
||||||
if isReserved {
|
|
||||||
buf.WriteByte(s.quoter.Prefix)
|
|
||||||
}
|
|
||||||
buf.WriteString(word)
|
|
||||||
if isReserved {
|
|
||||||
buf.WriteByte(s.quoter.Suffix)
|
|
||||||
}
|
|
||||||
i = j
|
|
||||||
} else {
|
|
||||||
if sql[i] == '\'' {
|
|
||||||
beginSingleQuote = !beginSingleQuote
|
|
||||||
}
|
|
||||||
buf.WriteByte(sql[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return buf.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SeqFilter filter SQL replace ?, ? ... to $1, $2 ...
|
// SeqFilter filter SQL replace ?, ? ... to $1, $2 ...
|
||||||
type SeqFilter struct {
|
type SeqFilter struct {
|
||||||
Prefix string
|
Prefix string
|
||||||
|
|
|
@ -525,7 +525,7 @@ func (db *mssql) ForUpdateSQL(query string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mssql) Filters() []Filter {
|
func (db *mssql) Filters() []Filter {
|
||||||
return []Filter{&QuoteFilter{db.Quoter()}}
|
return []Filter{ /*&QuoteFilter{db.Quoter()}*/ }
|
||||||
}
|
}
|
||||||
|
|
||||||
type odbcDriver struct {
|
type odbcDriver struct {
|
||||||
|
|
|
@ -793,7 +793,7 @@ func (db *oracle) GetIndexes(ctx context.Context, tableName string) (map[string]
|
||||||
|
|
||||||
func (db *oracle) Filters() []Filter {
|
func (db *oracle) Filters() []Filter {
|
||||||
return []Filter{
|
return []Filter{
|
||||||
&QuoteFilter{db.Quoter()},
|
/*&QuoteFilter{db.Quoter()},*/
|
||||||
&SeqFilter{Prefix: ":", Start: 1},
|
&SeqFilter{Prefix: ":", Start: 1},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1231,7 +1231,7 @@ func (db *postgres) GetIndexes(ctx context.Context, tableName string) (map[strin
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) Filters() []Filter {
|
func (db *postgres) Filters() []Filter {
|
||||||
return []Filter{&QuoteFilter{db.Quoter()}, &SeqFilter{Prefix: "$", Start: 1}}
|
return []Filter{ /*&QuoteFilter{db.Quoter()}, */ &SeqFilter{Prefix: "$", Start: 1}}
|
||||||
}
|
}
|
||||||
|
|
||||||
type pqDriver struct {
|
type pqDriver struct {
|
||||||
|
|
|
@ -20,7 +20,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
|
||||||
}
|
}
|
||||||
|
|
||||||
if statement.RawSQL != "" {
|
if statement.RawSQL != "" {
|
||||||
return statement.RawSQL, statement.RawParams, nil
|
return statement.GenRawSQL(), statement.RawParams, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(statement.TableName()) <= 0 {
|
if len(statement.TableName()) <= 0 {
|
||||||
|
@ -74,7 +74,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
|
||||||
|
|
||||||
func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
|
func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
|
||||||
if statement.RawSQL != "" {
|
if statement.RawSQL != "" {
|
||||||
return statement.RawSQL, statement.RawParams, nil
|
return statement.GenRawSQL(), statement.RawParams, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
statement.SetRefBean(bean)
|
statement.SetRefBean(bean)
|
||||||
|
@ -153,7 +153,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
|
||||||
|
|
||||||
func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interface{}, error) {
|
func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interface{}, error) {
|
||||||
if statement.RawSQL != "" {
|
if statement.RawSQL != "" {
|
||||||
return statement.RawSQL, statement.RawParams, nil
|
return statement.GenRawSQL(), statement.RawParams, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var condArgs []interface{}
|
var condArgs []interface{}
|
||||||
|
@ -313,7 +313,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
|
||||||
|
|
||||||
func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) {
|
func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) {
|
||||||
if statement.RawSQL != "" {
|
if statement.RawSQL != "" {
|
||||||
return statement.RawSQL, statement.RawParams, nil
|
return statement.GenRawSQL(), statement.RawParams, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var sqlStr string
|
var sqlStr string
|
||||||
|
@ -382,7 +382,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
|
||||||
|
|
||||||
func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) {
|
func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) {
|
||||||
if statement.RawSQL != "" {
|
if statement.RawSQL != "" {
|
||||||
return statement.RawSQL, statement.RawParams, nil
|
return statement.GenRawSQL(), statement.RawParams, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var sqlStr string
|
var sqlStr string
|
||||||
|
|
|
@ -98,6 +98,15 @@ func (statement *Statement) omitStr() string {
|
||||||
return statement.dialect.Quoter().Join(statement.OmitColumnMap, " ,")
|
return statement.dialect.Quoter().Join(statement.OmitColumnMap, " ,")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenRawSQL generates correct raw sql
|
||||||
|
func (statement *Statement) GenRawSQL() string {
|
||||||
|
if statement.RawSQL == "" || statement.dialect.URI().DBType == schemas.MYSQL ||
|
||||||
|
statement.dialect.URI().DBType == schemas.SQLITE {
|
||||||
|
return statement.RawSQL
|
||||||
|
}
|
||||||
|
return statement.dialect.Quoter().Replace(statement.RawSQL)
|
||||||
|
}
|
||||||
|
|
||||||
func (statement *Statement) SetContextCache(ctxCache contexts.ContextCache) {
|
func (statement *Statement) SetContextCache(ctxCache contexts.ContextCache) {
|
||||||
statement.Context = ctxCache
|
statement.Context = ctxCache
|
||||||
}
|
}
|
||||||
|
|
2
rows.go
2
rows.go
|
@ -80,7 +80,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
sqlStr = rows.session.statement.RawSQL
|
sqlStr = rows.session.statement.GenRawSQL()
|
||||||
args = rows.session.statement.RawParams
|
args = rows.session.statement.RawParams
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -196,3 +196,41 @@ func (q Quoter) Strings(s []string) []string {
|
||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Replace replaces common quote(`) as the quotes on the sql
|
||||||
|
func (q Quoter) Replace(sql string) string {
|
||||||
|
if q.IsEmpty() {
|
||||||
|
return sql
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf strings.Builder
|
||||||
|
buf.Grow(len(sql))
|
||||||
|
|
||||||
|
var beginSingleQuote bool
|
||||||
|
for i := 0; i < len(sql); i++ {
|
||||||
|
if !beginSingleQuote && sql[i] == CommanQuoteMark {
|
||||||
|
var j = i + 1
|
||||||
|
for ; j < len(sql); j++ {
|
||||||
|
if sql[j] == CommanQuoteMark {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
word := sql[i+1 : j]
|
||||||
|
isReserved := q.IsReserved(word)
|
||||||
|
if isReserved {
|
||||||
|
buf.WriteByte(q.Prefix)
|
||||||
|
}
|
||||||
|
buf.WriteString(word)
|
||||||
|
if isReserved {
|
||||||
|
buf.WriteByte(q.Suffix)
|
||||||
|
}
|
||||||
|
i = j
|
||||||
|
} else {
|
||||||
|
if sql[i] == '\'' {
|
||||||
|
beginSingleQuote = !beginSingleQuote
|
||||||
|
}
|
||||||
|
buf.WriteByte(sql[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
|
@ -59,7 +59,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
sqlStr = session.statement.RawSQL
|
sqlStr = session.statement.GenRawSQL()
|
||||||
args = session.statement.RawParams
|
args = session.statement.RawParams
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue