diff --git a/integrations/session_find_test.go b/integrations/session_find_test.go index 1ecf454a..6701b1b5 100644 --- a/integrations/session_find_test.go +++ b/integrations/session_find_test.go @@ -249,7 +249,7 @@ func TestOrder(t *testing.T) { assert.NoError(t, err) users = make([]Userinfo, 0) - err = testEngine.OrderBy("case username like ? desc", "a").Find(&users) + err = testEngine.OrderBy("CASE WHEN username LIKE ? THEN 0 ELSE 1 END DESC", "a").Find(&users) assert.NoError(t, err) } diff --git a/internal/statements/query.go b/internal/statements/query.go index 31228dbb..137ad02f 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -11,6 +11,7 @@ import ( "strings" "xorm.io/builder" + "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -250,12 +251,13 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB distinct = "DISTINCT " } - condSQL, condArgs, err := statement.GenCondSQL(statement.cond) - if err != nil { + condWriter := builder.NewWriter() + if err := statement.cond.WriteTo(condWriter); err != nil { return "", nil, err } - if len(condSQL) > 0 { - whereStr = fmt.Sprintf(" WHERE %s", condSQL) + + if condWriter.Len() > 0 { + whereStr = " WHERE " } pLimitN := statement.LimitN @@ -297,11 +299,13 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB } } - if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s%s%s%s", - column, statement.Start, column, fromStr, whereStr, orderByWriter.String()); err != nil { + if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s%s%s", + column, statement.Start, column, fromStr, whereStr); err != nil { + return "", nil, err + } + if err := utils.WriteBuilder(mssqlCondi, condWriter, orderByWriter); err != nil { return "", nil, err } - mssqlCondi.Append(orderByWriter.Args()...) if err := statement.WriteGroupBy(mssqlCondi); err != nil { return "", nil, err @@ -315,14 +319,19 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB buf := builder.NewWriter() fmt.Fprintf(buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) + if err := utils.WriteBuilder(buf, condWriter); err != nil { + return "", nil, err + } if mssqlCondi.Len() > 0 { if len(whereStr) > 0 { fmt.Fprint(buf, " AND ") } else { fmt.Fprint(buf, " WHERE ") } - fmt.Fprint(buf, mssqlCondi.String()) - buf.Append(mssqlCondi.Args()...) + + if err := utils.WriteBuilder(buf, mssqlCondi); err != nil { + return "", nil, err + } } if err := statement.WriteGroupBy(buf); err != nil { @@ -361,10 +370,10 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB } } if statement.IsForUpdate { - return dialect.ForUpdateSQL(buf.String()), condArgs, nil + return dialect.ForUpdateSQL(buf.String()), buf.Args(), nil } - return buf.String(), condArgs, nil + return buf.String(), buf.Args(), nil } // GenExistSQL generates Exist SQL diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 9255b478..8250921e 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -455,6 +455,10 @@ func (statement *Statement) Limit(limit int, start ...int) *Statement { return statement } +func (statement *Statement) HasOrderBy() bool { + return statement.OrderStr != "" +} + // ResetOrderBy reset ordery conditions func (statement *Statement) ResetOrderBy() { statement.OrderStr = "" diff --git a/internal/utils/builder.go b/internal/utils/builder.go new file mode 100644 index 00000000..73bbf87a --- /dev/null +++ b/internal/utils/builder.go @@ -0,0 +1,22 @@ +// Copyright 2022 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import ( + "fmt" + + "xorm.io/builder" +) + +// WriteBuilder writes writers to one +func WriteBuilder(w *builder.BytesWriter, inputs ...*builder.BytesWriter) error { + for _, input := range inputs { + if _, err := fmt.Fprint(w, input.String()); err != nil { + return err + } + w.Append(input.Args()...) + } + return nil +} diff --git a/session_delete.go b/session_delete.go index 8d8c2555..cd502a60 100644 --- a/session_delete.go +++ b/session_delete.go @@ -11,6 +11,7 @@ import ( "xorm.io/builder" "xorm.io/xorm/caches" + "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -89,16 +90,6 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri return nil } -func writeBuilder(w *builder.BytesWriter, inputs ...*builder.BytesWriter) error { - for _, input := range inputs { - if _, err := fmt.Fprint(w, input.String()); err != nil { - return err - } - w.Append(input.Args()...) - } - return nil -} - // Delete records, bean's non-empty fields are conditions func (session *Session) Delete(beans ...interface{}) (int64, error) { if session.isAutoClose { @@ -194,7 +185,7 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) { copy(argsForCache, deleteSQLWriter.Args()) argsForCache = append(deleteSQLWriter.Args(), argsForCache...) if session.statement.GetUnscoped() || table == nil || table.DeletedColumn() == nil { // tag "deleted" is disabled - if err := writeBuilder(realSQLWriter, deleteSQLWriter, orderCondWriter); err != nil { + if err := utils.WriteBuilder(realSQLWriter, deleteSQLWriter, orderCondWriter); err != nil { return 0, err } } else { @@ -212,7 +203,7 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) { realSQLWriter.Append(val) realSQLWriter.Append(condWriter.Args()...) - if err := writeBuilder(realSQLWriter, orderCondWriter); err != nil { + if err := utils.WriteBuilder(realSQLWriter, orderCondWriter); err != nil { return 0, err } diff --git a/session_update.go b/session_update.go index fefbee90..235fa5b0 100644 --- a/session_update.go +++ b/session_update.go @@ -60,7 +60,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri ids = make([]schemas.PK, 0) for rows.Next() { - var res = make([]string, len(table.PrimaryKeys)) + res := make([]string, len(table.PrimaryKeys)) err = rows.ScanSlice(&res) if err != nil { return err @@ -176,8 +176,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 // -- var err error - var isMap = t.Kind() == reflect.Map - var isStruct = t.Kind() == reflect.Struct + isMap := t.Kind() == reflect.Map + isStruct := t.Kind() == reflect.Struct if isStruct { if err := session.statement.SetRefBean(bean); err != nil { return 0, err @@ -226,7 +226,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 args = append(args, val) } - var colName = col.Name + colName := col.Name if isStruct { session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) @@ -279,7 +279,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 condBeanIsStruct := false if len(condiBean) > 0 { if c, ok := condiBean[0].(map[string]interface{}); ok { - var eq = make(builder.Eq) + eq := make(builder.Eq) for k, v := range c { eq[session.engine.Quote(k)] = v } @@ -323,11 +323,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 st := session.statement var ( - sqlStr string - condArgs []interface{} - condSQL string cond = session.statement.Conds().And(autoCond) - doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion) verValue *reflect.Value ) @@ -347,70 +343,65 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return 0, ErrNoColumnsTobeUpdated } - condSQL, condArgs, err = session.statement.GenCondSQL(cond) - if err != nil { + whereWriter := builder.NewWriter() + if cond.IsValid() { + fmt.Fprint(whereWriter, "WHERE ") + } + if err := cond.WriteTo(whereWriter); err != nil { + return 0, err + } + if err := st.WriteOrderBy(whereWriter); err != nil { return 0, err } - if len(condSQL) > 0 { - condSQL = "WHERE " + condSQL - } - - if st.OrderStr != "" { - condSQL += fmt.Sprintf(" ORDER BY %v", st.OrderStr) - } - - var tableName = session.statement.TableName() + tableName := session.statement.TableName() // TODO: Oracle support needed var top string if st.LimitN != nil { limitValue := *st.LimitN switch session.engine.dialect.URI().DBType { case schemas.MYSQL: - condSQL += fmt.Sprintf(" LIMIT %d", limitValue) + fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) case schemas.SQLITE: - tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) + fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) + cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", - session.engine.Quote(tableName), tempCondSQL), condArgs...)) - condSQL, condArgs, err = session.statement.GenCondSQL(cond) - if err != nil { + session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)) + + whereWriter = builder.NewWriter() + fmt.Fprint(whereWriter, "WHERE ") + if err := cond.WriteTo(whereWriter); err != nil { return 0, err } - if len(condSQL) > 0 { - condSQL = "WHERE " + condSQL - } case schemas.POSTGRES: - tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) + fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) + cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", - session.engine.Quote(tableName), tempCondSQL), condArgs...)) - condSQL, condArgs, err = session.statement.GenCondSQL(cond) - if err != nil { + session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)) + + whereWriter = builder.NewWriter() + fmt.Fprint(whereWriter, "WHERE ") + if err := cond.WriteTo(whereWriter); err != nil { return 0, err } - - if len(condSQL) > 0 { - condSQL = "WHERE " + condSQL - } case schemas.MSSQL: - if st.OrderStr != "" && table != nil && len(table.PrimaryKeys) == 1 { + if st.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], - session.engine.Quote(tableName), condSQL), condArgs...) + session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...) - condSQL, condArgs, err = session.statement.GenCondSQL(cond) - if err != nil { + whereWriter = builder.NewWriter() + fmt.Fprint(whereWriter, "WHERE ") + if err := cond.WriteTo(whereWriter); err != nil { return 0, err } - if len(condSQL) > 0 { - condSQL = "WHERE " + condSQL - } } else { top = fmt.Sprintf("TOP (%d) ", limitValue) } } } - var tableAlias = session.engine.Quote(tableName) + tableAlias := session.engine.Quote(tableName) var fromSQL string if session.statement.TableAlias != "" { switch session.engine.dialect.URI().DBType { @@ -422,14 +413,18 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v%v", + updateWriter := builder.NewWriter() + if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v%v", top, tableAlias, strings.Join(colNames, ", "), fromSQL, - condSQL) + whereWriter.String()); err != nil { + return 0, err + } + updateWriter.Append(whereWriter.Args()...) - res, err := session.exec(sqlStr, append(args, condArgs...)...) + res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...) if err != nil { return 0, err } else if doIncVer { @@ -535,7 +530,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac } args = append(args, val) - var colName = col.Name + colName := col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnTime(bean, col, t)