From e89f74daa0fd551ba07b96b2dfb02d5bdf59e43f Mon Sep 17 00:00:00 2001 From: evalphobia Date: Fri, 28 Aug 2015 16:54:19 +0900 Subject: [PATCH 1/3] Added feature for SELECT ... FOR UPDATE --- session.go | 6 ++++++ statement.go | 11 +++++++++++ 2 files changed, 17 insertions(+) diff --git a/session.go b/session.go index f0edb0ba..9b17958e 100644 --- a/session.go +++ b/session.go @@ -225,6 +225,12 @@ func (session *Session) Distinct(columns ...string) *Session { return session } +// Set Read/Write locking for UPDATE +func (session *Session) ForUpdate() *Session { + session.Statement.IsForUpdate = true + return session +} + // Only not use the paramters as select or update columns func (session *Session) Omit(columns ...string) *Session { session.Statement.Omit(columns...) diff --git a/statement.go b/statement.go index f4d0bcf5..063a6193 100644 --- a/statement.go +++ b/statement.go @@ -66,6 +66,7 @@ type Statement struct { UseCache bool UseAutoTime bool IsDistinct bool + IsForUpdate bool TableAlias string allUseBool bool checkVersion bool @@ -102,6 +103,7 @@ func (statement *Statement) Init() { statement.UseCache = true statement.UseAutoTime = true statement.IsDistinct = false + statement.IsForUpdate = false statement.TableAlias = "" statement.selectStr = "" statement.allUseBool = false @@ -802,6 +804,12 @@ func (statement *Statement) Distinct(columns ...string) *Statement { return statement } +// Generate "SELECT ... FOR UPDATE" statment +func (statement *Statement) ForUpdate() *Statement { + statement.IsForUpdate = true + return statement +} + // replace select func (s *Statement) Select(str string) *Statement { s.selectStr = str @@ -1274,6 +1282,9 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) } } + if statement.IsForUpdate { + a += " FOR UPDATE" + } return } From 693501fd607555f0d45740f017a547cf450e3e2a Mon Sep 17 00:00:00 2001 From: evalphobia Date: Sun, 30 Aug 2015 20:07:18 +0900 Subject: [PATCH 2/3] Fixed FOR UPDATE for each dialects #290 --- mssql_dialect.go | 4 ++++ sqlite3_dialect.go | 4 ++++ statement.go | 2 +- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mssql_dialect.go b/mssql_dialect.go index 0eef76d4..41989b91 100644 --- a/mssql_dialect.go +++ b/mssql_dialect.go @@ -509,6 +509,10 @@ func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, chars return sql } +func (db *mssql) ForUpdateSql(query string) string { + return query +} + func (db *mssql) Filters() []core.Filter { return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}} } diff --git a/sqlite3_dialect.go b/sqlite3_dialect.go index 94e7d6b3..80873dbd 100644 --- a/sqlite3_dialect.go +++ b/sqlite3_dialect.go @@ -250,6 +250,10 @@ func (db *sqlite3) DropIndexSql(tableName string, index *core.Index) string { return fmt.Sprintf("DROP INDEX %v", quote(idxName)) } +func (db *sqlite3) ForUpdateSql(query string) string { + return query +} + /*func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { args := []interface{}{tableName} sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" diff --git a/statement.go b/statement.go index 063a6193..61d42dfe 100644 --- a/statement.go +++ b/statement.go @@ -1283,7 +1283,7 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { } } if statement.IsForUpdate { - a += " FOR UPDATE" + a = statement.Engine.dialect.ForUpdateSql(a) } return From 65f413ecf31cc472e03f82f03fb0ae4f9c4d2709 Mon Sep 17 00:00:00 2001 From: evalphobia Date: Sun, 30 Aug 2015 20:23:46 +0900 Subject: [PATCH 3/3] refactored genSelectSql --- statement.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/statement.go b/statement.go index 61d42dfe..6064c542 100644 --- a/statement.go +++ b/statement.go @@ -1192,6 +1192,7 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { distinct = "DISTINCT " } + var dialect core.Dialect = statement.Engine.Dialect() var top string var mssqlCondi string /*var orderBy string @@ -1203,7 +1204,7 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { if statement.WhereStr != "" { whereStr = fmt.Sprintf(" WHERE %v", statement.WhereStr) if statement.ConditionStr != "" { - whereStr = fmt.Sprintf("%v %s %v", whereStr, statement.Engine.Dialect().AndStr(), + whereStr = fmt.Sprintf("%v %s %v", whereStr, dialect.AndStr(), statement.ConditionStr) } } else if statement.ConditionStr != "" { @@ -1211,7 +1212,7 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { } var fromStr string = " FROM " + statement.Engine.Quote(statement.TableName()) if statement.TableAlias != "" { - if statement.Engine.dialect.DBType() == core.ORACLE { + if dialect.DBType() == core.ORACLE { fromStr += " " + statement.Engine.Quote(statement.TableAlias) } else { fromStr += " AS " + statement.Engine.Quote(statement.TableAlias) @@ -1221,7 +1222,7 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr) } - if statement.Engine.dialect.DBType() == core.MSSQL { + if dialect.DBType() == core.MSSQL { if statement.LimitN > 0 { top = fmt.Sprintf(" TOP %d ", statement.LimitN) } @@ -1271,19 +1272,19 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { if statement.OrderStr != "" { a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) } - if statement.Engine.dialect.DBType() != core.MSSQL && statement.Engine.dialect.DBType() != core.ORACLE { + if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { if statement.Start > 0 { a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) } else if statement.LimitN > 0 { a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) } - } else if statement.Engine.dialect.DBType() == core.ORACLE { + } else if dialect.DBType() == core.ORACLE { if statement.Start != 0 || statement.LimitN != 0 { a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) } } if statement.IsForUpdate { - a = statement.Engine.dialect.ForUpdateSql(a) + a = dialect.ForUpdateSql(a) } return