Refactor write update (#2310)
Reviewed-on: https://gitea.com/xorm/xorm/pulls/2310
This commit is contained in:
parent
9aab1f689c
commit
cb4f310151
|
@ -427,7 +427,158 @@ func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string, args []interface{}) error {
|
func (statement *Statement) GenConditionsFromMap(m interface{}) ([]builder.Cond, error) {
|
||||||
|
switch t := m.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
conds := []builder.Cond{}
|
||||||
|
for k, v := range t {
|
||||||
|
conds = append(conds, builder.Eq{k: v})
|
||||||
|
}
|
||||||
|
return conds, nil
|
||||||
|
case map[string]string:
|
||||||
|
conds := []builder.Cond{}
|
||||||
|
for k, v := range t {
|
||||||
|
conds = append(conds, builder.Eq{k: v})
|
||||||
|
}
|
||||||
|
return conds, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported condition map type %v", t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) writeVersionIncrSet(w builder.Writer, v reflect.Value, hasPreviousSet bool) error {
|
||||||
|
if v.Type().Kind() != reflect.Struct {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
table := statement.RefTable
|
||||||
|
if !(statement.RefTable != nil && table.Version != "" && statement.CheckVersion) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
verValue, err := table.VersionColumn().ValueOfV(&v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if verValue == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasPreviousSet {
|
||||||
|
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := fmt.Fprint(w, statement.quote(table.Version), " = ", statement.quote(table.Version), " + 1"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool) error {
|
||||||
|
for i, expr := range statement.IncrColumns {
|
||||||
|
if i > 0 || hasPreviousSet {
|
||||||
|
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " + ?"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.Append(expr.Arg)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool) error {
|
||||||
|
// for update action to like "column = column - ?"
|
||||||
|
for i, expr := range statement.DecrColumns {
|
||||||
|
if i > 0 || hasPreviousSet {
|
||||||
|
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " - ?"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.Append(expr.Arg)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet bool) error {
|
||||||
|
// for update action to like "column = expression"
|
||||||
|
for i, expr := range statement.ExprColumns {
|
||||||
|
if i > 0 || hasPreviousSet {
|
||||||
|
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch tp := expr.Arg.(type) {
|
||||||
|
case string:
|
||||||
|
if len(tp) == 0 {
|
||||||
|
tp = "''"
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", tp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case *builder.Builder:
|
||||||
|
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ("); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := tp.WriteTo(statement.QuoteReplacer(w)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(w, ")"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ?"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.Append(expr.Arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error {
|
||||||
|
previousLen := w.Len()
|
||||||
|
for i, colName := range colNames {
|
||||||
|
if i > 0 {
|
||||||
|
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(w, colName); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.Append(args...)
|
||||||
|
|
||||||
|
if err := statement.writeIncrSets(w, w.Len() > previousLen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := statement.writeDecrSets(w, w.Len() > previousLen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := statement.writeExprSets(w, w.Len() > previousLen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := statement.writeVersionIncrSet(w, v, w.Len() > previousLen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
|
||||||
|
|
||||||
|
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []interface{}) error {
|
||||||
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
|
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -444,17 +595,16 @@ func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond
|
||||||
if _, err := fmt.Fprint(updateWriter, " SET "); err != nil {
|
if _, err := fmt.Fprint(updateWriter, " SET "); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for i, colName := range colNames {
|
previousLen := updateWriter.Len()
|
||||||
if i > 0 {
|
|
||||||
if _, err := fmt.Fprint(updateWriter, ", "); err != nil {
|
if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if no columns to be updated, return error
|
||||||
|
if previousLen == updateWriter.Len() {
|
||||||
|
return ErrNoColumnsTobeUpdated
|
||||||
}
|
}
|
||||||
if _, err := fmt.Fprint(updateWriter, colName); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
updateWriter.Append(args...)
|
|
||||||
|
|
||||||
// write from
|
// write from
|
||||||
if err := statement.writeUpdateFrom(updateWriter); err != nil {
|
if err := statement.writeUpdateFrom(updateWriter); err != nil {
|
||||||
|
|
|
@ -5,17 +5,17 @@
|
||||||
package xorm
|
package xorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"xorm.io/builder"
|
"xorm.io/builder"
|
||||||
|
"xorm.io/xorm/internal/statements"
|
||||||
"xorm.io/xorm/internal/utils"
|
"xorm.io/xorm/internal/utils"
|
||||||
"xorm.io/xorm/schemas"
|
"xorm.io/xorm/schemas"
|
||||||
)
|
)
|
||||||
|
|
||||||
// enumerated all errors
|
// enumerated all errors
|
||||||
var (
|
var (
|
||||||
ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
|
ErrNoColumnsTobeUpdated = statements.ErrNoColumnsTobeUpdated
|
||||||
)
|
)
|
||||||
|
|
||||||
func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) {
|
func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) {
|
||||||
|
@ -74,9 +74,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
v := utils.ReflectValue(bean)
|
v := utils.ReflectValue(bean)
|
||||||
t := v.Type()
|
t := v.Type()
|
||||||
|
|
||||||
var colNames []string
|
|
||||||
var args []interface{}
|
|
||||||
|
|
||||||
// handle before update processors
|
// handle before update processors
|
||||||
for _, closure := range session.beforeClosures {
|
for _, closure := range session.beforeClosures {
|
||||||
closure(bean)
|
closure(bean)
|
||||||
|
@ -87,6 +84,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
}
|
}
|
||||||
// --
|
// --
|
||||||
|
|
||||||
|
var colNames []string
|
||||||
|
var args []interface{}
|
||||||
var err error
|
var err error
|
||||||
isMap := t.Kind() == reflect.Map
|
isMap := t.Kind() == reflect.Map
|
||||||
isStruct := t.Kind() == reflect.Struct
|
isStruct := t.Kind() == reflect.Struct
|
||||||
|
@ -148,41 +147,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// for update action to like "column = column + ?"
|
|
||||||
incColumns := session.statement.IncrColumns
|
|
||||||
for _, expr := range incColumns {
|
|
||||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" + ?")
|
|
||||||
args = append(args, expr.Arg)
|
|
||||||
}
|
|
||||||
// for update action to like "column = column - ?"
|
|
||||||
decColumns := session.statement.DecrColumns
|
|
||||||
for _, expr := range decColumns {
|
|
||||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" - ?")
|
|
||||||
args = append(args, expr.Arg)
|
|
||||||
}
|
|
||||||
// for update action to like "column = expression"
|
|
||||||
exprColumns := session.statement.ExprColumns
|
|
||||||
for _, expr := range exprColumns {
|
|
||||||
switch tp := expr.Arg.(type) {
|
|
||||||
case string:
|
|
||||||
if len(tp) == 0 {
|
|
||||||
tp = "''"
|
|
||||||
}
|
|
||||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp)
|
|
||||||
case *builder.Builder:
|
|
||||||
subQuery, subArgs, err := builder.ToSQL(tp)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
subQuery = session.statement.ReplaceQuote(subQuery)
|
|
||||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")")
|
|
||||||
args = append(args, subArgs...)
|
|
||||||
default:
|
|
||||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=?")
|
|
||||||
args = append(args, expr.Arg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = session.statement.ProcessIDParam(); err != nil {
|
if err = session.statement.ProcessIDParam(); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -211,23 +175,18 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
verValue *reflect.Value
|
verValue *reflect.Value
|
||||||
)
|
)
|
||||||
if doIncVer {
|
if doIncVer {
|
||||||
verValue, err = table.VersionColumn().ValueOf(bean)
|
verValue, err = table.VersionColumn().ValueOfV(&v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if verValue != nil {
|
if verValue != nil {
|
||||||
cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
|
cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
|
||||||
colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(colNames) == 0 {
|
|
||||||
return 0, ErrNoColumnsTobeUpdated
|
|
||||||
}
|
|
||||||
|
|
||||||
updateWriter := builder.NewWriter()
|
updateWriter := builder.NewWriter()
|
||||||
if err := session.statement.WriteUpdate(updateWriter, cond, colNames, args); err != nil {
|
if err := session.statement.WriteUpdate(updateWriter, cond, v, colNames, args); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue