From fbf37fc795df623c45a0a5234906f52fdb70b943 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 25 Jul 2017 16:50:20 +0800 Subject: [PATCH] improve count usage (#654) --- session_sum.go => session_stats.go | 5 +---- session_sum_test.go => session_stats_test.go | 8 ++++++++ statement.go | 14 ++++++++++---- 3 files changed, 19 insertions(+), 8 deletions(-) rename session_sum.go => session_stats.go (97%) rename session_sum_test.go => session_stats_test.go (93%) diff --git a/session_sum.go b/session_stats.go similarity index 97% rename from session_sum.go rename to session_stats.go index 2d5ba6bd..49ef0560 100644 --- a/session_sum.go +++ b/session_stats.go @@ -18,10 +18,7 @@ func (session *Session) Count(bean ...interface{}) (int64, error) { var args []interface{} var err error if session.Statement.RawSQL == "" { - if len(bean) == 0 { - return 0, ErrTableNotFound - } - sqlStr, args, err = session.Statement.genCountSQL(bean[0]) + sqlStr, args, err = session.Statement.genCountSQL(bean...) if err != nil { return 0, err } diff --git a/session_sum_test.go b/session_stats_test.go similarity index 93% rename from session_sum_test.go rename to session_stats_test.go index 2d2ad9b2..73f30e1d 100644 --- a/session_sum_test.go +++ b/session_stats_test.go @@ -127,6 +127,14 @@ func TestCount(t *testing.T) { total, err = testEngine.Where(cond).Count(new(UserinfoCount)) assert.NoError(t, err) assert.EqualValues(t, 1, total) + + total, err = testEngine.Where(cond).Table("userinfo_count").Count() + assert.NoError(t, err) + assert.EqualValues(t, 1, total) + + total, err = testEngine.Table("userinfo_count").Count() + assert.NoError(t, err) + assert.EqualValues(t, 1, total) } func TestSQLCount(t *testing.T) { diff --git a/statement.go b/statement.go index c97dea50..b4d2d877 100644 --- a/statement.go +++ b/statement.go @@ -981,10 +981,16 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}, error) { - statement.setRefValue(rValue(bean)) - - condSQL, condArgs, err := statement.genConds(bean) +func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) { + var condSQL string + var condArgs []interface{} + var err error + if len(beans) > 0 { + statement.setRefValue(rValue(beans[0])) + condSQL, condArgs, err = statement.genConds(beans[0]) + } else { + condSQL, condArgs, err = builder.ToSQL(statement.cond) + } if err != nil { return "", nil, err }