diff --git a/session_update.go b/session_update.go index 27e2deb0..73484a82 100644 --- a/session_update.go +++ b/session_update.go @@ -262,39 +262,40 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } cond = cond.And(builder.Eq{session.Engine.Quote(table.Version): verValue.Interface()}) - condSQL, condArgs, _ = builder.ToSQL(cond) - - if len(condSQL) > 0 { - condSQL = "WHERE " + condSQL - } - - if st.LimitN > 0 { - condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) - } - - sqlStr = fmt.Sprintf("UPDATE %v SET %v, %v %v", - session.Engine.Quote(session.Statement.TableName()), - strings.Join(colNames, ", "), - session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1", - condSQL) - + colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1") doIncVer = true - } else { - condSQL, condArgs, _ = builder.ToSQL(cond) - if len(condSQL) > 0 { - condSQL = "WHERE " + condSQL - } - - if st.LimitN > 0 { - condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) - } - - sqlStr = fmt.Sprintf("UPDATE %v SET %v %v", - session.Engine.Quote(session.Statement.TableName()), - strings.Join(colNames, ", "), - condSQL) } + condSQL, condArgs, _ = builder.ToSQL(cond) + if len(condSQL) > 0 { + condSQL = "WHERE " + condSQL + } + + if st.OrderStr != "" { + condSQL = condSQL + fmt.Sprintf(" ORDER BY %v", st.OrderStr) + } + + // TODO: Only Mysql support + // MSSQL: update top (100) table1 set field1 = 1 + if st.LimitN > 0 { + if st.Engine.dialect.DBType() == core.MYSQL { + condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) + } else if st.Engine.dialect.DBType() == core.SQLITE { + tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) + cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", + session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...)) + condSQL, condArgs, _ = builder.ToSQL(cond) + if len(condSQL) > 0 { + condSQL = "WHERE " + condSQL + } + } + } + + sqlStr = fmt.Sprintf("UPDATE %v SET %v %v", + session.Engine.Quote(session.Statement.TableName()), + strings.Join(colNames, ", "), + condSQL) + res, err := session.exec(sqlStr, append(args, condArgs...)...) if err != nil { return 0, err diff --git a/session_update_test.go b/session_update_test.go index eb1b79df..9eeb6186 100644 --- a/session_update_test.go +++ b/session_update_test.go @@ -34,3 +34,41 @@ func TestUpdateMap(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) } + +func TestUpdateLimit(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UpdateTable struct { + Id int64 + Name string + Age int + } + + assert.NoError(t, testEngine.Sync2(new(UpdateTable))) + var tb = UpdateTable{ + Name: "test1", + Age: 35, + } + cnt, err := testEngine.Insert(&tb) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + tb.Name = "test2" + tb.Id = 0 + cnt, err = testEngine.Insert(&tb) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.OrderBy("name desc").Limit(1).Update(&UpdateTable{ + Age: 30, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var uts []UpdateTable + err = testEngine.Find(&uts) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(uts)) + assert.EqualValues(t, 35, uts[0].Age) + assert.EqualValues(t, 30, uts[1].Age) +} diff --git a/xorm_test.go b/xorm_test.go index df77f281..b124deec 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -16,15 +16,15 @@ var ( ) func prepareSqlite3Engine() error { - if testEngine == nil { - os.Remove("./test.db") - var err error - testEngine, err = NewEngine("sqlite3", "./test.db") - if err != nil { - return err - } - testEngine.ShowSQL(*showSQL) + //if testEngine == nil { + os.Remove("./test.db") + var err error + testEngine, err = NewEngine("sqlite3", "./test.db") + if err != nil { + return err } + testEngine.ShowSQL(*showSQL) + //} return nil }