Only replace quotes when necessary (#1584)
fix test improve code improve code improve code improve code Fix replace quote fix test Only replace quotes when necessary Reviewed-on: https://gitea.com/xorm/xorm/pulls/1584
This commit is contained in:
parent
00b65c6d99
commit
3617ee736f
|
@ -7,8 +7,6 @@ package dialects
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"xorm.io/xorm/schemas"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Filter is an interface to filter SQL
|
// Filter is an interface to filter SQL
|
||||||
|
@ -16,48 +14,6 @@ type Filter interface {
|
||||||
Do(sql string) string
|
Do(sql string) string
|
||||||
}
|
}
|
||||||
|
|
||||||
// QuoteFilter filter SQL replace ` to database's own quote character
|
|
||||||
type QuoteFilter struct {
|
|
||||||
quoter schemas.Quoter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *QuoteFilter) Do(sql string) string {
|
|
||||||
if s.quoter.IsEmpty() {
|
|
||||||
return sql
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf strings.Builder
|
|
||||||
buf.Grow(len(sql))
|
|
||||||
|
|
||||||
var beginSingleQuote bool
|
|
||||||
for i := 0; i < len(sql); i++ {
|
|
||||||
if !beginSingleQuote && sql[i] == '`' {
|
|
||||||
var j = i + 1
|
|
||||||
for ; j < len(sql); j++ {
|
|
||||||
if sql[j] == '`' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
word := sql[i+1 : j]
|
|
||||||
isReserved := s.quoter.IsReserved(word)
|
|
||||||
if isReserved {
|
|
||||||
buf.WriteByte(s.quoter.Prefix)
|
|
||||||
}
|
|
||||||
buf.WriteString(word)
|
|
||||||
if isReserved {
|
|
||||||
buf.WriteByte(s.quoter.Suffix)
|
|
||||||
}
|
|
||||||
i = j
|
|
||||||
} else {
|
|
||||||
if sql[i] == '\'' {
|
|
||||||
beginSingleQuote = !beginSingleQuote
|
|
||||||
}
|
|
||||||
buf.WriteByte(sql[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return buf.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SeqFilter filter SQL replace ?, ? ... to $1, $2 ...
|
// SeqFilter filter SQL replace ?, ? ... to $1, $2 ...
|
||||||
type SeqFilter struct {
|
type SeqFilter struct {
|
||||||
Prefix string
|
Prefix string
|
||||||
|
|
|
@ -3,38 +3,9 @@ package dialects
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"xorm.io/xorm/schemas"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestQuoteFilter_Do(t *testing.T) {
|
|
||||||
f := QuoteFilter{schemas.Quoter{'[', ']', schemas.AlwaysReserve}}
|
|
||||||
var kases = []struct {
|
|
||||||
source string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?",
|
|
||||||
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"SELECT 'abc```test```''', `a` FROM b",
|
|
||||||
"SELECT 'abc```test```''', [a] FROM b",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"UPDATE table SET `a` = ~ `a`, `b`='abc`'",
|
|
||||||
"UPDATE table SET [a] = ~ [a], [b]='abc`'",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, kase := range kases {
|
|
||||||
t.Run(kase.source, func(t *testing.T) {
|
|
||||||
assert.EqualValues(t, kase.expected, f.Do(kase.source))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSeqFilter(t *testing.T) {
|
func TestSeqFilter(t *testing.T) {
|
||||||
var kases = map[string]string{
|
var kases = map[string]string{
|
||||||
"SELECT * FROM TABLE1 WHERE a=? AND b=?": "SELECT * FROM TABLE1 WHERE a=$1 AND b=$2",
|
"SELECT * FROM TABLE1 WHERE a=? AND b=?": "SELECT * FROM TABLE1 WHERE a=$1 AND b=$2",
|
||||||
|
|
|
@ -525,7 +525,7 @@ func (db *mssql) ForUpdateSQL(query string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mssql) Filters() []Filter {
|
func (db *mssql) Filters() []Filter {
|
||||||
return []Filter{&QuoteFilter{db.Quoter()}}
|
return []Filter{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type odbcDriver struct {
|
type odbcDriver struct {
|
||||||
|
|
|
@ -793,7 +793,6 @@ func (db *oracle) GetIndexes(ctx context.Context, tableName string) (map[string]
|
||||||
|
|
||||||
func (db *oracle) Filters() []Filter {
|
func (db *oracle) Filters() []Filter {
|
||||||
return []Filter{
|
return []Filter{
|
||||||
&QuoteFilter{db.Quoter()},
|
|
||||||
&SeqFilter{Prefix: ":", Start: 1},
|
&SeqFilter{Prefix: ":", Start: 1},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1231,7 +1231,7 @@ func (db *postgres) GetIndexes(ctx context.Context, tableName string) (map[strin
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) Filters() []Filter {
|
func (db *postgres) Filters() []Filter {
|
||||||
return []Filter{&QuoteFilter{db.Quoter()}, &SeqFilter{Prefix: "$", Start: 1}}
|
return []Filter{&SeqFilter{Prefix: "$", Start: 1}}
|
||||||
}
|
}
|
||||||
|
|
||||||
type pqDriver struct {
|
type pqDriver struct {
|
||||||
|
|
|
@ -16,11 +16,11 @@ import (
|
||||||
|
|
||||||
func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) {
|
func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) {
|
||||||
if len(sqlOrArgs) > 0 {
|
if len(sqlOrArgs) > 0 {
|
||||||
return ConvertSQLOrArgs(sqlOrArgs...)
|
return statement.ConvertSQLOrArgs(sqlOrArgs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if statement.RawSQL != "" {
|
if statement.RawSQL != "" {
|
||||||
return statement.RawSQL, statement.RawParams, nil
|
return statement.GenRawSQL(), statement.RawParams, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(statement.TableName()) <= 0 {
|
if len(statement.TableName()) <= 0 {
|
||||||
|
@ -74,7 +74,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
|
||||||
|
|
||||||
func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
|
func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
|
||||||
if statement.RawSQL != "" {
|
if statement.RawSQL != "" {
|
||||||
return statement.RawSQL, statement.RawParams, nil
|
return statement.GenRawSQL(), statement.RawParams, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
statement.SetRefBean(bean)
|
statement.SetRefBean(bean)
|
||||||
|
@ -83,6 +83,8 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
|
||||||
for _, colName := range columns {
|
for _, colName := range columns {
|
||||||
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
|
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
|
||||||
colName = statement.quote(colName)
|
colName = statement.quote(colName)
|
||||||
|
} else {
|
||||||
|
colName = statement.ReplaceQuote(colName)
|
||||||
}
|
}
|
||||||
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
|
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
|
||||||
}
|
}
|
||||||
|
@ -153,7 +155,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
|
||||||
|
|
||||||
func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interface{}, error) {
|
func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interface{}, error) {
|
||||||
if statement.RawSQL != "" {
|
if statement.RawSQL != "" {
|
||||||
return statement.RawSQL, statement.RawParams, nil
|
return statement.GenRawSQL(), statement.RawParams, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var condArgs []interface{}
|
var condArgs []interface{}
|
||||||
|
@ -193,7 +195,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
|
||||||
distinct = "DISTINCT "
|
distinct = "DISTINCT "
|
||||||
}
|
}
|
||||||
|
|
||||||
condSQL, condArgs, err := builder.ToSQL(statement.cond)
|
condSQL, condArgs, err := statement.GenCondSQL(statement.cond)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
@ -313,7 +315,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
|
||||||
|
|
||||||
func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) {
|
func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) {
|
||||||
if statement.RawSQL != "" {
|
if statement.RawSQL != "" {
|
||||||
return statement.RawSQL, statement.RawParams, nil
|
return statement.GenRawSQL(), statement.RawParams, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var sqlStr string
|
var sqlStr string
|
||||||
|
@ -332,7 +334,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
|
||||||
}
|
}
|
||||||
|
|
||||||
if statement.Conds().IsValid() {
|
if statement.Conds().IsValid() {
|
||||||
condSQL, condArgs, err := builder.ToSQL(statement.Conds())
|
condSQL, condArgs, err := statement.GenCondSQL(statement.Conds())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
@ -382,7 +384,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
|
||||||
|
|
||||||
func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) {
|
func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) {
|
||||||
if statement.RawSQL != "" {
|
if statement.RawSQL != "" {
|
||||||
return statement.RawSQL, statement.RawParams, nil
|
return statement.GenRawSQL(), statement.RawParams, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var sqlStr string
|
var sqlStr string
|
||||||
|
|
|
@ -98,6 +98,27 @@ func (statement *Statement) omitStr() string {
|
||||||
return statement.dialect.Quoter().Join(statement.OmitColumnMap, " ,")
|
return statement.dialect.Quoter().Join(statement.OmitColumnMap, " ,")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenRawSQL generates correct raw sql
|
||||||
|
func (statement *Statement) GenRawSQL() string {
|
||||||
|
return statement.ReplaceQuote(statement.RawSQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []interface{}, error) {
|
||||||
|
condSQL, condArgs, err := builder.ToSQL(condOrBuilder)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
return statement.ReplaceQuote(condSQL), condArgs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) ReplaceQuote(sql string) string {
|
||||||
|
if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL ||
|
||||||
|
statement.dialect.URI().DBType == schemas.SQLITE {
|
||||||
|
return sql
|
||||||
|
}
|
||||||
|
return statement.dialect.Quoter().Replace(sql)
|
||||||
|
}
|
||||||
|
|
||||||
func (statement *Statement) SetContextCache(ctxCache contexts.ContextCache) {
|
func (statement *Statement) SetContextCache(ctxCache contexts.ContextCache) {
|
||||||
statement.Context = ctxCache
|
statement.Context = ctxCache
|
||||||
}
|
}
|
||||||
|
@ -348,7 +369,11 @@ func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
|
||||||
|
|
||||||
// SetExpr Generate "Update ... Set column = {expression}" statement
|
// SetExpr Generate "Update ... Set column = {expression}" statement
|
||||||
func (statement *Statement) SetExpr(column string, expression interface{}) *Statement {
|
func (statement *Statement) SetExpr(column string, expression interface{}) *Statement {
|
||||||
|
if e, ok := expression.(string); ok {
|
||||||
|
statement.ExprColumns.addParam(column, statement.dialect.Quoter().Replace(e))
|
||||||
|
} else {
|
||||||
statement.ExprColumns.addParam(column, expression)
|
statement.ExprColumns.addParam(column, expression)
|
||||||
|
}
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -367,7 +392,7 @@ func (statement *Statement) ForUpdate() *Statement {
|
||||||
|
|
||||||
// Select replace select
|
// Select replace select
|
||||||
func (statement *Statement) Select(str string) *Statement {
|
func (statement *Statement) Select(str string) *Statement {
|
||||||
statement.SelectStr = str
|
statement.SelectStr = statement.ReplaceQuote(str)
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -458,7 +483,7 @@ func (statement *Statement) OrderBy(order string) *Statement {
|
||||||
if len(statement.OrderStr) > 0 {
|
if len(statement.OrderStr) > 0 {
|
||||||
statement.OrderStr += ", "
|
statement.OrderStr += ", "
|
||||||
}
|
}
|
||||||
statement.OrderStr += order
|
statement.OrderStr += statement.ReplaceQuote(order)
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -537,7 +562,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
|
||||||
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
|
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
|
||||||
aliasName = schemas.CommonQuoter.Trim(aliasName)
|
aliasName = schemas.CommonQuoter.Trim(aliasName)
|
||||||
|
|
||||||
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), aliasName, statement.ReplaceQuote(condition))
|
||||||
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
||||||
case *builder.Builder:
|
case *builder.Builder:
|
||||||
subSQL, subQueryArgs, err := tp.ToSQL()
|
subSQL, subQueryArgs, err := tp.ToSQL()
|
||||||
|
@ -550,7 +575,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
|
||||||
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
|
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
|
||||||
aliasName = schemas.CommonQuoter.Trim(aliasName)
|
aliasName = schemas.CommonQuoter.Trim(aliasName)
|
||||||
|
|
||||||
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), aliasName, statement.ReplaceQuote(condition))
|
||||||
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
||||||
default:
|
default:
|
||||||
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
|
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
|
||||||
|
@ -559,7 +584,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
|
||||||
statement.dialect.Quoter().QuoteTo(&buf, tbName)
|
statement.dialect.Quoter().QuoteTo(&buf, tbName)
|
||||||
tbName = buf.String()
|
tbName = buf.String()
|
||||||
}
|
}
|
||||||
fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
|
fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condition))
|
||||||
}
|
}
|
||||||
|
|
||||||
statement.JoinStr = buf.String()
|
statement.JoinStr = buf.String()
|
||||||
|
@ -578,13 +603,13 @@ func (statement *Statement) tbNameNoSchema(table *schemas.Table) string {
|
||||||
|
|
||||||
// GroupBy generate "Group By keys" statement
|
// GroupBy generate "Group By keys" statement
|
||||||
func (statement *Statement) GroupBy(keys string) *Statement {
|
func (statement *Statement) GroupBy(keys string) *Statement {
|
||||||
statement.GroupByStr = keys
|
statement.GroupByStr = statement.ReplaceQuote(keys)
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
|
|
||||||
// Having generate "Having conditions" statement
|
// Having generate "Having conditions" statement
|
||||||
func (statement *Statement) Having(conditions string) *Statement {
|
func (statement *Statement) Having(conditions string) *Statement {
|
||||||
statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
|
statement.HavingStr = fmt.Sprintf("HAVING %v", statement.ReplaceQuote(conditions))
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -926,7 +951,7 @@ func (statement *Statement) GenConds(bean interface{}) (string, []interface{}, e
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return builder.ToSQL(statement.cond)
|
return statement.GenCondSQL(statement.cond)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) quoteColumnStr(columnStr string) string {
|
func (statement *Statement) quoteColumnStr(columnStr string) string {
|
||||||
|
@ -934,7 +959,15 @@ func (statement *Statement) quoteColumnStr(columnStr string) string {
|
||||||
return statement.dialect.Quoter().Join(columns, ",")
|
return statement.dialect.Quoter().Join(columns, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConvertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) {
|
func (statement *Statement) ConvertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) {
|
||||||
|
sql, args, err := convertSQLOrArgs(sqlOrArgs...)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
return statement.ReplaceQuote(sql), args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) {
|
||||||
switch sqlOrArgs[0].(type) {
|
switch sqlOrArgs[0].(type) {
|
||||||
case string:
|
case string:
|
||||||
return sqlOrArgs[0].(string), sqlOrArgs[1:], nil
|
return sqlOrArgs[0].(string), sqlOrArgs[1:], nil
|
||||||
|
|
2
rows.go
2
rows.go
|
@ -80,7 +80,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
sqlStr = rows.session.statement.RawSQL
|
sqlStr = rows.session.statement.GenRawSQL()
|
||||||
args = rows.session.statement.RawParams
|
args = rows.session.statement.RawParams
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -196,3 +196,41 @@ func (q Quoter) Strings(s []string) []string {
|
||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Replace replaces common quote(`) as the quotes on the sql
|
||||||
|
func (q Quoter) Replace(sql string) string {
|
||||||
|
if q.IsEmpty() {
|
||||||
|
return sql
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf strings.Builder
|
||||||
|
buf.Grow(len(sql))
|
||||||
|
|
||||||
|
var beginSingleQuote bool
|
||||||
|
for i := 0; i < len(sql); i++ {
|
||||||
|
if !beginSingleQuote && sql[i] == CommanQuoteMark {
|
||||||
|
var j = i + 1
|
||||||
|
for ; j < len(sql); j++ {
|
||||||
|
if sql[j] == CommanQuoteMark {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
word := sql[i+1 : j]
|
||||||
|
isReserved := q.IsReserved(word)
|
||||||
|
if isReserved {
|
||||||
|
buf.WriteByte(q.Prefix)
|
||||||
|
}
|
||||||
|
buf.WriteString(word)
|
||||||
|
if isReserved {
|
||||||
|
buf.WriteByte(q.Suffix)
|
||||||
|
}
|
||||||
|
i = j
|
||||||
|
} else {
|
||||||
|
if sql[i] == '\'' {
|
||||||
|
beginSingleQuote = !beginSingleQuote
|
||||||
|
}
|
||||||
|
buf.WriteByte(sql[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
|
@ -146,3 +146,30 @@ func TestTrim(t *testing.T) {
|
||||||
assert.EqualValues(t, dst, Quoter{'[', ']', AlwaysReserve}.Trim(src))
|
assert.EqualValues(t, dst, Quoter{'[', ']', AlwaysReserve}.Trim(src))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReplace(t *testing.T) {
|
||||||
|
q := Quoter{'[', ']', AlwaysReserve}
|
||||||
|
var kases = []struct {
|
||||||
|
source string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?",
|
||||||
|
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SELECT 'abc```test```''', `a` FROM b",
|
||||||
|
"SELECT 'abc```test```''', [a] FROM b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"UPDATE table SET `a` = ~ `a`, `b`='abc`'",
|
||||||
|
"UPDATE table SET [a] = ~ [a], [b]='abc`'",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, kase := range kases {
|
||||||
|
t.Run(kase.source, func(t *testing.T) {
|
||||||
|
assert.EqualValues(t, kase.expected, q.Replace(kase.source))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -59,7 +59,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
sqlStr = session.statement.RawSQL
|
sqlStr = session.statement.GenRawSQL()
|
||||||
args = session.statement.RawParams
|
args = session.statement.RawParams
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"xorm.io/xorm/core"
|
"xorm.io/xorm/core"
|
||||||
"xorm.io/xorm/internal/statements"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
|
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
|
||||||
|
@ -172,7 +171,7 @@ func (session *Session) Exec(sqlOrArgs ...interface{}) (sql.Result, error) {
|
||||||
return nil, ErrUnSupportedType
|
return nil, ErrUnSupportedType
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlStr, args, err := statements.ConvertSQLOrArgs(sqlOrArgs...)
|
sqlStr, args, err := session.statement.ConvertSQLOrArgs(sqlOrArgs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -240,7 +240,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
}
|
}
|
||||||
colNames = append(colNames, session.engine.Quote(colName)+"="+tp)
|
colNames = append(colNames, session.engine.Quote(colName)+"="+tp)
|
||||||
case *builder.Builder:
|
case *builder.Builder:
|
||||||
subQuery, subArgs, err := builder.ToSQL(tp)
|
subQuery, subArgs, err := session.statement.GenCondSQL(tp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -317,7 +317,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
condSQL, condArgs, err = builder.ToSQL(cond)
|
if len(colNames) <= 0 {
|
||||||
|
return 0, errors.New("No content found to be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
condSQL, condArgs, err = session.statement.GenCondSQL(cond)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -335,24 +339,25 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
var top string
|
var top string
|
||||||
if st.LimitN != nil {
|
if st.LimitN != nil {
|
||||||
limitValue := *st.LimitN
|
limitValue := *st.LimitN
|
||||||
if session.engine.dialect.URI().DBType == schemas.MYSQL {
|
switch session.engine.dialect.URI().DBType {
|
||||||
|
case schemas.MYSQL:
|
||||||
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
|
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
|
||||||
} else if session.engine.dialect.URI().DBType == schemas.SQLITE {
|
case schemas.SQLITE:
|
||||||
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
|
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
|
||||||
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
|
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
|
||||||
session.engine.Quote(tableName), tempCondSQL), condArgs...))
|
session.engine.Quote(tableName), tempCondSQL), condArgs...))
|
||||||
condSQL, condArgs, err = builder.ToSQL(cond)
|
condSQL, condArgs, err = session.statement.GenCondSQL(cond)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
if len(condSQL) > 0 {
|
if len(condSQL) > 0 {
|
||||||
condSQL = "WHERE " + condSQL
|
condSQL = "WHERE " + condSQL
|
||||||
}
|
}
|
||||||
} else if session.engine.dialect.URI().DBType == schemas.POSTGRES {
|
case schemas.POSTGRES:
|
||||||
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
|
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
|
||||||
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
|
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
|
||||||
session.engine.Quote(tableName), tempCondSQL), condArgs...))
|
session.engine.Quote(tableName), tempCondSQL), condArgs...))
|
||||||
condSQL, condArgs, err = builder.ToSQL(cond)
|
condSQL, condArgs, err = session.statement.GenCondSQL(cond)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -360,14 +365,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
if len(condSQL) > 0 {
|
if len(condSQL) > 0 {
|
||||||
condSQL = "WHERE " + condSQL
|
condSQL = "WHERE " + condSQL
|
||||||
}
|
}
|
||||||
} else if session.engine.dialect.URI().DBType == schemas.MSSQL {
|
case schemas.MSSQL:
|
||||||
if st.OrderStr != "" && session.engine.dialect.URI().DBType == schemas.MSSQL &&
|
if st.OrderStr != "" && table != nil && len(table.PrimaryKeys) == 1 {
|
||||||
table != nil && len(table.PrimaryKeys) == 1 {
|
|
||||||
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
|
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
|
||||||
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
|
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
|
||||||
session.engine.Quote(tableName), condSQL), condArgs...)
|
session.engine.Quote(tableName), condSQL), condArgs...)
|
||||||
|
|
||||||
condSQL, condArgs, err = builder.ToSQL(cond)
|
condSQL, condArgs, err = session.statement.GenCondSQL(cond)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -380,10 +384,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(colNames) <= 0 {
|
|
||||||
return 0, errors.New("No content found to be updated")
|
|
||||||
}
|
|
||||||
|
|
||||||
var tableAlias = session.engine.Quote(tableName)
|
var tableAlias = session.engine.Quote(tableName)
|
||||||
var fromSQL string
|
var fromSQL string
|
||||||
if session.statement.TableAlias != "" {
|
if session.statement.TableAlias != "" {
|
||||||
|
|
Loading…
Reference in New Issue