diff --git a/mssql_dialect.go b/mssql_dialect.go index 0eef76d4..41989b91 100644 --- a/mssql_dialect.go +++ b/mssql_dialect.go @@ -509,6 +509,10 @@ func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, chars return sql } +func (db *mssql) ForUpdateSql(query string) string { + return query +} + func (db *mssql) Filters() []core.Filter { return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}} } diff --git a/session.go b/session.go index 0a365aba..419a5876 100644 --- a/session.go +++ b/session.go @@ -225,6 +225,12 @@ func (session *Session) Distinct(columns ...string) *Session { return session } +// Set Read/Write locking for UPDATE +func (session *Session) ForUpdate() *Session { + session.Statement.IsForUpdate = true + return session +} + // Only not use the paramters as select or update columns func (session *Session) Omit(columns ...string) *Session { session.Statement.Omit(columns...) diff --git a/sqlite3_dialect.go b/sqlite3_dialect.go index 94e7d6b3..80873dbd 100644 --- a/sqlite3_dialect.go +++ b/sqlite3_dialect.go @@ -250,6 +250,10 @@ func (db *sqlite3) DropIndexSql(tableName string, index *core.Index) string { return fmt.Sprintf("DROP INDEX %v", quote(idxName)) } +func (db *sqlite3) ForUpdateSql(query string) string { + return query +} + /*func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { args := []interface{}{tableName} sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" diff --git a/statement.go b/statement.go index f4d0bcf5..6064c542 100644 --- a/statement.go +++ b/statement.go @@ -66,6 +66,7 @@ type Statement struct { UseCache bool UseAutoTime bool IsDistinct bool + IsForUpdate bool TableAlias string allUseBool bool checkVersion bool @@ -102,6 +103,7 @@ func (statement *Statement) Init() { statement.UseCache = true statement.UseAutoTime = true statement.IsDistinct = false + statement.IsForUpdate = false statement.TableAlias = "" statement.selectStr = "" statement.allUseBool = false @@ -802,6 +804,12 @@ func (statement *Statement) Distinct(columns ...string) *Statement { return statement } +// Generate "SELECT ... FOR UPDATE" statment +func (statement *Statement) ForUpdate() *Statement { + statement.IsForUpdate = true + return statement +} + // replace select func (s *Statement) Select(str string) *Statement { s.selectStr = str @@ -1184,6 +1192,7 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { distinct = "DISTINCT " } + var dialect core.Dialect = statement.Engine.Dialect() var top string var mssqlCondi string /*var orderBy string @@ -1195,7 +1204,7 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { if statement.WhereStr != "" { whereStr = fmt.Sprintf(" WHERE %v", statement.WhereStr) if statement.ConditionStr != "" { - whereStr = fmt.Sprintf("%v %s %v", whereStr, statement.Engine.Dialect().AndStr(), + whereStr = fmt.Sprintf("%v %s %v", whereStr, dialect.AndStr(), statement.ConditionStr) } } else if statement.ConditionStr != "" { @@ -1203,7 +1212,7 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { } var fromStr string = " FROM " + statement.Engine.Quote(statement.TableName()) if statement.TableAlias != "" { - if statement.Engine.dialect.DBType() == core.ORACLE { + if dialect.DBType() == core.ORACLE { fromStr += " " + statement.Engine.Quote(statement.TableAlias) } else { fromStr += " AS " + statement.Engine.Quote(statement.TableAlias) @@ -1213,7 +1222,7 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr) } - if statement.Engine.dialect.DBType() == core.MSSQL { + if dialect.DBType() == core.MSSQL { if statement.LimitN > 0 { top = fmt.Sprintf(" TOP %d ", statement.LimitN) } @@ -1263,17 +1272,20 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { if statement.OrderStr != "" { a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) } - if statement.Engine.dialect.DBType() != core.MSSQL && statement.Engine.dialect.DBType() != core.ORACLE { + if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { if statement.Start > 0 { a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) } else if statement.LimitN > 0 { a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) } - } else if statement.Engine.dialect.DBType() == core.ORACLE { + } else if dialect.DBType() == core.ORACLE { if statement.Start != 0 || statement.LimitN != 0 { a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) } } + if statement.IsForUpdate { + a = dialect.ForUpdateSql(a) + } return }