Fix distinct count

This commit is contained in:
Lunny Xiao 2024-01-31 18:04:27 +08:00
parent e884f059a4
commit 1c0a25c2b7
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
3 changed files with 49 additions and 30 deletions

View File

@ -19,7 +19,7 @@ func (statement *Statement) isUsingLegacyLimitOffset() bool {
func (statement *Statement) writeMssqlLegacySelect(buf *builder.BytesWriter, columnStr string) error { func (statement *Statement) writeMssqlLegacySelect(buf *builder.BytesWriter, columnStr string) error {
return statement.writeMultiple(buf, return statement.writeMultiple(buf,
statement.writeStrings("SELECT"), statement.writeStrings("SELECT"),
statement.writeDistinct, statement.writeDistinct(columnStr),
statement.writeTop, statement.writeTop,
statement.writeFrom, statement.writeFrom,
statement.writeWhereWithMssqlPagination, statement.writeWhereWithMssqlPagination,
@ -32,7 +32,9 @@ func (statement *Statement) writeMssqlLegacySelect(buf *builder.BytesWriter, col
func (statement *Statement) writeOracleLegacySelect(buf *builder.BytesWriter, columnStr string) error { func (statement *Statement) writeOracleLegacySelect(buf *builder.BytesWriter, columnStr string) error {
return statement.writeMultiple(buf, return statement.writeMultiple(buf,
statement.writeSelectColumns(columnStr), statement.writeStrings("SELECT"),
statement.writeDistinct(columnStr),
statement.writeStrings(" ", columnStr),
statement.writeFrom, statement.writeFrom,
statement.writeOracleLimit(columnStr), statement.writeOracleLimit(columnStr),
statement.writeGroupBy, statement.writeGroupBy,

View File

@ -141,29 +141,20 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
} }
} }
selectSQL := statement.SelectStr selectBuf := builder.NewWriter()
if len(selectSQL) <= 0 { if err := statement.writeSelectCount(selectBuf); err != nil {
if statement.IsDistinct { return "", nil, err
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr())
} else if statement.ColumnStr() != "" {
selectSQL = fmt.Sprintf("count(%s)", statement.ColumnStr())
} else {
selectSQL = "count(*)"
}
} }
buf := builder.NewWriter() buf := builder.NewWriter()
if statement.GroupByStr != "" {
if _, err := fmt.Fprintf(buf, "SELECT %s FROM (", selectSQL); err != nil {
return "", nil, err
}
}
var subQuerySelect string var subQuerySelect string
if statement.GroupByStr != "" { if statement.GroupByStr != "" {
if err := statement.writeStrings("SELECT ", selectBuf.String(), " FROM (")(buf); err != nil {
return "", nil, err
}
subQuerySelect = statement.GroupByStr subQuerySelect = statement.GroupByStr
} else { } else {
subQuerySelect = selectSQL subQuerySelect = selectBuf.String()
} }
if err := statement.writeSelect(buf, subQuerySelect, true); err != nil { if err := statement.writeSelect(buf, subQuerySelect, true); err != nil {
@ -198,20 +189,27 @@ func (statement *Statement) writeTop(w *builder.BytesWriter) error {
return err return err
} }
func (statement *Statement) writeDistinct(w *builder.BytesWriter) error { func (statement *Statement) writeDistinct(selectStr string) func(w *builder.BytesWriter) error {
if statement.IsDistinct && !strings.HasPrefix(statement.SelectStr, "count(") { return func(w *builder.BytesWriter) error {
_, err := fmt.Fprint(w, " DISTINCT") if statement.IsDistinct && !strings.HasPrefix(selectStr, "COUNT(") {
return err _, err := fmt.Fprint(w, " DISTINCT")
return err
}
return nil
} }
return nil
} }
func (statement *Statement) writeSelectColumns(columnStr string) func(w *builder.BytesWriter) error { func (statement *Statement) writeSelectCount(w *builder.BytesWriter) error {
return statement.groupWriteFns( if statement.SelectStr != "" {
statement.writeStrings("SELECT"), return statement.writeStrings(statement.SelectStr)(w)
statement.writeDistinct, }
statement.writeStrings(" ", columnStr),
) if statement.IsDistinct {
return statement.writeStrings("COUNT(DISTINCT ", statement.ColumnStr(), ")")(w)
} else if statement.ColumnStr() != "" {
return statement.writeStrings("COUNT(", statement.ColumnStr(), ")")(w)
}
return statement.writeStrings("COUNT(*)")(w)
} }
func (statement *Statement) writeWhereCond(w *builder.BytesWriter, cond builder.Cond) error { func (statement *Statement) writeWhereCond(w *builder.BytesWriter, cond builder.Cond) error {
@ -253,7 +251,9 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri
} }
return statement.writeMultiple(buf, return statement.writeMultiple(buf,
statement.writeSelectColumns(columnStr), statement.writeStrings("SELECT"),
statement.writeDistinct(columnStr),
statement.writeStrings(" ", columnStr),
statement.writeFrom, statement.writeFrom,
statement.writeWhere, statement.writeWhere,
statement.writeGroupBy, statement.writeGroupBy,

View File

@ -1254,3 +1254,20 @@ func TestFindInMaxID(t *testing.T) {
err := testEngine.In("id", builder.Select("max(id)").From(testEngine.Quote(tableName))).Find(&res) err := testEngine.In("id", builder.Select("max(id)").From(testEngine.Quote(tableName))).Find(&res)
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestDistinctFindAndCount(t *testing.T) {
assert.NoError(t, PrepareEngine())
type TestDistinctFindAndCount struct {
Id int64
Name string `xorm:"index"`
Age2 int
}
assertSync(t, new(TestDistinctFindAndCount))
objects := make([]*TestDistinctFindAndCount, 0, 10)
total, err := testEngine.Distinct(testEngine.TableName(new(TestDistinctFindAndCount)) + ".*").FindAndCount(&objects)
assert.NoError(t, err)
assert.EqualValues(t, 0, total)
}