From cea778734ccc9da409c84c12d2914b6e54654ff3 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 22 Feb 2018 20:24:40 -0600 Subject: [PATCH] fix FindAndCount bug with Limit (#851) --- session_find.go | 2 +- session_find_test.go | 16 +++++++++++----- session_query.go | 2 +- statement.go | 28 +++++++++++++++------------- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/session_find.go b/session_find.go index 68880d97..44eae714 100644 --- a/session_find.go +++ b/session_find.go @@ -151,7 +151,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } args = append(session.statement.joinArgs, condArgs...) - sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL) + sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true) if err != nil { return err } diff --git a/session_find_test.go b/session_find_test.go index b2e8f5a3..4088e05e 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -523,9 +523,9 @@ func TestFindMark(t *testing.T) { func TestFindAndCountOneFunc(t *testing.T) { type FindAndCountStruct struct { - Id int64 + Id int64 Content string - Msg bool `xorm:"bit"` + Msg bool `xorm:"bit"` } assert.NoError(t, prepareEngine()) @@ -534,11 +534,11 @@ func TestFindAndCountOneFunc(t *testing.T) { cnt, err := testEngine.Insert([]FindAndCountStruct{ { Content: "111", - Msg: false, + Msg: false, }, { Content: "222", - Msg: true, + Msg: true, }, }) assert.NoError(t, err) @@ -555,4 +555,10 @@ func TestFindAndCountOneFunc(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, 1, cnt) -} \ No newline at end of file + + results = make([]FindAndCountStruct, 0, 1) + cnt, err = testEngine.Where("msg = ?", true).Limit(1).FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, 1, cnt) +} diff --git a/session_query.go b/session_query.go index f8098f84..e8fbd8d3 100644 --- a/session_query.go +++ b/session_query.go @@ -70,7 +70,7 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa } args := append(session.statement.joinArgs, condArgs...) - sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL) + sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true) if err != nil { return "", nil, err } diff --git a/statement.go b/statement.go index 6400425b..35c4a472 100644 --- a/statement.go +++ b/statement.go @@ -988,7 +988,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, return "", nil, err } - sqlStr, err := statement.genSelectSQL(columnStr, condSQL) + sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true) if err != nil { return "", nil, err } @@ -1018,7 +1018,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa selectSQL = "count(*)" } } - sqlStr, err := statement.genSelectSQL(selectSQL, condSQL) + sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false) if err != nil { return "", nil, err } @@ -1043,7 +1043,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri return "", nil, err } - sqlStr, err := statement.genSelectSQL(sumSelect, condSQL) + sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true) if err != nil { return "", nil, err } @@ -1051,7 +1051,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) { +func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bool) (a string, err error) { var distinct string if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { distinct = "DISTINCT " @@ -1149,15 +1149,17 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e if statement.OrderStr != "" { a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) } - if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { - if statement.Start > 0 { - a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) - } else if statement.LimitN > 0 { - a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) - } - } else if dialect.DBType() == core.ORACLE { - if statement.Start != 0 || statement.LimitN != 0 { - a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) + if needLimit { + if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { + if statement.Start > 0 { + a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) + } else if statement.LimitN > 0 { + a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) + } + } else if dialect.DBType() == core.ORACLE { + if statement.Start != 0 || statement.LimitN != 0 { + a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) + } } } if statement.IsForUpdate {