Improve statement (#1555)

Fix test

Improve statement

Reviewed-on: https://gitea.com/xorm/xorm/pulls/1555
This commit is contained in:
Lunny Xiao 2020-02-27 05:49:43 +00:00
parent e2f9100419
commit 7bf9a7a73c
4 changed files with 54 additions and 45 deletions

View File

@ -66,13 +66,18 @@ func (q Quoter) Trim(s string) string {
return s
}
if s[0:1] == q[0] {
s = s[1:]
var buf strings.Builder
for i := 0; i < len(s); i++ {
switch {
case i == 0 && s[i:i+1] == q[0]:
case i == len(s)-1 && s[i:i+1] == q[1]:
case s[i:i+1] == q[1] && s[i+1] == '.':
case s[i:i+1] == q[0] && s[i-1] == '.':
default:
buf.WriteByte(s[i])
}
}
if len(s) > 0 && s[len(s)-1:] == q[1] {
return s[:len(s)-1]
}
return s
return buf.String()
}
func (q Quoter) Join(a []string, sep string) string {

View File

@ -65,7 +65,13 @@ func TestStrings(t *testing.T) {
}
func TestTrim(t *testing.T) {
raw := "[table_name]"
assert.EqualValues(t, raw, CommonQuoter.Trim(raw))
assert.EqualValues(t, "table_name", Quoter{"[", "]"}.Trim(raw))
var kases = map[string]string{
"[table_name]": "table_name",
"[schema].[table_name]": "schema.table_name",
}
for src, dst := range kases {
assert.EqualValues(t, src, CommonQuoter.Trim(src))
assert.EqualValues(t, dst, Quoter{"[", "]"}.Trim(src))
}
}

View File

@ -615,7 +615,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
}
func (statement *Statement) columnStr() string {
return statement.Engine.dialect.Quoter().Join(statement.columnMap, ", ")
return statement.dialect.Quoter().Join(statement.columnMap, ", ")
}
// AllCols update use only: update all columns
@ -750,10 +750,11 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
statement.lastError = err
return statement
}
tbs := strings.Split(tp.TableName(), ".")
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
case *builder.Builder:
@ -762,17 +763,18 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
statement.lastError = err
return statement
}
tbs := strings.Split(tp.TableName(), ".")
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default:
tbName := statement.Engine.TableName(tablename, true)
if !isSubQuery(tbName) {
var buf strings.Builder
statement.Engine.QuoteTo(&buf, tbName)
statement.dialect.Quoter().QuoteTo(&buf, tbName)
tbName = buf.String()
}
fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
@ -836,14 +838,14 @@ func (statement *Statement) genColumnStr() string {
buf.WriteString(".")
}
statement.Engine.QuoteTo(&buf, col.Name)
statement.dialect.Quoter().QuoteTo(&buf, col.Name)
}
return buf.String()
}
func (statement *Statement) genCreateTableSQL() string {
return statement.Engine.dialect.CreateTableSQL(statement.RefTable, statement.TableName(),
return statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName(),
statement.StoreEngine, statement.Charset)
}
@ -852,11 +854,7 @@ func (statement *Statement) genIndexSQL() []string {
tbName := statement.TableName()
for _, index := range statement.RefTable.Indexes {
if index.Type == schemas.IndexType {
sql := statement.Engine.dialect.CreateIndexSQL(tbName, index)
/*idxTBName := strings.Replace(tbName, ".", "_", -1)
idxTBName = strings.Replace(idxTBName, `"`, "", -1)
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)),
quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/
sql := statement.dialect.CreateIndexSQL(tbName, index)
sqls = append(sqls, sql)
}
}
@ -872,7 +870,7 @@ func (statement *Statement) genUniqueSQL() []string {
tbName := statement.TableName()
for _, index := range statement.RefTable.Indexes {
if index.Type == schemas.UniqueType {
sql := statement.Engine.dialect.CreateIndexSQL(tbName, index)
sql := statement.dialect.CreateIndexSQL(tbName, index)
sqls = append(sqls, sql)
}
}
@ -895,9 +893,9 @@ func (statement *Statement) genDelIndexSQL() []string {
} else if index.Type == schemas.IndexType {
rIdxName = indexName(idxPrefixName, idxName)
}
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
if statement.Engine.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
sql := fmt.Sprintf("DROP INDEX %v", statement.quote(statement.Engine.TableName(rIdxName, true)))
if statement.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.quote(tbName))
}
sqls = append(sqls, sql)
}
@ -905,10 +903,10 @@ func (statement *Statement) genDelIndexSQL() []string {
}
func (statement *Statement) genAddColumnStr(col *schemas.Column) (string, []interface{}) {
quote := statement.Engine.Quote
quote := statement.quote
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
dialects.String(statement.Engine.dialect, col))
if statement.Engine.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
dialects.String(statement.dialect, col))
if statement.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ";"
@ -946,7 +944,7 @@ func (statement *Statement) genConds(bean interface{}) (string, []interface{}, e
func (statement *Statement) quoteColumnStr(columnStr string) string {
columns := strings.Split(columnStr, ",")
return statement.Engine.dialect.Quoter().Join(columns, ",")
return statement.dialect.Quoter().Join(columns, ",")
}
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
@ -1040,7 +1038,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
var sumStrs = make([]string, 0, len(columns))
for _, colName := range columns {
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
colName = statement.Engine.Quote(colName)
colName = statement.quote(colName)
}
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
}
@ -1062,8 +1060,8 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
var (
distinct string
dialect = statement.Engine.Dialect()
quote = statement.Engine.Quote
dialect = statement.dialect
quote = statement.quote
fromStr = " FROM "
top, mssqlCondi, whereStr string
)
@ -1207,10 +1205,10 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName
var colnames = make([]string, len(cols))
for i, col := range cols {
if includeTableName {
colnames[i] = statement.Engine.Quote(statement.TableName()) +
"." + statement.Engine.Quote(col.Name)
colnames[i] = statement.quote(statement.TableName()) +
"." + statement.quote(col.Name)
} else {
colnames[i] = statement.Engine.Quote(col.Name)
colnames[i] = statement.quote(col.Name)
}
}
return strings.Join(colnames, ", ")
@ -1231,7 +1229,7 @@ func (statement *Statement) convertIDSQL(sqlStr string) string {
var top string
pLimitN := statement.LimitN
if pLimitN != nil && statement.Engine.dialect.DBType() == schemas.MSSQL {
if pLimitN != nil && statement.dialect.DBType() == schemas.MSSQL {
top = fmt.Sprintf("TOP %d ", *pLimitN)
}
@ -1251,7 +1249,7 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
if len(sqls) != 2 {
if len(sqls) == 1 {
return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
colstrs, statement.Engine.Quote(statement.TableName()))
colstrs, statement.quote(statement.TableName()))
}
return "", ""
}
@ -1260,9 +1258,9 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
// TODO: for postgres only, if any other database?
var paraStr string
if statement.Engine.dialect.DBType() == schemas.POSTGRES {
if statement.dialect.DBType() == schemas.POSTGRES {
paraStr = "$"
} else if statement.Engine.dialect.DBType() == schemas.MSSQL {
} else if statement.dialect.DBType() == schemas.MSSQL {
paraStr = ":"
}
@ -1278,6 +1276,6 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
}
return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
colstrs, statement.Engine.Quote(statement.TableName()),
colstrs, statement.quote(statement.TableName()),
whereStr)
}

View File

@ -80,7 +80,7 @@ const insertSelectPlaceHolder = true
func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error {
switch argv := arg.(type) {
case bool:
if statement.Engine.dialect.DBType() == schemas.MSSQL {
if statement.dialect.DBType() == schemas.MSSQL {
if argv {
if _, err := w.WriteString("1"); err != nil {
return err
@ -119,7 +119,7 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er
w.Append(arg)
} else {
var convertFunc = convertStringSingleQuote
if statement.Engine.dialect.DBType() == schemas.MYSQL {
if statement.dialect.DBType() == schemas.MYSQL {
convertFunc = convertString
}
if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil {