added QueryMap QueryStruct and etc. for Row
This commit is contained in:
parent
7da81a8908
commit
85579d38ad
108
db.go
108
db.go
|
@ -2,6 +2,7 @@ package core
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"reflect"
|
||||
"regexp"
|
||||
|
@ -29,10 +30,24 @@ func StructToSlice(query string, st interface{}) (string, []interface{}, error)
|
|||
}
|
||||
|
||||
args := make([]interface{}, 0)
|
||||
var err error
|
||||
query = re.ReplaceAllStringFunc(query, func(src string) string {
|
||||
args = append(args, vv.Elem().FieldByName(src[1:]).Interface())
|
||||
fv := vv.Elem().FieldByName(src[1:]).Interface()
|
||||
if v, ok := fv.(driver.Valuer); ok {
|
||||
var value driver.Value
|
||||
value, err = v.Value()
|
||||
if err != nil {
|
||||
return "?"
|
||||
}
|
||||
args = append(args, value)
|
||||
} else {
|
||||
args = append(args, fv)
|
||||
}
|
||||
return "?"
|
||||
})
|
||||
if err != nil {
|
||||
return "", []interface{}{}, err
|
||||
}
|
||||
return query, args, nil
|
||||
}
|
||||
|
||||
|
@ -81,28 +96,87 @@ func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) {
|
|||
}
|
||||
|
||||
type Row struct {
|
||||
*sql.Row
|
||||
rows *Rows
|
||||
// One of these two will be non-nil:
|
||||
err error // deferred error for easy chaining
|
||||
Mapper IMapper
|
||||
err error // deferred error for easy chaining
|
||||
}
|
||||
|
||||
func (row *Row) Columns() ([]string, error) {
|
||||
if row.err != nil {
|
||||
return nil, row.err
|
||||
}
|
||||
return row.rows.Columns()
|
||||
}
|
||||
|
||||
func (row *Row) Scan(dest ...interface{}) error {
|
||||
if row.err != nil {
|
||||
return row.err
|
||||
}
|
||||
return row.Row.Scan(dest...)
|
||||
defer row.rows.Close()
|
||||
|
||||
for _, dp := range dest {
|
||||
if _, ok := dp.(*sql.RawBytes); ok {
|
||||
return errors.New("sql: RawBytes isn't allowed on Row.Scan")
|
||||
}
|
||||
}
|
||||
|
||||
if !row.rows.Next() {
|
||||
if err := row.rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
err := row.rows.Scan(dest...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Make sure the query can be processed to completion with no errors.
|
||||
if err := row.rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (row *Row) ScanStructByName(dest interface{}) error {
|
||||
if row.err != nil {
|
||||
return row.err
|
||||
}
|
||||
return row.rows.ScanStructByName(dest)
|
||||
}
|
||||
|
||||
func (row *Row) ScanStructByIndex(dest interface{}) error {
|
||||
if row.err != nil {
|
||||
return row.err
|
||||
}
|
||||
return row.rows.ScanStructByIndex(dest)
|
||||
}
|
||||
|
||||
// scan data to a slice's pointer, slice's length should equal to columns' number
|
||||
func (row *Row) ScanSlice(dest interface{}) error {
|
||||
if row.err != nil {
|
||||
return row.err
|
||||
}
|
||||
return row.rows.ScanSlice(dest)
|
||||
}
|
||||
|
||||
// scan data to a map's pointer
|
||||
func (row *Row) ScanMap(dest interface{}) error {
|
||||
if row.err != nil {
|
||||
return row.err
|
||||
}
|
||||
return row.rows.ScanMap(dest)
|
||||
}
|
||||
|
||||
func (db *DB) QueryRow(query string, args ...interface{}) *Row {
|
||||
row := db.DB.QueryRow(query, args...)
|
||||
return &Row{row, nil, db.Mapper}
|
||||
rows, err := db.Query(query, args...)
|
||||
return &Row{rows, err}
|
||||
}
|
||||
|
||||
func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
|
||||
query, args, err := MapToSlice(query, mp)
|
||||
if err != nil {
|
||||
return &Row{nil, err, db.Mapper}
|
||||
return &Row{nil, err}
|
||||
}
|
||||
return db.QueryRow(query, args...)
|
||||
}
|
||||
|
@ -110,7 +184,7 @@ func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
|
|||
func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
|
||||
query, args, err := StructToSlice(query, st)
|
||||
if err != nil {
|
||||
return &Row{nil, err, db.Mapper}
|
||||
return &Row{nil, err}
|
||||
}
|
||||
return db.QueryRow(query, args...)
|
||||
}
|
||||
|
@ -200,14 +274,14 @@ func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) {
|
|||
}
|
||||
|
||||
func (s *Stmt) QueryRow(args ...interface{}) *Row {
|
||||
row := s.Stmt.QueryRow(args...)
|
||||
return &Row{row, nil, s.Mapper}
|
||||
rows, err := s.Query(args...)
|
||||
return &Row{rows, err}
|
||||
}
|
||||
|
||||
func (s *Stmt) QueryRowMap(mp interface{}) *Row {
|
||||
vv := reflect.ValueOf(mp)
|
||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
|
||||
return &Row{nil, errors.New("mp should be a map's pointer"), s.Mapper}
|
||||
return &Row{nil, errors.New("mp should be a map's pointer")}
|
||||
}
|
||||
|
||||
args := make([]interface{}, len(s.names))
|
||||
|
@ -221,7 +295,7 @@ func (s *Stmt) QueryRowMap(mp interface{}) *Row {
|
|||
func (s *Stmt) QueryRowStruct(st interface{}) *Row {
|
||||
vv := reflect.ValueOf(st)
|
||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
|
||||
return &Row{nil, errors.New("st should be a struct's pointer"), s.Mapper}
|
||||
return &Row{nil, errors.New("st should be a struct's pointer")}
|
||||
}
|
||||
|
||||
args := make([]interface{}, len(s.names))
|
||||
|
@ -553,14 +627,14 @@ func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) {
|
|||
}
|
||||
|
||||
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
|
||||
row := tx.Tx.QueryRow(query, args...)
|
||||
return &Row{row, nil, tx.Mapper}
|
||||
rows, err := tx.Query(query, args...)
|
||||
return &Row{rows, err}
|
||||
}
|
||||
|
||||
func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row {
|
||||
query, args, err := MapToSlice(query, mp)
|
||||
if err != nil {
|
||||
return &Row{nil, err, tx.Mapper}
|
||||
return &Row{nil, err}
|
||||
}
|
||||
return tx.QueryRow(query, args...)
|
||||
}
|
||||
|
@ -568,7 +642,7 @@ func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row {
|
|||
func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row {
|
||||
query, args, err := StructToSlice(query, st)
|
||||
if err != nil {
|
||||
return &Row{nil, err, tx.Mapper}
|
||||
return &Row{nil, err}
|
||||
}
|
||||
return tx.QueryRow(query, args...)
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ type User struct {
|
|||
Age float32
|
||||
Alias string
|
||||
NickName string
|
||||
Created time.Time
|
||||
Created NullTime
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -85,7 +85,7 @@ func BenchmarkOriQuery(b *testing.B) {
|
|||
var Id int64
|
||||
var Name, Title, Alias, NickName string
|
||||
var Age float32
|
||||
var Created time.Time
|
||||
var Created NullTime
|
||||
err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName, &Created)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
|
@ -600,7 +600,7 @@ func TestExecStruct(t *testing.T) {
|
|||
Age: 1.2,
|
||||
Alias: "lunny",
|
||||
NickName: "lunny xiao",
|
||||
Created: time.Now(),
|
||||
Created: NullTime(time.Now()),
|
||||
}
|
||||
|
||||
_, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+
|
||||
|
@ -645,7 +645,7 @@ func BenchmarkExecStruct(b *testing.B) {
|
|||
Age: 1.2,
|
||||
Alias: "lunny",
|
||||
NickName: "lunny xiao",
|
||||
Created: time.Now(),
|
||||
Created: NullTime(time.Now()),
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type NullTime time.Time
|
||||
|
||||
var (
|
||||
_ driver.Valuer = NullTime{}
|
||||
)
|
||||
|
||||
func (ns *NullTime) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
return convertTime(ns, value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (ns NullTime) Value() (driver.Value, error) {
|
||||
if (time.Time)(ns).IsZero() {
|
||||
return nil, nil
|
||||
}
|
||||
return (time.Time)(ns).Format("2006-01-02 15:04:05"), nil
|
||||
}
|
||||
|
||||
func convertTime(dest *NullTime, src interface{}) error {
|
||||
// Common cases, without reflect.
|
||||
switch s := src.(type) {
|
||||
case string:
|
||||
t, err := time.Parse("2006-01-02 15:04:05", s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*dest = NullTime(t)
|
||||
return nil
|
||||
case []uint8:
|
||||
t, err := time.Parse("2006-01-02 15:04:05", string(s))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*dest = NullTime(t)
|
||||
return nil
|
||||
case nil:
|
||||
default:
|
||||
return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest)
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue