diff --git a/session_delete.go b/session_delete.go index 675d4d8c..7b0a0641 100644 --- a/session_delete.go +++ b/session_delete.go @@ -101,7 +101,8 @@ func (session *Session) Delete(bean interface{}) (int64, error) { if err != nil { return 0, err } - if len(condSQL) == 0 && session.statement.LimitN == 0 { + pLimitN := session.statement.LimitN + if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { return 0, ErrNeedDeletedCond } @@ -119,8 +120,9 @@ func (session *Session) Delete(bean interface{}) (int64, error) { if len(session.statement.OrderStr) > 0 { orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr) } - if session.statement.LimitN > 0 { - orderSQL += fmt.Sprintf(" LIMIT %d", session.statement.LimitN) + if pLimitN != nil && *pLimitN > 0 { + limitNValue := *pLimitN + orderSQL += fmt.Sprintf(" LIMIT %d", limitNValue) } if len(orderSQL) > 0 { @@ -139,7 +141,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { } else { deleteSQL += " WHERE " + inSQL } - // TODO: how to handle delete limit on mssql? + // TODO: how to handle delete limit on mssql? case core.MSSQL: return 0, ErrNotImplemented default: @@ -180,7 +182,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { } else { realSQL += " WHERE " + inSQL } - // TODO: how to handle delete limit on mssql? + // TODO: how to handle delete limit on mssql? case core.MSSQL: return 0, ErrNotImplemented default: diff --git a/session_iterate.go b/session_iterate.go index a1642bb3..4a3cc083 100644 --- a/session_iterate.go +++ b/session_iterate.go @@ -63,9 +63,9 @@ func (session *Session) BufferSize(size int) *Session { func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error { var bufferSize = session.statement.bufferSize - var limit = session.statement.LimitN - if limit > 0 && bufferSize > limit { - bufferSize = limit + var pLimitN = session.statement.LimitN + if pLimitN != nil && bufferSize > *pLimitN { + bufferSize = *pLimitN } var start = session.statement.Start v := rValue(bean) @@ -94,8 +94,8 @@ func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error { } start = start + slice.Elem().Len() - if limit > 0 && start+bufferSize > limit { - bufferSize = limit - start + if pLimitN != nil && start+bufferSize > *pLimitN { + bufferSize = *pLimitN - start } } diff --git a/session_update.go b/session_update.go index 22d516e7..a26b15a6 100644 --- a/session_update.go +++ b/session_update.go @@ -337,11 +337,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var tableName = session.statement.TableName() // TODO: Oracle support needed var top string - if st.LimitN > 0 { + if st.LimitN != nil { + limitValue := *st.LimitN if st.Engine.dialect.DBType() == core.MYSQL { - condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) + condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue) } else if st.Engine.dialect.DBType() == core.SQLITE { - tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) + tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) condSQL, condArgs, err = builder.ToSQL(cond) @@ -352,7 +353,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 condSQL = "WHERE " + condSQL } } else if st.Engine.dialect.DBType() == core.POSTGRES { - tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) + tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) condSQL, condArgs, err = builder.ToSQL(cond) @@ -367,7 +368,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if st.OrderStr != "" && st.Engine.dialect.DBType() == core.MSSQL && table != nil && len(table.PrimaryKeys) == 1 { cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", - table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0], + table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], session.engine.Quote(tableName), condSQL), condArgs...) condSQL, condArgs, err = builder.ToSQL(cond) @@ -378,7 +379,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 condSQL = "WHERE " + condSQL } } else { - top = fmt.Sprintf("TOP (%d) ", st.LimitN) + top = fmt.Sprintf("TOP (%d) ", limitValue) } } } diff --git a/statement.go b/statement.go index 67e35213..dc251d30 100644 --- a/statement.go +++ b/statement.go @@ -20,7 +20,7 @@ type Statement struct { RefTable *core.Table Engine *Engine Start int - LimitN int + LimitN *int idParam *core.PK OrderStr string JoinStr string @@ -65,7 +65,7 @@ type Statement struct { func (statement *Statement) Init() { statement.RefTable = nil statement.Start = 0 - statement.LimitN = 0 + statement.LimitN = nil statement.OrderStr = "" statement.UseCascade = true statement.JoinStr = "" @@ -671,7 +671,7 @@ func (statement *Statement) Top(limit int) *Statement { // Limit generate LIMIT start, limit statement func (statement *Statement) Limit(limit int, start ...int) *Statement { - statement.LimitN = limit + statement.LimitN = &limit if len(start) > 0 { statement.Start = start[0] } @@ -1071,9 +1071,11 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr) } + pLimitN := statement.LimitN if dialect.DBType() == core.MSSQL { - if statement.LimitN > 0 { - top = fmt.Sprintf("TOP %d ", statement.LimitN) + if pLimitN != nil { + LimitNValue := *pLimitN + top = fmt.Sprintf("TOP %d ", LimitNValue) } if statement.Start > 0 { var column string @@ -1134,12 +1136,16 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n if needLimit { if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { if statement.Start > 0 { - fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start) - } else if statement.LimitN > 0 { - fmt.Fprint(&buf, " LIMIT ", statement.LimitN) + if pLimitN != nil { + fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start) + } else { + fmt.Fprintf(&buf, "LIMIT 0 OFFSET %v", statement.Start) + } + } else if pLimitN != nil { + fmt.Fprint(&buf, " LIMIT ", *pLimitN) } } else if dialect.DBType() == core.ORACLE { - if statement.Start != 0 || statement.LimitN != 0 { + if statement.Start != 0 || pLimitN != nil { oldString := buf.String() buf.Reset() rawColStr := columnStr @@ -1147,7 +1153,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n rawColStr = "at.*" } fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", - columnStr, rawColStr, oldString, statement.Start+statement.LimitN, statement.Start) + columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start) } } } @@ -1204,8 +1210,9 @@ func (statement *Statement) convertIDSQL(sqlStr string) string { } var top string - if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL { - top = fmt.Sprintf("TOP %d ", statement.LimitN) + pLimitN := statement.LimitN + if pLimitN != nil && statement.Engine.dialect.DBType() == core.MSSQL { + top = fmt.Sprintf("TOP %d ", *pLimitN) } newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1]) diff --git a/tag_id_test.go b/tag_id_test.go index f1c5a6bc..dce5f688 100644 --- a/tag_id_test.go +++ b/tag_id_test.go @@ -7,8 +7,8 @@ package xorm import ( "testing" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/core" ) type IDGonicMapper struct { @@ -76,7 +76,7 @@ func TestSameMapperID(t *testing.T) { for _, tb := range tables { if tb.Name == "IDSameMapper" { if len(tb.PKColumns()) != 1 || tb.PKColumns()[0].Name != "ID" { - t.Fatal(tb) + t.Fatalf("tb %s tb.PKColumns() is %d not 1, tb.PKColumns()[0].Name is %s not ID", tb.Name, len(tb.PKColumns()), tb.PKColumns()[0].Name) } return }