Refactor write update (#2310)

Reviewed-on: https://gitea.com/xorm/xorm/pulls/2310
This commit is contained in:
Lunny Xiao 2023-07-25 10:49:55 +00:00
parent 9aab1f689c
commit cb4f310151
2 changed files with 167 additions and 58 deletions

View File

@ -427,7 +427,158 @@ func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter,
} }
} }
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string, args []interface{}) error { func (statement *Statement) GenConditionsFromMap(m interface{}) ([]builder.Cond, error) {
switch t := m.(type) {
case map[string]interface{}:
conds := []builder.Cond{}
for k, v := range t {
conds = append(conds, builder.Eq{k: v})
}
return conds, nil
case map[string]string:
conds := []builder.Cond{}
for k, v := range t {
conds = append(conds, builder.Eq{k: v})
}
return conds, nil
default:
return nil, fmt.Errorf("unsupported condition map type %v", t)
}
}
func (statement *Statement) writeVersionIncrSet(w builder.Writer, v reflect.Value, hasPreviousSet bool) error {
if v.Type().Kind() != reflect.Struct {
return nil
}
table := statement.RefTable
if !(statement.RefTable != nil && table.Version != "" && statement.CheckVersion) {
return nil
}
verValue, err := table.VersionColumn().ValueOfV(&v)
if err != nil {
return err
}
if verValue == nil {
return nil
}
if hasPreviousSet {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, statement.quote(table.Version), " = ", statement.quote(table.Version), " + 1"); err != nil {
return err
}
return nil
}
func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool) error {
for i, expr := range statement.IncrColumns {
if i > 0 || hasPreviousSet {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " + ?"); err != nil {
return err
}
w.Append(expr.Arg)
}
return nil
}
func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool) error {
// for update action to like "column = column - ?"
for i, expr := range statement.DecrColumns {
if i > 0 || hasPreviousSet {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " - ?"); err != nil {
return err
}
w.Append(expr.Arg)
}
return nil
}
func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet bool) error {
// for update action to like "column = expression"
for i, expr := range statement.ExprColumns {
if i > 0 || hasPreviousSet {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
switch tp := expr.Arg.(type) {
case string:
if len(tp) == 0 {
tp = "''"
}
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", tp); err != nil {
return err
}
case *builder.Builder:
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ("); err != nil {
return err
}
if err := tp.WriteTo(statement.QuoteReplacer(w)); err != nil {
return err
}
if _, err := fmt.Fprint(w, ")"); err != nil {
return err
}
default:
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ?"); err != nil {
return err
}
w.Append(expr.Arg)
}
}
return nil
}
func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error {
previousLen := w.Len()
for i, colName := range colNames {
if i > 0 {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, colName); err != nil {
return err
}
}
w.Append(args...)
if err := statement.writeIncrSets(w, w.Len() > previousLen); err != nil {
return err
}
if err := statement.writeDecrSets(w, w.Len() > previousLen); err != nil {
return err
}
if err := statement.writeExprSets(w, w.Len() > previousLen); err != nil {
return err
}
if err := statement.writeVersionIncrSet(w, v, w.Len() > previousLen); err != nil {
return err
}
return nil
}
var ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []interface{}) error {
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil { if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
return err return err
} }
@ -444,17 +595,16 @@ func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond
if _, err := fmt.Fprint(updateWriter, " SET "); err != nil { if _, err := fmt.Fprint(updateWriter, " SET "); err != nil {
return err return err
} }
for i, colName := range colNames { previousLen := updateWriter.Len()
if i > 0 {
if _, err := fmt.Fprint(updateWriter, ", "); err != nil { if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil {
return err return err
} }
}
if _, err := fmt.Fprint(updateWriter, colName); err != nil { // if no columns to be updated, return error
return err if previousLen == updateWriter.Len() {
} return ErrNoColumnsTobeUpdated
} }
updateWriter.Append(args...)
// write from // write from
if err := statement.writeUpdateFrom(updateWriter); err != nil { if err := statement.writeUpdateFrom(updateWriter); err != nil {

View File

@ -5,17 +5,17 @@
package xorm package xorm
import ( import (
"errors"
"reflect" "reflect"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm/internal/statements"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
// enumerated all errors // enumerated all errors
var ( var (
ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated") ErrNoColumnsTobeUpdated = statements.ErrNoColumnsTobeUpdated
) )
func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) { func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) {
@ -74,9 +74,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
v := utils.ReflectValue(bean) v := utils.ReflectValue(bean)
t := v.Type() t := v.Type()
var colNames []string
var args []interface{}
// handle before update processors // handle before update processors
for _, closure := range session.beforeClosures { for _, closure := range session.beforeClosures {
closure(bean) closure(bean)
@ -87,6 +84,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
// -- // --
var colNames []string
var args []interface{}
var err error var err error
isMap := t.Kind() == reflect.Map isMap := t.Kind() == reflect.Map
isStruct := t.Kind() == reflect.Struct isStruct := t.Kind() == reflect.Struct
@ -148,41 +147,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
// for update action to like "column = column + ?"
incColumns := session.statement.IncrColumns
for _, expr := range incColumns {
colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" + ?")
args = append(args, expr.Arg)
}
// for update action to like "column = column - ?"
decColumns := session.statement.DecrColumns
for _, expr := range decColumns {
colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" - ?")
args = append(args, expr.Arg)
}
// for update action to like "column = expression"
exprColumns := session.statement.ExprColumns
for _, expr := range exprColumns {
switch tp := expr.Arg.(type) {
case string:
if len(tp) == 0 {
tp = "''"
}
colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp)
case *builder.Builder:
subQuery, subArgs, err := builder.ToSQL(tp)
if err != nil {
return 0, err
}
subQuery = session.statement.ReplaceQuote(subQuery)
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")")
args = append(args, subArgs...)
default:
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=?")
args = append(args, expr.Arg)
}
}
if err = session.statement.ProcessIDParam(); err != nil { if err = session.statement.ProcessIDParam(); err != nil {
return 0, err return 0, err
} }
@ -211,23 +175,18 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
verValue *reflect.Value verValue *reflect.Value
) )
if doIncVer { if doIncVer {
verValue, err = table.VersionColumn().ValueOf(bean) verValue, err = table.VersionColumn().ValueOfV(&v)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if verValue != nil { if verValue != nil {
cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()}) cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1")
} }
} }
if len(colNames) == 0 {
return 0, ErrNoColumnsTobeUpdated
}
updateWriter := builder.NewWriter() updateWriter := builder.NewWriter()
if err := session.statement.WriteUpdate(updateWriter, cond, colNames, args); err != nil { if err := session.statement.WriteUpdate(updateWriter, cond, v, colNames, args); err != nil {
return 0, err return 0, err
} }