improve count usage (#654)
This commit is contained in:
parent
dbc493df5e
commit
fbf37fc795
|
@ -18,10 +18,7 @@ func (session *Session) Count(bean ...interface{}) (int64, error) {
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
var err error
|
var err error
|
||||||
if session.Statement.RawSQL == "" {
|
if session.Statement.RawSQL == "" {
|
||||||
if len(bean) == 0 {
|
sqlStr, args, err = session.Statement.genCountSQL(bean...)
|
||||||
return 0, ErrTableNotFound
|
|
||||||
}
|
|
||||||
sqlStr, args, err = session.Statement.genCountSQL(bean[0])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
|
@ -127,6 +127,14 @@ func TestCount(t *testing.T) {
|
||||||
total, err = testEngine.Where(cond).Count(new(UserinfoCount))
|
total, err = testEngine.Where(cond).Count(new(UserinfoCount))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.EqualValues(t, 1, total)
|
assert.EqualValues(t, 1, total)
|
||||||
|
|
||||||
|
total, err = testEngine.Where(cond).Table("userinfo_count").Count()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, total)
|
||||||
|
|
||||||
|
total, err = testEngine.Table("userinfo_count").Count()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, total)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSQLCount(t *testing.T) {
|
func TestSQLCount(t *testing.T) {
|
14
statement.go
14
statement.go
|
@ -981,10 +981,16 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
|
||||||
return sqlStr, append(statement.joinArgs, condArgs...), nil
|
return sqlStr, append(statement.joinArgs, condArgs...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}, error) {
|
func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
|
||||||
statement.setRefValue(rValue(bean))
|
var condSQL string
|
||||||
|
var condArgs []interface{}
|
||||||
condSQL, condArgs, err := statement.genConds(bean)
|
var err error
|
||||||
|
if len(beans) > 0 {
|
||||||
|
statement.setRefValue(rValue(beans[0]))
|
||||||
|
condSQL, condArgs, err = statement.genConds(beans[0])
|
||||||
|
} else {
|
||||||
|
condSQL, condArgs, err = builder.ToSQL(statement.cond)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue