remove QuoteStr() usage (#1360)
This commit is contained in:
parent
674c9089df
commit
18b32486cf
40
engine.go
40
engine.go
|
@ -177,6 +177,7 @@ func (engine *Engine) SupportInsertMany() bool {
|
||||||
|
|
||||||
// QuoteStr Engine's database use which character as quote.
|
// QuoteStr Engine's database use which character as quote.
|
||||||
// mysql, sqlite use ` and postgres use "
|
// mysql, sqlite use ` and postgres use "
|
||||||
|
// Deprecated, use Quote() instead
|
||||||
func (engine *Engine) QuoteStr() string {
|
func (engine *Engine) QuoteStr() string {
|
||||||
return engine.dialect.QuoteStr()
|
return engine.dialect.QuoteStr()
|
||||||
}
|
}
|
||||||
|
@ -196,13 +197,10 @@ func (engine *Engine) Quote(value string) string {
|
||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' {
|
buf := builder.StringBuilder{}
|
||||||
return value
|
engine.QuoteTo(&buf, value)
|
||||||
}
|
|
||||||
|
|
||||||
value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1)
|
return buf.String()
|
||||||
|
|
||||||
return engine.dialect.QuoteStr() + value + engine.dialect.QuoteStr()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// QuoteTo quotes string and writes into the buffer
|
// QuoteTo quotes string and writes into the buffer
|
||||||
|
@ -216,20 +214,30 @@ func (engine *Engine) QuoteTo(buf *builder.StringBuilder, value string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' {
|
quotePair := engine.dialect.Quote("")
|
||||||
buf.WriteString(value)
|
|
||||||
|
if value[0] == '`' || len(quotePair) < 2 || value[0] == quotePair[0] { // no quote
|
||||||
|
_, _ = buf.WriteString(value)
|
||||||
return
|
return
|
||||||
|
} else {
|
||||||
|
prefix, suffix := quotePair[0], quotePair[1]
|
||||||
|
|
||||||
|
_ = buf.WriteByte(prefix)
|
||||||
|
for i := 0; i < len(value); i++ {
|
||||||
|
if value[i] == '.' {
|
||||||
|
_ = buf.WriteByte(suffix)
|
||||||
|
_ = buf.WriteByte('.')
|
||||||
|
_ = buf.WriteByte(prefix)
|
||||||
|
} else {
|
||||||
|
_ = buf.WriteByte(value[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = buf.WriteByte(suffix)
|
||||||
}
|
}
|
||||||
|
|
||||||
value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1)
|
|
||||||
|
|
||||||
buf.WriteString(engine.dialect.QuoteStr())
|
|
||||||
buf.WriteString(value)
|
|
||||||
buf.WriteString(engine.dialect.QuoteStr())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (engine *Engine) quote(sql string) string {
|
func (engine *Engine) quote(sql string) string {
|
||||||
return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr()
|
return engine.dialect.Quote(sql)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SqlType will be deprecated, please use SQLType instead
|
// SqlType will be deprecated, please use SQLType instead
|
||||||
|
@ -1581,7 +1589,7 @@ func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{
|
||||||
func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) {
|
func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) {
|
||||||
switch sqlTypeName {
|
switch sqlTypeName {
|
||||||
case core.Time:
|
case core.Time:
|
||||||
s := t.Format("2006-01-02 15:04:05") //time.RFC3339
|
s := t.Format("2006-01-02 15:04:05") // time.RFC3339
|
||||||
v = s[11:19]
|
v = s[11:19]
|
||||||
case core.Date:
|
case core.Date:
|
||||||
v = t.Format("2006-01-02")
|
v = t.Format("2006-01-02")
|
||||||
|
|
23
helpers.go
23
helpers.go
|
@ -281,7 +281,7 @@ func rValue(bean interface{}) reflect.Value {
|
||||||
|
|
||||||
func rType(bean interface{}) reflect.Type {
|
func rType(bean interface{}) reflect.Type {
|
||||||
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
|
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
|
||||||
//return reflect.TypeOf(sliceValue.Interface())
|
// return reflect.TypeOf(sliceValue.Interface())
|
||||||
return sliceValue.Type()
|
return sliceValue.Type()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -309,3 +309,24 @@ func sliceEq(left, right []string) bool {
|
||||||
func indexName(tableName, idxName string) string {
|
func indexName(tableName, idxName string) string {
|
||||||
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
|
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func eraseAny(value string, strToErase ...string) string {
|
||||||
|
if len(strToErase) == 0 {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
var replaceSeq []string
|
||||||
|
for _, s := range strToErase {
|
||||||
|
replaceSeq = append(replaceSeq, s, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
replacer := strings.NewReplacer(replaceSeq...)
|
||||||
|
|
||||||
|
return replacer.Replace(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func quoteColumns(cols []string, quoteFunc func(string) string, sep string) string {
|
||||||
|
for i := range cols {
|
||||||
|
cols[i] = quoteFunc(cols[i])
|
||||||
|
}
|
||||||
|
return strings.Join(cols, sep+" ")
|
||||||
|
}
|
||||||
|
|
|
@ -4,7 +4,11 @@
|
||||||
|
|
||||||
package xorm
|
package xorm
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
func TestSplitTag(t *testing.T) {
|
func TestSplitTag(t *testing.T) {
|
||||||
var cases = []struct {
|
var cases = []struct {
|
||||||
|
@ -24,3 +28,19 @@ func TestSplitTag(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEraseAny(t *testing.T) {
|
||||||
|
raw := "SELECT * FROM `table`.[table_name]"
|
||||||
|
assert.EqualValues(t, raw, eraseAny(raw))
|
||||||
|
assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`"))
|
||||||
|
assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuoteColumns(t *testing.T) {
|
||||||
|
cols := []string{"f1", "f2", "f3"}
|
||||||
|
quoteFunc := func(value string) string {
|
||||||
|
return "[" + value + "]"
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.EqualValues(t, "[f1], [f2], [f3]", quoteColumns(cols, quoteFunc, ","))
|
||||||
|
}
|
||||||
|
|
|
@ -242,23 +242,17 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
|
||||||
|
|
||||||
var sql string
|
var sql string
|
||||||
if session.engine.dialect.DBType() == core.ORACLE {
|
if session.engine.dialect.DBType() == core.ORACLE {
|
||||||
temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
|
temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
|
||||||
session.engine.Quote(tableName),
|
session.engine.Quote(tableName),
|
||||||
session.engine.QuoteStr(),
|
quoteColumns(colNames, session.engine.Quote, ","))
|
||||||
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
|
sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
|
||||||
session.engine.QuoteStr())
|
|
||||||
sql = fmt.Sprintf("INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL",
|
|
||||||
session.engine.Quote(tableName),
|
session.engine.Quote(tableName),
|
||||||
session.engine.QuoteStr(),
|
quoteColumns(colNames, session.engine.Quote, ","),
|
||||||
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
|
|
||||||
session.engine.QuoteStr(),
|
|
||||||
strings.Join(colMultiPlaces, temp))
|
strings.Join(colMultiPlaces, temp))
|
||||||
} else {
|
} else {
|
||||||
sql = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
|
sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
|
||||||
session.engine.Quote(tableName),
|
session.engine.Quote(tableName),
|
||||||
session.engine.QuoteStr(),
|
quoteColumns(colNames, session.engine.Quote, ","),
|
||||||
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
|
|
||||||
session.engine.QuoteStr(),
|
|
||||||
strings.Join(colMultiPlaces, "),("))
|
strings.Join(colMultiPlaces, "),("))
|
||||||
}
|
}
|
||||||
res, err := session.exec(sql, args...)
|
res, err := session.exec(sql, args...)
|
||||||
|
@ -378,11 +372,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
||||||
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
|
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
|
||||||
}
|
}
|
||||||
if len(colPlaces) > 0 {
|
if len(colPlaces) > 0 {
|
||||||
sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v)%s VALUES (%v)",
|
sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
|
||||||
session.engine.Quote(tableName),
|
session.engine.Quote(tableName),
|
||||||
session.engine.QuoteStr(),
|
quoteColumns(colNames, session.engine.Quote, ","),
|
||||||
strings.Join(colNames, session.engine.Quote(", ")),
|
|
||||||
session.engine.QuoteStr(),
|
|
||||||
output,
|
output,
|
||||||
colPlaces)
|
colPlaces)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -96,14 +96,15 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
|
||||||
return ErrCacheFailed
|
return ErrCacheFailed
|
||||||
}
|
}
|
||||||
kvs := strings.Split(strings.TrimSpace(sqls[1]), ",")
|
kvs := strings.Split(strings.TrimSpace(sqls[1]), ",")
|
||||||
|
|
||||||
for idx, kv := range kvs {
|
for idx, kv := range kvs {
|
||||||
sps := strings.SplitN(kv, "=", 2)
|
sps := strings.SplitN(kv, "=", 2)
|
||||||
sps2 := strings.Split(sps[0], ".")
|
sps2 := strings.Split(sps[0], ".")
|
||||||
colName := sps2[len(sps2)-1]
|
colName := sps2[len(sps2)-1]
|
||||||
if strings.Contains(colName, "`") {
|
// treat quote prefix, suffix and '`' as quotes
|
||||||
colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1))
|
quotes := append(strings.Split(session.engine.Quote(""), ""), "`")
|
||||||
} else if strings.Contains(colName, session.engine.QuoteStr()) {
|
if strings.ContainsAny(colName, strings.Join(quotes, "")) {
|
||||||
colName = strings.TrimSpace(strings.Replace(colName, session.engine.QuoteStr(), "", -1))
|
colName = strings.TrimSpace(eraseAny(colName, quotes...))
|
||||||
} else {
|
} else {
|
||||||
session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
|
session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
|
||||||
return ErrCacheFailed
|
return ErrCacheFailed
|
||||||
|
@ -221,19 +222,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//for update action to like "column = column + ?"
|
// for update action to like "column = column + ?"
|
||||||
incColumns := session.statement.getInc()
|
incColumns := session.statement.getInc()
|
||||||
for _, v := range incColumns {
|
for _, v := range incColumns {
|
||||||
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" + ?")
|
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" + ?")
|
||||||
args = append(args, v.arg)
|
args = append(args, v.arg)
|
||||||
}
|
}
|
||||||
//for update action to like "column = column - ?"
|
// for update action to like "column = column - ?"
|
||||||
decColumns := session.statement.getDec()
|
decColumns := session.statement.getDec()
|
||||||
for _, v := range decColumns {
|
for _, v := range decColumns {
|
||||||
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" - ?")
|
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" - ?")
|
||||||
args = append(args, v.arg)
|
args = append(args, v.arg)
|
||||||
}
|
}
|
||||||
//for update action to like "column = expression"
|
// for update action to like "column = expression"
|
||||||
exprColumns := session.statement.getExpr()
|
exprColumns := session.statement.getExpr()
|
||||||
for _, v := range exprColumns {
|
for _, v := range exprColumns {
|
||||||
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+v.expr)
|
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+v.expr)
|
||||||
|
@ -382,7 +383,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
}
|
}
|
||||||
|
|
||||||
if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache {
|
if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache {
|
||||||
//session.cacheUpdate(table, tableName, sqlStr, args...)
|
// session.cacheUpdate(table, tableName, sqlStr, args...)
|
||||||
session.engine.logger.Debug("[cacheUpdate] clear table ", tableName)
|
session.engine.logger.Debug("[cacheUpdate] clear table ", tableName)
|
||||||
cacher.ClearIds(tableName)
|
cacher.ClearIds(tableName)
|
||||||
cacher.ClearBeans(tableName)
|
cacher.ClearBeans(tableName)
|
||||||
|
|
|
@ -11,8 +11,8 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"xorm.io/core"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"xorm.io/core"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUpdateMap(t *testing.T) {
|
func TestUpdateMap(t *testing.T) {
|
||||||
|
|
29
statement.go
29
statement.go
|
@ -6,7 +6,6 @@ package xorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -398,7 +397,7 @@ func (statement *Statement) buildUpdates(bean interface{},
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
//TODO: how to handler?
|
// TODO: how to handler?
|
||||||
panic("not supported")
|
panic("not supported")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -579,21 +578,9 @@ func (statement *Statement) getExpr() map[string]exprParam {
|
||||||
|
|
||||||
func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
|
func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
|
||||||
newColumns := make([]string, 0)
|
newColumns := make([]string, 0)
|
||||||
|
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
|
||||||
for _, col := range columns {
|
for _, col := range columns {
|
||||||
col = strings.Replace(col, "`", "", -1)
|
newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...)))
|
||||||
col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
|
|
||||||
ccols := strings.Split(col, ",")
|
|
||||||
for _, c := range ccols {
|
|
||||||
fields := strings.Split(strings.TrimSpace(c), ".")
|
|
||||||
if len(fields) == 1 {
|
|
||||||
newColumns = append(newColumns, statement.Engine.quote(fields[0]))
|
|
||||||
} else if len(fields) == 2 {
|
|
||||||
newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
|
|
||||||
statement.Engine.quote(fields[1]))
|
|
||||||
} else {
|
|
||||||
panic(errors.New("unwanted colnames"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return newColumns
|
return newColumns
|
||||||
}
|
}
|
||||||
|
@ -764,7 +751,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
tbs := strings.Split(tp.TableName(), ".")
|
tbs := strings.Split(tp.TableName(), ".")
|
||||||
var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr())
|
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
|
||||||
|
|
||||||
|
var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
|
||||||
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
||||||
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
||||||
case *builder.Builder:
|
case *builder.Builder:
|
||||||
|
@ -774,7 +763,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
tbs := strings.Split(tp.TableName(), ".")
|
tbs := strings.Split(tp.TableName(), ".")
|
||||||
var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr())
|
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
|
||||||
|
|
||||||
|
var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
|
||||||
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
||||||
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
||||||
default:
|
default:
|
||||||
|
@ -1246,7 +1237,7 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
|
||||||
|
|
||||||
var whereStr = sqls[1]
|
var whereStr = sqls[1]
|
||||||
|
|
||||||
//TODO: for postgres only, if any other database?
|
// TODO: for postgres only, if any other database?
|
||||||
var paraStr string
|
var paraStr string
|
||||||
if statement.Engine.dialect.DBType() == core.POSTGRES {
|
if statement.Engine.dialect.DBType() == core.POSTGRES {
|
||||||
paraStr = "$"
|
paraStr = "$"
|
||||||
|
|
|
@ -9,8 +9,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"xorm.io/core"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"xorm.io/core"
|
||||||
)
|
)
|
||||||
|
|
||||||
var colStrTests = []struct {
|
var colStrTests = []struct {
|
||||||
|
@ -237,3 +237,12 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) {
|
||||||
testEngine.Update(record)
|
testEngine.Update(record)
|
||||||
assertGetRecord()
|
assertGetRecord()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCol2NewColsWithQuote(t *testing.T) {
|
||||||
|
cols := []string{"f1", "f2", "t3.f3"}
|
||||||
|
|
||||||
|
statement := createTestStatement()
|
||||||
|
|
||||||
|
quotedCols := statement.col2NewColsWithQuote(cols...)
|
||||||
|
assert.EqualValues(t, []string{statement.Engine.Quote("f1"), statement.Engine.Quote("f2"), statement.Engine.Quote("t3.f3")}, quotedCols)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue