Some refactors (#2348)

Reviewed-on: https://gitea.com/xorm/xorm/pulls/2348
This commit is contained in:
Lunny Xiao 2023-10-27 09:16:46 +00:00
parent 6ef0a7798f
commit 8eeb1ef8ac
3 changed files with 62 additions and 53 deletions

View File

@ -10,11 +10,12 @@ import (
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
) )
func (statement *Statement) writePagination(bw *builder.BytesWriter) error { func (statement *Statement) writePagination(bw *builder.BytesWriter) error {
dbType := statement.dialect.URI().DBType dbType := statement.dialect.URI().DBType
if dbType == "mssql" || dbType == "oracle" { if dbType == schemas.MSSQL || dbType == schemas.ORACLE {
return statement.writeOffsetFetch(bw) return statement.writeOffsetFetch(bw)
} }
return statement.writeLimitOffset(bw) return statement.writeLimitOffset(bw)
@ -50,15 +51,15 @@ func (statement *Statement) writeOffsetFetch(w builder.Writer) error {
} }
func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error { func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error {
if !statement.cond.IsValid() { if statement.cond.IsValid() {
return statement.writeMssqlPaginationCond(w) if _, err := fmt.Fprint(w, " WHERE "); err != nil {
} return err
if _, err := fmt.Fprint(w, " WHERE "); err != nil { }
return err if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil {
} return err
if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil { }
return err
} }
return statement.writeMssqlPaginationCond(w) return statement.writeMssqlPaginationCond(w)
} }
@ -115,15 +116,8 @@ func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) err
if _, err := fmt.Fprint(subWriter, "))"); err != nil { if _, err := fmt.Fprint(subWriter, "))"); err != nil {
return err return err
} }
if err := statement.writeWhereOrAnd(w, statement.cond.IsValid()); err != nil {
if statement.cond.IsValid() { return err
if _, err := fmt.Fprint(w, " AND "); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
} }
return utils.WriteBuilder(w, subWriter) return utils.WriteBuilder(w, subWriter)

View File

@ -350,6 +350,15 @@ func (statement *Statement) writeUpdateFrom(updateWriter *builder.BytesWriter) e
return err return err
} }
func (statement *Statement) writeWhereOrAnd(updateWriter *builder.BytesWriter, hasConditions bool) error {
if hasConditions {
_, err := fmt.Fprint(updateWriter, " AND ")
return err
}
_, err := fmt.Fprint(updateWriter, " WHERE ")
return err
}
func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter, cond builder.Cond) error { func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter, cond builder.Cond) error {
if statement.LimitN == nil { if statement.LimitN == nil {
return nil return nil
@ -364,14 +373,8 @@ func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter,
_, err := fmt.Fprintf(updateWriter, " LIMIT %d", limitValue) _, err := fmt.Fprintf(updateWriter, " LIMIT %d", limitValue)
return err return err
case schemas.SQLITE: case schemas.SQLITE:
if cond.IsValid() { if err := statement.writeWhereOrAnd(updateWriter, cond.IsValid()); err != nil {
if _, err := fmt.Fprint(updateWriter, " AND "); err != nil { return err
return err
}
} else {
if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil {
return err
}
} }
if _, err := fmt.Fprint(updateWriter, "rowid IN (SELECT rowid FROM ", statement.quote(tableName)); err != nil { if _, err := fmt.Fprint(updateWriter, "rowid IN (SELECT rowid FROM ", statement.quote(tableName)); err != nil {
return err return err
@ -385,14 +388,8 @@ func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter,
_, err := fmt.Fprintf(updateWriter, " LIMIT %d)", limitValue) _, err := fmt.Fprintf(updateWriter, " LIMIT %d)", limitValue)
return err return err
case schemas.POSTGRES: case schemas.POSTGRES:
if cond.IsValid() { if err := statement.writeWhereOrAnd(updateWriter, cond.IsValid()); err != nil {
if _, err := fmt.Fprint(updateWriter, " AND "); err != nil { return err
return err
}
} else {
if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil {
return err
}
} }
if _, err := fmt.Fprint(updateWriter, "CTID IN (SELECT CTID FROM ", statement.quote(tableName)); err != nil { if _, err := fmt.Fprint(updateWriter, "CTID IN (SELECT CTID FROM ", statement.quote(tableName)); err != nil {
return err return err
@ -477,9 +474,9 @@ func (statement *Statement) writeVersionIncrSet(w builder.Writer, v reflect.Valu
return nil return nil
} }
func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool) error { func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSets bool) error {
for i, expr := range statement.IncrColumns { for i, expr := range statement.IncrColumns {
if i > 0 || hasPreviousSet { if i > 0 || hasPreviousSets {
if _, err := fmt.Fprint(w, ", "); err != nil { if _, err := fmt.Fprint(w, ", "); err != nil {
return err return err
} }
@ -492,10 +489,10 @@ func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool)
return nil return nil
} }
func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool) error { func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSets bool) error {
// for update action to like "column = column - ?" // for update action to like "column = column - ?"
for i, expr := range statement.DecrColumns { for i, expr := range statement.DecrColumns {
if i > 0 || hasPreviousSet { if i > 0 || hasPreviousSets {
if _, err := fmt.Fprint(w, ", "); err != nil { if _, err := fmt.Fprint(w, ", "); err != nil {
return err return err
} }
@ -508,10 +505,10 @@ func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool)
return nil return nil
} }
func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet bool) error { func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSets bool) error {
// for update action to like "column = expression" // for update action to like "column = expression"
for i, expr := range statement.ExprColumns { for i, expr := range statement.ExprColumns {
if i > 0 || hasPreviousSet { if i > 0 || hasPreviousSets {
if _, err := fmt.Fprint(w, ", "); err != nil { if _, err := fmt.Fprint(w, ", "); err != nil {
return err return err
} }
@ -544,33 +541,51 @@ func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet
return nil return nil
} }
func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error { func (statement *Statement) writeSetColumns(colNames []string, args []interface{}) func(w *builder.BytesWriter) error {
previousLen := w.Len() return func(w *builder.BytesWriter) error {
for i, colName := range colNames { if len(colNames) == 0 {
if i > 0 { return nil
if _, err := fmt.Fprint(w, ", "); err != nil { }
if len(colNames) != len(args) {
return fmt.Errorf("columns elements %d but args elements %d", len(colNames), len(args))
}
for i, colName := range colNames {
if i > 0 {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, colName); err != nil {
return err return err
} }
} }
if _, err := fmt.Fprint(w, colName); err != nil { w.Append(args...)
return err return nil
}
} }
w.Append(args...) }
if err := statement.writeIncrSets(w, w.Len() > previousLen); err != nil { func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error {
if err := statement.writeSetColumns(colNames, args)(w); err != nil {
return err return err
} }
if err := statement.writeDecrSets(w, w.Len() > previousLen); err != nil { setNumber := len(colNames)
if err := statement.writeIncrSets(w, setNumber > 0); err != nil {
return err return err
} }
if err := statement.writeExprSets(w, w.Len() > previousLen); err != nil { setNumber += len(statement.IncrColumns)
if err := statement.writeDecrSets(w, setNumber > 0); err != nil {
return err return err
} }
if err := statement.writeVersionIncrSet(w, v, w.Len() > previousLen); err != nil { setNumber += len(statement.DecrColumns)
if err := statement.writeExprSets(w, setNumber > 0); err != nil {
return err
}
setNumber += len(statement.ExprColumns)
if err := statement.writeVersionIncrSet(w, v, setNumber > 0); err != nil {
return err return err
} }
return nil return nil