Improve some codes (#1551)
Improve code Reviewed-on: https://gitea.com/xorm/xorm/pulls/1551
This commit is contained in:
parent
5a5375a170
commit
a421331cf9
14
helpers.go
14
helpers.go
|
@ -200,17 +200,3 @@ func sliceEq(left, right []string) bool {
|
||||||
func indexName(tableName, idxName string) string {
|
func indexName(tableName, idxName string) string {
|
||||||
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
|
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func eraseAny(value string, strToErase ...string) string {
|
|
||||||
if len(strToErase) == 0 {
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
var replaceSeq []string
|
|
||||||
for _, s := range strToErase {
|
|
||||||
replaceSeq = append(replaceSeq, s, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
replacer := strings.NewReplacer(replaceSeq...)
|
|
||||||
|
|
||||||
return replacer.Replace(value)
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,18 +0,0 @@
|
||||||
// Copyright 2017 The Xorm Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package xorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestEraseAny(t *testing.T) {
|
|
||||||
raw := "SELECT * FROM `table`.[table_name]"
|
|
||||||
assert.EqualValues(t, raw, eraseAny(raw))
|
|
||||||
assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`"))
|
|
||||||
assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]"))
|
|
||||||
}
|
|
|
@ -69,7 +69,7 @@ func (q Quoter) Trim(s string) string {
|
||||||
if s[0:1] == q[0] {
|
if s[0:1] == q[0] {
|
||||||
s = s[1:]
|
s = s[1:]
|
||||||
}
|
}
|
||||||
if len(s) > 0 && s[len(s)-1:] == q[0] {
|
if len(s) > 0 && s[len(s)-1:] == q[1] {
|
||||||
return s[:len(s)-1]
|
return s[:len(s)-1]
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
|
|
|
@ -63,3 +63,9 @@ func TestStrings(t *testing.T) {
|
||||||
quotedCols := quoter.Strings(cols)
|
quotedCols := quoter.Strings(cols)
|
||||||
assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols)
|
assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTrim(t *testing.T) {
|
||||||
|
raw := "[table_name]"
|
||||||
|
assert.EqualValues(t, raw, CommonQuoter.Trim(raw))
|
||||||
|
assert.EqualValues(t, "table_name", Quoter{"[", "]"}.Trim(raw))
|
||||||
|
}
|
||||||
|
|
|
@ -135,7 +135,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
|
||||||
return ErrTableNotFound
|
return ErrTableNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
var columnStr = session.statement.ColumnStr
|
var columnStr = session.statement.columnStr()
|
||||||
if len(session.statement.selectStr) > 0 {
|
if len(session.statement.selectStr) > 0 {
|
||||||
columnStr = session.statement.selectStr
|
columnStr = session.statement.selectStr
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -29,7 +29,7 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa
|
||||||
return "", nil, ErrTableNotFound
|
return "", nil, ErrTableNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
var columnStr = session.statement.ColumnStr
|
var columnStr = session.statement.columnStr()
|
||||||
if len(session.statement.selectStr) > 0 {
|
if len(session.statement.selectStr) > 0 {
|
||||||
columnStr = session.statement.selectStr
|
columnStr = session.statement.selectStr
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -103,14 +103,8 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
|
||||||
sps := strings.SplitN(kv, "=", 2)
|
sps := strings.SplitN(kv, "=", 2)
|
||||||
sps2 := strings.Split(sps[0], ".")
|
sps2 := strings.Split(sps[0], ".")
|
||||||
colName := sps2[len(sps2)-1]
|
colName := sps2[len(sps2)-1]
|
||||||
// treat quote prefix, suffix and '`' as quotes
|
colName = session.engine.dialect.Quoter().Trim(colName)
|
||||||
quotes := append(strings.Split(session.engine.Quote(""), ""), "`")
|
colName = schemas.CommonQuoter.Trim(colName)
|
||||||
if strings.ContainsAny(colName, strings.Join(quotes, "")) {
|
|
||||||
colName = strings.TrimSpace(eraseAny(colName, quotes...))
|
|
||||||
} else {
|
|
||||||
session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
|
|
||||||
return ErrCacheFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
if col := table.GetColumn(colName); col != nil {
|
if col := table.GetColumn(colName); col != nil {
|
||||||
fieldValue, err := col.ValueOf(bean)
|
fieldValue, err := col.ValueOf(bean)
|
||||||
|
@ -182,7 +176,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
return 0, ErrTableNotFound
|
return 0, ErrTableNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
if session.statement.ColumnStr == "" {
|
if session.statement.columnStr() == "" {
|
||||||
colNames, args = session.statement.buildUpdates(bean, false, false,
|
colNames, args = session.statement.buildUpdates(bean, false, false,
|
||||||
false, false, true)
|
false, false, true)
|
||||||
} else {
|
} else {
|
||||||
|
|
12
statement.go
12
statement.go
|
@ -30,7 +30,6 @@ type Statement struct {
|
||||||
joinArgs []interface{}
|
joinArgs []interface{}
|
||||||
GroupByStr string
|
GroupByStr string
|
||||||
HavingStr string
|
HavingStr string
|
||||||
ColumnStr string
|
|
||||||
selectStr string
|
selectStr string
|
||||||
useAllCols bool
|
useAllCols bool
|
||||||
AltTableName string
|
AltTableName string
|
||||||
|
@ -86,7 +85,6 @@ func (statement *Statement) Reset() {
|
||||||
statement.joinArgs = make([]interface{}, 0)
|
statement.joinArgs = make([]interface{}, 0)
|
||||||
statement.GroupByStr = ""
|
statement.GroupByStr = ""
|
||||||
statement.HavingStr = ""
|
statement.HavingStr = ""
|
||||||
statement.ColumnStr = ""
|
|
||||||
statement.columnMap = columnMap{}
|
statement.columnMap = columnMap{}
|
||||||
statement.omitColumnMap = columnMap{}
|
statement.omitColumnMap = columnMap{}
|
||||||
statement.AltTableName = ""
|
statement.AltTableName = ""
|
||||||
|
@ -612,11 +610,13 @@ func (statement *Statement) Cols(columns ...string) *Statement {
|
||||||
for _, nc := range cols {
|
for _, nc := range cols {
|
||||||
statement.columnMap.add(nc)
|
statement.columnMap.add(nc)
|
||||||
}
|
}
|
||||||
|
|
||||||
statement.ColumnStr = statement.dialect.Quoter().Join(statement.columnMap, ", ")
|
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) columnStr() string {
|
||||||
|
return statement.Engine.dialect.Quoter().Join(statement.columnMap, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
// AllCols update use only: update all columns
|
// AllCols update use only: update all columns
|
||||||
func (statement *Statement) AllCols() *Statement {
|
func (statement *Statement) AllCols() *Statement {
|
||||||
statement.useAllCols = true
|
statement.useAllCols = true
|
||||||
|
@ -955,7 +955,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
|
||||||
statement.setRefBean(bean)
|
statement.setRefBean(bean)
|
||||||
}
|
}
|
||||||
|
|
||||||
var columnStr = statement.ColumnStr
|
var columnStr = statement.columnStr()
|
||||||
if len(statement.selectStr) > 0 {
|
if len(statement.selectStr) > 0 {
|
||||||
columnStr = statement.selectStr
|
columnStr = statement.selectStr
|
||||||
} else {
|
} else {
|
||||||
|
@ -1020,7 +1020,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
|
||||||
var selectSQL = statement.selectStr
|
var selectSQL = statement.selectStr
|
||||||
if len(selectSQL) <= 0 {
|
if len(selectSQL) <= 0 {
|
||||||
if statement.IsDistinct {
|
if statement.IsDistinct {
|
||||||
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
|
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.columnStr())
|
||||||
} else {
|
} else {
|
||||||
selectSQL = "count(*)"
|
selectSQL = "count(*)"
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue