refactor query functions

This commit is contained in:
Lunny Xiao 2017-08-22 14:27:38 +08:00
parent d7f04c3cec
commit 267c2dbc23
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
11 changed files with 155 additions and 255 deletions

View File

@ -273,21 +273,6 @@ func (engine *Engine) logSQL(sqlStr string, sqlArgs ...interface{}) {
} }
} }
func (engine *Engine) logSQLQueryTime(sqlStr string, args []interface{}, executionBlock func() (*core.Stmt, *core.Rows, error)) (*core.Stmt, *core.Rows, error) {
if engine.showSQL && engine.showExecTime {
b4ExecTime := time.Now()
stmt, res, err := executionBlock()
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
engine.logger.Infof("[SQL] %s %v - took: %v", sqlStr, args, execDuration)
} else {
engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
return stmt, res, err
}
return executionBlock()
}
func (engine *Engine) logSQLExecutionTime(sqlStr string, args []interface{}, executionBlock func() (sql.Result, error)) (sql.Result, error) { func (engine *Engine) logSQLExecutionTime(sqlStr string, args []interface{}, executionBlock func() (sql.Result, error)) (sql.Result, error) {
if engine.showSQL && engine.showExecTime { if engine.showSQL && engine.showExecTime {
b4ExecTime := time.Now() b4ExecTime := time.Now()

33
rows.go
View File

@ -17,7 +17,6 @@ type Rows struct {
NoTypeCheck bool NoTypeCheck bool
session *Session session *Session
stmt *core.Stmt
rows *core.Rows rows *core.Rows
fields []string fields []string
beanType reflect.Type beanType reflect.Type
@ -29,8 +28,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
rows.session = session rows.session = session
rows.beanType = reflect.Indirect(reflect.ValueOf(bean)).Type() rows.beanType = reflect.Indirect(reflect.ValueOf(bean)).Type()
defer rows.session.resetStatement()
var sqlStr string var sqlStr string
var args []interface{} var args []interface{}
var err error var err error
@ -53,34 +50,13 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
args = rows.session.statement.RawParams args = rows.session.statement.RawParams
} }
for _, filter := range rows.session.engine.dialect.Filters() { rows.rows, err = rows.session.queryRows(sqlStr, args...)
sqlStr = filter.Do(sqlStr, session.engine.dialect, rows.session.statement.RefTable)
}
rows.session.saveLastSQL(sqlStr, args...)
if rows.session.prepareStmt {
rows.stmt, err = rows.session.DB().Prepare(sqlStr)
if err != nil { if err != nil {
rows.lastError = err rows.lastError = err
rows.Close() rows.Close()
return nil, err return nil, err
} }
rows.rows, err = rows.stmt.Query(args...)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
} else {
rows.rows, err = rows.session.DB().Query(sqlStr, args...)
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
}
rows.fields, err = rows.rows.Columns() rows.fields, err = rows.rows.Columns()
if err != nil { if err != nil {
rows.lastError = err rows.lastError = err
@ -142,17 +118,10 @@ func (rows *Rows) Close() error {
if rows.rows != nil { if rows.rows != nil {
rows.lastError = rows.rows.Close() rows.lastError = rows.rows.Close()
if rows.lastError != nil { if rows.lastError != nil {
defer rows.stmt.Close()
return rows.lastError return rows.lastError
} }
} }
if rows.stmt != nil {
rows.lastError = rows.stmt.Close()
}
} else { } else {
if rows.stmt != nil {
defer rows.stmt.Close()
}
if rows.rows != nil { if rows.rows != nil {
defer rows.rows.Close() defer rows.rows.Close()
} }

View File

@ -777,14 +777,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f
return pk, nil return pk, nil
} }
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
for _, filter := range session.engine.dialect.Filters() {
*sqlStr = filter.Do(*sqlStr, session.engine.dialect, session.statement.RefTable)
}
session.saveLastSQL(*sqlStr, paramStr...)
}
// saveLastSQL stores executed query information // saveLastSQL stores executed query information
func (session *Session) saveLastSQL(sql string, args ...interface{}) { func (session *Session) saveLastSQL(sql string, args ...interface{}) {
session.lastSQL = sql session.lastSQL = sql

View File

@ -31,7 +31,7 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
tableName := session.statement.TableName() tableName := session.statement.TableName()
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil { if err != nil {
resultsSlice, err := session.query(newsql, args...) resultsSlice, err := session.queryBytes(newsql, args...)
if err != nil { if err != nil {
return err return err
} }

View File

@ -10,7 +10,6 @@ import (
"reflect" "reflect"
"github.com/go-xorm/builder" "github.com/go-xorm/builder"
"github.com/go-xorm/core"
) )
// Exist returns true if the record exist otherwise return false // Exist returns true if the record exist otherwise return false
@ -69,19 +68,11 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
args = session.statement.RawParams args = session.statement.RawParams
} }
session.queryPreprocess(&sqlStr, args...) rows, err := session.queryRows(sqlStr, args...)
var rawRows *core.Rows
if session.isAutoCommit {
_, rawRows, err = session.innerQuery(sqlStr, args...)
} else {
rawRows, err = session.tx.Query(sqlStr, args...)
}
if err != nil { if err != nil {
return false, err return false, err
} }
defer rows.Close()
defer rawRows.Close() return rows.Next(), nil
return rawRows.Next(), nil
} }

View File

@ -157,21 +157,13 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error { func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error {
var rawRows *core.Rows rows, err := session.queryRows(sqlStr, args...)
var err error
session.queryPreprocess(&sqlStr, args...)
if session.isAutoCommit {
_, rawRows, err = session.innerQuery(sqlStr, args...)
} else {
rawRows, err = session.tx.Query(sqlStr, args...)
}
if err != nil { if err != nil {
return err return err
} }
defer rawRows.Close() defer rows.Close()
fields, err := rawRows.Columns() fields, err := rows.Columns()
if err != nil { if err != nil {
return err return err
} }
@ -245,20 +237,20 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va
if err != nil { if err != nil {
return err return err
} }
return session.rows2Beans(rawRows, fields, len(fields), tb, newElemFunc, containerValueSetFunc) return session.rows2Beans(rows, fields, len(fields), tb, newElemFunc, containerValueSetFunc)
} }
for rawRows.Next() { for rows.Next() {
var newValue = newElemFunc(fields) var newValue = newElemFunc(fields)
bean := newValue.Interface() bean := newValue.Interface()
switch elemType.Kind() { switch elemType.Kind() {
case reflect.Slice: case reflect.Slice:
err = rawRows.ScanSlice(bean) err = rows.ScanSlice(bean)
case reflect.Map: case reflect.Map:
err = rawRows.ScanMap(bean) err = rows.ScanMap(bean)
default: default:
err = rawRows.Scan(bean) err = rows.Scan(bean)
} }
if err != nil { if err != nil {

View File

@ -65,30 +65,21 @@ func (session *Session) Get(bean interface{}) (bool, error) {
} }
func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
session.queryPreprocess(&sqlStr, args...) rows, err := session.queryRows(sqlStr, args...)
var rawRows *core.Rows
var err error
if session.isAutoCommit {
_, rawRows, err = session.innerQuery(sqlStr, args...)
} else {
rawRows, err = session.tx.Query(sqlStr, args...)
}
if err != nil { if err != nil {
return false, err return false, err
} }
defer rows.Close()
defer rawRows.Close() if !rows.Next() {
if !rawRows.Next() {
return false, nil return false, nil
} }
switch beanKind { switch beanKind {
case reflect.Struct: case reflect.Struct:
fields, err := rawRows.Columns() fields, err := rows.Columns()
if err != nil { if err != nil {
// WARN: Alougth rawRows return true, but get fields failed // WARN: Alougth rows return true, but get fields failed
return true, err return true, err
} }
dataStruct := rValue(bean) dataStruct := rValue(bean)
@ -96,19 +87,20 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlS
return false, err return false, err
} }
scanResults, err := session.row2Slice(rawRows, fields, len(fields), bean) scanResults, err := session.row2Slice(rows, fields, len(fields), bean)
if err != nil { if err != nil {
return false, err return false, err
} }
rawRows.Close() // close it before covert data
rows.Close()
_, err = session.slice2Bean(scanResults, fields, len(fields), bean, &dataStruct, session.statement.RefTable) _, err = session.slice2Bean(scanResults, fields, len(fields), bean, &dataStruct, session.statement.RefTable)
case reflect.Slice: case reflect.Slice:
err = rawRows.ScanSlice(bean) err = rows.ScanSlice(bean)
case reflect.Map: case reflect.Map:
err = rawRows.ScanMap(bean) err = rows.ScanMap(bean)
default: default:
err = rawRows.Scan(bean) err = rows.Scan(bean)
} }
return true, err return true, err

View File

@ -395,7 +395,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
// for postgres, many of them didn't implement lastInsertId, so we should // for postgres, many of them didn't implement lastInsertId, so we should
// implemented it ourself. // implemented it ourself.
if session.engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 { if session.engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 {
res, err := session.query("select seq_atable.currval from dual", args...) res, err := session.queryBytes("select seq_atable.currval from dual", args...)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -440,7 +440,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
} else if session.engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 { } else if session.engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 {
//assert table.AutoIncrement != "" //assert table.AutoIncrement != ""
sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement) sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement)
res, err := session.query(sqlStr, args...) res, err := session.queryBytes(sqlStr, args...)
if err != nil { if err != nil {
return 0, err return 0, err

View File

@ -14,55 +14,6 @@ import (
"github.com/go-xorm/core" "github.com/go-xorm/core"
) )
func (session *Session) query(sqlStr string, paramStr ...interface{}) ([]map[string][]byte, error) {
session.queryPreprocess(&sqlStr, paramStr...)
if session.isAutoCommit {
return session.innerQuery2(sqlStr, paramStr...)
}
return session.txQuery(session.tx, sqlStr, paramStr...)
}
func (session *Session) txQuery(tx *core.Tx, sqlStr string, params ...interface{}) ([]map[string][]byte, error) {
rows, err := tx.Query(sqlStr, params...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2maps(rows)
}
func (session *Session) innerQuery(sqlStr string, params ...interface{}) (*core.Stmt, *core.Rows, error) {
var callback func() (*core.Stmt, *core.Rows, error)
if session.prepareStmt {
callback = func() (*core.Stmt, *core.Rows, error) {
stmt, err := session.doPrepare(sqlStr)
if err != nil {
return nil, nil, err
}
rows, err := stmt.Query(params...)
if err != nil {
return nil, nil, err
}
return stmt, rows, nil
}
} else {
callback = func() (*core.Stmt, *core.Rows, error) {
rows, err := session.DB().Query(sqlStr, params...)
if err != nil {
return nil, nil, err
}
return nil, rows, err
}
}
stmt, rows, err := session.engine.logSQLQueryTime(sqlStr, params, callback)
if err != nil {
return nil, nil, err
}
return stmt, rows, nil
}
func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) {
fields, err := rows.Columns() fields, err := rows.Columns()
if err != nil { if err != nil {
@ -117,27 +68,6 @@ func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, er
return result, nil return result, nil
} }
func (session *Session) innerQuery2(sqlStr string, params ...interface{}) ([]map[string][]byte, error) {
_, rows, err := session.innerQuery(sqlStr, params...)
if rows != nil {
defer rows.Close()
}
if err != nil {
return nil, err
}
return rows2maps(rows)
}
// Query runs a raw sql and return records as []map[string][]byte
func (session *Session) Query(sqlStr string, paramStr ...interface{}) ([]map[string][]byte, error) {
defer session.resetStatement()
if session.isAutoClose {
defer session.Close()
}
return session.query(sqlStr, paramStr...)
}
func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) { func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) {
fields, err := rows.Columns() fields, err := rows.Columns()
if err != nil { if err != nil {
@ -234,42 +164,136 @@ func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string,
return result, nil return result, nil
} }
func txQuery2(tx *core.Tx, sqlStr string, params ...interface{}) ([]map[string]string, error) { func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
rows, err := tx.Query(sqlStr, params...) for _, filter := range session.engine.dialect.Filters() {
if err != nil { *sqlStr = filter.Do(*sqlStr, session.engine.dialect, session.statement.RefTable)
return nil, err
}
defer rows.Close()
return rows2Strings(rows)
} }
func query2(db *core.DB, sqlStr string, params ...interface{}) ([]map[string]string, error) { session.lastSQL = *sqlStr
rows, err := db.Query(sqlStr, params...) session.lastSQLArgs = paramStr
if err != nil {
return nil, err
}
defer rows.Close()
return rows2Strings(rows)
} }
// QueryString runs a raw sql and return records as []map[string]string func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Rows, error) {
func (session *Session) QueryString(sqlStr string, args ...interface{}) ([]map[string]string, error) {
defer session.resetStatement() defer session.resetStatement()
session.queryPreprocess(&sqlStr, args...)
if session.engine.showSQL {
if session.engine.showExecTime {
b4ExecTime := time.Now()
defer func() {
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration)
} else {
session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
}()
} else {
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args)
} else {
session.engine.logger.Infof("[SQL] %v", sqlStr)
}
}
}
if session.isAutoCommit {
if session.prepareStmt {
// don't clear stmt since session will cache them
stmt, err := session.doPrepare(sqlStr)
if err != nil {
return nil, err
}
rows, err := stmt.Query(args...)
if err != nil {
return nil, err
}
return rows, nil
}
rows, err := session.DB().Query(sqlStr, args...)
if err != nil {
return nil, err
}
return rows, nil
}
rows, err := session.tx.Query(sqlStr, args...)
if err != nil {
return nil, err
}
return rows, nil
}
func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row {
return core.NewRow(session.queryRows(sqlStr, args...))
}
func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) {
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2maps(rows)
}
// Query runs a raw sql and return records as []map[string][]byte
func (session *Session) Query(sqlStr string, args ...interface{}) ([]map[string][]byte, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
return session.queryBytes(sqlStr, args...)
}
// QueryString runs a raw sql and return records as []map[string]string
func (session *Session) QueryString(sqlStr string, args ...interface{}) ([]map[string]string, error) {
if session.isAutoClose {
defer session.Close()
}
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2Strings(rows)
}
func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) {
defer session.resetStatement()
session.queryPreprocess(&sqlStr, args...) session.queryPreprocess(&sqlStr, args...)
if session.isAutoCommit { if session.engine.showSQL {
return query2(session.DB(), sqlStr, args...) if session.engine.showExecTime {
b4ExecTime := time.Now()
defer func() {
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration)
} else {
session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
}()
} else {
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args)
} else {
session.engine.logger.Infof("[SQL] %v", sqlStr)
}
} }
return txQuery2(session.tx, sqlStr, args...)
} }
// Execute sql if !session.isAutoCommit {
func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Result, error) { return session.tx.Exec(sqlStr, args...)
}
if session.prepareStmt { if session.prepareStmt {
stmt, err := session.doPrepare(sqlStr) stmt, err := session.doPrepare(sqlStr)
if err != nil { if err != nil {
@ -286,32 +310,8 @@ func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Resul
return session.DB().Exec(sqlStr, args...) return session.DB().Exec(sqlStr, args...)
} }
func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) {
for _, filter := range session.engine.dialect.Filters() {
// TODO: for table name, it's no need to RefTable
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
}
session.saveLastSQL(sqlStr, args...)
return session.engine.logSQLExecutionTime(sqlStr, args, func() (sql.Result, error) {
if session.isAutoCommit {
// FIXME: oci8 can not auto commit (github.com/mattn/go-oci8)
if session.engine.dialect.DBType() == core.ORACLE {
session.Begin()
r, err := session.tx.Exec(sqlStr, args...)
session.Commit()
return r, err
}
return session.innerExec(sqlStr, args...)
}
return session.tx.Exec(sqlStr, args...)
})
}
// Exec raw sql // Exec raw sql
func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
defer session.resetStatement()
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }

View File

@ -142,7 +142,7 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
var needDrop = true var needDrop = true
if !session.engine.dialect.SupportDropIfExists() { if !session.engine.dialect.SupportDropIfExists() {
sqlStr, args := session.engine.dialect.TableCheckSql(tableName) sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
results, err := session.query(sqlStr, args...) results, err := session.queryBytes(sqlStr, args...)
if err != nil { if err != nil {
return err return err
} }
@ -174,7 +174,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
func (session *Session) isTableExist(tableName string) (bool, error) { func (session *Session) isTableExist(tableName string) (bool, error) {
defer session.resetStatement() defer session.resetStatement()
sqlStr, args := session.engine.dialect.TableCheckSql(tableName) sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
results, err := session.query(sqlStr, args...) results, err := session.queryBytes(sqlStr, args...)
return len(results) > 0, err return len(results) > 0, err
} }
@ -200,8 +200,7 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) {
var total int64 var total int64
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName)) sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName))
err := session.DB().QueryRow(sqlStr).Scan(&total) err := session.queryRow(sqlStr).Scan(&total)
session.saveLastSQL(sqlStr)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil

View File

@ -13,7 +13,6 @@ import (
// Count counts the records. bean's non-empty fields // Count counts the records. bean's non-empty fields
// are conditions. // are conditions.
func (session *Session) Count(bean ...interface{}) (int64, error) { func (session *Session) Count(bean ...interface{}) (int64, error) {
defer session.resetStatement()
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
@ -31,15 +30,8 @@ func (session *Session) Count(bean ...interface{}) (int64, error) {
args = session.statement.RawParams args = session.statement.RawParams
} }
session.queryPreprocess(&sqlStr, args...)
var total int64 var total int64
if session.isAutoCommit { err = session.queryRow(sqlStr, args...).Scan(&total)
err = session.DB().QueryRow(sqlStr, args...).Scan(&total)
} else {
err = session.tx.QueryRow(sqlStr, args...).Scan(&total)
}
if err == sql.ErrNoRows || err == nil { if err == sql.ErrNoRows || err == nil {
return total, nil return total, nil
} }
@ -49,7 +41,6 @@ func (session *Session) Count(bean ...interface{}) (int64, error) {
// sum call sum some column. bean's non-empty fields are conditions. // sum call sum some column. bean's non-empty fields are conditions.
func (session *Session) sum(res interface{}, bean interface{}, columnNames ...string) error { func (session *Session) sum(res interface{}, bean interface{}, columnNames ...string) error {
defer session.resetStatement()
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
@ -73,22 +64,11 @@ func (session *Session) sum(res interface{}, bean interface{}, columnNames ...st
args = session.statement.RawParams args = session.statement.RawParams
} }
session.queryPreprocess(&sqlStr, args...)
if isSlice { if isSlice {
if session.isAutoCommit { err = session.queryRow(sqlStr, args...).ScanSlice(res)
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(res)
} else { } else {
err = session.tx.QueryRow(sqlStr, args...).ScanSlice(res) err = session.queryRow(sqlStr, args...).Scan(res)
} }
} else {
if session.isAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).Scan(res)
} else {
err = session.tx.QueryRow(sqlStr, args...).Scan(res)
}
}
if err == sql.ErrNoRows || err == nil { if err == sql.ErrNoRows || err == nil {
return nil return nil
} }