fix FindAndCount bug with Limit (#851)

This commit is contained in:
Lunny Xiao 2018-02-22 20:24:40 -06:00 committed by GitHub
parent 2e295feace
commit cea778734c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 20 deletions

View File

@ -151,7 +151,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
args = append(session.statement.joinArgs, condArgs...) args = append(session.statement.joinArgs, condArgs...)
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL) sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true)
if err != nil { if err != nil {
return err return err
} }

View File

@ -555,4 +555,10 @@ func TestFindAndCountOneFunc(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, 1, len(results))
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
results = make([]FindAndCountStruct, 0, 1)
cnt, err = testEngine.Where("msg = ?", true).Limit(1).FindAndCount(&results)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
assert.EqualValues(t, 1, cnt)
} }

View File

@ -70,7 +70,7 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa
} }
args := append(session.statement.joinArgs, condArgs...) args := append(session.statement.joinArgs, condArgs...)
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL) sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }

View File

@ -988,7 +988,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
return "", nil, err return "", nil, err
} }
sqlStr, err := statement.genSelectSQL(columnStr, condSQL) sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -1018,7 +1018,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
selectSQL = "count(*)" selectSQL = "count(*)"
} }
} }
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL) sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -1043,7 +1043,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
return "", nil, err return "", nil, err
} }
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL) sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -1051,7 +1051,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
return sqlStr, append(statement.joinArgs, condArgs...), nil return sqlStr, append(statement.joinArgs, condArgs...), nil
} }
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) { func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bool) (a string, err error) {
var distinct string var distinct string
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
distinct = "DISTINCT " distinct = "DISTINCT "
@ -1149,6 +1149,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e
if statement.OrderStr != "" { if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
} }
if needLimit {
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
if statement.Start > 0 { if statement.Start > 0 {
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
@ -1160,6 +1161,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e
a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
} }
} }
}
if statement.IsForUpdate { if statement.IsForUpdate {
a = dialect.ForUpdateSql(a) a = dialect.ForUpdateSql(a)
} }