fix FindAndCount bug with Limit (#851)
This commit is contained in:
parent
2e295feace
commit
cea778734c
|
@ -151,7 +151,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
args = append(session.statement.joinArgs, condArgs...)
|
args = append(session.statement.joinArgs, condArgs...)
|
||||||
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL)
|
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -523,9 +523,9 @@ func TestFindMark(t *testing.T) {
|
||||||
|
|
||||||
func TestFindAndCountOneFunc(t *testing.T) {
|
func TestFindAndCountOneFunc(t *testing.T) {
|
||||||
type FindAndCountStruct struct {
|
type FindAndCountStruct struct {
|
||||||
Id int64
|
Id int64
|
||||||
Content string
|
Content string
|
||||||
Msg bool `xorm:"bit"`
|
Msg bool `xorm:"bit"`
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.NoError(t, prepareEngine())
|
assert.NoError(t, prepareEngine())
|
||||||
|
@ -534,11 +534,11 @@ func TestFindAndCountOneFunc(t *testing.T) {
|
||||||
cnt, err := testEngine.Insert([]FindAndCountStruct{
|
cnt, err := testEngine.Insert([]FindAndCountStruct{
|
||||||
{
|
{
|
||||||
Content: "111",
|
Content: "111",
|
||||||
Msg: false,
|
Msg: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Content: "222",
|
Content: "222",
|
||||||
Msg: true,
|
Msg: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -555,4 +555,10 @@ func TestFindAndCountOneFunc(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.EqualValues(t, 1, len(results))
|
assert.EqualValues(t, 1, len(results))
|
||||||
assert.EqualValues(t, 1, cnt)
|
assert.EqualValues(t, 1, cnt)
|
||||||
|
|
||||||
|
results = make([]FindAndCountStruct, 0, 1)
|
||||||
|
cnt, err = testEngine.Where("msg = ?", true).Limit(1).FindAndCount(&results)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, len(results))
|
||||||
|
assert.EqualValues(t, 1, cnt)
|
||||||
}
|
}
|
|
@ -70,7 +70,7 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa
|
||||||
}
|
}
|
||||||
|
|
||||||
args := append(session.statement.joinArgs, condArgs...)
|
args := append(session.statement.joinArgs, condArgs...)
|
||||||
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL)
|
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
28
statement.go
28
statement.go
|
@ -988,7 +988,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlStr, err := statement.genSelectSQL(columnStr, condSQL)
|
sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
@ -1018,7 +1018,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
|
||||||
selectSQL = "count(*)"
|
selectSQL = "count(*)"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL)
|
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
@ -1043,7 +1043,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL)
|
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
@ -1051,7 +1051,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
|
||||||
return sqlStr, append(statement.joinArgs, condArgs...), nil
|
return sqlStr, append(statement.joinArgs, condArgs...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) {
|
func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bool) (a string, err error) {
|
||||||
var distinct string
|
var distinct string
|
||||||
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
|
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
|
||||||
distinct = "DISTINCT "
|
distinct = "DISTINCT "
|
||||||
|
@ -1149,15 +1149,17 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e
|
||||||
if statement.OrderStr != "" {
|
if statement.OrderStr != "" {
|
||||||
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
|
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
|
||||||
}
|
}
|
||||||
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
|
if needLimit {
|
||||||
if statement.Start > 0 {
|
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
|
||||||
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
|
if statement.Start > 0 {
|
||||||
} else if statement.LimitN > 0 {
|
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
|
||||||
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
|
} else if statement.LimitN > 0 {
|
||||||
}
|
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
|
||||||
} else if dialect.DBType() == core.ORACLE {
|
}
|
||||||
if statement.Start != 0 || statement.LimitN != 0 {
|
} else if dialect.DBType() == core.ORACLE {
|
||||||
a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
|
if statement.Start != 0 || statement.LimitN != 0 {
|
||||||
|
a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if statement.IsForUpdate {
|
if statement.IsForUpdate {
|
||||||
|
|
Loading…
Reference in New Issue