diff --git a/internal/statements/delete.go b/internal/statements/delete.go index fccebbed..a77cf862 100644 --- a/internal/statements/delete.go +++ b/internal/statements/delete.go @@ -14,8 +14,8 @@ import ( "xorm.io/xorm/schemas" ) -func (statement *Statement) writeDeleteOrder(w builder.Writer) error { - if err := statement.WriteOrderBy(w); err != nil { +func (statement *Statement) writeDeleteOrder(w *builder.BytesWriter) error { + if err := statement.writeOrderBys(w); err != nil { return err } diff --git a/internal/statements/order_by.go b/internal/statements/order_by.go index b7eeeb87..595c0430 100644 --- a/internal/statements/order_by.go +++ b/internal/statements/order_by.go @@ -6,85 +6,91 @@ package statements import ( "fmt" - "strings" "xorm.io/builder" ) +type orderBy struct { + orderStr interface{} + orderArgs []interface{} + direction string // ASC, DESC or "", "" means raw orderStr +} + func (statement *Statement) HasOrderBy() bool { - return statement.orderStr != "" + return len(statement.orderBy) > 0 } // ResetOrderBy reset ordery conditions func (statement *Statement) ResetOrderBy() { - statement.orderStr = "" - statement.orderArgs = nil + statement.orderBy = []orderBy{} +} + +func (statement *Statement) writeOrderBy(w *builder.BytesWriter, orderBy orderBy) error { + switch t := orderBy.orderStr.(type) { + case (*builder.Expression): + if _, err := fmt.Fprint(w.Builder, statement.dialect.Quoter().Replace(t.Content())); err != nil { + return err + } + w.Append(t.Args()...) + return nil + case string: + if orderBy.direction == "" { + if _, err := fmt.Fprint(w.Builder, statement.dialect.Quoter().Replace(t)); err != nil { + return err + } + w.Append(orderBy.orderArgs...) + return nil + } + if err := statement.dialect.Quoter().QuoteTo(w.Builder, t); err != nil { + return err + } + _, err := fmt.Fprint(w, " ", orderBy.direction) + return err + default: + return ErrUnSupportedSQLType + } } // WriteOrderBy write order by to writer -func (statement *Statement) WriteOrderBy(w builder.Writer) error { - if len(statement.orderStr) > 0 { - if _, err := fmt.Fprint(w, " ORDER BY ", statement.orderStr); err != nil { +func (statement *Statement) writeOrderBys(w *builder.BytesWriter) error { + if len(statement.orderBy) == 0 { + return nil + } + + if _, err := fmt.Fprint(w, " ORDER BY "); err != nil { + return err + } + for i, ob := range statement.orderBy { + if err := statement.writeOrderBy(w, ob); err != nil { return err } - w.Append(statement.orderArgs...) + if i < len(statement.orderBy)-1 { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } } return nil } // OrderBy generate "Order By order" statement func (statement *Statement) OrderBy(order interface{}, args ...interface{}) *Statement { - if len(statement.orderStr) > 0 { - statement.orderStr += ", " - } - var rawOrder string - switch t := order.(type) { - case (*builder.Expression): - rawOrder = t.Content() - args = t.Args() - case string: - rawOrder = t - default: - statement.LastError = ErrUnSupportedSQLType - return statement - } - statement.orderStr += statement.ReplaceQuote(rawOrder) - if len(args) > 0 { - statement.orderArgs = append(statement.orderArgs, args...) - } + statement.orderBy = append(statement.orderBy, orderBy{order, args, ""}) return statement } // Desc generate `ORDER BY xx DESC` func (statement *Statement) Desc(colNames ...string) *Statement { - var buf strings.Builder - if len(statement.orderStr) > 0 { - fmt.Fprint(&buf, statement.orderStr, ", ") + for _, colName := range colNames { + statement.orderBy = append(statement.orderBy, orderBy{colName, nil, "DESC"}) } - for i, col := range colNames { - if i > 0 { - fmt.Fprint(&buf, ", ") - } - _ = statement.dialect.Quoter().QuoteTo(&buf, col) - fmt.Fprint(&buf, " DESC") - } - statement.orderStr = buf.String() return statement } // Asc provide asc order by query condition, the input parameters are columns. func (statement *Statement) Asc(colNames ...string) *Statement { - var buf strings.Builder - if len(statement.orderStr) > 0 { - fmt.Fprint(&buf, statement.orderStr, ", ") + for _, colName := range colNames { + statement.orderBy = append(statement.orderBy, orderBy{colName, nil, "ASC"}) } - for i, col := range colNames { - if i > 0 { - fmt.Fprint(&buf, ", ") - } - _ = statement.dialect.Quoter().QuoteTo(&buf, col) - fmt.Fprint(&buf, " ASC") - } - statement.orderStr = buf.String() return statement } diff --git a/internal/statements/query.go b/internal/statements/query.go index 2e38f0fe..63e079e7 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -320,7 +320,7 @@ func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) err if err := statement.writeWhere(subWriter); err != nil { return err } - if err := statement.WriteOrderBy(subWriter); err != nil { + if err := statement.writeOrderBys(subWriter); err != nil { return err } if err := statement.writeGroupBy(subWriter); err != nil { @@ -375,7 +375,7 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri if err := statement.writeHaving(buf); err != nil { return err } - if err := statement.WriteOrderBy(buf); err != nil { + if err := statement.writeOrderBys(buf); err != nil { return err } diff --git a/internal/statements/statement.go b/internal/statements/statement.go index afc38a2e..017f40a5 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -50,8 +50,7 @@ type Statement struct { Start int LimitN *int idParam schemas.PK - orderStr string - orderArgs []interface{} + orderBy []orderBy joins []join GroupByStr string HavingStr string @@ -163,15 +162,15 @@ func (statement *Statement) Reset() { // SQL adds raw sql statement func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement { - switch query.(type) { + switch t := query.(type) { case (*builder.Builder): var err error - statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL() + statement.RawSQL, statement.RawParams, err = t.ToSQL() if err != nil { statement.LastError = err } case string: - statement.RawSQL = query.(string) + statement.RawSQL = t statement.RawParams = args default: statement.LastError = ErrUnSupportedSQLType diff --git a/internal/statements/update.go b/internal/statements/update.go index 16ab5676..f0914b0b 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -319,7 +319,7 @@ func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil { return err } - if err := statement.WriteOrderBy(whereWriter); err != nil { + if err := statement.writeOrderBys(whereWriter); err != nil { return err }