Improve statement (#1555)
Fix test Improve statement Reviewed-on: https://gitea.com/xorm/xorm/pulls/1555
This commit is contained in:
parent
e2f9100419
commit
7bf9a7a73c
|
@ -66,13 +66,18 @@ func (q Quoter) Trim(s string) string {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
if s[0:1] == q[0] {
|
var buf strings.Builder
|
||||||
s = s[1:]
|
for i := 0; i < len(s); i++ {
|
||||||
|
switch {
|
||||||
|
case i == 0 && s[i:i+1] == q[0]:
|
||||||
|
case i == len(s)-1 && s[i:i+1] == q[1]:
|
||||||
|
case s[i:i+1] == q[1] && s[i+1] == '.':
|
||||||
|
case s[i:i+1] == q[0] && s[i-1] == '.':
|
||||||
|
default:
|
||||||
|
buf.WriteByte(s[i])
|
||||||
}
|
}
|
||||||
if len(s) > 0 && s[len(s)-1:] == q[1] {
|
|
||||||
return s[:len(s)-1]
|
|
||||||
}
|
}
|
||||||
return s
|
return buf.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q Quoter) Join(a []string, sep string) string {
|
func (q Quoter) Join(a []string, sep string) string {
|
||||||
|
|
|
@ -65,7 +65,13 @@ func TestStrings(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTrim(t *testing.T) {
|
func TestTrim(t *testing.T) {
|
||||||
raw := "[table_name]"
|
var kases = map[string]string{
|
||||||
assert.EqualValues(t, raw, CommonQuoter.Trim(raw))
|
"[table_name]": "table_name",
|
||||||
assert.EqualValues(t, "table_name", Quoter{"[", "]"}.Trim(raw))
|
"[schema].[table_name]": "schema.table_name",
|
||||||
|
}
|
||||||
|
|
||||||
|
for src, dst := range kases {
|
||||||
|
assert.EqualValues(t, src, CommonQuoter.Trim(src))
|
||||||
|
assert.EqualValues(t, dst, Quoter{"[", "]"}.Trim(src))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
66
statement.go
66
statement.go
|
@ -615,7 +615,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) columnStr() string {
|
func (statement *Statement) columnStr() string {
|
||||||
return statement.Engine.dialect.Quoter().Join(statement.columnMap, ", ")
|
return statement.dialect.Quoter().Join(statement.columnMap, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllCols update use only: update all columns
|
// AllCols update use only: update all columns
|
||||||
|
@ -750,10 +750,11 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
|
||||||
statement.lastError = err
|
statement.lastError = err
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
tbs := strings.Split(tp.TableName(), ".")
|
|
||||||
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
|
|
||||||
|
|
||||||
var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
|
fields := strings.Split(tp.TableName(), ".")
|
||||||
|
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
|
||||||
|
aliasName = schemas.CommonQuoter.Trim(aliasName)
|
||||||
|
|
||||||
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
||||||
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
||||||
case *builder.Builder:
|
case *builder.Builder:
|
||||||
|
@ -762,17 +763,18 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
|
||||||
statement.lastError = err
|
statement.lastError = err
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
tbs := strings.Split(tp.TableName(), ".")
|
|
||||||
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
|
|
||||||
|
|
||||||
var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
|
fields := strings.Split(tp.TableName(), ".")
|
||||||
|
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
|
||||||
|
aliasName = schemas.CommonQuoter.Trim(aliasName)
|
||||||
|
|
||||||
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
||||||
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
||||||
default:
|
default:
|
||||||
tbName := statement.Engine.TableName(tablename, true)
|
tbName := statement.Engine.TableName(tablename, true)
|
||||||
if !isSubQuery(tbName) {
|
if !isSubQuery(tbName) {
|
||||||
var buf strings.Builder
|
var buf strings.Builder
|
||||||
statement.Engine.QuoteTo(&buf, tbName)
|
statement.dialect.Quoter().QuoteTo(&buf, tbName)
|
||||||
tbName = buf.String()
|
tbName = buf.String()
|
||||||
}
|
}
|
||||||
fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
|
fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
|
||||||
|
@ -836,14 +838,14 @@ func (statement *Statement) genColumnStr() string {
|
||||||
buf.WriteString(".")
|
buf.WriteString(".")
|
||||||
}
|
}
|
||||||
|
|
||||||
statement.Engine.QuoteTo(&buf, col.Name)
|
statement.dialect.Quoter().QuoteTo(&buf, col.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return buf.String()
|
return buf.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) genCreateTableSQL() string {
|
func (statement *Statement) genCreateTableSQL() string {
|
||||||
return statement.Engine.dialect.CreateTableSQL(statement.RefTable, statement.TableName(),
|
return statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName(),
|
||||||
statement.StoreEngine, statement.Charset)
|
statement.StoreEngine, statement.Charset)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -852,11 +854,7 @@ func (statement *Statement) genIndexSQL() []string {
|
||||||
tbName := statement.TableName()
|
tbName := statement.TableName()
|
||||||
for _, index := range statement.RefTable.Indexes {
|
for _, index := range statement.RefTable.Indexes {
|
||||||
if index.Type == schemas.IndexType {
|
if index.Type == schemas.IndexType {
|
||||||
sql := statement.Engine.dialect.CreateIndexSQL(tbName, index)
|
sql := statement.dialect.CreateIndexSQL(tbName, index)
|
||||||
/*idxTBName := strings.Replace(tbName, ".", "_", -1)
|
|
||||||
idxTBName = strings.Replace(idxTBName, `"`, "", -1)
|
|
||||||
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)),
|
|
||||||
quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/
|
|
||||||
sqls = append(sqls, sql)
|
sqls = append(sqls, sql)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -872,7 +870,7 @@ func (statement *Statement) genUniqueSQL() []string {
|
||||||
tbName := statement.TableName()
|
tbName := statement.TableName()
|
||||||
for _, index := range statement.RefTable.Indexes {
|
for _, index := range statement.RefTable.Indexes {
|
||||||
if index.Type == schemas.UniqueType {
|
if index.Type == schemas.UniqueType {
|
||||||
sql := statement.Engine.dialect.CreateIndexSQL(tbName, index)
|
sql := statement.dialect.CreateIndexSQL(tbName, index)
|
||||||
sqls = append(sqls, sql)
|
sqls = append(sqls, sql)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -895,9 +893,9 @@ func (statement *Statement) genDelIndexSQL() []string {
|
||||||
} else if index.Type == schemas.IndexType {
|
} else if index.Type == schemas.IndexType {
|
||||||
rIdxName = indexName(idxPrefixName, idxName)
|
rIdxName = indexName(idxPrefixName, idxName)
|
||||||
}
|
}
|
||||||
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
|
sql := fmt.Sprintf("DROP INDEX %v", statement.quote(statement.Engine.TableName(rIdxName, true)))
|
||||||
if statement.Engine.dialect.IndexOnTable() {
|
if statement.dialect.IndexOnTable() {
|
||||||
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
|
sql += fmt.Sprintf(" ON %v", statement.quote(tbName))
|
||||||
}
|
}
|
||||||
sqls = append(sqls, sql)
|
sqls = append(sqls, sql)
|
||||||
}
|
}
|
||||||
|
@ -905,10 +903,10 @@ func (statement *Statement) genDelIndexSQL() []string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) genAddColumnStr(col *schemas.Column) (string, []interface{}) {
|
func (statement *Statement) genAddColumnStr(col *schemas.Column) (string, []interface{}) {
|
||||||
quote := statement.Engine.Quote
|
quote := statement.quote
|
||||||
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
|
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
|
||||||
dialects.String(statement.Engine.dialect, col))
|
dialects.String(statement.dialect, col))
|
||||||
if statement.Engine.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
|
if statement.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
|
||||||
sql += " COMMENT '" + col.Comment + "'"
|
sql += " COMMENT '" + col.Comment + "'"
|
||||||
}
|
}
|
||||||
sql += ";"
|
sql += ";"
|
||||||
|
@ -946,7 +944,7 @@ func (statement *Statement) genConds(bean interface{}) (string, []interface{}, e
|
||||||
|
|
||||||
func (statement *Statement) quoteColumnStr(columnStr string) string {
|
func (statement *Statement) quoteColumnStr(columnStr string) string {
|
||||||
columns := strings.Split(columnStr, ",")
|
columns := strings.Split(columnStr, ",")
|
||||||
return statement.Engine.dialect.Quoter().Join(columns, ",")
|
return statement.dialect.Quoter().Join(columns, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
|
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
|
||||||
|
@ -1040,7 +1038,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
|
||||||
var sumStrs = make([]string, 0, len(columns))
|
var sumStrs = make([]string, 0, len(columns))
|
||||||
for _, colName := range columns {
|
for _, colName := range columns {
|
||||||
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
|
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
|
||||||
colName = statement.Engine.Quote(colName)
|
colName = statement.quote(colName)
|
||||||
}
|
}
|
||||||
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
|
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
|
||||||
}
|
}
|
||||||
|
@ -1062,8 +1060,8 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
|
||||||
func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
|
func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
|
||||||
var (
|
var (
|
||||||
distinct string
|
distinct string
|
||||||
dialect = statement.Engine.Dialect()
|
dialect = statement.dialect
|
||||||
quote = statement.Engine.Quote
|
quote = statement.quote
|
||||||
fromStr = " FROM "
|
fromStr = " FROM "
|
||||||
top, mssqlCondi, whereStr string
|
top, mssqlCondi, whereStr string
|
||||||
)
|
)
|
||||||
|
@ -1207,10 +1205,10 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName
|
||||||
var colnames = make([]string, len(cols))
|
var colnames = make([]string, len(cols))
|
||||||
for i, col := range cols {
|
for i, col := range cols {
|
||||||
if includeTableName {
|
if includeTableName {
|
||||||
colnames[i] = statement.Engine.Quote(statement.TableName()) +
|
colnames[i] = statement.quote(statement.TableName()) +
|
||||||
"." + statement.Engine.Quote(col.Name)
|
"." + statement.quote(col.Name)
|
||||||
} else {
|
} else {
|
||||||
colnames[i] = statement.Engine.Quote(col.Name)
|
colnames[i] = statement.quote(col.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return strings.Join(colnames, ", ")
|
return strings.Join(colnames, ", ")
|
||||||
|
@ -1231,7 +1229,7 @@ func (statement *Statement) convertIDSQL(sqlStr string) string {
|
||||||
|
|
||||||
var top string
|
var top string
|
||||||
pLimitN := statement.LimitN
|
pLimitN := statement.LimitN
|
||||||
if pLimitN != nil && statement.Engine.dialect.DBType() == schemas.MSSQL {
|
if pLimitN != nil && statement.dialect.DBType() == schemas.MSSQL {
|
||||||
top = fmt.Sprintf("TOP %d ", *pLimitN)
|
top = fmt.Sprintf("TOP %d ", *pLimitN)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1251,7 +1249,7 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
|
||||||
if len(sqls) != 2 {
|
if len(sqls) != 2 {
|
||||||
if len(sqls) == 1 {
|
if len(sqls) == 1 {
|
||||||
return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
|
return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
|
||||||
colstrs, statement.Engine.Quote(statement.TableName()))
|
colstrs, statement.quote(statement.TableName()))
|
||||||
}
|
}
|
||||||
return "", ""
|
return "", ""
|
||||||
}
|
}
|
||||||
|
@ -1260,9 +1258,9 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
|
||||||
|
|
||||||
// TODO: for postgres only, if any other database?
|
// TODO: for postgres only, if any other database?
|
||||||
var paraStr string
|
var paraStr string
|
||||||
if statement.Engine.dialect.DBType() == schemas.POSTGRES {
|
if statement.dialect.DBType() == schemas.POSTGRES {
|
||||||
paraStr = "$"
|
paraStr = "$"
|
||||||
} else if statement.Engine.dialect.DBType() == schemas.MSSQL {
|
} else if statement.dialect.DBType() == schemas.MSSQL {
|
||||||
paraStr = ":"
|
paraStr = ":"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1278,6 +1276,6 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
|
return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
|
||||||
colstrs, statement.Engine.Quote(statement.TableName()),
|
colstrs, statement.quote(statement.TableName()),
|
||||||
whereStr)
|
whereStr)
|
||||||
}
|
}
|
||||||
|
|
|
@ -80,7 +80,7 @@ const insertSelectPlaceHolder = true
|
||||||
func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error {
|
func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error {
|
||||||
switch argv := arg.(type) {
|
switch argv := arg.(type) {
|
||||||
case bool:
|
case bool:
|
||||||
if statement.Engine.dialect.DBType() == schemas.MSSQL {
|
if statement.dialect.DBType() == schemas.MSSQL {
|
||||||
if argv {
|
if argv {
|
||||||
if _, err := w.WriteString("1"); err != nil {
|
if _, err := w.WriteString("1"); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -119,7 +119,7 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er
|
||||||
w.Append(arg)
|
w.Append(arg)
|
||||||
} else {
|
} else {
|
||||||
var convertFunc = convertStringSingleQuote
|
var convertFunc = convertStringSingleQuote
|
||||||
if statement.Engine.dialect.DBType() == schemas.MYSQL {
|
if statement.dialect.DBType() == schemas.MYSQL {
|
||||||
convertFunc = convertString
|
convertFunc = convertString
|
||||||
}
|
}
|
||||||
if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil {
|
if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil {
|
||||||
|
|
Loading…
Reference in New Issue