diff --git a/VERSION b/VERSION index a9981730..d5283b40 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -xorm v0.5.4.0630 +xorm v0.5.5.0707 diff --git a/engine.go b/engine.go index 3cd1ea9d..2834935d 100644 --- a/engine.go +++ b/engine.go @@ -1536,14 +1536,34 @@ func (engine *Engine) Rows(bean interface{}) (*Rows, error) { return session.Rows(bean) } -// Count counts the records. bean's non-empty fields -// are conditions. +// Count counts the records. bean's non-empty fields are conditions. func (engine *Engine) Count(bean interface{}) (int64, error) { session := engine.NewSession() defer session.Close() return session.Count(bean) } +// Sum sum the records by some column. bean's non-empty fields are conditions. +func (engine *Engine) Sum(bean interface{}, colName string) (float64, error) { + session := engine.NewSession() + defer session.Close() + return session.Sum(bean, colName) +} + +// Sums sum the records by some columns. bean's non-empty fields are conditions. +func (engine *Engine) Sums(bean interface{}, colNames ...string) ([]float64, error) { + session := engine.NewSession() + defer session.Close() + return session.Sums(bean, colNames...) +} + +// SumsInt like Sums but return slice of int64 instead of float64. +func (engine *Engine) SumsInt(bean interface{}, colNames ...string) ([]int64, error) { + session := engine.NewSession() + defer session.Close() + return session.SumsInt(bean, colNames...) +} + // Import SQL DDL file func (engine *Engine) ImportFile(ddlPath string) ([]sql.Result, error) { file, err := os.Open(ddlPath) diff --git a/session.go b/session.go index 77fcce82..19a35b1d 100644 --- a/session.go +++ b/session.go @@ -1073,21 +1073,115 @@ func (session *Session) Count(bean interface{}) (int64, error) { args = session.Statement.RawParams } - resultsSlice, err := session.query(sqlStr, args...) + session.queryPreprocess(&sqlStr, args...) + + var err error + var total int64 + if session.IsAutoCommit { + err = session.DB().QueryRow(sqlStr, args...).Scan(&total) + } else { + err = session.Tx.QueryRow(sqlStr, args...).Scan(&total) + } if err != nil { return 0, err } - var total int64 - if len(resultsSlice) > 0 { - results := resultsSlice[0] - for _, value := range results { - total, err = strconv.ParseInt(string(value), 10, 64) - break - } + return total, nil +} + +// Sum call sum some column. bean's non-empty fields are conditions. +func (session *Session) Sum(bean interface{}, columnName string) (float64, error) { + defer session.resetStatement() + if session.IsAutoClose { + defer session.Close() } - return int64(total), err + var sqlStr string + var args []interface{} + if len(session.Statement.RawSQL) == 0 { + sqlStr, args = session.Statement.genSumSql(bean, columnName) + } else { + sqlStr = session.Statement.RawSQL + args = session.Statement.RawParams + } + + session.queryPreprocess(&sqlStr, args...) + + var err error + var res float64 + if session.IsAutoCommit { + err = session.DB().QueryRow(sqlStr, args...).Scan(&res) + } else { + err = session.Tx.QueryRow(sqlStr, args...).Scan(&res) + } + if err != nil { + return 0, err + } + + return res, nil +} + +// Sums call sum some columns. bean's non-empty fields are conditions. +func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64, error) { + defer session.resetStatement() + if session.IsAutoClose { + defer session.Close() + } + + var sqlStr string + var args []interface{} + if len(session.Statement.RawSQL) == 0 { + sqlStr, args = session.Statement.genSumSql(bean, columnNames...) + } else { + sqlStr = session.Statement.RawSQL + args = session.Statement.RawParams + } + + session.queryPreprocess(&sqlStr, args...) + + var err error + var res = make([]float64, len(columnNames), len(columnNames)) + if session.IsAutoCommit { + err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res) + } else { + err = session.Tx.QueryRow(sqlStr, args...).ScanSlice(&res) + } + if err != nil { + return nil, err + } + + return res, nil +} + +func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int64, error) { + defer session.resetStatement() + if session.IsAutoClose { + defer session.Close() + } + + var sqlStr string + var args []interface{} + if len(session.Statement.RawSQL) == 0 { + sqlStr, args = session.Statement.genSumSql(bean, columnNames...) + } else { + sqlStr = session.Statement.RawSQL + args = session.Statement.RawParams + } + + session.queryPreprocess(&sqlStr, args...) + + var err error + var res = make([]int64, 0, len(columnNames)) + if session.IsAutoCommit { + err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res) + } else { + err = session.Tx.QueryRow(sqlStr, args...).ScanSlice(&res) + } + if err != nil { + return nil, err + } + + return res, nil } // Find retrieve records from table, condiBeans's non-empty fields diff --git a/statement.go b/statement.go index 6bf20e67..8ab1d72b 100644 --- a/statement.go +++ b/statement.go @@ -1220,6 +1220,27 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{} return statement.genSelectSQL(fmt.Sprintf("count(%v)", id)), append(append(statement.joinArgs, statement.Params...), statement.BeanArgs...) } +func (statement *Statement) genSumSql(bean interface{}, columns ...string) (string, []interface{}) { + table := statement.Engine.TableInfo(bean) + statement.RefTable = table + + var addedTableName = (len(statement.JoinStr) > 0) + + if !statement.noAutoCondition { + colNames, args := statement.buildConditions(table, bean, true, true, false, true, addedTableName) + + statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.Dialect().AndStr()+" ") + statement.BeanArgs = args + } + + statement.attachInSql() + var sumStrs = make([]string, 0, len(columns)) + for _, colName := range columns { + sumStrs = append(sumStrs, fmt.Sprintf("sum(%s)", colName)) + } + return statement.genSelectSQL(strings.Join(sumStrs, ", ")), append(append(statement.joinArgs, statement.Params...), statement.BeanArgs...) +} + func (statement *Statement) genSelectSQL(columnStr string) (a string) { var distinct string if statement.IsDistinct { diff --git a/xorm.go b/xorm.go index 2115330e..b1047c79 100644 --- a/xorm.go +++ b/xorm.go @@ -17,7 +17,7 @@ import ( const ( // Version show the xorm's version - Version string = "0.5.4.0630" + Version string = "0.5.5.0707" ) func regDrvsNDialects() bool {