refactor query functions

This commit is contained in:
Lunny Xiao 2017-08-22 14:27:38 +08:00
parent d7f04c3cec
commit 267c2dbc23
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
11 changed files with 155 additions and 255 deletions

View File

@ -273,21 +273,6 @@ func (engine *Engine) logSQL(sqlStr string, sqlArgs ...interface{}) {
}
}
func (engine *Engine) logSQLQueryTime(sqlStr string, args []interface{}, executionBlock func() (*core.Stmt, *core.Rows, error)) (*core.Stmt, *core.Rows, error) {
if engine.showSQL && engine.showExecTime {
b4ExecTime := time.Now()
stmt, res, err := executionBlock()
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
engine.logger.Infof("[SQL] %s %v - took: %v", sqlStr, args, execDuration)
} else {
engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
return stmt, res, err
}
return executionBlock()
}
func (engine *Engine) logSQLExecutionTime(sqlStr string, args []interface{}, executionBlock func() (sql.Result, error)) (sql.Result, error) {
if engine.showSQL && engine.showExecTime {
b4ExecTime := time.Now()

41
rows.go
View File

@ -17,7 +17,6 @@ type Rows struct {
NoTypeCheck bool
session *Session
stmt *core.Stmt
rows *core.Rows
fields []string
beanType reflect.Type
@ -29,8 +28,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
rows.session = session
rows.beanType = reflect.Indirect(reflect.ValueOf(bean)).Type()
defer rows.session.resetStatement()
var sqlStr string
var args []interface{}
var err error
@ -53,32 +50,11 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
args = rows.session.statement.RawParams
}
for _, filter := range rows.session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.engine.dialect, rows.session.statement.RefTable)
}
rows.session.saveLastSQL(sqlStr, args...)
if rows.session.prepareStmt {
rows.stmt, err = rows.session.DB().Prepare(sqlStr)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
rows.rows, err = rows.stmt.Query(args...)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
} else {
rows.rows, err = rows.session.DB().Query(sqlStr, args...)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
rows.rows, err = rows.session.queryRows(sqlStr, args...)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
rows.fields, err = rows.rows.Columns()
@ -142,17 +118,10 @@ func (rows *Rows) Close() error {
if rows.rows != nil {
rows.lastError = rows.rows.Close()
if rows.lastError != nil {
defer rows.stmt.Close()
return rows.lastError
}
}
if rows.stmt != nil {
rows.lastError = rows.stmt.Close()
}
} else {
if rows.stmt != nil {
defer rows.stmt.Close()
}
if rows.rows != nil {
defer rows.rows.Close()
}

View File

@ -777,14 +777,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
return pk, nil
}
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
for _, filter := range session.engine.dialect.Filters() {
*sqlStr = filter.Do(*sqlStr, session.engine.dialect, session.statement.RefTable)
}
session.saveLastSQL(*sqlStr, paramStr...)
}
// saveLastSQL stores executed query information
func (session *Session) saveLastSQL(sql string, args ...interface{}) {
session.lastSQL = sql

View File

@ -31,7 +31,7 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
tableName := session.statement.TableName()
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil {
resultsSlice, err := session.query(newsql, args...)
resultsSlice, err := session.queryBytes(newsql, args...)
if err != nil {
return err
}

View File

@ -10,7 +10,6 @@ import (
"reflect"
"github.com/go-xorm/builder"
"github.com/go-xorm/core"
)
// Exist returns true if the record exist otherwise return false
@ -69,19 +68,11 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
args = session.statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var rawRows *core.Rows
if session.isAutoCommit {
_, rawRows, err = session.innerQuery(sqlStr, args...)
} else {
rawRows, err = session.tx.Query(sqlStr, args...)
}
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return false, err
}
defer rows.Close()
defer rawRows.Close()
return rawRows.Next(), nil
return rows.Next(), nil
}

View File

@ -157,21 +157,13 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
}
func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error {
var rawRows *core.Rows
var err error
session.queryPreprocess(&sqlStr, args...)
if session.isAutoCommit {
_, rawRows, err = session.innerQuery(sqlStr, args...)
} else {
rawRows, err = session.tx.Query(sqlStr, args...)
}
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return err
}
defer rawRows.Close()
defer rows.Close()
fields, err := rawRows.Columns()
fields, err := rows.Columns()
if err != nil {
return err
}
@ -245,20 +237,20 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va
if err != nil {
return err
}
return session.rows2Beans(rawRows, fields, len(fields), tb, newElemFunc, containerValueSetFunc)
return session.rows2Beans(rows, fields, len(fields), tb, newElemFunc, containerValueSetFunc)
}
for rawRows.Next() {
for rows.Next() {
var newValue = newElemFunc(fields)
bean := newValue.Interface()
switch elemType.Kind() {
case reflect.Slice:
err = rawRows.ScanSlice(bean)
err = rows.ScanSlice(bean)
case reflect.Map:
err = rawRows.ScanMap(bean)
err = rows.ScanMap(bean)
default:
err = rawRows.Scan(bean)
err = rows.Scan(bean)
}
if err != nil {

View File

@ -65,30 +65,21 @@ func (session *Session) Get(bean interface{}) (bool, error) {
}
func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
session.queryPreprocess(&sqlStr, args...)
var rawRows *core.Rows
var err error
if session.isAutoCommit {
_, rawRows, err = session.innerQuery(sqlStr, args...)
} else {
rawRows, err = session.tx.Query(sqlStr, args...)
}
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return false, err
}
defer rows.Close()
defer rawRows.Close()
if !rawRows.Next() {
if !rows.Next() {
return false, nil
}
switch beanKind {
case reflect.Struct:
fields, err := rawRows.Columns()
fields, err := rows.Columns()
if err != nil {
// WARN: Alougth rawRows return true, but get fields failed
// WARN: Alougth rows return true, but get fields failed
return true, err
}
dataStruct := rValue(bean)
@ -96,19 +87,20 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlS
return false, err
}
scanResults, err := session.row2Slice(rawRows, fields, len(fields), bean)
scanResults, err := session.row2Slice(rows, fields, len(fields), bean)
if err != nil {
return false, err
}
rawRows.Close()
// close it before covert data
rows.Close()
_, err = session.slice2Bean(scanResults, fields, len(fields), bean, &dataStruct, session.statement.RefTable)
case reflect.Slice:
err = rawRows.ScanSlice(bean)
err = rows.ScanSlice(bean)
case reflect.Map:
err = rawRows.ScanMap(bean)
err = rows.ScanMap(bean)
default:
err = rawRows.Scan(bean)
err = rows.Scan(bean)
}
return true, err

View File

@ -395,7 +395,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
// for postgres, many of them didn't implement lastInsertId, so we should
// implemented it ourself.
if session.engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 {
res, err := session.query("select seq_atable.currval from dual", args...)
res, err := session.queryBytes("select seq_atable.currval from dual", args...)
if err != nil {
return 0, err
}
@ -440,7 +440,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
} else if session.engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 {
//assert table.AutoIncrement != ""
sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement)
res, err := session.query(sqlStr, args...)
res, err := session.queryBytes(sqlStr, args...)
if err != nil {
return 0, err

View File

@ -14,55 +14,6 @@ import (
"github.com/go-xorm/core"
)
func (session *Session) query(sqlStr string, paramStr ...interface{}) ([]map[string][]byte, error) {
session.queryPreprocess(&sqlStr, paramStr...)
if session.isAutoCommit {
return session.innerQuery2(sqlStr, paramStr...)
}
return session.txQuery(session.tx, sqlStr, paramStr...)
}
func (session *Session) txQuery(tx *core.Tx, sqlStr string, params ...interface{}) ([]map[string][]byte, error) {
rows, err := tx.Query(sqlStr, params...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2maps(rows)
}
func (session *Session) innerQuery(sqlStr string, params ...interface{}) (*core.Stmt, *core.Rows, error) {
var callback func() (*core.Stmt, *core.Rows, error)
if session.prepareStmt {
callback = func() (*core.Stmt, *core.Rows, error) {
stmt, err := session.doPrepare(sqlStr)
if err != nil {
return nil, nil, err
}
rows, err := stmt.Query(params...)
if err != nil {
return nil, nil, err
}
return stmt, rows, nil
}
} else {
callback = func() (*core.Stmt, *core.Rows, error) {
rows, err := session.DB().Query(sqlStr, params...)
if err != nil {
return nil, nil, err
}
return nil, rows, err
}
}
stmt, rows, err := session.engine.logSQLQueryTime(sqlStr, params, callback)
if err != nil {
return nil, nil, err
}
return stmt, rows, nil
}
func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) {
fields, err := rows.Columns()
if err != nil {
@ -117,27 +68,6 @@ func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, er
return result, nil
}
func (session *Session) innerQuery2(sqlStr string, params ...interface{}) ([]map[string][]byte, error) {
_, rows, err := session.innerQuery(sqlStr, params...)
if rows != nil {
defer rows.Close()
}
if err != nil {
return nil, err
}
return rows2maps(rows)
}
// Query runs a raw sql and return records as []map[string][]byte
func (session *Session) Query(sqlStr string, paramStr ...interface{}) ([]map[string][]byte, error) {
defer session.resetStatement()
if session.isAutoClose {
defer session.Close()
}
return session.query(sqlStr, paramStr...)
}
func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) {
fields, err := rows.Columns()
if err != nil {
@ -234,42 +164,136 @@ func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string,
return result, nil
}
func txQuery2(tx *core.Tx, sqlStr string, params ...interface{}) ([]map[string]string, error) {
rows, err := tx.Query(sqlStr, params...)
if err != nil {
return nil, err
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
for _, filter := range session.engine.dialect.Filters() {
*sqlStr = filter.Do(*sqlStr, session.engine.dialect, session.statement.RefTable)
}
defer rows.Close()
return rows2Strings(rows)
session.lastSQL = *sqlStr
session.lastSQLArgs = paramStr
}
func query2(db *core.DB, sqlStr string, params ...interface{}) ([]map[string]string, error) {
rows, err := db.Query(sqlStr, params...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2Strings(rows)
}
// QueryString runs a raw sql and return records as []map[string]string
func (session *Session) QueryString(sqlStr string, args ...interface{}) ([]map[string]string, error) {
func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Rows, error) {
defer session.resetStatement()
session.queryPreprocess(&sqlStr, args...)
if session.engine.showSQL {
if session.engine.showExecTime {
b4ExecTime := time.Now()
defer func() {
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration)
} else {
session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
}()
} else {
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args)
} else {
session.engine.logger.Infof("[SQL] %v", sqlStr)
}
}
}
if session.isAutoCommit {
if session.prepareStmt {
// don't clear stmt since session will cache them
stmt, err := session.doPrepare(sqlStr)
if err != nil {
return nil, err
}
rows, err := stmt.Query(args...)
if err != nil {
return nil, err
}
return rows, nil
}
rows, err := session.DB().Query(sqlStr, args...)
if err != nil {
return nil, err
}
return rows, nil
}
rows, err := session.tx.Query(sqlStr, args...)
if err != nil {
return nil, err
}
return rows, nil
}
func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row {
return core.NewRow(session.queryRows(sqlStr, args...))
}
func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) {
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2maps(rows)
}
// Query runs a raw sql and return records as []map[string][]byte
func (session *Session) Query(sqlStr string, args ...interface{}) ([]map[string][]byte, error) {
if session.isAutoClose {
defer session.Close()
}
session.queryPreprocess(&sqlStr, args...)
if session.isAutoCommit {
return query2(session.DB(), sqlStr, args...)
}
return txQuery2(session.tx, sqlStr, args...)
return session.queryBytes(sqlStr, args...)
}
// Execute sql
func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Result, error) {
// QueryString runs a raw sql and return records as []map[string]string
func (session *Session) QueryString(sqlStr string, args ...interface{}) ([]map[string]string, error) {
if session.isAutoClose {
defer session.Close()
}
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2Strings(rows)
}
func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) {
defer session.resetStatement()
session.queryPreprocess(&sqlStr, args...)
if session.engine.showSQL {
if session.engine.showExecTime {
b4ExecTime := time.Now()
defer func() {
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration)
} else {
session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
}()
} else {
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args)
} else {
session.engine.logger.Infof("[SQL] %v", sqlStr)
}
}
}
if !session.isAutoCommit {
return session.tx.Exec(sqlStr, args...)
}
if session.prepareStmt {
stmt, err := session.doPrepare(sqlStr)
if err != nil {
@ -286,32 +310,8 @@ func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Resul
return session.DB().Exec(sqlStr, args...)
}
func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) {
for _, filter := range session.engine.dialect.Filters() {
// TODO: for table name, it's no need to RefTable
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
}
session.saveLastSQL(sqlStr, args...)
return session.engine.logSQLExecutionTime(sqlStr, args, func() (sql.Result, error) {
if session.isAutoCommit {
// FIXME: oci8 can not auto commit (github.com/mattn/go-oci8)
if session.engine.dialect.DBType() == core.ORACLE {
session.Begin()
r, err := session.tx.Exec(sqlStr, args...)
session.Commit()
return r, err
}
return session.innerExec(sqlStr, args...)
}
return session.tx.Exec(sqlStr, args...)
})
}
// Exec raw sql
func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
defer session.resetStatement()
if session.isAutoClose {
defer session.Close()
}

View File

@ -142,7 +142,7 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
var needDrop = true
if !session.engine.dialect.SupportDropIfExists() {
sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
results, err := session.query(sqlStr, args...)
results, err := session.queryBytes(sqlStr, args...)
if err != nil {
return err
}
@ -174,7 +174,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
func (session *Session) isTableExist(tableName string) (bool, error) {
defer session.resetStatement()
sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
results, err := session.query(sqlStr, args...)
results, err := session.queryBytes(sqlStr, args...)
return len(results) > 0, err
}
@ -200,8 +200,7 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) {
var total int64
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName))
err := session.DB().QueryRow(sqlStr).Scan(&total)
session.saveLastSQL(sqlStr)
err := session.queryRow(sqlStr).Scan(&total)
if err != nil {
if err == sql.ErrNoRows {
err = nil

View File

@ -13,7 +13,6 @@ import (
// Count counts the records. bean's non-empty fields
// are conditions.
func (session *Session) Count(bean ...interface{}) (int64, error) {
defer session.resetStatement()
if session.isAutoClose {
defer session.Close()
}
@ -31,15 +30,8 @@ func (session *Session) Count(bean ...interface{}) (int64, error) {
args = session.statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
var total int64
if session.isAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).Scan(&total)
} else {
err = session.tx.QueryRow(sqlStr, args...).Scan(&total)
}
err = session.queryRow(sqlStr, args...).Scan(&total)
if err == sql.ErrNoRows || err == nil {
return total, nil
}
@ -49,7 +41,6 @@ func (session *Session) Count(bean ...interface{}) (int64, error) {
// sum call sum some column. bean's non-empty fields are conditions.
func (session *Session) sum(res interface{}, bean interface{}, columnNames ...string) error {
defer session.resetStatement()
if session.isAutoClose {
defer session.Close()
}
@ -73,22 +64,11 @@ func (session *Session) sum(res interface{}, bean interface{}, columnNames ...st
args = session.statement.RawParams
}
session.queryPreprocess(&sqlStr, args...)
if isSlice {
if session.isAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(res)
} else {
err = session.tx.QueryRow(sqlStr, args...).ScanSlice(res)
}
err = session.queryRow(sqlStr, args...).ScanSlice(res)
} else {
if session.isAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).Scan(res)
} else {
err = session.tx.QueryRow(sqlStr, args...).Scan(res)
}
err = session.queryRow(sqlStr, args...).Scan(res)
}
if err == sql.ErrNoRows || err == nil {
return nil
}