diff --git a/internal/statements/update.go b/internal/statements/update.go index 1eb431a8..40ffa273 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -427,7 +427,133 @@ func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter, } } -func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string, args []interface{}) error { +func (statement *Statement) GenConditionsFromMap(m interface{}) ([]builder.Cond, error) { + switch t := m.(type) { + case map[string]interface{}: + conds := []builder.Cond{} + for k, v := range t { + conds = append(conds, builder.Eq{k: v}) + } + return conds, nil + case map[string]string: + conds := []builder.Cond{} + for k, v := range t { + conds = append(conds, builder.Eq{k: v}) + } + return conds, nil + default: + return nil, fmt.Errorf("unsupported condition map type %v", t) + } +} + +func (statement *Statement) writeVersionIncrSet(w builder.Writer, v reflect.Value) error { + if v.Type().Kind() != reflect.Struct { + return nil + } + + table := statement.RefTable + if !(statement.RefTable != nil && table.Version != "" && statement.CheckVersion) { + return nil + } + + verValue, err := table.VersionColumn().ValueOfV(&v) + if err != nil { + return err + } + + if verValue == nil { + return nil + } + + if _, err := fmt.Fprint(w, statement.quote(table.Version), " = ", statement.quote(table.Version), " + 1"); err != nil { + return err + } + return nil +} + +func (statement *Statement) writeIncrSets(w builder.Writer) error { + for i, expr := range statement.IncrColumns { + if i > 0 { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } + if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " + ?"); err != nil { + return err + } + w.Append(expr.Arg) + } + return nil +} + +func (statement *Statement) writeDecrSets(w builder.Writer) error { + // for update action to like "column = column - ?" + for i, expr := range statement.DecrColumns { + if i > 0 { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } + if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " - ?"); err != nil { + return err + } + w.Append(expr.Arg) + } + return nil +} + +func (statement *Statement) writeExprSets(w *builder.BytesWriter) error { + // for update action to like "column = expression" + for i, expr := range statement.ExprColumns { + if i > 0 { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } + switch tp := expr.Arg.(type) { + case string: + if len(tp) == 0 { + tp = "''" + } + if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", tp); err != nil { + return err + } + case *builder.Builder: + if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ("); err != nil { + return err + } + if err := tp.WriteTo(statement.QuoteReplacer(w)); err != nil { + return err + } + if _, err := fmt.Fprint(w, ")"); err != nil { + return err + } + default: + if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ?"); err != nil { + return err + } + w.Append(expr.Arg) + } + } + return nil +} + +func (statement *Statement) writeUpdateSets(w *builder.BytesWriter) error { + if err := statement.writeIncrSets(w); err != nil { + return err + } + if err := statement.writeDecrSets(w); err != nil { + return err + } + if err := statement.writeExprSets(w); err != nil { + return err + } + return nil +} + +var ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated") + +func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []interface{}) error { if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil { return err } @@ -444,6 +570,7 @@ func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond if _, err := fmt.Fprint(updateWriter, " SET "); err != nil { return err } + previousLen := updateWriter.Len() for i, colName := range colNames { if i > 0 { if _, err := fmt.Fprint(updateWriter, ", "); err != nil { @@ -456,6 +583,19 @@ func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond } updateWriter.Append(args...) + if err := statement.writeUpdateSets(updateWriter); err != nil { + return err + } + + if err := statement.writeVersionIncrSet(updateWriter, v); err != nil { + return err + } + + // if no columns to be updated, return error + if previousLen == updateWriter.Len() { + return ErrNoColumnsTobeUpdated + } + // write from if err := statement.writeUpdateFrom(updateWriter); err != nil { return err diff --git a/session_update.go b/session_update.go index 97c55eb4..b3640ad2 100644 --- a/session_update.go +++ b/session_update.go @@ -5,17 +5,17 @@ package xorm import ( - "errors" "reflect" "xorm.io/builder" + "xorm.io/xorm/internal/statements" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) // enumerated all errors var ( - ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated") + ErrNoColumnsTobeUpdated = statements.ErrNoColumnsTobeUpdated ) func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) { @@ -74,9 +74,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 v := utils.ReflectValue(bean) t := v.Type() - var colNames []string - var args []interface{} - // handle before update processors for _, closure := range session.beforeClosures { closure(bean) @@ -87,6 +84,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } // -- + var colNames []string + var args []interface{} var err error isMap := t.Kind() == reflect.Map isStruct := t.Kind() == reflect.Struct @@ -148,41 +147,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - // for update action to like "column = column + ?" - incColumns := session.statement.IncrColumns - for _, expr := range incColumns { - colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" + ?") - args = append(args, expr.Arg) - } - // for update action to like "column = column - ?" - decColumns := session.statement.DecrColumns - for _, expr := range decColumns { - colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" - ?") - args = append(args, expr.Arg) - } - // for update action to like "column = expression" - exprColumns := session.statement.ExprColumns - for _, expr := range exprColumns { - switch tp := expr.Arg.(type) { - case string: - if len(tp) == 0 { - tp = "''" - } - colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp) - case *builder.Builder: - subQuery, subArgs, err := builder.ToSQL(tp) - if err != nil { - return 0, err - } - subQuery = session.statement.ReplaceQuote(subQuery) - colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")") - args = append(args, subArgs...) - default: - colNames = append(colNames, session.engine.Quote(expr.ColName)+"=?") - args = append(args, expr.Arg) - } - } - if err = session.statement.ProcessIDParam(); err != nil { return 0, err } @@ -211,23 +175,18 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 verValue *reflect.Value ) if doIncVer { - verValue, err = table.VersionColumn().ValueOf(bean) + verValue, err = table.VersionColumn().ValueOfV(&v) if err != nil { return 0, err } if verValue != nil { cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()}) - colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1") } } - if len(colNames) == 0 { - return 0, ErrNoColumnsTobeUpdated - } - updateWriter := builder.NewWriter() - if err := session.statement.WriteUpdate(updateWriter, cond, colNames, args); err != nil { + if err := session.statement.WriteUpdate(updateWriter, cond, v, colNames, args); err != nil { return 0, err }