From 8a877636fdbbb0f7133b158fe5cde3588464b035 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 6 Jun 2017 14:54:59 +0800 Subject: [PATCH] add custom SQL count support (#609) * add custom SQL count support * fix tests --- doc.go | 9 ++++++++- session_sum.go | 7 +++++-- session_sum_test.go | 22 ++++++++++++++++++++++ 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/doc.go b/doc.go index 9944bcd8..51c3a2a8 100644 --- a/doc.go +++ b/doc.go @@ -51,11 +51,15 @@ There are 8 major ORM methods and many helpful methods to use to operate databas // INSERT INTO struct1 () values () // INSERT INTO struct2 () values (),(),() -2. Query one record from database +2. Query one record or one variable from database has, err := engine.Get(&user) // SELECT * FROM user LIMIT 1 + var id int64 + has, err := engine.Table("user").Where("name = ?", name).Get(&id) + // SELECT id FROM user WHERE name = ? LIMIT 1 + 3. Query multiple records from database var sliceOfStructs []Struct @@ -99,6 +103,9 @@ another is Rows counts, err := engine.Count(&user) // SELECT count(*) AS total FROM user + counts, err := engine.SQL("select count(*) FROM user").Count() + // select count(*) FROM user + 8. Sum records sumFloat64, err := engine.Sum(&user, "id") diff --git a/session_sum.go b/session_sum.go index e1409c7f..8b2d38c2 100644 --- a/session_sum.go +++ b/session_sum.go @@ -8,7 +8,7 @@ import "database/sql" // Count counts the records. bean's non-empty fields // are conditions. -func (session *Session) Count(bean interface{}) (int64, error) { +func (session *Session) Count(bean ...interface{}) (int64, error) { defer session.resetStatement() if session.IsAutoClose { defer session.Close() @@ -17,7 +17,10 @@ func (session *Session) Count(bean interface{}) (int64, error) { var sqlStr string var args []interface{} if session.Statement.RawSQL == "" { - sqlStr, args = session.Statement.genCountSQL(bean) + if len(bean) == 0 { + return 0, ErrTableNotFound + } + sqlStr, args = session.Statement.genCountSQL(bean[0]) } else { sqlStr = session.Statement.RawSQL args = session.Statement.RawParams diff --git a/session_sum_test.go b/session_sum_test.go index 05190653..2d2ad9b2 100644 --- a/session_sum_test.go +++ b/session_sum_test.go @@ -128,3 +128,25 @@ func TestCount(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, total) } + +func TestSQLCount(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UserinfoCount2 struct { + Id int64 + Departname string + } + + type UserinfoBooks struct { + Id int64 + Pid int64 + IsOpen bool + } + + assertSync(t, new(UserinfoCount2), new(UserinfoBooks)) + + total, err := testEngine.SQL("SELECT count(id) FROM userinfo_count2"). + Count() + assert.NoError(t, err) + assert.EqualValues(t, 0, total) +}