From c855ca4e5988c65aecc766bb014a9cd483dd66e2 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 27 Oct 2023 14:01:14 +0000 Subject: [PATCH] Some refactors (#2348) (#2352) backport #2348 Reviewed-on: https://gitea.com/xorm/xorm/pulls/2348 Reviewed-on: https://gitea.com/xorm/xorm/pulls/2352 --- .../statements/{statement_args.go => args.go} | 0 internal/statements/pagination.go | 30 +++---- internal/statements/update.go | 85 +++++++++++-------- 3 files changed, 62 insertions(+), 53 deletions(-) rename internal/statements/{statement_args.go => args.go} (100%) diff --git a/internal/statements/statement_args.go b/internal/statements/args.go similarity index 100% rename from internal/statements/statement_args.go rename to internal/statements/args.go diff --git a/internal/statements/pagination.go b/internal/statements/pagination.go index 3c7a3913..24a9d203 100644 --- a/internal/statements/pagination.go +++ b/internal/statements/pagination.go @@ -10,11 +10,12 @@ import ( "xorm.io/builder" "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" ) func (statement *Statement) writePagination(bw *builder.BytesWriter) error { dbType := statement.dialect.URI().DBType - if dbType == "mssql" || dbType == "oracle" { + if dbType == schemas.MSSQL || dbType == schemas.ORACLE { return statement.writeOffsetFetch(bw) } return statement.writeLimitOffset(bw) @@ -50,15 +51,15 @@ func (statement *Statement) writeOffsetFetch(w builder.Writer) error { } func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error { - if !statement.cond.IsValid() { - return statement.writeMssqlPaginationCond(w) - } - if _, err := fmt.Fprint(w, " WHERE "); err != nil { - return err - } - if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil { - return err + if statement.cond.IsValid() { + if _, err := fmt.Fprint(w, " WHERE "); err != nil { + return err + } + if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil { + return err + } } + return statement.writeMssqlPaginationCond(w) } @@ -115,15 +116,8 @@ func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) err if _, err := fmt.Fprint(subWriter, "))"); err != nil { return err } - - if statement.cond.IsValid() { - if _, err := fmt.Fprint(w, " AND "); err != nil { - return err - } - } else { - if _, err := fmt.Fprint(w, " WHERE "); err != nil { - return err - } + if err := statement.writeWhereOrAnd(w, statement.cond.IsValid()); err != nil { + return err } return utils.WriteBuilder(w, subWriter) diff --git a/internal/statements/update.go b/internal/statements/update.go index 5d71f34d..20e95447 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -350,6 +350,15 @@ func (statement *Statement) writeUpdateFrom(updateWriter *builder.BytesWriter) e return err } +func (statement *Statement) writeWhereOrAnd(updateWriter *builder.BytesWriter, hasConditions bool) error { + if hasConditions { + _, err := fmt.Fprint(updateWriter, " AND ") + return err + } + _, err := fmt.Fprint(updateWriter, " WHERE ") + return err +} + func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter, cond builder.Cond) error { if statement.LimitN == nil { return nil @@ -364,14 +373,8 @@ func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter, _, err := fmt.Fprintf(updateWriter, " LIMIT %d", limitValue) return err case schemas.SQLITE: - if cond.IsValid() { - if _, err := fmt.Fprint(updateWriter, " AND "); err != nil { - return err - } - } else { - if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil { - return err - } + if err := statement.writeWhereOrAnd(updateWriter, cond.IsValid()); err != nil { + return err } if _, err := fmt.Fprint(updateWriter, "rowid IN (SELECT rowid FROM ", statement.quote(tableName)); err != nil { return err @@ -385,14 +388,8 @@ func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter, _, err := fmt.Fprintf(updateWriter, " LIMIT %d)", limitValue) return err case schemas.POSTGRES: - if cond.IsValid() { - if _, err := fmt.Fprint(updateWriter, " AND "); err != nil { - return err - } - } else { - if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil { - return err - } + if err := statement.writeWhereOrAnd(updateWriter, cond.IsValid()); err != nil { + return err } if _, err := fmt.Fprint(updateWriter, "CTID IN (SELECT CTID FROM ", statement.quote(tableName)); err != nil { return err @@ -477,9 +474,9 @@ func (statement *Statement) writeVersionIncrSet(w builder.Writer, v reflect.Valu return nil } -func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool) error { +func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSets bool) error { for i, expr := range statement.IncrColumns { - if i > 0 || hasPreviousSet { + if i > 0 || hasPreviousSets { if _, err := fmt.Fprint(w, ", "); err != nil { return err } @@ -492,10 +489,10 @@ func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool) return nil } -func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool) error { +func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSets bool) error { // for update action to like "column = column - ?" for i, expr := range statement.DecrColumns { - if i > 0 || hasPreviousSet { + if i > 0 || hasPreviousSets { if _, err := fmt.Fprint(w, ", "); err != nil { return err } @@ -508,10 +505,10 @@ func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool) return nil } -func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet bool) error { +func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSets bool) error { // for update action to like "column = expression" for i, expr := range statement.ExprColumns { - if i > 0 || hasPreviousSet { + if i > 0 || hasPreviousSets { if _, err := fmt.Fprint(w, ", "); err != nil { return err } @@ -544,33 +541,51 @@ func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet return nil } -func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error { - previousLen := w.Len() - for i, colName := range colNames { - if i > 0 { - if _, err := fmt.Fprint(w, ", "); err != nil { +func (statement *Statement) writeSetColumns(colNames []string, args []interface{}) func(w *builder.BytesWriter) error { + return func(w *builder.BytesWriter) error { + if len(colNames) == 0 { + return nil + } + if len(colNames) != len(args) { + return fmt.Errorf("columns elements %d but args elements %d", len(colNames), len(args)) + } + for i, colName := range colNames { + if i > 0 { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } + if _, err := fmt.Fprint(w, colName); err != nil { return err } } - if _, err := fmt.Fprint(w, colName); err != nil { - return err - } + w.Append(args...) + return nil } - w.Append(args...) +} - if err := statement.writeIncrSets(w, w.Len() > previousLen); err != nil { +func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error { + if err := statement.writeSetColumns(colNames, args)(w); err != nil { return err } - if err := statement.writeDecrSets(w, w.Len() > previousLen); err != nil { + setNumber := len(colNames) + if err := statement.writeIncrSets(w, setNumber > 0); err != nil { return err } - if err := statement.writeExprSets(w, w.Len() > previousLen); err != nil { + setNumber += len(statement.IncrColumns) + if err := statement.writeDecrSets(w, setNumber > 0); err != nil { return err } - if err := statement.writeVersionIncrSet(w, v, w.Len() > previousLen); err != nil { + setNumber += len(statement.DecrColumns) + if err := statement.writeExprSets(w, setNumber > 0); err != nil { + return err + } + + setNumber += len(statement.ExprColumns) + if err := statement.writeVersionIncrSet(w, v, setNumber > 0); err != nil { return err } return nil