diff --git a/session_stats.go b/session_stats.go index c2cac830..88286e50 100644 --- a/session_stats.go +++ b/session_stats.go @@ -30,6 +30,10 @@ func (session *Session) Count(bean ...interface{}) (int64, error) { args = session.statement.RawParams } + if len(session.statement.selectStr) > 0 { + sqlStr = "SELECT COUNT(*) FROM ("+sqlStr+") _TEMP_" + } + var total int64 err = session.queryRow(sqlStr, args...).Scan(&total) if err == sql.ErrNoRows || err == nil { diff --git a/session_stats_test.go b/session_stats_test.go index d66a7e1f..7b9afcfc 100644 --- a/session_stats_test.go +++ b/session_stats_test.go @@ -182,6 +182,10 @@ func TestCount(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + total, err = testEngine.Select(colName).Where(cond).Count(new(UserinfoCount)) + assert.NoError(t, err) + assert.EqualValues(t, 1, total) + total, err = testEngine.Where(cond).Count(new(UserinfoCount)) assert.NoError(t, err) assert.EqualValues(t, 1, total)