remove QuoteStr() usage (#1360)

This commit is contained in:
BetaCat 2019-07-24 09:41:06 +08:00 committed by Lunny Xiao
parent 674c9089df
commit 18b32486cf
8 changed files with 105 additions and 63 deletions

View File

@ -177,6 +177,7 @@ func (engine *Engine) SupportInsertMany() bool {
// QuoteStr Engine's database use which character as quote. // QuoteStr Engine's database use which character as quote.
// mysql, sqlite use ` and postgres use " // mysql, sqlite use ` and postgres use "
// Deprecated, use Quote() instead
func (engine *Engine) QuoteStr() string { func (engine *Engine) QuoteStr() string {
return engine.dialect.QuoteStr() return engine.dialect.QuoteStr()
} }
@ -196,13 +197,10 @@ func (engine *Engine) Quote(value string) string {
return value return value
} }
if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' { buf := builder.StringBuilder{}
return value engine.QuoteTo(&buf, value)
}
value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1) return buf.String()
return engine.dialect.QuoteStr() + value + engine.dialect.QuoteStr()
} }
// QuoteTo quotes string and writes into the buffer // QuoteTo quotes string and writes into the buffer
@ -216,20 +214,30 @@ func (engine *Engine) QuoteTo(buf *builder.StringBuilder, value string) {
return return
} }
if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' { quotePair := engine.dialect.Quote("")
buf.WriteString(value)
if value[0] == '`' || len(quotePair) < 2 || value[0] == quotePair[0] { // no quote
_, _ = buf.WriteString(value)
return return
} else {
prefix, suffix := quotePair[0], quotePair[1]
_ = buf.WriteByte(prefix)
for i := 0; i < len(value); i++ {
if value[i] == '.' {
_ = buf.WriteByte(suffix)
_ = buf.WriteByte('.')
_ = buf.WriteByte(prefix)
} else {
_ = buf.WriteByte(value[i])
}
}
_ = buf.WriteByte(suffix)
} }
value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1)
buf.WriteString(engine.dialect.QuoteStr())
buf.WriteString(value)
buf.WriteString(engine.dialect.QuoteStr())
} }
func (engine *Engine) quote(sql string) string { func (engine *Engine) quote(sql string) string {
return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr() return engine.dialect.Quote(sql)
} }
// SqlType will be deprecated, please use SQLType instead // SqlType will be deprecated, please use SQLType instead

View File

@ -309,3 +309,24 @@ func sliceEq(left, right []string) bool {
func indexName(tableName, idxName string) string { func indexName(tableName, idxName string) string {
return fmt.Sprintf("IDX_%v_%v", tableName, idxName) return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
} }
func eraseAny(value string, strToErase ...string) string {
if len(strToErase) == 0 {
return value
}
var replaceSeq []string
for _, s := range strToErase {
replaceSeq = append(replaceSeq, s, "")
}
replacer := strings.NewReplacer(replaceSeq...)
return replacer.Replace(value)
}
func quoteColumns(cols []string, quoteFunc func(string) string, sep string) string {
for i := range cols {
cols[i] = quoteFunc(cols[i])
}
return strings.Join(cols, sep+" ")
}

View File

@ -4,7 +4,11 @@
package xorm package xorm
import "testing" import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSplitTag(t *testing.T) { func TestSplitTag(t *testing.T) {
var cases = []struct { var cases = []struct {
@ -24,3 +28,19 @@ func TestSplitTag(t *testing.T) {
} }
} }
} }
func TestEraseAny(t *testing.T) {
raw := "SELECT * FROM `table`.[table_name]"
assert.EqualValues(t, raw, eraseAny(raw))
assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`"))
assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]"))
}
func TestQuoteColumns(t *testing.T) {
cols := []string{"f1", "f2", "f3"}
quoteFunc := func(value string) string {
return "[" + value + "]"
}
assert.EqualValues(t, "[f1], [f2], [f3]", quoteColumns(cols, quoteFunc, ","))
}

View File

@ -242,23 +242,17 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
var sql string var sql string
if session.engine.dialect.DBType() == core.ORACLE { if session.engine.dialect.DBType() == core.ORACLE {
temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (", temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
session.engine.Quote(tableName), session.engine.Quote(tableName),
session.engine.QuoteStr(), quoteColumns(colNames, session.engine.Quote, ","))
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
session.engine.QuoteStr())
sql = fmt.Sprintf("INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL",
session.engine.Quote(tableName), session.engine.Quote(tableName),
session.engine.QuoteStr(), quoteColumns(colNames, session.engine.Quote, ","),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
strings.Join(colMultiPlaces, temp)) strings.Join(colMultiPlaces, temp))
} else { } else {
sql = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)", sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
session.engine.Quote(tableName), session.engine.Quote(tableName),
session.engine.QuoteStr(), quoteColumns(colNames, session.engine.Quote, ","),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
strings.Join(colMultiPlaces, "),(")) strings.Join(colMultiPlaces, "),("))
} }
res, err := session.exec(sql, args...) res, err := session.exec(sql, args...)
@ -378,11 +372,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement) output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
} }
if len(colPlaces) > 0 { if len(colPlaces) > 0 {
sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v)%s VALUES (%v)", sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
session.engine.Quote(tableName), session.engine.Quote(tableName),
session.engine.QuoteStr(), quoteColumns(colNames, session.engine.Quote, ","),
strings.Join(colNames, session.engine.Quote(", ")),
session.engine.QuoteStr(),
output, output,
colPlaces) colPlaces)
} else { } else {

View File

@ -96,14 +96,15 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
return ErrCacheFailed return ErrCacheFailed
} }
kvs := strings.Split(strings.TrimSpace(sqls[1]), ",") kvs := strings.Split(strings.TrimSpace(sqls[1]), ",")
for idx, kv := range kvs { for idx, kv := range kvs {
sps := strings.SplitN(kv, "=", 2) sps := strings.SplitN(kv, "=", 2)
sps2 := strings.Split(sps[0], ".") sps2 := strings.Split(sps[0], ".")
colName := sps2[len(sps2)-1] colName := sps2[len(sps2)-1]
if strings.Contains(colName, "`") { // treat quote prefix, suffix and '`' as quotes
colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1)) quotes := append(strings.Split(session.engine.Quote(""), ""), "`")
} else if strings.Contains(colName, session.engine.QuoteStr()) { if strings.ContainsAny(colName, strings.Join(quotes, "")) {
colName = strings.TrimSpace(strings.Replace(colName, session.engine.QuoteStr(), "", -1)) colName = strings.TrimSpace(eraseAny(colName, quotes...))
} else { } else {
session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName) session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
return ErrCacheFailed return ErrCacheFailed

View File

@ -11,8 +11,8 @@ import (
"testing" "testing"
"time" "time"
"xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core"
) )
func TestUpdateMap(t *testing.T) { func TestUpdateMap(t *testing.T) {

View File

@ -6,7 +6,6 @@ package xorm
import ( import (
"database/sql/driver" "database/sql/driver"
"errors"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -579,21 +578,9 @@ func (statement *Statement) getExpr() map[string]exprParam {
func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
newColumns := make([]string, 0) newColumns := make([]string, 0)
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
for _, col := range columns { for _, col := range columns {
col = strings.Replace(col, "`", "", -1) newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...)))
col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
fields := strings.Split(strings.TrimSpace(c), ".")
if len(fields) == 1 {
newColumns = append(newColumns, statement.Engine.quote(fields[0]))
} else if len(fields) == 2 {
newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
statement.Engine.quote(fields[1]))
} else {
panic(errors.New("unwanted colnames"))
}
}
} }
return newColumns return newColumns
} }
@ -764,7 +751,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
return statement return statement
} }
tbs := strings.Split(tp.TableName(), ".") tbs := strings.Split(tp.TableName(), ".")
var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr()) quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...) statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
case *builder.Builder: case *builder.Builder:
@ -774,7 +763,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
return statement return statement
} }
tbs := strings.Split(tp.TableName(), ".") tbs := strings.Split(tp.TableName(), ".")
var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr()) quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...) statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default: default:

View File

@ -9,8 +9,8 @@ import (
"strings" "strings"
"testing" "testing"
"xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core"
) )
var colStrTests = []struct { var colStrTests = []struct {
@ -237,3 +237,12 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) {
testEngine.Update(record) testEngine.Update(record)
assertGetRecord() assertGetRecord()
} }
func TestCol2NewColsWithQuote(t *testing.T) {
cols := []string{"f1", "f2", "t3.f3"}
statement := createTestStatement()
quotedCols := statement.col2NewColsWithQuote(cols...)
assert.EqualValues(t, []string{statement.Engine.Quote("f1"), statement.Engine.Quote("f2"), statement.Engine.Quote("t3.f3")}, quotedCols)
}