diff --git a/scan.go b/scan.go index 0a9ef613..e19037a0 100644 --- a/scan.go +++ b/scan.go @@ -29,6 +29,25 @@ func (engine *Engine) row2mapStr(rows *core.Rows, types []*sql.ColumnType, field return result, nil } +func (engine *Engine) row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string][]byte, error) { + var scanResults = make([]interface{}, len(fields)) + for i := 0; i < len(fields); i++ { + var s sql.NullString + scanResults[i] = &s + } + + if err := rows.Scan(scanResults...); err != nil { + return nil, err + } + + result := make(map[string][]byte, len(fields)) + for ii, key := range fields { + s := scanResults[ii].(*sql.NullString) + result[key] = []byte(s.String) + } + return result, nil +} + func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { results := make([]string, 0, len(fields)) var scanResults = make([]interface{}, len(fields)) diff --git a/session_raw.go b/session_raw.go index 4cfe297a..d5c4520b 100644 --- a/session_raw.go +++ b/session_raw.go @@ -79,41 +79,17 @@ func value2Bytes(rawValue *reflect.Value) ([]byte, error) { return []byte(str), nil } -func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, err error) { - result := make(map[string][]byte) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for ii, key := range fields { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) - //if row is null then ignore - if rawValue.Interface() == nil { - result[key] = []byte{} - continue - } - - if data, err := value2Bytes(&rawValue); err == nil { - result[key] = data - } else { - return nil, err // !nashtsai! REVIEW, should return err or just error log? - } - } - return result, nil -} - -func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { +func (session *Session) rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { fields, err := rows.Columns() if err != nil { return nil, err } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } for rows.Next() { - result, err := row2map(rows, fields) + result, err := session.engine.row2mapBytes(rows, types, fields) if err != nil { return nil, err } @@ -130,7 +106,7 @@ func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[st } defer rows.Close() - return rows2maps(rows) + return session.rows2maps(rows) } func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) {