diff --git a/dialects/dialect.go b/dialects/dialect.go index 81d1ee8d..fc11eac1 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -118,12 +118,9 @@ func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query stri defer rows.Close() if rows.Next() { - if rows.Err() != nil { - return true, rows.Err() - } return true, nil } - return false, nil + return false, rows.Err() } // IsColumnExist returns true if the column of the table exist diff --git a/dialects/mssql.go b/dialects/mssql.go index 08232487..b80543af 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -264,6 +264,9 @@ func (db *mssql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Ve var version, level, edition string if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } return nil, errors.New("unknow version") } @@ -456,9 +459,6 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } var name, ctype, vdefault string var maxLen, precision, scale int var nullable, isPK, defaultIsNull, isIncrement bool @@ -512,6 +512,9 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -527,9 +530,6 @@ func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) @@ -539,6 +539,9 @@ func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema table.Name = strings.Trim(name, "` ") tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -562,11 +565,8 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, colName, isUnique string @@ -604,6 +604,9 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? } index.AddColumn(colName) } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } diff --git a/dialects/mysql.go b/dialects/mysql.go index f1445b01..bccaf480 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -213,7 +213,10 @@ func (db *mysql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Ve var version string if !rows.Next() { - return nil, errors.New("Unknow version") + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") } if err := rows.Scan(&version); err != nil { @@ -405,9 +408,6 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -506,6 +506,9 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -522,9 +525,6 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() var name, engine string var autoIncr, comment *string @@ -540,6 +540,9 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema table.StoreEngine = engine tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -570,11 +573,8 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, colName, nonUnique string err = rows.Scan(&indexName, &nonUnique, &colName) @@ -586,7 +586,7 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName continue } - if "YES" == nonUnique || nonUnique == "1" { + if nonUnique == "YES" || nonUnique == "1" { indexType = schemas.IndexType } else { indexType = schemas.UniqueType @@ -610,6 +610,9 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName } index.AddColumn(colName) } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } diff --git a/dialects/oracle.go b/dialects/oracle.go index 9240046a..45223dd2 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -525,6 +525,9 @@ func (db *oracle) Version(ctx context.Context, queryer core.Queryer) (*schemas.V var version string if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } return nil, errors.New("unknow version") } @@ -677,9 +680,6 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -759,6 +759,9 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -775,9 +778,6 @@ func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() err = rows.Scan(&table.Name) if err != nil { @@ -786,6 +786,9 @@ func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -800,11 +803,8 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, colName, uniqueness string @@ -838,6 +838,9 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam } index.AddColumn(colName) } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } diff --git a/dialects/postgres.go b/dialects/postgres.go index e1dca631..0bd8486a 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -810,6 +810,9 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas var version string if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } return nil, errors.New("unknow version") } @@ -1062,7 +1065,10 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab } defer rows.Close() - return rows.Next(), nil + if rows.Next() { + return true, nil + } + return false, rows.Err() } func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { @@ -1098,9 +1104,6 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -1216,6 +1219,9 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -1237,9 +1243,6 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) @@ -1249,6 +1252,9 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch table.Name = name tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -1279,9 +1285,6 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, indexdef string var colNames []string @@ -1322,6 +1325,9 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN index.IsRegular = isRegular indexes[index.Name] = index } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } @@ -1459,9 +1465,6 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri } defer rows.Close() if rows.Next() { - if rows.Err() != nil { - return "", rows.Err() - } var defaultSchema string if err = rows.Scan(&defaultSchema); err != nil { return "", err @@ -1469,6 +1472,9 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri parts := strings.Split(defaultSchema, ",") return strings.TrimSpace(parts[len(parts)-1]), nil } + if rows.Err() != nil { + return "", rows.Err() + } return "", errors.New("no default schema") } diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index da28d9d1..afcbce71 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -169,7 +169,10 @@ func (db *sqlite3) Version(ctx context.Context, queryer core.Queryer) (*schemas. var version string if !rows.Next() { - return nil, errors.New("Unknow version") + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") } if err := rows.Scan(&version); err != nil { @@ -416,14 +419,14 @@ func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableNa var name string if rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } err = rows.Scan(&name) if err != nil { return nil, nil, err } } + if rows.Err() != nil { + return nil, nil, rows.Err() + } if name == "" { return nil, nil, errors.New("no table named " + tableName) @@ -485,6 +488,9 @@ func (db *sqlite3) GetTables(queryer core.Queryer, ctx context.Context) ([]*sche } tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -500,9 +506,6 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var tmpSQL sql.NullString err = rows.Scan(&tmpSQL) if err != nil { @@ -547,6 +550,9 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa index.IsRegular = isRegular indexes[index.Name] = index } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } diff --git a/engine.go b/engine.go index 35104b04..20c07e13 100644 --- a/engine.go +++ b/engine.go @@ -551,9 +551,6 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch sess := engine.NewSession() defer sess.Close() for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") if err != nil { return err @@ -610,6 +607,9 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } } + if rows.Err() != nil { + return rows.Err() + } // FIXME: Hack for postgres if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil { diff --git a/rows.go b/rows.go index 5e0a1ffe..8e7cc075 100644 --- a/rows.go +++ b/rows.go @@ -5,7 +5,6 @@ package xorm import ( - "database/sql" "errors" "fmt" "reflect" @@ -17,10 +16,9 @@ import ( // Rows rows wrapper a rows to type Rows struct { - session *Session - rows *core.Rows - beanType reflect.Type - lastError error + session *Session + rows *core.Rows + beanType reflect.Type } func newRows(session *Session, bean interface{}) (*Rows, error) { @@ -62,15 +60,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { // !oinume! Add "