diff --git a/internal/statements/legacy_select.go b/internal/statements/legacy_select.go index 144ad96d..691d9137 100644 --- a/internal/statements/legacy_select.go +++ b/internal/statements/legacy_select.go @@ -19,7 +19,7 @@ func (statement *Statement) isUsingLegacyLimitOffset() bool { func (statement *Statement) writeMssqlLegacySelect(buf *builder.BytesWriter, columnStr string) error { return statement.writeMultiple(buf, statement.writeStrings("SELECT"), - statement.writeDistinct, + statement.writeDistinct(columnStr), statement.writeTop, statement.writeFrom, statement.writeWhereWithMssqlPagination, @@ -32,7 +32,9 @@ func (statement *Statement) writeMssqlLegacySelect(buf *builder.BytesWriter, col func (statement *Statement) writeOracleLegacySelect(buf *builder.BytesWriter, columnStr string) error { return statement.writeMultiple(buf, - statement.writeSelectColumns(columnStr), + statement.writeStrings("SELECT"), + statement.writeDistinct(columnStr), + statement.writeStrings(" ", columnStr), statement.writeFrom, statement.writeOracleLimit(columnStr), statement.writeGroupBy, diff --git a/internal/statements/query.go b/internal/statements/query.go index e817403c..bb0bf1b2 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -141,29 +141,20 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa } } - selectSQL := statement.SelectStr - if len(selectSQL) <= 0 { - if statement.IsDistinct { - selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr()) - } else if statement.ColumnStr() != "" { - selectSQL = fmt.Sprintf("count(%s)", statement.ColumnStr()) - } else { - selectSQL = "count(*)" - } + selectBuf := builder.NewWriter() + if err := statement.writeSelectCount(selectBuf); err != nil { + return "", nil, err } buf := builder.NewWriter() - if statement.GroupByStr != "" { - if _, err := fmt.Fprintf(buf, "SELECT %s FROM (", selectSQL); err != nil { - return "", nil, err - } - } - var subQuerySelect string if statement.GroupByStr != "" { + if err := statement.writeStrings("SELECT ", selectBuf.String(), " FROM (")(buf); err != nil { + return "", nil, err + } subQuerySelect = statement.GroupByStr } else { - subQuerySelect = selectSQL + subQuerySelect = selectBuf.String() } if err := statement.writeSelect(buf, subQuerySelect, true); err != nil { @@ -198,20 +189,27 @@ func (statement *Statement) writeTop(w *builder.BytesWriter) error { return err } -func (statement *Statement) writeDistinct(w *builder.BytesWriter) error { - if statement.IsDistinct && !strings.HasPrefix(statement.SelectStr, "count(") { - _, err := fmt.Fprint(w, " DISTINCT") - return err +func (statement *Statement) writeDistinct(selectStr string) func(w *builder.BytesWriter) error { + return func(w *builder.BytesWriter) error { + if statement.IsDistinct && !strings.HasPrefix(selectStr, "COUNT(") { + _, err := fmt.Fprint(w, " DISTINCT") + return err + } + return nil } - return nil } -func (statement *Statement) writeSelectColumns(columnStr string) func(w *builder.BytesWriter) error { - return statement.groupWriteFns( - statement.writeStrings("SELECT"), - statement.writeDistinct, - statement.writeStrings(" ", columnStr), - ) +func (statement *Statement) writeSelectCount(w *builder.BytesWriter) error { + if statement.SelectStr != "" { + return statement.writeStrings(statement.SelectStr)(w) + } + + if statement.IsDistinct { + return statement.writeStrings("COUNT(DISTINCT ", statement.ColumnStr(), ")")(w) + } else if statement.ColumnStr() != "" { + return statement.writeStrings("COUNT(", statement.ColumnStr(), ")")(w) + } + return statement.writeStrings("COUNT(*)")(w) } func (statement *Statement) writeWhereCond(w *builder.BytesWriter, cond builder.Cond) error { @@ -253,7 +251,9 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri } return statement.writeMultiple(buf, - statement.writeSelectColumns(columnStr), + statement.writeStrings("SELECT"), + statement.writeDistinct(columnStr), + statement.writeStrings(" ", columnStr), statement.writeFrom, statement.writeWhere, statement.writeGroupBy, diff --git a/tests/session_find_test.go b/tests/session_find_test.go index d991e6ba..1b8019f1 100644 --- a/tests/session_find_test.go +++ b/tests/session_find_test.go @@ -1254,3 +1254,20 @@ func TestFindInMaxID(t *testing.T) { err := testEngine.In("id", builder.Select("max(id)").From(testEngine.Quote(tableName))).Find(&res) assert.NoError(t, err) } + +func TestDistinctFindAndCount(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestDistinctFindAndCount struct { + Id int64 + Name string `xorm:"index"` + Age2 int + } + + assertSync(t, new(TestDistinctFindAndCount)) + + objects := make([]*TestDistinctFindAndCount, 0, 10) + total, err := testEngine.Distinct(testEngine.TableName(new(TestDistinctFindAndCount)) + ".*").FindAndCount(&objects) + assert.NoError(t, err) + assert.EqualValues(t, 0, total) +}