fix FindAndCount bug with Limit (#851)

This commit is contained in:
Lunny Xiao 2018-02-22 20:24:40 -06:00 committed by GitHub
parent 2e295feace
commit cea778734c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 20 deletions

View File

@ -151,7 +151,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
}
args = append(session.statement.joinArgs, condArgs...)
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL)
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true)
if err != nil {
return err
}

View File

@ -523,9 +523,9 @@ func TestFindMark(t *testing.T) {
func TestFindAndCountOneFunc(t *testing.T) {
type FindAndCountStruct struct {
Id int64
Id int64
Content string
Msg bool `xorm:"bit"`
Msg bool `xorm:"bit"`
}
assert.NoError(t, prepareEngine())
@ -534,11 +534,11 @@ func TestFindAndCountOneFunc(t *testing.T) {
cnt, err := testEngine.Insert([]FindAndCountStruct{
{
Content: "111",
Msg: false,
Msg: false,
},
{
Content: "222",
Msg: true,
Msg: true,
},
})
assert.NoError(t, err)
@ -555,4 +555,10 @@ func TestFindAndCountOneFunc(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
assert.EqualValues(t, 1, cnt)
results = make([]FindAndCountStruct, 0, 1)
cnt, err = testEngine.Where("msg = ?", true).Limit(1).FindAndCount(&results)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
assert.EqualValues(t, 1, cnt)
}

View File

@ -70,7 +70,7 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa
}
args := append(session.statement.joinArgs, condArgs...)
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL)
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true)
if err != nil {
return "", nil, err
}

View File

@ -988,7 +988,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
return "", nil, err
}
sqlStr, err := statement.genSelectSQL(columnStr, condSQL)
sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true)
if err != nil {
return "", nil, err
}
@ -1018,7 +1018,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
selectSQL = "count(*)"
}
}
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL)
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false)
if err != nil {
return "", nil, err
}
@ -1043,7 +1043,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
return "", nil, err
}
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL)
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true)
if err != nil {
return "", nil, err
}
@ -1051,7 +1051,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
return sqlStr, append(statement.joinArgs, condArgs...), nil
}
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) {
func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bool) (a string, err error) {
var distinct string
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
distinct = "DISTINCT "
@ -1149,15 +1149,17 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e
if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
}
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
if statement.Start > 0 {
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
} else if statement.LimitN > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
}
} else if dialect.DBType() == core.ORACLE {
if statement.Start != 0 || statement.LimitN != 0 {
a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
if needLimit {
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
if statement.Start > 0 {
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
} else if statement.LimitN > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
}
} else if dialect.DBType() == core.ORACLE {
if statement.Start != 0 || statement.LimitN != 0 {
a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
}
}
}
if statement.IsForUpdate {