diff --git a/integrations/session_count_test.go b/integrations/session_count_test.go new file mode 100644 index 00000000..1517dede --- /dev/null +++ b/integrations/session_count_test.go @@ -0,0 +1,172 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package integrations + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "xorm.io/builder" +) + +func TestCount(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UserinfoCount struct { + Departname string + } + assert.NoError(t, testEngine.Sync2(new(UserinfoCount))) + + colName := testEngine.GetColumnMapper().Obj2Table("Departname") + var cond builder.Cond = builder.Eq{ + "`" + colName + "`": "dev", + } + + total, err := testEngine.Where(cond).Count(new(UserinfoCount)) + assert.NoError(t, err) + assert.EqualValues(t, 0, total) + + cnt, err := testEngine.Insert(&UserinfoCount{ + Departname: "dev", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + total, err = testEngine.Where(cond).Count(new(UserinfoCount)) + assert.NoError(t, err) + assert.EqualValues(t, 1, total) + + total, err = testEngine.Where(cond).Table("userinfo_count").Count() + assert.NoError(t, err) + assert.EqualValues(t, 1, total) + + total, err = testEngine.Table("userinfo_count").Count() + assert.NoError(t, err) + assert.EqualValues(t, 1, total) +} + +func TestSQLCount(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UserinfoCount2 struct { + Id int64 + Departname string + } + + type UserinfoBooks struct { + Id int64 + Pid int64 + IsOpen bool + } + + assertSync(t, new(UserinfoCount2), new(UserinfoBooks)) + + total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableName("userinfo_count2", true)). + Count() + assert.NoError(t, err) + assert.EqualValues(t, 0, total) +} + +func TestCountWithOthers(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type CountWithOthers struct { + Id int64 + Name string + } + + assertSync(t, new(CountWithOthers)) + + _, err := testEngine.Insert(&CountWithOthers{ + Name: "orderby", + }) + assert.NoError(t, err) + + _, err = testEngine.Insert(&CountWithOthers{ + Name: "limit", + }) + assert.NoError(t, err) + + total, err := testEngine.OrderBy("id desc").Limit(1).Count(new(CountWithOthers)) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) +} + +type CountWithTableName struct { + Id int64 + Name string +} + +func (CountWithTableName) TableName() string { + return "count_with_table_name1" +} + +func TestWithTableName(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + assertSync(t, new(CountWithTableName)) + + _, err := testEngine.Insert(&CountWithTableName{ + Name: "orderby", + }) + assert.NoError(t, err) + + _, err = testEngine.Insert(CountWithTableName{ + Name: "limit", + }) + assert.NoError(t, err) + + total, err := testEngine.OrderBy("id desc").Count(new(CountWithTableName)) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) + + total, err = testEngine.OrderBy("id desc").Count(CountWithTableName{}) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) +} + +func TestCountWithSelectCols(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + assertSync(t, new(CountWithTableName)) + + _, err := testEngine.Insert(&CountWithTableName{ + Name: "orderby", + }) + assert.NoError(t, err) + + _, err = testEngine.Insert(CountWithTableName{ + Name: "limit", + }) + assert.NoError(t, err) + + total, err := testEngine.Cols("id").Count(new(CountWithTableName)) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) + + total, err = testEngine.Select("count(id)").Count(CountWithTableName{}) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) +} + +func TestCountWithGroupBy(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + assertSync(t, new(CountWithTableName)) + + _, err := testEngine.Insert(&CountWithTableName{ + Name: "1", + }) + assert.NoError(t, err) + + _, err = testEngine.Insert(CountWithTableName{ + Name: "2", + }) + assert.NoError(t, err) + + cnt, err := testEngine.GroupBy("name").Count(new(CountWithTableName)) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) +} diff --git a/integrations/session_find_test.go b/integrations/session_find_test.go index f90b5317..becb1494 100644 --- a/integrations/session_find_test.go +++ b/integrations/session_find_test.go @@ -678,6 +678,36 @@ func TestFindAndCountWithTableName(t *testing.T) { assert.EqualValues(t, 1, cnt) } +func TestFindAndCountWithGroupBy(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type FindAndCountWithGroupBy struct { + Id int64 + Age int `xorm:"index"` + Name string + } + + assert.NoError(t, testEngine.Sync2(new(FindAndCountWithGroupBy))) + + _, err := testEngine.Insert([]FindAndCountWithGroupBy{ + { + Name: "test1", + Age: 10, + }, + { + Name: "test2", + Age: 20, + }, + }) + assert.NoError(t, err) + + var results []FindAndCountWithGroupBy + cnt, err := testEngine.GroupBy("age").FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + assert.EqualValues(t, 2, len(results)) +} + type FindMapDevice struct { Deviceid string `xorm:"pk"` Status int diff --git a/integrations/session_stats_test.go b/integrations/session_sum_test.go similarity index 53% rename from integrations/session_stats_test.go rename to integrations/session_sum_test.go index 47a64076..b447c699 100644 --- a/integrations/session_stats_test.go +++ b/integrations/session_sum_test.go @@ -10,7 +10,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "xorm.io/builder" ) func isFloatEq(i, j float64, precision int) bool { @@ -158,143 +157,3 @@ func TestSumCustomColumn(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 3, int(sumInt)) } - -func TestCount(t *testing.T) { - assert.NoError(t, PrepareEngine()) - - type UserinfoCount struct { - Departname string - } - assert.NoError(t, testEngine.Sync2(new(UserinfoCount))) - - colName := testEngine.GetColumnMapper().Obj2Table("Departname") - var cond builder.Cond = builder.Eq{ - "`" + colName + "`": "dev", - } - - total, err := testEngine.Where(cond).Count(new(UserinfoCount)) - assert.NoError(t, err) - assert.EqualValues(t, 0, total) - - cnt, err := testEngine.Insert(&UserinfoCount{ - Departname: "dev", - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - total, err = testEngine.Where(cond).Count(new(UserinfoCount)) - assert.NoError(t, err) - assert.EqualValues(t, 1, total) - - total, err = testEngine.Where(cond).Table("userinfo_count").Count() - assert.NoError(t, err) - assert.EqualValues(t, 1, total) - - total, err = testEngine.Table("userinfo_count").Count() - assert.NoError(t, err) - assert.EqualValues(t, 1, total) -} - -func TestSQLCount(t *testing.T) { - assert.NoError(t, PrepareEngine()) - - type UserinfoCount2 struct { - Id int64 - Departname string - } - - type UserinfoBooks struct { - Id int64 - Pid int64 - IsOpen bool - } - - assertSync(t, new(UserinfoCount2), new(UserinfoBooks)) - - total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableName("userinfo_count2", true)). - Count() - assert.NoError(t, err) - assert.EqualValues(t, 0, total) -} - -func TestCountWithOthers(t *testing.T) { - assert.NoError(t, PrepareEngine()) - - type CountWithOthers struct { - Id int64 - Name string - } - - assertSync(t, new(CountWithOthers)) - - _, err := testEngine.Insert(&CountWithOthers{ - Name: "orderby", - }) - assert.NoError(t, err) - - _, err = testEngine.Insert(&CountWithOthers{ - Name: "limit", - }) - assert.NoError(t, err) - - total, err := testEngine.OrderBy("id desc").Limit(1).Count(new(CountWithOthers)) - assert.NoError(t, err) - assert.EqualValues(t, 2, total) -} - -type CountWithTableName struct { - Id int64 - Name string -} - -func (CountWithTableName) TableName() string { - return "count_with_table_name1" -} - -func TestWithTableName(t *testing.T) { - assert.NoError(t, PrepareEngine()) - - assertSync(t, new(CountWithTableName)) - - _, err := testEngine.Insert(&CountWithTableName{ - Name: "orderby", - }) - assert.NoError(t, err) - - _, err = testEngine.Insert(CountWithTableName{ - Name: "limit", - }) - assert.NoError(t, err) - - total, err := testEngine.OrderBy("id desc").Count(new(CountWithTableName)) - assert.NoError(t, err) - assert.EqualValues(t, 2, total) - - total, err = testEngine.OrderBy("id desc").Count(CountWithTableName{}) - assert.NoError(t, err) - assert.EqualValues(t, 2, total) -} - -func TestCountWithSelectCols(t *testing.T) { - assert.NoError(t, PrepareEngine()) - - assertSync(t, new(CountWithTableName)) - - _, err := testEngine.Insert(&CountWithTableName{ - Name: "orderby", - }) - assert.NoError(t, err) - - _, err = testEngine.Insert(CountWithTableName{ - Name: "limit", - }) - assert.NoError(t, err) - - total, err := testEngine.Cols("id").Count(new(CountWithTableName)) - assert.NoError(t, err) - assert.EqualValues(t, 2, total) - - total, err = testEngine.Select("count(id)").Count(CountWithTableName{}) - assert.NoError(t, err) - assert.EqualValues(t, 2, total) -} diff --git a/internal/statements/query.go b/internal/statements/query.go index f1b36770..8b4cd919 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -181,11 +181,22 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa selectSQL = "count(*)" } } - sqlStr, condArgs, err := statement.genSelectSQL(selectSQL, false, false) + var subQuerySelect string + if statement.GroupByStr != "" { + subQuerySelect = statement.GroupByStr + } else { + subQuerySelect = selectSQL + } + + sqlStr, condArgs, err := statement.genSelectSQL(subQuerySelect, false, false) if err != nil { return "", nil, err } + if statement.GroupByStr != "" { + sqlStr = fmt.Sprintf("SELECT %s FROM (%s) sub", selectSQL, sqlStr) + } + return sqlStr, append(statement.joinArgs, condArgs...), nil }