Some refactors (#2348)

Reviewed-on: https://gitea.com/xorm/xorm/pulls/2348
This commit is contained in:
Lunny Xiao 2023-10-27 09:16:46 +00:00
parent 6ef0a7798f
commit 8eeb1ef8ac
3 changed files with 62 additions and 53 deletions

View File

@ -10,11 +10,12 @@ import (
"xorm.io/builder"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
func (statement *Statement) writePagination(bw *builder.BytesWriter) error {
dbType := statement.dialect.URI().DBType
if dbType == "mssql" || dbType == "oracle" {
if dbType == schemas.MSSQL || dbType == schemas.ORACLE {
return statement.writeOffsetFetch(bw)
}
return statement.writeLimitOffset(bw)
@ -50,15 +51,15 @@ func (statement *Statement) writeOffsetFetch(w builder.Writer) error {
}
func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error {
if !statement.cond.IsValid() {
return statement.writeMssqlPaginationCond(w)
}
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil {
return err
if statement.cond.IsValid() {
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil {
return err
}
}
return statement.writeMssqlPaginationCond(w)
}
@ -115,15 +116,8 @@ func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) err
if _, err := fmt.Fprint(subWriter, "))"); err != nil {
return err
}
if statement.cond.IsValid() {
if _, err := fmt.Fprint(w, " AND "); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
if err := statement.writeWhereOrAnd(w, statement.cond.IsValid()); err != nil {
return err
}
return utils.WriteBuilder(w, subWriter)

View File

@ -350,6 +350,15 @@ func (statement *Statement) writeUpdateFrom(updateWriter *builder.BytesWriter) e
return err
}
func (statement *Statement) writeWhereOrAnd(updateWriter *builder.BytesWriter, hasConditions bool) error {
if hasConditions {
_, err := fmt.Fprint(updateWriter, " AND ")
return err
}
_, err := fmt.Fprint(updateWriter, " WHERE ")
return err
}
func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter, cond builder.Cond) error {
if statement.LimitN == nil {
return nil
@ -364,14 +373,8 @@ func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter,
_, err := fmt.Fprintf(updateWriter, " LIMIT %d", limitValue)
return err
case schemas.SQLITE:
if cond.IsValid() {
if _, err := fmt.Fprint(updateWriter, " AND "); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil {
return err
}
if err := statement.writeWhereOrAnd(updateWriter, cond.IsValid()); err != nil {
return err
}
if _, err := fmt.Fprint(updateWriter, "rowid IN (SELECT rowid FROM ", statement.quote(tableName)); err != nil {
return err
@ -385,14 +388,8 @@ func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter,
_, err := fmt.Fprintf(updateWriter, " LIMIT %d)", limitValue)
return err
case schemas.POSTGRES:
if cond.IsValid() {
if _, err := fmt.Fprint(updateWriter, " AND "); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil {
return err
}
if err := statement.writeWhereOrAnd(updateWriter, cond.IsValid()); err != nil {
return err
}
if _, err := fmt.Fprint(updateWriter, "CTID IN (SELECT CTID FROM ", statement.quote(tableName)); err != nil {
return err
@ -477,9 +474,9 @@ func (statement *Statement) writeVersionIncrSet(w builder.Writer, v reflect.Valu
return nil
}
func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool) error {
func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSets bool) error {
for i, expr := range statement.IncrColumns {
if i > 0 || hasPreviousSet {
if i > 0 || hasPreviousSets {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
@ -492,10 +489,10 @@ func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool)
return nil
}
func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool) error {
func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSets bool) error {
// for update action to like "column = column - ?"
for i, expr := range statement.DecrColumns {
if i > 0 || hasPreviousSet {
if i > 0 || hasPreviousSets {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
@ -508,10 +505,10 @@ func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool)
return nil
}
func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet bool) error {
func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSets bool) error {
// for update action to like "column = expression"
for i, expr := range statement.ExprColumns {
if i > 0 || hasPreviousSet {
if i > 0 || hasPreviousSets {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
@ -544,33 +541,51 @@ func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet
return nil
}
func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error {
previousLen := w.Len()
for i, colName := range colNames {
if i > 0 {
if _, err := fmt.Fprint(w, ", "); err != nil {
func (statement *Statement) writeSetColumns(colNames []string, args []interface{}) func(w *builder.BytesWriter) error {
return func(w *builder.BytesWriter) error {
if len(colNames) == 0 {
return nil
}
if len(colNames) != len(args) {
return fmt.Errorf("columns elements %d but args elements %d", len(colNames), len(args))
}
for i, colName := range colNames {
if i > 0 {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, colName); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, colName); err != nil {
return err
}
w.Append(args...)
return nil
}
w.Append(args...)
}
if err := statement.writeIncrSets(w, w.Len() > previousLen); err != nil {
func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error {
if err := statement.writeSetColumns(colNames, args)(w); err != nil {
return err
}
if err := statement.writeDecrSets(w, w.Len() > previousLen); err != nil {
setNumber := len(colNames)
if err := statement.writeIncrSets(w, setNumber > 0); err != nil {
return err
}
if err := statement.writeExprSets(w, w.Len() > previousLen); err != nil {
setNumber += len(statement.IncrColumns)
if err := statement.writeDecrSets(w, setNumber > 0); err != nil {
return err
}
if err := statement.writeVersionIncrSet(w, v, w.Len() > previousLen); err != nil {
setNumber += len(statement.DecrColumns)
if err := statement.writeExprSets(w, setNumber > 0); err != nil {
return err
}
setNumber += len(statement.ExprColumns)
if err := statement.writeVersionIncrSet(w, v, setNumber > 0); err != nil {
return err
}
return nil