added QueryMap QueryStruct and etc. for Row

This commit is contained in:
Lunny Xiao 2016-01-08 14:33:03 +08:00
parent 7da81a8908
commit 85579d38ad
3 changed files with 147 additions and 21 deletions

106
db.go
View File

@ -2,6 +2,7 @@ package core
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"errors" "errors"
"reflect" "reflect"
"regexp" "regexp"
@ -29,10 +30,24 @@ func StructToSlice(query string, st interface{}) (string, []interface{}, error)
} }
args := make([]interface{}, 0) args := make([]interface{}, 0)
var err error
query = re.ReplaceAllStringFunc(query, func(src string) string { query = re.ReplaceAllStringFunc(query, func(src string) string {
args = append(args, vv.Elem().FieldByName(src[1:]).Interface()) fv := vv.Elem().FieldByName(src[1:]).Interface()
if v, ok := fv.(driver.Valuer); ok {
var value driver.Value
value, err = v.Value()
if err != nil {
return "?"
}
args = append(args, value)
} else {
args = append(args, fv)
}
return "?" return "?"
}) })
if err != nil {
return "", []interface{}{}, err
}
return query, args, nil return query, args, nil
} }
@ -81,28 +96,87 @@ func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) {
} }
type Row struct { type Row struct {
*sql.Row rows *Rows
// One of these two will be non-nil: // One of these two will be non-nil:
err error // deferred error for easy chaining err error // deferred error for easy chaining
Mapper IMapper }
func (row *Row) Columns() ([]string, error) {
if row.err != nil {
return nil, row.err
}
return row.rows.Columns()
} }
func (row *Row) Scan(dest ...interface{}) error { func (row *Row) Scan(dest ...interface{}) error {
if row.err != nil { if row.err != nil {
return row.err return row.err
} }
return row.Row.Scan(dest...) defer row.rows.Close()
for _, dp := range dest {
if _, ok := dp.(*sql.RawBytes); ok {
return errors.New("sql: RawBytes isn't allowed on Row.Scan")
}
}
if !row.rows.Next() {
if err := row.rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
err := row.rows.Scan(dest...)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
if err := row.rows.Close(); err != nil {
return err
}
return nil
}
func (row *Row) ScanStructByName(dest interface{}) error {
if row.err != nil {
return row.err
}
return row.rows.ScanStructByName(dest)
}
func (row *Row) ScanStructByIndex(dest interface{}) error {
if row.err != nil {
return row.err
}
return row.rows.ScanStructByIndex(dest)
}
// scan data to a slice's pointer, slice's length should equal to columns' number
func (row *Row) ScanSlice(dest interface{}) error {
if row.err != nil {
return row.err
}
return row.rows.ScanSlice(dest)
}
// scan data to a map's pointer
func (row *Row) ScanMap(dest interface{}) error {
if row.err != nil {
return row.err
}
return row.rows.ScanMap(dest)
} }
func (db *DB) QueryRow(query string, args ...interface{}) *Row { func (db *DB) QueryRow(query string, args ...interface{}) *Row {
row := db.DB.QueryRow(query, args...) rows, err := db.Query(query, args...)
return &Row{row, nil, db.Mapper} return &Row{rows, err}
} }
func (db *DB) QueryRowMap(query string, mp interface{}) *Row { func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
query, args, err := MapToSlice(query, mp) query, args, err := MapToSlice(query, mp)
if err != nil { if err != nil {
return &Row{nil, err, db.Mapper} return &Row{nil, err}
} }
return db.QueryRow(query, args...) return db.QueryRow(query, args...)
} }
@ -110,7 +184,7 @@ func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
func (db *DB) QueryRowStruct(query string, st interface{}) *Row { func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
query, args, err := StructToSlice(query, st) query, args, err := StructToSlice(query, st)
if err != nil { if err != nil {
return &Row{nil, err, db.Mapper} return &Row{nil, err}
} }
return db.QueryRow(query, args...) return db.QueryRow(query, args...)
} }
@ -200,14 +274,14 @@ func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) {
} }
func (s *Stmt) QueryRow(args ...interface{}) *Row { func (s *Stmt) QueryRow(args ...interface{}) *Row {
row := s.Stmt.QueryRow(args...) rows, err := s.Query(args...)
return &Row{row, nil, s.Mapper} return &Row{rows, err}
} }
func (s *Stmt) QueryRowMap(mp interface{}) *Row { func (s *Stmt) QueryRowMap(mp interface{}) *Row {
vv := reflect.ValueOf(mp) vv := reflect.ValueOf(mp)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
return &Row{nil, errors.New("mp should be a map's pointer"), s.Mapper} return &Row{nil, errors.New("mp should be a map's pointer")}
} }
args := make([]interface{}, len(s.names)) args := make([]interface{}, len(s.names))
@ -221,7 +295,7 @@ func (s *Stmt) QueryRowMap(mp interface{}) *Row {
func (s *Stmt) QueryRowStruct(st interface{}) *Row { func (s *Stmt) QueryRowStruct(st interface{}) *Row {
vv := reflect.ValueOf(st) vv := reflect.ValueOf(st)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return &Row{nil, errors.New("st should be a struct's pointer"), s.Mapper} return &Row{nil, errors.New("st should be a struct's pointer")}
} }
args := make([]interface{}, len(s.names)) args := make([]interface{}, len(s.names))
@ -553,14 +627,14 @@ func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) {
} }
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
row := tx.Tx.QueryRow(query, args...) rows, err := tx.Query(query, args...)
return &Row{row, nil, tx.Mapper} return &Row{rows, err}
} }
func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row { func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row {
query, args, err := MapToSlice(query, mp) query, args, err := MapToSlice(query, mp)
if err != nil { if err != nil {
return &Row{nil, err, tx.Mapper} return &Row{nil, err}
} }
return tx.QueryRow(query, args...) return tx.QueryRow(query, args...)
} }
@ -568,7 +642,7 @@ func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row {
func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row { func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row {
query, args, err := StructToSlice(query, st) query, args, err := StructToSlice(query, st)
if err != nil { if err != nil {
return &Row{nil, err, tx.Mapper} return &Row{nil, err}
} }
return tx.QueryRow(query, args...) return tx.QueryRow(query, args...)
} }

View File

@ -24,7 +24,7 @@ type User struct {
Age float32 Age float32
Alias string Alias string
NickName string NickName string
Created time.Time Created NullTime
} }
func init() { func init() {
@ -85,7 +85,7 @@ func BenchmarkOriQuery(b *testing.B) {
var Id int64 var Id int64
var Name, Title, Alias, NickName string var Name, Title, Alias, NickName string
var Age float32 var Age float32
var Created time.Time var Created NullTime
err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName, &Created) err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName, &Created)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
@ -600,7 +600,7 @@ func TestExecStruct(t *testing.T) {
Age: 1.2, Age: 1.2,
Alias: "lunny", Alias: "lunny",
NickName: "lunny xiao", NickName: "lunny xiao",
Created: time.Now(), Created: NullTime(time.Now()),
} }
_, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+ _, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+
@ -645,7 +645,7 @@ func BenchmarkExecStruct(b *testing.B) {
Age: 1.2, Age: 1.2,
Alias: "lunny", Alias: "lunny",
NickName: "lunny xiao", NickName: "lunny xiao",
Created: time.Now(), Created: NullTime(time.Now()),
} }
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {

52
scan.go Normal file
View File

@ -0,0 +1,52 @@
package core
import (
"database/sql/driver"
"fmt"
"time"
)
type NullTime time.Time
var (
_ driver.Valuer = NullTime{}
)
func (ns *NullTime) Scan(value interface{}) error {
if value == nil {
return nil
}
return convertTime(ns, value)
}
// Value implements the driver Valuer interface.
func (ns NullTime) Value() (driver.Value, error) {
if (time.Time)(ns).IsZero() {
return nil, nil
}
return (time.Time)(ns).Format("2006-01-02 15:04:05"), nil
}
func convertTime(dest *NullTime, src interface{}) error {
// Common cases, without reflect.
switch s := src.(type) {
case string:
t, err := time.Parse("2006-01-02 15:04:05", s)
if err != nil {
return err
}
*dest = NullTime(t)
return nil
case []uint8:
t, err := time.Parse("2006-01-02 15:04:05", string(s))
if err != nil {
return err
}
*dest = NullTime(t)
return nil
case nil:
default:
return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest)
}
return nil
}