Improve some codes (#1551)

Improve code

Reviewed-on: https://gitea.com/xorm/xorm/pulls/1551
This commit is contained in:
Lunny Xiao 2020-02-27 01:30:06 +00:00
parent 5a5375a170
commit a421331cf9
8 changed files with 18 additions and 50 deletions

View File

@ -200,17 +200,3 @@ func sliceEq(left, right []string) bool {
func indexName(tableName, idxName string) string { func indexName(tableName, idxName string) string {
return fmt.Sprintf("IDX_%v_%v", tableName, idxName) return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
} }
func eraseAny(value string, strToErase ...string) string {
if len(strToErase) == 0 {
return value
}
var replaceSeq []string
for _, s := range strToErase {
replaceSeq = append(replaceSeq, s, "")
}
replacer := strings.NewReplacer(replaceSeq...)
return replacer.Replace(value)
}

View File

@ -1,18 +0,0 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestEraseAny(t *testing.T) {
raw := "SELECT * FROM `table`.[table_name]"
assert.EqualValues(t, raw, eraseAny(raw))
assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`"))
assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]"))
}

View File

@ -69,7 +69,7 @@ func (q Quoter) Trim(s string) string {
if s[0:1] == q[0] { if s[0:1] == q[0] {
s = s[1:] s = s[1:]
} }
if len(s) > 0 && s[len(s)-1:] == q[0] { if len(s) > 0 && s[len(s)-1:] == q[1] {
return s[:len(s)-1] return s[:len(s)-1]
} }
return s return s

View File

@ -63,3 +63,9 @@ func TestStrings(t *testing.T) {
quotedCols := quoter.Strings(cols) quotedCols := quoter.Strings(cols)
assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols) assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols)
} }
func TestTrim(t *testing.T) {
raw := "[table_name]"
assert.EqualValues(t, raw, CommonQuoter.Trim(raw))
assert.EqualValues(t, "table_name", Quoter{"[", "]"}.Trim(raw))
}

View File

@ -135,7 +135,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
return ErrTableNotFound return ErrTableNotFound
} }
var columnStr = session.statement.ColumnStr var columnStr = session.statement.columnStr()
if len(session.statement.selectStr) > 0 { if len(session.statement.selectStr) > 0 {
columnStr = session.statement.selectStr columnStr = session.statement.selectStr
} else { } else {

View File

@ -29,7 +29,7 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa
return "", nil, ErrTableNotFound return "", nil, ErrTableNotFound
} }
var columnStr = session.statement.ColumnStr var columnStr = session.statement.columnStr()
if len(session.statement.selectStr) > 0 { if len(session.statement.selectStr) > 0 {
columnStr = session.statement.selectStr columnStr = session.statement.selectStr
} else { } else {

View File

@ -103,14 +103,8 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
sps := strings.SplitN(kv, "=", 2) sps := strings.SplitN(kv, "=", 2)
sps2 := strings.Split(sps[0], ".") sps2 := strings.Split(sps[0], ".")
colName := sps2[len(sps2)-1] colName := sps2[len(sps2)-1]
// treat quote prefix, suffix and '`' as quotes colName = session.engine.dialect.Quoter().Trim(colName)
quotes := append(strings.Split(session.engine.Quote(""), ""), "`") colName = schemas.CommonQuoter.Trim(colName)
if strings.ContainsAny(colName, strings.Join(quotes, "")) {
colName = strings.TrimSpace(eraseAny(colName, quotes...))
} else {
session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
return ErrCacheFailed
}
if col := table.GetColumn(colName); col != nil { if col := table.GetColumn(colName); col != nil {
fieldValue, err := col.ValueOf(bean) fieldValue, err := col.ValueOf(bean)
@ -182,7 +176,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
if session.statement.ColumnStr == "" { if session.statement.columnStr() == "" {
colNames, args = session.statement.buildUpdates(bean, false, false, colNames, args = session.statement.buildUpdates(bean, false, false,
false, false, true) false, false, true)
} else { } else {

View File

@ -30,7 +30,6 @@ type Statement struct {
joinArgs []interface{} joinArgs []interface{}
GroupByStr string GroupByStr string
HavingStr string HavingStr string
ColumnStr string
selectStr string selectStr string
useAllCols bool useAllCols bool
AltTableName string AltTableName string
@ -86,7 +85,6 @@ func (statement *Statement) Reset() {
statement.joinArgs = make([]interface{}, 0) statement.joinArgs = make([]interface{}, 0)
statement.GroupByStr = "" statement.GroupByStr = ""
statement.HavingStr = "" statement.HavingStr = ""
statement.ColumnStr = ""
statement.columnMap = columnMap{} statement.columnMap = columnMap{}
statement.omitColumnMap = columnMap{} statement.omitColumnMap = columnMap{}
statement.AltTableName = "" statement.AltTableName = ""
@ -612,11 +610,13 @@ func (statement *Statement) Cols(columns ...string) *Statement {
for _, nc := range cols { for _, nc := range cols {
statement.columnMap.add(nc) statement.columnMap.add(nc)
} }
statement.ColumnStr = statement.dialect.Quoter().Join(statement.columnMap, ", ")
return statement return statement
} }
func (statement *Statement) columnStr() string {
return statement.Engine.dialect.Quoter().Join(statement.columnMap, ", ")
}
// AllCols update use only: update all columns // AllCols update use only: update all columns
func (statement *Statement) AllCols() *Statement { func (statement *Statement) AllCols() *Statement {
statement.useAllCols = true statement.useAllCols = true
@ -955,7 +955,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
statement.setRefBean(bean) statement.setRefBean(bean)
} }
var columnStr = statement.ColumnStr var columnStr = statement.columnStr()
if len(statement.selectStr) > 0 { if len(statement.selectStr) > 0 {
columnStr = statement.selectStr columnStr = statement.selectStr
} else { } else {
@ -1020,7 +1020,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
var selectSQL = statement.selectStr var selectSQL = statement.selectStr
if len(selectSQL) <= 0 { if len(selectSQL) <= 0 {
if statement.IsDistinct { if statement.IsDistinct {
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr) selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.columnStr())
} else { } else {
selectSQL = "count(*)" selectSQL = "count(*)"
} }