diff --git a/convert.go b/convert.go index 533dbe99..1aaf5dca 100644 --- a/convert.go +++ b/convert.go @@ -193,6 +193,8 @@ func asFloat64(src interface{}) (float64, error) { return float64(v.Int32), nil case *sql.NullInt64: return float64(v.Int64), nil + case *sql.NullFloat64: + return v.Float64, nil } rv := reflect.ValueOf(src) @@ -717,6 +719,8 @@ func convertAssignV(dv reflect.Value, src interface{}) error { func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { switch tp.Kind() { + case reflect.Ptr: + return asKind(vv.Elem(), tp.Elem()) case reflect.Int64: return vv.Int(), nil case reflect.Int: diff --git a/convert/time.go b/convert/time.go index 5a3e5246..283c7f83 100644 --- a/convert/time.go +++ b/convert/time.go @@ -8,11 +8,16 @@ import ( "fmt" "strconv" "time" + + "xorm.io/xorm/internal/utils" ) // String2Time converts a string to time with original location func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) { if len(s) == 19 { + if s == utils.ZeroTime0 || s == utils.ZeroTime1 { + return &time.Time{}, nil + } dt, err := time.ParseInLocation("2006-01-02 15:04:05", s, originalLocation) if err != nil { return nil, err diff --git a/dialects/dialect.go b/dialects/dialect.go index df33155d..81d1ee8d 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -118,6 +118,9 @@ func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query stri defer rows.Close() if rows.Next() { + if rows.Err() != nil { + return true, rows.Err() + } return true, nil } return false, nil diff --git a/dialects/mssql.go b/dialects/mssql.go index e708ba80..08232487 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -456,6 +456,9 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { + if rows.Err() != nil { + return nil, nil, rows.Err() + } var name, ctype, vdefault string var maxLen, precision, scale int var nullable, isPK, defaultIsNull, isIncrement bool @@ -524,6 +527,9 @@ func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema tables := make([]*schemas.Table, 0) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) @@ -558,6 +564,9 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? indexes := make(map[string]*schemas.Index, 0) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } var indexType int var indexName, colName, isUnique string diff --git a/dialects/mysql.go b/dialects/mysql.go index 9312c071..88c1038e 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -405,6 +405,9 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { + if rows.Err() != nil { + return nil, nil, rows.Err() + } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -519,6 +522,9 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema tables := make([]*schemas.Table, 0) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } table := schemas.NewEmptyTable() var name, engine string var autoIncr, comment *string @@ -566,6 +572,9 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName indexes := make(map[string]*schemas.Index, 0) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } var indexType int var indexName, colName, nonUnique string err = rows.Scan(&indexName, &nonUnique, &colName) diff --git a/dialects/oracle.go b/dialects/oracle.go index 5dd92887..9240046a 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -677,6 +677,9 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { + if rows.Err() != nil { + return nil, nil, rows.Err() + } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -772,6 +775,9 @@ func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem tables := make([]*schemas.Table, 0) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } table := schemas.NewEmptyTable() err = rows.Scan(&table.Name) if err != nil { @@ -796,6 +802,9 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam indexes := make(map[string]*schemas.Index, 0) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } var indexType int var indexName, colName, uniqueness string diff --git a/dialects/postgres.go b/dialects/postgres.go index 4ec780e8..e1dca631 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -810,7 +810,7 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas var version string if !rows.Next() { - return nil, errors.New("Unknow version") + return nil, errors.New("unknow version") } if err := rows.Scan(&version); err != nil { @@ -1098,6 +1098,9 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A colSeq := make([]string, 0) for rows.Next() { + if rows.Err() != nil { + return nil, nil, rows.Err() + } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -1192,7 +1195,7 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A } } if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok { - return nil, nil, fmt.Errorf("Unknown colType: %s - %s", dataType, col.SQLType.Name) + return nil, nil, fmt.Errorf("unknown colType: %s - %s", dataType, col.SQLType.Name) } col.Length = maxLen @@ -1200,13 +1203,13 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A if !col.DefaultIsEmpty { if col.SQLType.IsText() { if strings.HasSuffix(col.Default, "::character varying") { - col.Default = strings.TrimRight(col.Default, "::character varying") + col.Default = strings.TrimSuffix(col.Default, "::character varying") } else if !strings.HasPrefix(col.Default, "'") { col.Default = "'" + col.Default + "'" } } else if col.SQLType.IsTime() { if strings.HasSuffix(col.Default, "::timestamp without time zone") { - col.Default = strings.TrimRight(col.Default, "::timestamp without time zone") + col.Default = strings.TrimSuffix(col.Default, "::timestamp without time zone") } } } @@ -1234,6 +1237,9 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch tables := make([]*schemas.Table, 0) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) @@ -1259,7 +1265,7 @@ func getIndexColName(indexdef string) []string { func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} - s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") + s := "SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1" if len(db.getSchema()) != 0 { args = append(args, db.getSchema()) s = s + " AND schemaname=$2" @@ -1271,8 +1277,11 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } var indexType int var indexName, indexdef string var colNames []string @@ -1450,6 +1459,9 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri } defer rows.Close() if rows.Next() { + if rows.Err() != nil { + return "", rows.Err() + } var defaultSchema string if err = rows.Scan(&defaultSchema); err != nil { return "", err @@ -1458,5 +1470,5 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri return strings.TrimSpace(parts[len(parts)-1]), nil } - return "", errors.New("No default schema") + return "", errors.New("no default schema") } diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 581272ad..da28d9d1 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -415,12 +415,14 @@ func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableNa defer rows.Close() var name string - for rows.Next() { + if rows.Next() { + if rows.Err() != nil { + return nil, nil, rows.Err() + } err = rows.Scan(&name) if err != nil { return nil, nil, err } - break } if name == "" { @@ -496,8 +498,11 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } var tmpSQL sql.NullString err = rows.Scan(&tmpSQL) if err != nil { diff --git a/engine.go b/engine.go index b4ef9593..35104b04 100644 --- a/engine.go +++ b/engine.go @@ -551,6 +551,9 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch sess := engine.NewSession() defer sess.Close() for rows.Next() { + if rows.Err() != nil { + return rows.Err() + } _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") if err != nil { return err diff --git a/scan.go b/scan.go index e4c0e4a1..444aa8ac 100644 --- a/scan.go +++ b/scan.go @@ -286,6 +286,9 @@ func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { return nil, err } for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } result, err := row2mapBytes(rows, types, fields) if err != nil { return nil, err diff --git a/session.go b/session.go index 5557d717..8c1d8c3b 100644 --- a/session.go +++ b/session.go @@ -364,25 +364,24 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, return } -func (session *Session) getField(dataStruct *reflect.Value, key string, table *schemas.Table, idx int) (*reflect.Value, error) { - var col *schemas.Column - if col = table.GetColumnIdx(key, idx); col == nil { - return nil, ErrFieldIsNotExist{key, table.Name} +func (session *Session) getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) { + var col = table.GetColumnIdx(colName, idx) + if col == nil { + return nil, nil, ErrFieldIsNotExist{colName, table.Name} } fieldValue, err := col.ValueOfV(dataStruct) if err != nil { - return nil, err + return nil, nil, err } if fieldValue == nil { - return nil, ErrFieldIsNotValid{key, table.Name} + return nil, nil, ErrFieldIsNotValid{colName, table.Name} } - if !fieldValue.IsValid() || !fieldValue.CanSet() { - return nil, ErrFieldIsNotValid{key, table.Name} + return nil, nil, ErrFieldIsNotValid{colName, table.Name} } - return fieldValue, nil + return col, fieldValue, nil } // Cell cell is a result of one column field @@ -392,6 +391,9 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sq table *schemas.Table, newElemFunc func([]string) reflect.Value, sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { for rows.Next() { + if rows.Err() != nil { + return rows.Err() + } var newValue = newElemFunc(fields) bean := newValue.Interface() dataStruct := newValue.Elem() @@ -435,6 +437,36 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, types []*sql return scanResults, nil } +func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Type, scanResult interface{}) error { + bs, ok := asBytes(scanResult) + if !ok { + return fmt.Errorf("unsupported database data type: %#v", scanResult) + } + if len(bs) == 0 { + return nil + } + + if fieldType.Kind() == reflect.String { + fieldValue.SetString(string(bs)) + return nil + } + + if fieldValue.CanAddr() { + err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) + if err != nil { + return err + } + } else { + x := reflect.New(fieldType) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) + if err != nil { + return err + } + fieldValue.Set(x.Elem()) + } + return nil +} + func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, scanResult interface{}, table *schemas.Table) error { v, ok := scanResult.(*interface{}) @@ -445,12 +477,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec return nil } - rawValue := reflect.Indirect(reflect.ValueOf(scanResult)) - // if row is null then ignore - if rawValue.Interface() == nil { - return nil - } - if fieldValue.CanAddr() { if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { data, ok := asBytes(scanResult) @@ -477,40 +503,11 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec return structConvert.FromDB(data) } - rawValueType := reflect.TypeOf(rawValue.Interface()) - vv := reflect.ValueOf(rawValue.Interface()) + vv := reflect.ValueOf(scanResult) fieldType := fieldValue.Type() if col.IsJSON { - var bs []byte - if rawValueType.Kind() == reflect.String { - bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(schemas.BytesType) { - bs = vv.Bytes() - } else { - return fmt.Errorf("unsupported database data type: %s %v", col.Name, rawValueType.Kind()) - } - - if len(bs) > 0 { - if fieldType.Kind() == reflect.String { - fieldValue.SetString(string(bs)) - return nil - } - if fieldValue.CanAddr() { - err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) - if err != nil { - return err - } - } else { - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - } - return nil + return session.setJSON(fieldValue, fieldType, scanResult) } switch fieldType.Kind() { @@ -529,30 +526,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } return nil case reflect.Complex64, reflect.Complex128: - // TODO: reimplement this - var bs []byte - if rawValueType.Kind() == reflect.String { - bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(schemas.BytesType) { - bs = vv.Bytes() - } - - if len(bs) > 0 { - if fieldValue.CanAddr() { - err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) - if err != nil { - return err - } - } else { - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - } - return nil + return session.setJSON(fieldValue, fieldType, scanResult) case reflect.Slice, reflect.Array: bs, ok := asBytes(scanResult) if ok && fieldType.Elem().Kind() == reflect.Uint8 { @@ -602,33 +576,11 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec fieldValue.Set(reflect.ValueOf(*t).Convert(fieldType)) return nil } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - err := nulVal.Scan(vv.Interface()) + err := nulVal.Scan(scanResult) if err == nil { return nil } session.engine.logger.Errorf("sql.Sanner error: %v", err) - } else if col.IsJSON { - if rawValueType.Kind() == reflect.String { - x := reflect.New(fieldType) - if len([]byte(vv.String())) > 0 { - err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - return nil - } else if rawValueType.Kind() == reflect.Slice { - x := reflect.New(fieldType) - if len(vv.Bytes()) > 0 { - err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - return nil - } } else if session.statement.UseCascade { table, err := session.engine.tagParser.ParseWithCache(*fieldValue) if err != nil { @@ -639,7 +591,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec return errors.New("unsupported non or composited primary key cascade") } var pk = make(schemas.PK, len(table.PrimaryKeys)) - pk[0], err = asKind(vv, rawValueType) + pk[0], err = asKind(vv, reflect.TypeOf(scanResult)) if err != nil { return err } @@ -675,9 +627,9 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b var tempMap = make(map[string]int) var pk schemas.PK - for ii, key := range fields { + for i, colName := range fields { var idx int - var lKey = strings.ToLower(key) + var lKey = strings.ToLower(colName) var ok bool if idx, ok = tempMap[lKey]; !ok { @@ -685,13 +637,9 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } else { idx = idx + 1 } - tempMap[lKey] = idx - col := table.GetColumnIdx(key, idx) - var scanResult = scanResults[ii] - - fieldValue, err := session.getField(dataStruct, key, table, idx) + col, fieldValue, err := session.getField(dataStruct, table, colName, idx) if err != nil { if _, ok := err.(ErrFieldIsNotValid); !ok { session.engine.logger.Warnf("%v", err) @@ -702,11 +650,11 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b continue } - if err := session.convertBeanField(col, fieldValue, scanResult, table); err != nil { + if err := session.convertBeanField(col, fieldValue, scanResults[i], table); err != nil { return nil, err } if col.IsPrimaryKey { - pk = append(pk, scanResult) + pk = append(pk, scanResults[i]) } } return pk, nil diff --git a/session_convert.go b/session_convert.go deleted file mode 100644 index 452801e2..00000000 --- a/session_convert.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "fmt" - "strconv" - "strings" - "time" - - "xorm.io/xorm/internal/utils" - "xorm.io/xorm/schemas" -) - -func (session *Session) str2Time(col *schemas.Column, data string) (outTime time.Time, outErr error) { - sdata := strings.TrimSpace(data) - var x time.Time - var err error - - var parseLoc = session.engine.DatabaseTZ - if col.TimeZone != nil { - parseLoc = col.TimeZone - } - - if sdata == utils.ZeroTime0 || sdata == utils.ZeroTime1 { - } else if !strings.ContainsAny(sdata, "- :") { // !nashtsai! has only found that mymysql driver is using this for time type column - // time stamp - sd, err := strconv.ParseInt(sdata, 10, 64) - if err == nil { - x = time.Unix(sd, 0) - } - } else if len(sdata) > 19 && strings.Contains(sdata, "-") { - x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc) - session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.Name, x, sdata) - if err != nil { - x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc) - } - if err != nil { - x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc) - } - } else if len(sdata) == 19 && strings.Contains(sdata, "-") { - x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc) - } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { - x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc) - } else if col.SQLType.Name == schemas.Time { - if strings.Contains(sdata, " ") { - ssd := strings.Split(sdata, " ") - sdata = ssd[1] - } - - sdata = strings.TrimSpace(sdata) - if session.engine.dialect.URI().DBType == schemas.MYSQL && len(sdata) > 8 { - sdata = sdata[len(sdata)-8:] - } - - st := fmt.Sprintf("2006-01-02 %v", sdata) - x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc) - } else { - outErr = fmt.Errorf("unsupported time format %v", sdata) - return - } - if err != nil { - outErr = fmt.Errorf("unsupported time format %v: %v", sdata, err) - return - } - outTime = x.In(session.engine.TZLocation) - return -} diff --git a/session_find.go b/session_find.go index 41d68479..89e34e80 100644 --- a/session_find.go +++ b/session_find.go @@ -255,6 +255,9 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect } for rows.Next() { + if rows.Err() != nil { + return rows.Err() + } var newValue = newElemFunc(fields) bean := newValue.Interface() @@ -322,6 +325,9 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in var i int ids = make([]schemas.PK, 0) for rows.Next() { + if rows.Err() != nil { + return rows.Err() + } i++ if i > 500 { session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache") diff --git a/session_get.go b/session_get.go index fa97e68e..1062bd9d 100644 --- a/session_get.go +++ b/session_get.go @@ -313,9 +313,12 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf defer rows.Close() if rows.Next() { + if rows.Err() != nil { + return true, rows.Err() + } err = rows.ScanSlice(&res) if err != nil { - return false, err + return true, err } } else { return false, ErrCacheFailed diff --git a/session_iterate.go b/session_iterate.go index 8cab8f48..dbbeb3f4 100644 --- a/session_iterate.go +++ b/session_iterate.go @@ -43,6 +43,9 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { i := 0 for rows.Next() { + if rows.Err() != nil { + return rows.Err() + } b := reflect.New(rows.beanType).Interface() err = rows.Scan(b) if err != nil { diff --git a/session_query.go b/session_query.go index d14c3908..8543ba12 100644 --- a/session_query.go +++ b/session_query.go @@ -33,6 +33,9 @@ func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string } for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } result, err := session.engine.row2mapStr(rows, types, fields) if err != nil { return nil, err @@ -54,6 +57,9 @@ func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]stri } for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } record, err := session.engine.row2sliceStr(rows, types, fields) if err != nil { return nil, err @@ -114,6 +120,9 @@ func (session *Session) rows2Interfaces(rows *core.Rows) (resultsSlice []map[str return nil, err } for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } result, err := session.engine.row2mapInterface(rows, types, fields) if err != nil { return nil, err diff --git a/session_update.go b/session_update.go index 78907e43..32e28ae0 100644 --- a/session_update.go +++ b/session_update.go @@ -59,6 +59,9 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri ids = make([]schemas.PK, 0) for rows.Next() { + if rows.Err() != nil { + return rows.Err() + } var res = make([]string, len(table.PrimaryKeys)) err = rows.ScanSlice(&res) if err != nil {