diff --git a/dialects/dialect.go b/dialects/dialect.go index 81d1ee8d..fc11eac1 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -118,12 +118,9 @@ func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query stri defer rows.Close() if rows.Next() { - if rows.Err() != nil { - return true, rows.Err() - } return true, nil } - return false, nil + return false, rows.Err() } // IsColumnExist returns true if the column of the table exist diff --git a/dialects/mssql.go b/dialects/mssql.go index 08232487..b80543af 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -264,6 +264,9 @@ func (db *mssql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Ve var version, level, edition string if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } return nil, errors.New("unknow version") } @@ -456,9 +459,6 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } var name, ctype, vdefault string var maxLen, precision, scale int var nullable, isPK, defaultIsNull, isIncrement bool @@ -512,6 +512,9 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -527,9 +530,6 @@ func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) @@ -539,6 +539,9 @@ func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema table.Name = strings.Trim(name, "` ") tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -562,11 +565,8 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, colName, isUnique string @@ -604,6 +604,9 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? } index.AddColumn(colName) } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } diff --git a/dialects/mysql.go b/dialects/mysql.go index f1445b01..bccaf480 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -213,7 +213,10 @@ func (db *mysql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Ve var version string if !rows.Next() { - return nil, errors.New("Unknow version") + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") } if err := rows.Scan(&version); err != nil { @@ -405,9 +408,6 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -506,6 +506,9 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -522,9 +525,6 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() var name, engine string var autoIncr, comment *string @@ -540,6 +540,9 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema table.StoreEngine = engine tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -570,11 +573,8 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, colName, nonUnique string err = rows.Scan(&indexName, &nonUnique, &colName) @@ -586,7 +586,7 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName continue } - if "YES" == nonUnique || nonUnique == "1" { + if nonUnique == "YES" || nonUnique == "1" { indexType = schemas.IndexType } else { indexType = schemas.UniqueType @@ -610,6 +610,9 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName } index.AddColumn(colName) } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } diff --git a/dialects/oracle.go b/dialects/oracle.go index 9240046a..45223dd2 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -525,6 +525,9 @@ func (db *oracle) Version(ctx context.Context, queryer core.Queryer) (*schemas.V var version string if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } return nil, errors.New("unknow version") } @@ -677,9 +680,6 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -759,6 +759,9 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -775,9 +778,6 @@ func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() err = rows.Scan(&table.Name) if err != nil { @@ -786,6 +786,9 @@ func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -800,11 +803,8 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, colName, uniqueness string @@ -838,6 +838,9 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam } index.AddColumn(colName) } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } diff --git a/dialects/postgres.go b/dialects/postgres.go index e1dca631..0bd8486a 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -810,6 +810,9 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas var version string if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } return nil, errors.New("unknow version") } @@ -1062,7 +1065,10 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab } defer rows.Close() - return rows.Next(), nil + if rows.Next() { + return true, nil + } + return false, rows.Err() } func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { @@ -1098,9 +1104,6 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -1216,6 +1219,9 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -1237,9 +1243,6 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) @@ -1249,6 +1252,9 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch table.Name = name tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -1279,9 +1285,6 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, indexdef string var colNames []string @@ -1322,6 +1325,9 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN index.IsRegular = isRegular indexes[index.Name] = index } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } @@ -1459,9 +1465,6 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri } defer rows.Close() if rows.Next() { - if rows.Err() != nil { - return "", rows.Err() - } var defaultSchema string if err = rows.Scan(&defaultSchema); err != nil { return "", err @@ -1469,6 +1472,9 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri parts := strings.Split(defaultSchema, ",") return strings.TrimSpace(parts[len(parts)-1]), nil } + if rows.Err() != nil { + return "", rows.Err() + } return "", errors.New("no default schema") } diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index da28d9d1..afcbce71 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -169,7 +169,10 @@ func (db *sqlite3) Version(ctx context.Context, queryer core.Queryer) (*schemas. var version string if !rows.Next() { - return nil, errors.New("Unknow version") + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") } if err := rows.Scan(&version); err != nil { @@ -416,14 +419,14 @@ func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableNa var name string if rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } err = rows.Scan(&name) if err != nil { return nil, nil, err } } + if rows.Err() != nil { + return nil, nil, rows.Err() + } if name == "" { return nil, nil, errors.New("no table named " + tableName) @@ -485,6 +488,9 @@ func (db *sqlite3) GetTables(queryer core.Queryer, ctx context.Context) ([]*sche } tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -500,9 +506,6 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var tmpSQL sql.NullString err = rows.Scan(&tmpSQL) if err != nil { @@ -547,6 +550,9 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa index.IsRegular = isRegular indexes[index.Name] = index } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } diff --git a/engine.go b/engine.go index 35104b04..20c07e13 100644 --- a/engine.go +++ b/engine.go @@ -551,9 +551,6 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch sess := engine.NewSession() defer sess.Close() for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") if err != nil { return err @@ -610,6 +607,9 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } } + if rows.Err() != nil { + return rows.Err() + } // FIXME: Hack for postgres if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil { diff --git a/rows.go b/rows.go index 5e0a1ffe..8e7cc075 100644 --- a/rows.go +++ b/rows.go @@ -5,7 +5,6 @@ package xorm import ( - "database/sql" "errors" "fmt" "reflect" @@ -17,10 +16,9 @@ import ( // Rows rows wrapper a rows to type Rows struct { - session *Session - rows *core.Rows - beanType reflect.Type - lastError error + session *Session + rows *core.Rows + beanType reflect.Type } func newRows(session *Session, bean interface{}) (*Rows, error) { @@ -62,15 +60,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { // !oinume! Add " IS NULL" to WHERE whatever condiBean is given. // See https://gitea.com/xorm/xorm/issues/179 if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled - var colName = session.engine.Quote(col.Name) - if addedTableName { - var nm = session.statement.TableName() - if len(session.statement.TableAlias) > 0 { - nm = session.statement.TableAlias - } - colName = session.engine.Quote(nm) + "." + colName - } - autoCond = session.statement.CondDeleted(col) } } @@ -86,7 +75,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { rows.rows, err = rows.session.queryRows(sqlStr, args...) if err != nil { - rows.lastError = err rows.Close() return nil, err } @@ -96,25 +84,18 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { // Next move cursor to next record, return false if end has reached func (rows *Rows) Next() bool { - if rows.lastError == nil && rows.rows != nil { - hasNext := rows.rows.Next() - if !hasNext { - rows.lastError = sql.ErrNoRows - } - return hasNext - } - return false + return rows.rows.Next() } // Err returns the error, if any, that was encountered during iteration. Err may be called after an explicit or implicit Close. func (rows *Rows) Err() error { - return rows.lastError + return rows.rows.Err() } // Scan row record to bean properties func (rows *Rows) Scan(bean interface{}) error { - if rows.lastError != nil { - return rows.lastError + if rows.Err() != nil { + return rows.Err() } if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { @@ -158,5 +139,5 @@ func (rows *Rows) Close() error { return rows.rows.Close() } - return rows.lastError + return rows.Err() } diff --git a/scan.go b/scan.go index 444aa8ac..728d013a 100644 --- a/scan.go +++ b/scan.go @@ -286,15 +286,15 @@ func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { return nil, err } for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } result, err := row2mapBytes(rows, types, fields) if err != nil { return nil, err } resultsSlice = append(resultsSlice, result) } + if rows.Err() != nil { + return nil, rows.Err() + } return resultsSlice, nil } diff --git a/session.go b/session.go index 8c1d8c3b..62d6a770 100644 --- a/session.go +++ b/session.go @@ -391,9 +391,6 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sq table *schemas.Table, newElemFunc func([]string) reflect.Value, sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } var newValue = newElemFunc(fields) bean := newValue.Interface() dataStruct := newValue.Elem() @@ -415,7 +412,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sq bean: bean, }) } - return nil + return rows.Err() } func (session *Session) row2Slice(rows *core.Rows, fields []string, types []*sql.ColumnType, bean interface{}) ([]interface{}, error) { diff --git a/session_exist.go b/session_exist.go index e52c618e..b5e4a655 100644 --- a/session_exist.go +++ b/session_exist.go @@ -25,5 +25,8 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { } defer rows.Close() - return rows.Next(), nil + if rows.Next() { + return true, nil + } + return false, rows.Err() } diff --git a/session_find.go b/session_find.go index 89e34e80..010ecd6c 100644 --- a/session_find.go +++ b/session_find.go @@ -255,9 +255,6 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect } for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } var newValue = newElemFunc(fields) bean := newValue.Interface() @@ -278,7 +275,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect return err } } - return nil + return rows.Err() } func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error { @@ -325,9 +322,6 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in var i int ids = make([]schemas.PK, 0) for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } i++ if i > 500 { session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache") @@ -348,6 +342,9 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in ids = append(ids, pk) } + if rows.Err() != nil { + return rows.Err() + } session.engine.logger.Debugf("[cache] cache sql: %v, %v, %v, %v, %v", ids, tableName, sqlStr, newsql, args) err = caches.PutCacheSql(cacher, ids, tableName, newsql, args) diff --git a/session_get.go b/session_get.go index 1062bd9d..08172524 100644 --- a/session_get.go +++ b/session_get.go @@ -159,10 +159,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, defer rows.Close() if !rows.Next() { - if rows.Err() != nil { - return false, rows.Err() - } - return false, nil + return false, rows.Err() } // WARN: Alougth rows return true, but we may also return error. @@ -313,14 +310,14 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf defer rows.Close() if rows.Next() { - if rows.Err() != nil { - return true, rows.Err() - } err = rows.ScanSlice(&res) if err != nil { return true, err } } else { + if rows.Err() != nil { + return false, rows.Err() + } return false, ErrCacheFailed } diff --git a/session_iterate.go b/session_iterate.go index dbbeb3f4..f6301009 100644 --- a/session_iterate.go +++ b/session_iterate.go @@ -43,9 +43,6 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { i := 0 for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } b := reflect.New(rows.beanType).Interface() err = rows.Scan(b) if err != nil { @@ -57,7 +54,7 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { } i++ } - return err + return rows.Err() } // BufferSize sets the buffersize for iterate diff --git a/session_query.go b/session_query.go index 8543ba12..a4070985 100644 --- a/session_query.go +++ b/session_query.go @@ -33,15 +33,15 @@ func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string } for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } result, err := session.engine.row2mapStr(rows, types, fields) if err != nil { return nil, err } resultsSlice = append(resultsSlice, result) } + if rows.Err() != nil { + return nil, rows.Err() + } return resultsSlice, nil } @@ -57,15 +57,15 @@ func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]stri } for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } record, err := session.engine.row2sliceStr(rows, types, fields) if err != nil { return nil, err } resultsSlice = append(resultsSlice, record) } + if rows.Err() != nil { + return nil, rows.Err() + } return resultsSlice, nil } @@ -120,15 +120,15 @@ func (session *Session) rows2Interfaces(rows *core.Rows) (resultsSlice []map[str return nil, err } for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } result, err := session.engine.row2mapInterface(rows, types, fields) if err != nil { return nil, err } resultsSlice = append(resultsSlice, result) } + if rows.Err() != nil { + return nil, rows.Err() + } return resultsSlice, nil } diff --git a/session_update.go b/session_update.go index 32e28ae0..4f8e6961 100644 --- a/session_update.go +++ b/session_update.go @@ -59,9 +59,6 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri ids = make([]schemas.PK, 0) for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } var res = make([]string, len(table.PrimaryKeys)) err = rows.ScanSlice(&res) if err != nil { @@ -84,6 +81,9 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri ids = append(ids, pk) } + if rows.Err() != nil { + return rows.Err() + } session.engine.logger.Debugf("[cache] find updated id: %v", ids) } /*else { session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args)