From 7bf9a7a73c5e3f93b2abdbf1486b638041b07c91 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 27 Feb 2020 05:49:43 +0000 Subject: [PATCH] Improve statement (#1555) Fix test Improve statement Reviewed-on: https://gitea.com/xorm/xorm/pulls/1555 --- schemas/quote.go | 17 +++++++---- schemas/quote_test.go | 12 ++++++-- statement.go | 66 +++++++++++++++++++++---------------------- statement_args.go | 4 +-- 4 files changed, 54 insertions(+), 45 deletions(-) diff --git a/schemas/quote.go b/schemas/quote.go index d10a5dc8..0e022240 100644 --- a/schemas/quote.go +++ b/schemas/quote.go @@ -66,13 +66,18 @@ func (q Quoter) Trim(s string) string { return s } - if s[0:1] == q[0] { - s = s[1:] + var buf strings.Builder + for i := 0; i < len(s); i++ { + switch { + case i == 0 && s[i:i+1] == q[0]: + case i == len(s)-1 && s[i:i+1] == q[1]: + case s[i:i+1] == q[1] && s[i+1] == '.': + case s[i:i+1] == q[0] && s[i-1] == '.': + default: + buf.WriteByte(s[i]) + } } - if len(s) > 0 && s[len(s)-1:] == q[1] { - return s[:len(s)-1] - } - return s + return buf.String() } func (q Quoter) Join(a []string, sep string) string { diff --git a/schemas/quote_test.go b/schemas/quote_test.go index 174d1a0d..24739377 100644 --- a/schemas/quote_test.go +++ b/schemas/quote_test.go @@ -65,7 +65,13 @@ func TestStrings(t *testing.T) { } func TestTrim(t *testing.T) { - raw := "[table_name]" - assert.EqualValues(t, raw, CommonQuoter.Trim(raw)) - assert.EqualValues(t, "table_name", Quoter{"[", "]"}.Trim(raw)) + var kases = map[string]string{ + "[table_name]": "table_name", + "[schema].[table_name]": "schema.table_name", + } + + for src, dst := range kases { + assert.EqualValues(t, src, CommonQuoter.Trim(src)) + assert.EqualValues(t, dst, Quoter{"[", "]"}.Trim(src)) + } } diff --git a/statement.go b/statement.go index d3048601..c07ddfe9 100644 --- a/statement.go +++ b/statement.go @@ -615,7 +615,7 @@ func (statement *Statement) Cols(columns ...string) *Statement { } func (statement *Statement) columnStr() string { - return statement.Engine.dialect.Quoter().Join(statement.columnMap, ", ") + return statement.dialect.Quoter().Join(statement.columnMap, ", ") } // AllCols update use only: update all columns @@ -750,10 +750,11 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition statement.lastError = err return statement } - tbs := strings.Split(tp.TableName(), ".") - quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") - var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, "")) + fields := strings.Split(tp.TableName(), ".") + aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) + aliasName = schemas.CommonQuoter.Trim(aliasName) + fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) case *builder.Builder: @@ -762,17 +763,18 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition statement.lastError = err return statement } - tbs := strings.Split(tp.TableName(), ".") - quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") - var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, "")) + fields := strings.Split(tp.TableName(), ".") + aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) + aliasName = schemas.CommonQuoter.Trim(aliasName) + fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) default: tbName := statement.Engine.TableName(tablename, true) if !isSubQuery(tbName) { var buf strings.Builder - statement.Engine.QuoteTo(&buf, tbName) + statement.dialect.Quoter().QuoteTo(&buf, tbName) tbName = buf.String() } fmt.Fprintf(&buf, "%s ON %v", tbName, condition) @@ -836,14 +838,14 @@ func (statement *Statement) genColumnStr() string { buf.WriteString(".") } - statement.Engine.QuoteTo(&buf, col.Name) + statement.dialect.Quoter().QuoteTo(&buf, col.Name) } return buf.String() } func (statement *Statement) genCreateTableSQL() string { - return statement.Engine.dialect.CreateTableSQL(statement.RefTable, statement.TableName(), + return statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName(), statement.StoreEngine, statement.Charset) } @@ -852,11 +854,7 @@ func (statement *Statement) genIndexSQL() []string { tbName := statement.TableName() for _, index := range statement.RefTable.Indexes { if index.Type == schemas.IndexType { - sql := statement.Engine.dialect.CreateIndexSQL(tbName, index) - /*idxTBName := strings.Replace(tbName, ".", "_", -1) - idxTBName = strings.Replace(idxTBName, `"`, "", -1) - sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)), - quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/ + sql := statement.dialect.CreateIndexSQL(tbName, index) sqls = append(sqls, sql) } } @@ -872,7 +870,7 @@ func (statement *Statement) genUniqueSQL() []string { tbName := statement.TableName() for _, index := range statement.RefTable.Indexes { if index.Type == schemas.UniqueType { - sql := statement.Engine.dialect.CreateIndexSQL(tbName, index) + sql := statement.dialect.CreateIndexSQL(tbName, index) sqls = append(sqls, sql) } } @@ -895,9 +893,9 @@ func (statement *Statement) genDelIndexSQL() []string { } else if index.Type == schemas.IndexType { rIdxName = indexName(idxPrefixName, idxName) } - sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true))) - if statement.Engine.dialect.IndexOnTable() { - sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName)) + sql := fmt.Sprintf("DROP INDEX %v", statement.quote(statement.Engine.TableName(rIdxName, true))) + if statement.dialect.IndexOnTable() { + sql += fmt.Sprintf(" ON %v", statement.quote(tbName)) } sqls = append(sqls, sql) } @@ -905,10 +903,10 @@ func (statement *Statement) genDelIndexSQL() []string { } func (statement *Statement) genAddColumnStr(col *schemas.Column) (string, []interface{}) { - quote := statement.Engine.Quote + quote := statement.quote sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()), - dialects.String(statement.Engine.dialect, col)) - if statement.Engine.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 { + dialects.String(statement.dialect, col)) + if statement.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 { sql += " COMMENT '" + col.Comment + "'" } sql += ";" @@ -946,7 +944,7 @@ func (statement *Statement) genConds(bean interface{}) (string, []interface{}, e func (statement *Statement) quoteColumnStr(columnStr string) string { columns := strings.Split(columnStr, ",") - return statement.Engine.dialect.Quoter().Join(columns, ",") + return statement.dialect.Quoter().Join(columns, ",") } func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) { @@ -1040,7 +1038,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri var sumStrs = make([]string, 0, len(columns)) for _, colName := range columns { if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") { - colName = statement.Engine.Quote(colName) + colName = statement.quote(colName) } sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName)) } @@ -1062,8 +1060,8 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) { var ( distinct string - dialect = statement.Engine.Dialect() - quote = statement.Engine.Quote + dialect = statement.dialect + quote = statement.quote fromStr = " FROM " top, mssqlCondi, whereStr string ) @@ -1207,10 +1205,10 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName var colnames = make([]string, len(cols)) for i, col := range cols { if includeTableName { - colnames[i] = statement.Engine.Quote(statement.TableName()) + - "." + statement.Engine.Quote(col.Name) + colnames[i] = statement.quote(statement.TableName()) + + "." + statement.quote(col.Name) } else { - colnames[i] = statement.Engine.Quote(col.Name) + colnames[i] = statement.quote(col.Name) } } return strings.Join(colnames, ", ") @@ -1231,7 +1229,7 @@ func (statement *Statement) convertIDSQL(sqlStr string) string { var top string pLimitN := statement.LimitN - if pLimitN != nil && statement.Engine.dialect.DBType() == schemas.MSSQL { + if pLimitN != nil && statement.dialect.DBType() == schemas.MSSQL { top = fmt.Sprintf("TOP %d ", *pLimitN) } @@ -1251,7 +1249,7 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) { if len(sqls) != 2 { if len(sqls) == 1 { return sqls[0], fmt.Sprintf("SELECT %v FROM %v", - colstrs, statement.Engine.Quote(statement.TableName())) + colstrs, statement.quote(statement.TableName())) } return "", "" } @@ -1260,9 +1258,9 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) { // TODO: for postgres only, if any other database? var paraStr string - if statement.Engine.dialect.DBType() == schemas.POSTGRES { + if statement.dialect.DBType() == schemas.POSTGRES { paraStr = "$" - } else if statement.Engine.dialect.DBType() == schemas.MSSQL { + } else if statement.dialect.DBType() == schemas.MSSQL { paraStr = ":" } @@ -1278,6 +1276,6 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) { } return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", - colstrs, statement.Engine.Quote(statement.TableName()), + colstrs, statement.quote(statement.TableName()), whereStr) } diff --git a/statement_args.go b/statement_args.go index 4f35ce6e..22bfeb7b 100644 --- a/statement_args.go +++ b/statement_args.go @@ -80,7 +80,7 @@ const insertSelectPlaceHolder = true func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error { switch argv := arg.(type) { case bool: - if statement.Engine.dialect.DBType() == schemas.MSSQL { + if statement.dialect.DBType() == schemas.MSSQL { if argv { if _, err := w.WriteString("1"); err != nil { return err @@ -119,7 +119,7 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er w.Append(arg) } else { var convertFunc = convertStringSingleQuote - if statement.Engine.dialect.DBType() == schemas.MYSQL { + if statement.dialect.DBType() == schemas.MYSQL { convertFunc = convertString } if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil {