Improve statement (#1555)
Fix test Improve statement Reviewed-on: https://gitea.com/xorm/xorm/pulls/1555
This commit is contained in:
parent
e2f9100419
commit
7bf9a7a73c
|
@ -66,13 +66,18 @@ func (q Quoter) Trim(s string) string {
|
|||
return s
|
||||
}
|
||||
|
||||
if s[0:1] == q[0] {
|
||||
s = s[1:]
|
||||
var buf strings.Builder
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch {
|
||||
case i == 0 && s[i:i+1] == q[0]:
|
||||
case i == len(s)-1 && s[i:i+1] == q[1]:
|
||||
case s[i:i+1] == q[1] && s[i+1] == '.':
|
||||
case s[i:i+1] == q[0] && s[i-1] == '.':
|
||||
default:
|
||||
buf.WriteByte(s[i])
|
||||
}
|
||||
}
|
||||
if len(s) > 0 && s[len(s)-1:] == q[1] {
|
||||
return s[:len(s)-1]
|
||||
}
|
||||
return s
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (q Quoter) Join(a []string, sep string) string {
|
||||
|
|
|
@ -65,7 +65,13 @@ func TestStrings(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTrim(t *testing.T) {
|
||||
raw := "[table_name]"
|
||||
assert.EqualValues(t, raw, CommonQuoter.Trim(raw))
|
||||
assert.EqualValues(t, "table_name", Quoter{"[", "]"}.Trim(raw))
|
||||
var kases = map[string]string{
|
||||
"[table_name]": "table_name",
|
||||
"[schema].[table_name]": "schema.table_name",
|
||||
}
|
||||
|
||||
for src, dst := range kases {
|
||||
assert.EqualValues(t, src, CommonQuoter.Trim(src))
|
||||
assert.EqualValues(t, dst, Quoter{"[", "]"}.Trim(src))
|
||||
}
|
||||
}
|
||||
|
|
66
statement.go
66
statement.go
|
@ -615,7 +615,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
|
|||
}
|
||||
|
||||
func (statement *Statement) columnStr() string {
|
||||
return statement.Engine.dialect.Quoter().Join(statement.columnMap, ", ")
|
||||
return statement.dialect.Quoter().Join(statement.columnMap, ", ")
|
||||
}
|
||||
|
||||
// AllCols update use only: update all columns
|
||||
|
@ -750,10 +750,11 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
|
|||
statement.lastError = err
|
||||
return statement
|
||||
}
|
||||
tbs := strings.Split(tp.TableName(), ".")
|
||||
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
|
||||
|
||||
var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
|
||||
fields := strings.Split(tp.TableName(), ".")
|
||||
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
|
||||
aliasName = schemas.CommonQuoter.Trim(aliasName)
|
||||
|
||||
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
||||
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
||||
case *builder.Builder:
|
||||
|
@ -762,17 +763,18 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
|
|||
statement.lastError = err
|
||||
return statement
|
||||
}
|
||||
tbs := strings.Split(tp.TableName(), ".")
|
||||
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
|
||||
|
||||
var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
|
||||
fields := strings.Split(tp.TableName(), ".")
|
||||
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
|
||||
aliasName = schemas.CommonQuoter.Trim(aliasName)
|
||||
|
||||
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
||||
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
||||
default:
|
||||
tbName := statement.Engine.TableName(tablename, true)
|
||||
if !isSubQuery(tbName) {
|
||||
var buf strings.Builder
|
||||
statement.Engine.QuoteTo(&buf, tbName)
|
||||
statement.dialect.Quoter().QuoteTo(&buf, tbName)
|
||||
tbName = buf.String()
|
||||
}
|
||||
fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
|
||||
|
@ -836,14 +838,14 @@ func (statement *Statement) genColumnStr() string {
|
|||
buf.WriteString(".")
|
||||
}
|
||||
|
||||
statement.Engine.QuoteTo(&buf, col.Name)
|
||||
statement.dialect.Quoter().QuoteTo(&buf, col.Name)
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (statement *Statement) genCreateTableSQL() string {
|
||||
return statement.Engine.dialect.CreateTableSQL(statement.RefTable, statement.TableName(),
|
||||
return statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName(),
|
||||
statement.StoreEngine, statement.Charset)
|
||||
}
|
||||
|
||||
|
@ -852,11 +854,7 @@ func (statement *Statement) genIndexSQL() []string {
|
|||
tbName := statement.TableName()
|
||||
for _, index := range statement.RefTable.Indexes {
|
||||
if index.Type == schemas.IndexType {
|
||||
sql := statement.Engine.dialect.CreateIndexSQL(tbName, index)
|
||||
/*idxTBName := strings.Replace(tbName, ".", "_", -1)
|
||||
idxTBName = strings.Replace(idxTBName, `"`, "", -1)
|
||||
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)),
|
||||
quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/
|
||||
sql := statement.dialect.CreateIndexSQL(tbName, index)
|
||||
sqls = append(sqls, sql)
|
||||
}
|
||||
}
|
||||
|
@ -872,7 +870,7 @@ func (statement *Statement) genUniqueSQL() []string {
|
|||
tbName := statement.TableName()
|
||||
for _, index := range statement.RefTable.Indexes {
|
||||
if index.Type == schemas.UniqueType {
|
||||
sql := statement.Engine.dialect.CreateIndexSQL(tbName, index)
|
||||
sql := statement.dialect.CreateIndexSQL(tbName, index)
|
||||
sqls = append(sqls, sql)
|
||||
}
|
||||
}
|
||||
|
@ -895,9 +893,9 @@ func (statement *Statement) genDelIndexSQL() []string {
|
|||
} else if index.Type == schemas.IndexType {
|
||||
rIdxName = indexName(idxPrefixName, idxName)
|
||||
}
|
||||
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
|
||||
if statement.Engine.dialect.IndexOnTable() {
|
||||
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
|
||||
sql := fmt.Sprintf("DROP INDEX %v", statement.quote(statement.Engine.TableName(rIdxName, true)))
|
||||
if statement.dialect.IndexOnTable() {
|
||||
sql += fmt.Sprintf(" ON %v", statement.quote(tbName))
|
||||
}
|
||||
sqls = append(sqls, sql)
|
||||
}
|
||||
|
@ -905,10 +903,10 @@ func (statement *Statement) genDelIndexSQL() []string {
|
|||
}
|
||||
|
||||
func (statement *Statement) genAddColumnStr(col *schemas.Column) (string, []interface{}) {
|
||||
quote := statement.Engine.Quote
|
||||
quote := statement.quote
|
||||
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
|
||||
dialects.String(statement.Engine.dialect, col))
|
||||
if statement.Engine.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
|
||||
dialects.String(statement.dialect, col))
|
||||
if statement.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
|
||||
sql += " COMMENT '" + col.Comment + "'"
|
||||
}
|
||||
sql += ";"
|
||||
|
@ -946,7 +944,7 @@ func (statement *Statement) genConds(bean interface{}) (string, []interface{}, e
|
|||
|
||||
func (statement *Statement) quoteColumnStr(columnStr string) string {
|
||||
columns := strings.Split(columnStr, ",")
|
||||
return statement.Engine.dialect.Quoter().Join(columns, ",")
|
||||
return statement.dialect.Quoter().Join(columns, ",")
|
||||
}
|
||||
|
||||
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
|
||||
|
@ -1040,7 +1038,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
|
|||
var sumStrs = make([]string, 0, len(columns))
|
||||
for _, colName := range columns {
|
||||
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
|
||||
colName = statement.Engine.Quote(colName)
|
||||
colName = statement.quote(colName)
|
||||
}
|
||||
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
|
||||
}
|
||||
|
@ -1062,8 +1060,8 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
|
|||
func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
|
||||
var (
|
||||
distinct string
|
||||
dialect = statement.Engine.Dialect()
|
||||
quote = statement.Engine.Quote
|
||||
dialect = statement.dialect
|
||||
quote = statement.quote
|
||||
fromStr = " FROM "
|
||||
top, mssqlCondi, whereStr string
|
||||
)
|
||||
|
@ -1207,10 +1205,10 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName
|
|||
var colnames = make([]string, len(cols))
|
||||
for i, col := range cols {
|
||||
if includeTableName {
|
||||
colnames[i] = statement.Engine.Quote(statement.TableName()) +
|
||||
"." + statement.Engine.Quote(col.Name)
|
||||
colnames[i] = statement.quote(statement.TableName()) +
|
||||
"." + statement.quote(col.Name)
|
||||
} else {
|
||||
colnames[i] = statement.Engine.Quote(col.Name)
|
||||
colnames[i] = statement.quote(col.Name)
|
||||
}
|
||||
}
|
||||
return strings.Join(colnames, ", ")
|
||||
|
@ -1231,7 +1229,7 @@ func (statement *Statement) convertIDSQL(sqlStr string) string {
|
|||
|
||||
var top string
|
||||
pLimitN := statement.LimitN
|
||||
if pLimitN != nil && statement.Engine.dialect.DBType() == schemas.MSSQL {
|
||||
if pLimitN != nil && statement.dialect.DBType() == schemas.MSSQL {
|
||||
top = fmt.Sprintf("TOP %d ", *pLimitN)
|
||||
}
|
||||
|
||||
|
@ -1251,7 +1249,7 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
|
|||
if len(sqls) != 2 {
|
||||
if len(sqls) == 1 {
|
||||
return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
|
||||
colstrs, statement.Engine.Quote(statement.TableName()))
|
||||
colstrs, statement.quote(statement.TableName()))
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
@ -1260,9 +1258,9 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
|
|||
|
||||
// TODO: for postgres only, if any other database?
|
||||
var paraStr string
|
||||
if statement.Engine.dialect.DBType() == schemas.POSTGRES {
|
||||
if statement.dialect.DBType() == schemas.POSTGRES {
|
||||
paraStr = "$"
|
||||
} else if statement.Engine.dialect.DBType() == schemas.MSSQL {
|
||||
} else if statement.dialect.DBType() == schemas.MSSQL {
|
||||
paraStr = ":"
|
||||
}
|
||||
|
||||
|
@ -1278,6 +1276,6 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
|
|||
}
|
||||
|
||||
return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
|
||||
colstrs, statement.Engine.Quote(statement.TableName()),
|
||||
colstrs, statement.quote(statement.TableName()),
|
||||
whereStr)
|
||||
}
|
||||
|
|
|
@ -80,7 +80,7 @@ const insertSelectPlaceHolder = true
|
|||
func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error {
|
||||
switch argv := arg.(type) {
|
||||
case bool:
|
||||
if statement.Engine.dialect.DBType() == schemas.MSSQL {
|
||||
if statement.dialect.DBType() == schemas.MSSQL {
|
||||
if argv {
|
||||
if _, err := w.WriteString("1"); err != nil {
|
||||
return err
|
||||
|
@ -119,7 +119,7 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er
|
|||
w.Append(arg)
|
||||
} else {
|
||||
var convertFunc = convertStringSingleQuote
|
||||
if statement.Engine.dialect.DBType() == schemas.MYSQL {
|
||||
if statement.dialect.DBType() == schemas.MYSQL {
|
||||
convertFunc = convertString
|
||||
}
|
||||
if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil {
|
||||
|
|
Loading…
Reference in New Issue