diff --git a/session_update.go b/session_update.go index 73484a82..0f2d1b5c 100644 --- a/session_update.go +++ b/session_update.go @@ -253,9 +253,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var condSQL string cond := session.Statement.cond.And(autoCond) - doIncVer := false + var doIncVer = (table != nil && table.Version != "" && session.Statement.checkVersion) var verValue *reflect.Value - if table != nil && table.Version != "" && session.Statement.checkVersion { + if doIncVer { verValue, err = table.VersionColumn().ValueOf(bean) if err != nil { return 0, err @@ -263,7 +263,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 cond = cond.And(builder.Eq{session.Engine.Quote(table.Version): verValue.Interface()}) colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1") - doIncVer = true } condSQL, condArgs, _ = builder.ToSQL(cond) @@ -275,8 +274,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 condSQL = condSQL + fmt.Sprintf(" ORDER BY %v", st.OrderStr) } - // TODO: Only Mysql support - // MSSQL: update top (100) table1 set field1 = 1 + // TODO: Oracle support needed + var top string if st.LimitN > 0 { if st.Engine.dialect.DBType() == core.MYSQL { condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) @@ -288,10 +287,21 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } + } else if st.Engine.dialect.DBType() == core.POSTGRES { + tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) + cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", + session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...)) + condSQL, condArgs, _ = builder.ToSQL(cond) + if len(condSQL) > 0 { + condSQL = "WHERE " + condSQL + } + } else if st.Engine.dialect.DBType() == core.MSSQL { + top = fmt.Sprintf("top (%d) ", st.LimitN) } } - sqlStr = fmt.Sprintf("UPDATE %v SET %v %v", + sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v", + top, session.Engine.Quote(session.Statement.TableName()), strings.Join(colNames, ", "), condSQL)