diff --git a/session_stats.go b/session_stats.go index bb89d4c2..df02e95a 100644 --- a/session_stats.go +++ b/session_stats.go @@ -4,7 +4,11 @@ package xorm -import "database/sql" +import ( + "database/sql" + "errors" + "reflect" +) // Count counts the records. bean's non-empty fields // are conditions. @@ -43,20 +47,26 @@ func (session *Session) Count(bean ...interface{}) (int64, error) { return 0, err } -// Sum call sum some column. bean's non-empty fields are conditions. -func (session *Session) Sum(bean interface{}, columnName string) (float64, error) { +// sum call sum some column. bean's non-empty fields are conditions. +func (session *Session) sum(res interface{}, bean interface{}, columnNames ...string) error { defer session.resetStatement() if session.isAutoClose { defer session.Close() } + v := reflect.ValueOf(res) + if v.Kind() != reflect.Ptr { + return errors.New("need a pointer to a variable") + } + + var isSlice = v.Elem().Kind() == reflect.Slice var sqlStr string var args []interface{} var err error if len(session.statement.RawSQL) == 0 { - sqlStr, args, err = session.statement.genSumSQL(bean, columnName) + sqlStr, args, err = session.statement.genSumSQL(bean, columnNames...) if err != nil { - return 0, err + return err } } else { sqlStr = session.statement.RawSQL @@ -65,120 +75,44 @@ func (session *Session) Sum(bean interface{}, columnName string) (float64, error session.queryPreprocess(&sqlStr, args...) - var res float64 - if session.isAutoCommit { - err = session.DB().QueryRow(sqlStr, args...).Scan(&res) + if isSlice { + if session.isAutoCommit { + err = session.DB().QueryRow(sqlStr, args...).ScanSlice(res) + } else { + err = session.tx.QueryRow(sqlStr, args...).ScanSlice(res) + } } else { - err = session.tx.QueryRow(sqlStr, args...).Scan(&res) + if session.isAutoCommit { + err = session.DB().QueryRow(sqlStr, args...).Scan(res) + } else { + err = session.tx.QueryRow(sqlStr, args...).Scan(res) + } } if err == sql.ErrNoRows || err == nil { - return res, nil + return nil } - return 0, err + return err +} + +// Sum call sum some column. bean's non-empty fields are conditions. +func (session *Session) Sum(bean interface{}, columnName string) (res float64, err error) { + return res, session.sum(&res, bean, columnName) } // SumInt call sum some column. bean's non-empty fields are conditions. -func (session *Session) SumInt(bean interface{}, columnName string) (int64, error) { - defer session.resetStatement() - if session.isAutoClose { - defer session.Close() - } - - var sqlStr string - var args []interface{} - var err error - if len(session.statement.RawSQL) == 0 { - sqlStr, args, err = session.statement.genSumSQL(bean, columnName) - if err != nil { - return 0, err - } - } else { - sqlStr = session.statement.RawSQL - args = session.statement.RawParams - } - - session.queryPreprocess(&sqlStr, args...) - - var res int64 - if session.isAutoCommit { - err = session.DB().QueryRow(sqlStr, args...).Scan(&res) - } else { - err = session.tx.QueryRow(sqlStr, args...).Scan(&res) - } - - if err == sql.ErrNoRows || err == nil { - return res, nil - } - return 0, err +func (session *Session) SumInt(bean interface{}, columnName string) (res int64, err error) { + return res, session.sum(&res, bean, columnName) } // Sums call sum some columns. bean's non-empty fields are conditions. func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64, error) { - defer session.resetStatement() - if session.isAutoClose { - defer session.Close() - } - - var sqlStr string - var args []interface{} - var err error - if len(session.statement.RawSQL) == 0 { - sqlStr, args, err = session.statement.genSumSQL(bean, columnNames...) - if err != nil { - return nil, err - } - } else { - sqlStr = session.statement.RawSQL - args = session.statement.RawParams - } - - session.queryPreprocess(&sqlStr, args...) - var res = make([]float64, len(columnNames), len(columnNames)) - if session.isAutoCommit { - err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res) - } else { - err = session.tx.QueryRow(sqlStr, args...).ScanSlice(&res) - } - - if err == sql.ErrNoRows || err == nil { - return res, nil - } - return nil, err + return res, session.sum(&res, bean, columnNames...) } // SumsInt sum specify columns and return as []int64 instead of []float64 func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int64, error) { - defer session.resetStatement() - if session.isAutoClose { - defer session.Close() - } - - var sqlStr string - var args []interface{} - var err error - if len(session.statement.RawSQL) == 0 { - sqlStr, args, err = session.statement.genSumSQL(bean, columnNames...) - if err != nil { - return nil, err - } - } else { - sqlStr = session.statement.RawSQL - args = session.statement.RawParams - } - - session.queryPreprocess(&sqlStr, args...) - var res = make([]int64, len(columnNames), len(columnNames)) - if session.isAutoCommit { - err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res) - } else { - err = session.tx.QueryRow(sqlStr, args...).ScanSlice(&res) - } - - if err == sql.ErrNoRows || err == nil { - return res, nil - } - return nil, err + return res, session.sum(&res, bean, columnNames...) }