fix statement.Limit(0) will update or delete all data

This commit is contained in:
yifhao 2018-10-17 11:53:45 +08:00 committed by Lunny Xiao
parent 14a0c19a0c
commit 56efc798ed
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
4 changed files with 38 additions and 28 deletions

View File

@ -101,7 +101,8 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
if err != nil {
return 0, err
}
if len(condSQL) == 0 && session.statement.LimitN == 0 {
pLimitN := session.statement.LimitN
if len(condSQL) == 0 && pLimitN == nil && *pLimitN == 0 {
return 0, ErrNeedDeletedCond
}
@ -119,8 +120,9 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
if len(session.statement.OrderStr) > 0 {
orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr)
}
if session.statement.LimitN > 0 {
orderSQL += fmt.Sprintf(" LIMIT %d", session.statement.LimitN)
if pLimitN != nil && *pLimitN > 0 {
limitNValue := *pLimitN
orderSQL += fmt.Sprintf(" LIMIT %d", limitNValue)
}
if len(orderSQL) > 0 {
@ -139,7 +141,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
} else {
deleteSQL += " WHERE " + inSQL
}
// TODO: how to handle delete limit on mssql?
// TODO: how to handle delete limit on mssql?
case core.MSSQL:
return 0, ErrNotImplemented
default:
@ -180,7 +182,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
} else {
realSQL += " WHERE " + inSQL
}
// TODO: how to handle delete limit on mssql?
// TODO: how to handle delete limit on mssql?
case core.MSSQL:
return 0, ErrNotImplemented
default:

View File

@ -63,9 +63,9 @@ func (session *Session) BufferSize(size int) *Session {
func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error {
var bufferSize = session.statement.bufferSize
var limit = session.statement.LimitN
if limit > 0 && bufferSize > limit {
bufferSize = limit
var pLimitN = session.statement.LimitN
if pLimitN != nil && bufferSize > *pLimitN {
bufferSize = *pLimitN
}
var start = session.statement.Start
v := rValue(bean)
@ -94,8 +94,8 @@ func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error {
}
start = start + slice.Elem().Len()
if limit > 0 && start+bufferSize > limit {
bufferSize = limit - start
if pLimitN != nil && start+bufferSize > *pLimitN {
bufferSize = *pLimitN - start
}
}

View File

@ -337,11 +337,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var tableName = session.statement.TableName()
// TODO: Oracle support needed
var top string
if st.LimitN > 0 {
if st.LimitN != nil {
limitValue := *st.LimitN
if st.Engine.dialect.DBType() == core.MYSQL {
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
} else if st.Engine.dialect.DBType() == core.SQLITE {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...))
condSQL, condArgs, err = builder.ToSQL(cond)
@ -352,7 +353,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
condSQL = "WHERE " + condSQL
}
} else if st.Engine.dialect.DBType() == core.POSTGRES {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...))
condSQL, condArgs, err = builder.ToSQL(cond)
@ -367,7 +368,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if st.OrderStr != "" && st.Engine.dialect.DBType() == core.MSSQL &&
table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0],
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
session.engine.Quote(tableName), condSQL), condArgs...)
condSQL, condArgs, err = builder.ToSQL(cond)
@ -378,7 +379,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
condSQL = "WHERE " + condSQL
}
} else {
top = fmt.Sprintf("TOP (%d) ", st.LimitN)
top = fmt.Sprintf("TOP (%d) ", limitValue)
}
}
}

View File

@ -20,7 +20,7 @@ type Statement struct {
RefTable *core.Table
Engine *Engine
Start int
LimitN int
LimitN *int
idParam *core.PK
OrderStr string
JoinStr string
@ -65,7 +65,7 @@ type Statement struct {
func (statement *Statement) Init() {
statement.RefTable = nil
statement.Start = 0
statement.LimitN = 0
statement.LimitN = nil
statement.OrderStr = ""
statement.UseCascade = true
statement.JoinStr = ""
@ -671,7 +671,7 @@ func (statement *Statement) Top(limit int) *Statement {
// Limit generate LIMIT start, limit statement
func (statement *Statement) Limit(limit int, start ...int) *Statement {
statement.LimitN = limit
statement.LimitN = &limit
if len(start) > 0 {
statement.Start = start[0]
}
@ -1071,9 +1071,11 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
}
pLimitN := statement.LimitN
if dialect.DBType() == core.MSSQL {
if statement.LimitN > 0 {
top = fmt.Sprintf("TOP %d ", statement.LimitN)
if pLimitN != nil {
LimitNValue := *pLimitN
top = fmt.Sprintf("TOP %d ", LimitNValue)
}
if statement.Start > 0 {
var column string
@ -1134,12 +1136,16 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
if needLimit {
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
if statement.Start > 0 {
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start)
} else if statement.LimitN > 0 {
fmt.Fprint(&buf, " LIMIT ", statement.LimitN)
if pLimitN != nil {
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
} else {
fmt.Fprintf(&buf, "LIMIT 0 OFFSET %v", statement.Start)
}
} else if pLimitN != nil {
fmt.Fprint(&buf, " LIMIT ", *pLimitN)
}
} else if dialect.DBType() == core.ORACLE {
if statement.Start != 0 || statement.LimitN != 0 {
if statement.Start != 0 || pLimitN != nil {
oldString := buf.String()
buf.Reset()
rawColStr := columnStr
@ -1147,7 +1153,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
rawColStr = "at.*"
}
fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
columnStr, rawColStr, oldString, statement.Start+statement.LimitN, statement.Start)
columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start)
}
}
}
@ -1204,8 +1210,9 @@ func (statement *Statement) convertIDSQL(sqlStr string) string {
}
var top string
if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
top = fmt.Sprintf("TOP %d ", statement.LimitN)
pLimitN := statement.LimitN
if pLimitN != nil && statement.Engine.dialect.DBType() == core.MSSQL {
top = fmt.Sprintf("TOP %d ", *pLimitN)
}
newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])