From 94fd254638695349b5b57e6e271d3606f2d58696 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 11 Mar 2020 03:29:43 +0000 Subject: [PATCH] Support count with cols (#1595) Support count with cols Reviewed-on: https://gitea.com/xorm/xorm/pulls/1595 --- interface.go | 1 + internal/statements/query.go | 3 +++ session_stats_test.go | 24 ++++++++++++++++++++++++ 3 files changed, 28 insertions(+) diff --git a/interface.go b/interface.go index 67b8d4b1..be4da707 100644 --- a/interface.go +++ b/interface.go @@ -59,6 +59,7 @@ type Interface interface { QueryString(sqlOrArgs ...interface{}) ([]map[string]string, error) Rows(bean interface{}) (*Rows, error) SetExpr(string, interface{}) *Session + Select(string) *Session SQL(interface{}, ...interface{}) *Session Sum(bean interface{}, colName string) (float64, error) SumInt(bean interface{}, colName string) (int64, error) diff --git a/internal/statements/query.go b/internal/statements/query.go index 1568259e..ab3021bf 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -153,6 +153,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, return sqlStr, append(statement.joinArgs, condArgs...), nil } +// GenCountSQL generates the SQL for counting func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interface{}, error) { if statement.RawSQL != "" { return statement.GenRawSQL(), statement.RawParams, nil @@ -171,6 +172,8 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa if len(selectSQL) <= 0 { if statement.IsDistinct { selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr()) + } else if statement.ColumnStr() != "" { + selectSQL = fmt.Sprintf("count(%s)", statement.ColumnStr()) } else { selectSQL = "count(*)" } diff --git a/session_stats_test.go b/session_stats_test.go index d66a7e1f..1f11560b 100644 --- a/session_stats_test.go +++ b/session_stats_test.go @@ -274,3 +274,27 @@ func TestWithTableName(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, total) } + +func TestCountWithSelectCols(t *testing.T) { + assert.NoError(t, prepareEngine()) + + assertSync(t, new(CountWithTableName)) + + _, err := testEngine.Insert(&CountWithTableName{ + Name: "orderby", + }) + assert.NoError(t, err) + + _, err = testEngine.Insert(CountWithTableName{ + Name: "limit", + }) + assert.NoError(t, err) + + total, err := testEngine.Cols("id").Count(new(CountWithTableName)) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) + + total, err = testEngine.Select("count(id)").Count(CountWithTableName{}) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) +}