diff --git a/session.go b/session.go index f0edb0ba..a4ad2916 100644 --- a/session.go +++ b/session.go @@ -30,6 +30,7 @@ type Session struct { IsCommitedOrRollbacked bool TransType string IsAutoClose bool + LockRead bool //set sad lock for accepting ro read dirty data // Automatically reset the statement after operations that execute a SQL // query such as Count(), Find(), Get(), ... @@ -60,6 +61,7 @@ func (session *Session) Init() { session.IsCommitedOrRollbacked = false session.IsAutoClose = false session.AutoResetStatement = true + session.LockRead = false // !nashtsai! is lazy init better? session.afterInsertBeans = make(map[interface{}]*[]func(interface{}), 0) @@ -104,6 +106,16 @@ func (session *Session) Sql(querystring string, args ...interface{}) *Session { return session } +//set read lock +func (session *Session) SetLockRead(lr ...bool) *Session { + if 0 == len(lr) { + session.LockRead = true + } else { + session.LockRead = lr[0] + } + return session +} + // Method Where provides custom query condition. func (session *Session) Where(querystring string, args ...interface{}) *Session { session.Statement.Where(querystring, args...) @@ -1015,6 +1027,10 @@ func (session *Session) Get(bean interface{}) (bool, error) { if session.Statement.RawSQL == "" { sqlStr, args = session.Statement.genGetSql(bean) + //加入悲观锁 FOR oracle & pg & mysql + if session.LockRead { + sqlStr += " FOR UPDATE " + } } else { sqlStr = session.Statement.RawSQL args = session.Statement.RawParams diff --git a/statement.go b/statement.go index f4d0bcf5..0618552a 100644 --- a/statement.go +++ b/statement.go @@ -1172,14 +1172,14 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{} return statement.genSelectSql(fmt.Sprintf("count(%v)", id)), append(statement.Params, statement.BeanArgs...) } -func (statement *Statement) genSelectSql(columnStr string) (a string) { +func (statement *Statement) genSelectSql(columnStr string) string { /*if statement.GroupByStr != "" { if columnStr == "" { columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) } //statement.GroupByStr = columnStr }*/ - var distinct string + var distinct, a string if statement.IsDistinct { distinct = "DISTINCT " } @@ -1274,8 +1274,7 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) } } - - return + return a } func (statement *Statement) processIdParam() {