Improve statement (#1555)

Fix test

Improve statement

Reviewed-on: https://gitea.com/xorm/xorm/pulls/1555
This commit is contained in:
Lunny Xiao 2020-02-27 05:49:43 +00:00
parent e2f9100419
commit 7bf9a7a73c
4 changed files with 54 additions and 45 deletions

View File

@ -66,13 +66,18 @@ func (q Quoter) Trim(s string) string {
return s return s
} }
if s[0:1] == q[0] { var buf strings.Builder
s = s[1:] for i := 0; i < len(s); i++ {
switch {
case i == 0 && s[i:i+1] == q[0]:
case i == len(s)-1 && s[i:i+1] == q[1]:
case s[i:i+1] == q[1] && s[i+1] == '.':
case s[i:i+1] == q[0] && s[i-1] == '.':
default:
buf.WriteByte(s[i])
} }
if len(s) > 0 && s[len(s)-1:] == q[1] {
return s[:len(s)-1]
} }
return s return buf.String()
} }
func (q Quoter) Join(a []string, sep string) string { func (q Quoter) Join(a []string, sep string) string {

View File

@ -65,7 +65,13 @@ func TestStrings(t *testing.T) {
} }
func TestTrim(t *testing.T) { func TestTrim(t *testing.T) {
raw := "[table_name]" var kases = map[string]string{
assert.EqualValues(t, raw, CommonQuoter.Trim(raw)) "[table_name]": "table_name",
assert.EqualValues(t, "table_name", Quoter{"[", "]"}.Trim(raw)) "[schema].[table_name]": "schema.table_name",
}
for src, dst := range kases {
assert.EqualValues(t, src, CommonQuoter.Trim(src))
assert.EqualValues(t, dst, Quoter{"[", "]"}.Trim(src))
}
} }

View File

@ -615,7 +615,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
} }
func (statement *Statement) columnStr() string { func (statement *Statement) columnStr() string {
return statement.Engine.dialect.Quoter().Join(statement.columnMap, ", ") return statement.dialect.Quoter().Join(statement.columnMap, ", ")
} }
// AllCols update use only: update all columns // AllCols update use only: update all columns
@ -750,10 +750,11 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
statement.lastError = err statement.lastError = err
return statement return statement
} }
tbs := strings.Split(tp.TableName(), ".")
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, "")) fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...) statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
case *builder.Builder: case *builder.Builder:
@ -762,17 +763,18 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
statement.lastError = err statement.lastError = err
return statement return statement
} }
tbs := strings.Split(tp.TableName(), ".")
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, "")) fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...) statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default: default:
tbName := statement.Engine.TableName(tablename, true) tbName := statement.Engine.TableName(tablename, true)
if !isSubQuery(tbName) { if !isSubQuery(tbName) {
var buf strings.Builder var buf strings.Builder
statement.Engine.QuoteTo(&buf, tbName) statement.dialect.Quoter().QuoteTo(&buf, tbName)
tbName = buf.String() tbName = buf.String()
} }
fmt.Fprintf(&buf, "%s ON %v", tbName, condition) fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
@ -836,14 +838,14 @@ func (statement *Statement) genColumnStr() string {
buf.WriteString(".") buf.WriteString(".")
} }
statement.Engine.QuoteTo(&buf, col.Name) statement.dialect.Quoter().QuoteTo(&buf, col.Name)
} }
return buf.String() return buf.String()
} }
func (statement *Statement) genCreateTableSQL() string { func (statement *Statement) genCreateTableSQL() string {
return statement.Engine.dialect.CreateTableSQL(statement.RefTable, statement.TableName(), return statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName(),
statement.StoreEngine, statement.Charset) statement.StoreEngine, statement.Charset)
} }
@ -852,11 +854,7 @@ func (statement *Statement) genIndexSQL() []string {
tbName := statement.TableName() tbName := statement.TableName()
for _, index := range statement.RefTable.Indexes { for _, index := range statement.RefTable.Indexes {
if index.Type == schemas.IndexType { if index.Type == schemas.IndexType {
sql := statement.Engine.dialect.CreateIndexSQL(tbName, index) sql := statement.dialect.CreateIndexSQL(tbName, index)
/*idxTBName := strings.Replace(tbName, ".", "_", -1)
idxTBName = strings.Replace(idxTBName, `"`, "", -1)
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)),
quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/
sqls = append(sqls, sql) sqls = append(sqls, sql)
} }
} }
@ -872,7 +870,7 @@ func (statement *Statement) genUniqueSQL() []string {
tbName := statement.TableName() tbName := statement.TableName()
for _, index := range statement.RefTable.Indexes { for _, index := range statement.RefTable.Indexes {
if index.Type == schemas.UniqueType { if index.Type == schemas.UniqueType {
sql := statement.Engine.dialect.CreateIndexSQL(tbName, index) sql := statement.dialect.CreateIndexSQL(tbName, index)
sqls = append(sqls, sql) sqls = append(sqls, sql)
} }
} }
@ -895,9 +893,9 @@ func (statement *Statement) genDelIndexSQL() []string {
} else if index.Type == schemas.IndexType { } else if index.Type == schemas.IndexType {
rIdxName = indexName(idxPrefixName, idxName) rIdxName = indexName(idxPrefixName, idxName)
} }
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true))) sql := fmt.Sprintf("DROP INDEX %v", statement.quote(statement.Engine.TableName(rIdxName, true)))
if statement.Engine.dialect.IndexOnTable() { if statement.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName)) sql += fmt.Sprintf(" ON %v", statement.quote(tbName))
} }
sqls = append(sqls, sql) sqls = append(sqls, sql)
} }
@ -905,10 +903,10 @@ func (statement *Statement) genDelIndexSQL() []string {
} }
func (statement *Statement) genAddColumnStr(col *schemas.Column) (string, []interface{}) { func (statement *Statement) genAddColumnStr(col *schemas.Column) (string, []interface{}) {
quote := statement.Engine.Quote quote := statement.quote
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()), sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
dialects.String(statement.Engine.dialect, col)) dialects.String(statement.dialect, col))
if statement.Engine.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 { if statement.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'" sql += " COMMENT '" + col.Comment + "'"
} }
sql += ";" sql += ";"
@ -946,7 +944,7 @@ func (statement *Statement) genConds(bean interface{}) (string, []interface{}, e
func (statement *Statement) quoteColumnStr(columnStr string) string { func (statement *Statement) quoteColumnStr(columnStr string) string {
columns := strings.Split(columnStr, ",") columns := strings.Split(columnStr, ",")
return statement.Engine.dialect.Quoter().Join(columns, ",") return statement.dialect.Quoter().Join(columns, ",")
} }
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) { func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
@ -1040,7 +1038,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
var sumStrs = make([]string, 0, len(columns)) var sumStrs = make([]string, 0, len(columns))
for _, colName := range columns { for _, colName := range columns {
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") { if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
colName = statement.Engine.Quote(colName) colName = statement.quote(colName)
} }
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName)) sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
} }
@ -1062,8 +1060,8 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) { func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
var ( var (
distinct string distinct string
dialect = statement.Engine.Dialect() dialect = statement.dialect
quote = statement.Engine.Quote quote = statement.quote
fromStr = " FROM " fromStr = " FROM "
top, mssqlCondi, whereStr string top, mssqlCondi, whereStr string
) )
@ -1207,10 +1205,10 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName
var colnames = make([]string, len(cols)) var colnames = make([]string, len(cols))
for i, col := range cols { for i, col := range cols {
if includeTableName { if includeTableName {
colnames[i] = statement.Engine.Quote(statement.TableName()) + colnames[i] = statement.quote(statement.TableName()) +
"." + statement.Engine.Quote(col.Name) "." + statement.quote(col.Name)
} else { } else {
colnames[i] = statement.Engine.Quote(col.Name) colnames[i] = statement.quote(col.Name)
} }
} }
return strings.Join(colnames, ", ") return strings.Join(colnames, ", ")
@ -1231,7 +1229,7 @@ func (statement *Statement) convertIDSQL(sqlStr string) string {
var top string var top string
pLimitN := statement.LimitN pLimitN := statement.LimitN
if pLimitN != nil && statement.Engine.dialect.DBType() == schemas.MSSQL { if pLimitN != nil && statement.dialect.DBType() == schemas.MSSQL {
top = fmt.Sprintf("TOP %d ", *pLimitN) top = fmt.Sprintf("TOP %d ", *pLimitN)
} }
@ -1251,7 +1249,7 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
if len(sqls) != 2 { if len(sqls) != 2 {
if len(sqls) == 1 { if len(sqls) == 1 {
return sqls[0], fmt.Sprintf("SELECT %v FROM %v", return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
colstrs, statement.Engine.Quote(statement.TableName())) colstrs, statement.quote(statement.TableName()))
} }
return "", "" return "", ""
} }
@ -1260,9 +1258,9 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
// TODO: for postgres only, if any other database? // TODO: for postgres only, if any other database?
var paraStr string var paraStr string
if statement.Engine.dialect.DBType() == schemas.POSTGRES { if statement.dialect.DBType() == schemas.POSTGRES {
paraStr = "$" paraStr = "$"
} else if statement.Engine.dialect.DBType() == schemas.MSSQL { } else if statement.dialect.DBType() == schemas.MSSQL {
paraStr = ":" paraStr = ":"
} }
@ -1278,6 +1276,6 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
} }
return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
colstrs, statement.Engine.Quote(statement.TableName()), colstrs, statement.quote(statement.TableName()),
whereStr) whereStr)
} }

View File

@ -80,7 +80,7 @@ const insertSelectPlaceHolder = true
func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error { func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error {
switch argv := arg.(type) { switch argv := arg.(type) {
case bool: case bool:
if statement.Engine.dialect.DBType() == schemas.MSSQL { if statement.dialect.DBType() == schemas.MSSQL {
if argv { if argv {
if _, err := w.WriteString("1"); err != nil { if _, err := w.WriteString("1"); err != nil {
return err return err
@ -119,7 +119,7 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er
w.Append(arg) w.Append(arg)
} else { } else {
var convertFunc = convertStringSingleQuote var convertFunc = convertStringSingleQuote
if statement.Engine.dialect.DBType() == schemas.MYSQL { if statement.dialect.DBType() == schemas.MYSQL {
convertFunc = convertString convertFunc = convertString
} }
if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil { if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil {