This commit is contained in:
Lunny Xiao 2022-05-30 00:38:08 +08:00
parent ee2b5ef320
commit a0f42c421a
6 changed files with 93 additions and 72 deletions

View File

@ -249,7 +249,7 @@ func TestOrder(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
users = make([]Userinfo, 0) users = make([]Userinfo, 0)
err = testEngine.OrderBy("case username like ? desc", "a").Find(&users) err = testEngine.OrderBy("CASE WHEN username LIKE ? THEN 0 ELSE 1 END DESC", "a").Find(&users)
assert.NoError(t, err) assert.NoError(t, err)
} }

View File

@ -11,6 +11,7 @@ import (
"strings" "strings"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -250,12 +251,13 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
distinct = "DISTINCT " distinct = "DISTINCT "
} }
condSQL, condArgs, err := statement.GenCondSQL(statement.cond) condWriter := builder.NewWriter()
if err != nil { if err := statement.cond.WriteTo(condWriter); err != nil {
return "", nil, err return "", nil, err
} }
if len(condSQL) > 0 {
whereStr = fmt.Sprintf(" WHERE %s", condSQL) if condWriter.Len() > 0 {
whereStr = " WHERE "
} }
pLimitN := statement.LimitN pLimitN := statement.LimitN
@ -297,11 +299,13 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} }
} }
if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s%s%s%s", if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s%s%s",
column, statement.Start, column, fromStr, whereStr, orderByWriter.String()); err != nil { column, statement.Start, column, fromStr, whereStr); err != nil {
return "", nil, err
}
if err := utils.WriteBuilder(mssqlCondi, condWriter, orderByWriter); err != nil {
return "", nil, err return "", nil, err
} }
mssqlCondi.Append(orderByWriter.Args()...)
if err := statement.WriteGroupBy(mssqlCondi); err != nil { if err := statement.WriteGroupBy(mssqlCondi); err != nil {
return "", nil, err return "", nil, err
@ -315,14 +319,19 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
buf := builder.NewWriter() buf := builder.NewWriter()
fmt.Fprintf(buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) fmt.Fprintf(buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
if err := utils.WriteBuilder(buf, condWriter); err != nil {
return "", nil, err
}
if mssqlCondi.Len() > 0 { if mssqlCondi.Len() > 0 {
if len(whereStr) > 0 { if len(whereStr) > 0 {
fmt.Fprint(buf, " AND ") fmt.Fprint(buf, " AND ")
} else { } else {
fmt.Fprint(buf, " WHERE ") fmt.Fprint(buf, " WHERE ")
} }
fmt.Fprint(buf, mssqlCondi.String())
buf.Append(mssqlCondi.Args()...) if err := utils.WriteBuilder(buf, mssqlCondi); err != nil {
return "", nil, err
}
} }
if err := statement.WriteGroupBy(buf); err != nil { if err := statement.WriteGroupBy(buf); err != nil {
@ -361,10 +370,10 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} }
} }
if statement.IsForUpdate { if statement.IsForUpdate {
return dialect.ForUpdateSQL(buf.String()), condArgs, nil return dialect.ForUpdateSQL(buf.String()), buf.Args(), nil
} }
return buf.String(), condArgs, nil return buf.String(), buf.Args(), nil
} }
// GenExistSQL generates Exist SQL // GenExistSQL generates Exist SQL

View File

@ -455,6 +455,10 @@ func (statement *Statement) Limit(limit int, start ...int) *Statement {
return statement return statement
} }
func (statement *Statement) HasOrderBy() bool {
return statement.OrderStr != ""
}
// ResetOrderBy reset ordery conditions // ResetOrderBy reset ordery conditions
func (statement *Statement) ResetOrderBy() { func (statement *Statement) ResetOrderBy() {
statement.OrderStr = "" statement.OrderStr = ""

22
internal/utils/builder.go Normal file
View File

@ -0,0 +1,22 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package utils
import (
"fmt"
"xorm.io/builder"
)
// WriteBuilder writes writers to one
func WriteBuilder(w *builder.BytesWriter, inputs ...*builder.BytesWriter) error {
for _, input := range inputs {
if _, err := fmt.Fprint(w, input.String()); err != nil {
return err
}
w.Append(input.Args()...)
}
return nil
}

View File

@ -11,6 +11,7 @@ import (
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm/caches" "xorm.io/xorm/caches"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -89,16 +90,6 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri
return nil return nil
} }
func writeBuilder(w *builder.BytesWriter, inputs ...*builder.BytesWriter) error {
for _, input := range inputs {
if _, err := fmt.Fprint(w, input.String()); err != nil {
return err
}
w.Append(input.Args()...)
}
return nil
}
// Delete records, bean's non-empty fields are conditions // Delete records, bean's non-empty fields are conditions
func (session *Session) Delete(beans ...interface{}) (int64, error) { func (session *Session) Delete(beans ...interface{}) (int64, error) {
if session.isAutoClose { if session.isAutoClose {
@ -194,7 +185,7 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
copy(argsForCache, deleteSQLWriter.Args()) copy(argsForCache, deleteSQLWriter.Args())
argsForCache = append(deleteSQLWriter.Args(), argsForCache...) argsForCache = append(deleteSQLWriter.Args(), argsForCache...)
if session.statement.GetUnscoped() || table == nil || table.DeletedColumn() == nil { // tag "deleted" is disabled if session.statement.GetUnscoped() || table == nil || table.DeletedColumn() == nil { // tag "deleted" is disabled
if err := writeBuilder(realSQLWriter, deleteSQLWriter, orderCondWriter); err != nil { if err := utils.WriteBuilder(realSQLWriter, deleteSQLWriter, orderCondWriter); err != nil {
return 0, err return 0, err
} }
} else { } else {
@ -212,7 +203,7 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
realSQLWriter.Append(val) realSQLWriter.Append(val)
realSQLWriter.Append(condWriter.Args()...) realSQLWriter.Append(condWriter.Args()...)
if err := writeBuilder(realSQLWriter, orderCondWriter); err != nil { if err := utils.WriteBuilder(realSQLWriter, orderCondWriter); err != nil {
return 0, err return 0, err
} }

View File

@ -60,7 +60,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
ids = make([]schemas.PK, 0) ids = make([]schemas.PK, 0)
for rows.Next() { for rows.Next() {
var res = make([]string, len(table.PrimaryKeys)) res := make([]string, len(table.PrimaryKeys))
err = rows.ScanSlice(&res) err = rows.ScanSlice(&res)
if err != nil { if err != nil {
return err return err
@ -176,8 +176,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
// -- // --
var err error var err error
var isMap = t.Kind() == reflect.Map isMap := t.Kind() == reflect.Map
var isStruct = t.Kind() == reflect.Struct isStruct := t.Kind() == reflect.Struct
if isStruct { if isStruct {
if err := session.statement.SetRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return 0, err return 0, err
@ -226,7 +226,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
args = append(args, val) args = append(args, val)
} }
var colName = col.Name colName := col.Name
if isStruct { if isStruct {
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName) col := table.GetColumn(colName)
@ -279,7 +279,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
condBeanIsStruct := false condBeanIsStruct := false
if len(condiBean) > 0 { if len(condiBean) > 0 {
if c, ok := condiBean[0].(map[string]interface{}); ok { if c, ok := condiBean[0].(map[string]interface{}); ok {
var eq = make(builder.Eq) eq := make(builder.Eq)
for k, v := range c { for k, v := range c {
eq[session.engine.Quote(k)] = v eq[session.engine.Quote(k)] = v
} }
@ -323,11 +323,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
st := session.statement st := session.statement
var ( var (
sqlStr string
condArgs []interface{}
condSQL string
cond = session.statement.Conds().And(autoCond) cond = session.statement.Conds().And(autoCond)
doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion) doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion)
verValue *reflect.Value verValue *reflect.Value
) )
@ -347,70 +343,65 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, ErrNoColumnsTobeUpdated return 0, ErrNoColumnsTobeUpdated
} }
condSQL, condArgs, err = session.statement.GenCondSQL(cond) whereWriter := builder.NewWriter()
if err != nil { if cond.IsValid() {
fmt.Fprint(whereWriter, "WHERE ")
}
if err := cond.WriteTo(whereWriter); err != nil {
return 0, err
}
if err := st.WriteOrderBy(whereWriter); err != nil {
return 0, err return 0, err
} }
if len(condSQL) > 0 { tableName := session.statement.TableName()
condSQL = "WHERE " + condSQL
}
if st.OrderStr != "" {
condSQL += fmt.Sprintf(" ORDER BY %v", st.OrderStr)
}
var tableName = session.statement.TableName()
// TODO: Oracle support needed // TODO: Oracle support needed
var top string var top string
if st.LimitN != nil { if st.LimitN != nil {
limitValue := *st.LimitN limitValue := *st.LimitN
switch session.engine.dialect.URI().DBType { switch session.engine.dialect.URI().DBType {
case schemas.MYSQL: case schemas.MYSQL:
condSQL += fmt.Sprintf(" LIMIT %d", limitValue) fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
case schemas.SQLITE: case schemas.SQLITE:
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...)) session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
condSQL, condArgs, err = session.statement.GenCondSQL(cond)
if err != nil { whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(whereWriter); err != nil {
return 0, err return 0, err
} }
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
case schemas.POSTGRES: case schemas.POSTGRES:
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...)) session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
condSQL, condArgs, err = session.statement.GenCondSQL(cond)
if err != nil { whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(whereWriter); err != nil {
return 0, err return 0, err
} }
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
case schemas.MSSQL: case schemas.MSSQL:
if st.OrderStr != "" && table != nil && len(table.PrimaryKeys) == 1 { if st.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
session.engine.Quote(tableName), condSQL), condArgs...) session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)
condSQL, condArgs, err = session.statement.GenCondSQL(cond) whereWriter = builder.NewWriter()
if err != nil { fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(whereWriter); err != nil {
return 0, err return 0, err
} }
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
} else { } else {
top = fmt.Sprintf("TOP (%d) ", limitValue) top = fmt.Sprintf("TOP (%d) ", limitValue)
} }
} }
} }
var tableAlias = session.engine.Quote(tableName) tableAlias := session.engine.Quote(tableName)
var fromSQL string var fromSQL string
if session.statement.TableAlias != "" { if session.statement.TableAlias != "" {
switch session.engine.dialect.URI().DBType { switch session.engine.dialect.URI().DBType {
@ -422,14 +413,18 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v%v", updateWriter := builder.NewWriter()
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v%v",
top, top,
tableAlias, tableAlias,
strings.Join(colNames, ", "), strings.Join(colNames, ", "),
fromSQL, fromSQL,
condSQL) whereWriter.String()); err != nil {
return 0, err
}
updateWriter.Append(whereWriter.Args()...)
res, err := session.exec(sqlStr, append(args, condArgs...)...) res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...)
if err != nil { if err != nil {
return 0, err return 0, err
} else if doIncVer { } else if doIncVer {
@ -535,7 +530,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
} }
args = append(args, val) args = append(args, val)
var colName = col.Name colName := col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)