improve count usage (#654)

This commit is contained in:
Lunny Xiao 2017-07-25 16:50:20 +08:00 committed by GitHub
parent dbc493df5e
commit fbf37fc795
3 changed files with 19 additions and 8 deletions

View File

@ -18,10 +18,7 @@ func (session *Session) Count(bean ...interface{}) (int64, error) {
var args []interface{} var args []interface{}
var err error var err error
if session.Statement.RawSQL == "" { if session.Statement.RawSQL == "" {
if len(bean) == 0 { sqlStr, args, err = session.Statement.genCountSQL(bean...)
return 0, ErrTableNotFound
}
sqlStr, args, err = session.Statement.genCountSQL(bean[0])
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -127,6 +127,14 @@ func TestCount(t *testing.T) {
total, err = testEngine.Where(cond).Count(new(UserinfoCount)) total, err = testEngine.Where(cond).Count(new(UserinfoCount))
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, total) assert.EqualValues(t, 1, total)
total, err = testEngine.Where(cond).Table("userinfo_count").Count()
assert.NoError(t, err)
assert.EqualValues(t, 1, total)
total, err = testEngine.Table("userinfo_count").Count()
assert.NoError(t, err)
assert.EqualValues(t, 1, total)
} }
func TestSQLCount(t *testing.T) { func TestSQLCount(t *testing.T) {

View File

@ -981,10 +981,16 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
return sqlStr, append(statement.joinArgs, condArgs...), nil return sqlStr, append(statement.joinArgs, condArgs...), nil
} }
func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}, error) { func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
statement.setRefValue(rValue(bean)) var condSQL string
var condArgs []interface{}
condSQL, condArgs, err := statement.genConds(bean) var err error
if len(beans) > 0 {
statement.setRefValue(rValue(beans[0]))
condSQL, condArgs, err = statement.genConds(beans[0])
} else {
condSQL, condArgs, err = builder.ToSQL(statement.cond)
}
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }