diff --git a/session_cols_test.go b/session_cols_test.go index 5f5954c7..7deffa87 100644 --- a/session_cols_test.go +++ b/session_cols_test.go @@ -7,8 +7,9 @@ package xorm import ( "testing" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/builder" + "xorm.io/core" ) func TestSetExpr(t *testing.T) { @@ -34,6 +35,15 @@ func TestSetExpr(t *testing.T) { cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr)) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.SetExpr("show", + builder.Select("NOT show"). + From("user_expr"). + Where(builder.Eq{"id": 1})). + ID(1). + Update(new(UserExpr)) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) } func TestCols(t *testing.T) { diff --git a/session_update.go b/session_update.go index 402470e5..c5c65a45 100644 --- a/session_update.go +++ b/session_update.go @@ -245,7 +245,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if err != nil { return 0, err } - colNames = append(colNames, session.engine.Quote(colName)+" = "+subQuery) + colNames = append(colNames, session.engine.Quote(colName)+" = ("+subQuery+")") args = append(args, subArgs...) } } diff --git a/statement_args.go b/statement_args.go index c6168db1..5353ae1a 100644 --- a/statement_args.go +++ b/statement_args.go @@ -17,9 +17,15 @@ func writeArg(w *builder.BytesWriter, arg interface{}) error { return err } case *builder.Builder: + if _, err := w.WriteString("("); err != nil { + return err + } if err := argv.WriteTo(w); err != nil { return err } + if _, err := w.WriteString(")"); err != nil { + return err + } default: if _, err := w.WriteString(fmt.Sprintf("%v", argv)); err != nil { return err diff --git a/statement_exprparam.go b/statement_exprparam.go index a72f0aea..0cddca02 100644 --- a/statement_exprparam.go +++ b/statement_exprparam.go @@ -60,9 +60,15 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { for _, expr := range exprs.args { switch arg := expr.(type) { case *builder.Builder: + if _, err := w.WriteString("("); err != nil { + return err + } if err := arg.WriteTo(w); err != nil { return err } + if _, err := w.WriteString(")"); err != nil { + return err + } default: if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil { return err @@ -83,9 +89,15 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { switch arg := exprs.args[i].(type) { case *builder.Builder: + if _, err := w.WriteString("("); err != nil { + return err + } if err := arg.WriteTo(w); err != nil { return err } + if _, err := w.WriteString("("); err != nil { + return err + } default: w.Append(exprs.args[i]) }