diff --git a/integrations/session_update_test.go b/integrations/session_update_test.go index cc1042b6..586fce06 100644 --- a/integrations/session_update_test.go +++ b/integrations/session_update_test.go @@ -6,6 +6,7 @@ package integrations import ( "fmt" + "strings" "sync" "testing" "time" @@ -1458,3 +1459,43 @@ func TestNilFromDB(t *testing.T) { assert.NotNil(t, tt4.Field1) assert.NotNil(t, tt4.Field1.cb) } + +func TestUpdateSetExprs2(t *testing.T) { + type TblPassiveHost struct { + Id int64 `xorm:"not null pk autoincr INT(11)"` + Host string `xorm:"varchar(255) not null unique"` + Backend string `xorm:"text"` + Status uint32 `xorm:"not null tinyint default 0"` + + RecentCount int64 `xorm:"bigint default 0"` + TotalCount int64 `xorm:"bigint default 0"` + UrlCount int64 `xorm:"bigint default 0"` //number of urls + VulnCount int64 `xorm:"bigint default 0"` //number of vulns + + Atime time.Time `xorm:"DATETIME created"` + Utime time.Time `xorm:"DATETIME updated"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(TblPassiveHost)) + + var b = TblPassiveHost{ + Host: "192.168.1.1", + Backend: "linux", + Status: 1, + RecentCount: 2, + TotalCount: 3, + UrlCount: 4, + VulnCount: 5, + } + cnt, err := testEngine.Insert(&b) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var passiveHost = TblPassiveHost{} + _, err = testEngine.Where(`host=?`, b.Host). + Incr(`total_count`, 2). + SetExpr(`recent_count`, 2). + SetExpr(`backend`, strings.Join([]string{"192.168.2.1", "192.168.3.1"}, ",")).Update(&passiveHost) + assert.NoError(t, err) +} diff --git a/session_update.go b/session_update.go index 4f8e6961..52ccbb72 100644 --- a/session_update.go +++ b/session_update.go @@ -251,8 +251,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 case string: if len(tp) == 0 { tp = "''" + } else { + tp = strings.Replace(tp, "'", "''", -1) } - colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp) + colNames = append(colNames, session.engine.Quote(expr.ColName)+"='"+tp+"'") case *builder.Builder: subQuery, subArgs, err := session.statement.GenCondSQL(tp) if err != nil {