diff --git a/db.go b/db.go index a7947b05..169d8553 100644 --- a/db.go +++ b/db.go @@ -2,6 +2,7 @@ package core import ( "database/sql" + "database/sql/driver" "errors" "reflect" "regexp" @@ -29,10 +30,24 @@ func StructToSlice(query string, st interface{}) (string, []interface{}, error) } args := make([]interface{}, 0) + var err error query = re.ReplaceAllStringFunc(query, func(src string) string { - args = append(args, vv.Elem().FieldByName(src[1:]).Interface()) + fv := vv.Elem().FieldByName(src[1:]).Interface() + if v, ok := fv.(driver.Valuer); ok { + var value driver.Value + value, err = v.Value() + if err != nil { + return "?" + } + args = append(args, value) + } else { + args = append(args, fv) + } return "?" }) + if err != nil { + return "", []interface{}{}, err + } return query, args, nil } @@ -81,28 +96,87 @@ func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) { } type Row struct { - *sql.Row + rows *Rows // One of these two will be non-nil: - err error // deferred error for easy chaining - Mapper IMapper + err error // deferred error for easy chaining +} + +func (row *Row) Columns() ([]string, error) { + if row.err != nil { + return nil, row.err + } + return row.rows.Columns() } func (row *Row) Scan(dest ...interface{}) error { if row.err != nil { return row.err } - return row.Row.Scan(dest...) + defer row.rows.Close() + + for _, dp := range dest { + if _, ok := dp.(*sql.RawBytes); ok { + return errors.New("sql: RawBytes isn't allowed on Row.Scan") + } + } + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.Scan(dest...) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + if err := row.rows.Close(); err != nil { + return err + } + + return nil +} + +func (row *Row) ScanStructByName(dest interface{}) error { + if row.err != nil { + return row.err + } + return row.rows.ScanStructByName(dest) +} + +func (row *Row) ScanStructByIndex(dest interface{}) error { + if row.err != nil { + return row.err + } + return row.rows.ScanStructByIndex(dest) +} + +// scan data to a slice's pointer, slice's length should equal to columns' number +func (row *Row) ScanSlice(dest interface{}) error { + if row.err != nil { + return row.err + } + return row.rows.ScanSlice(dest) +} + +// scan data to a map's pointer +func (row *Row) ScanMap(dest interface{}) error { + if row.err != nil { + return row.err + } + return row.rows.ScanMap(dest) } func (db *DB) QueryRow(query string, args ...interface{}) *Row { - row := db.DB.QueryRow(query, args...) - return &Row{row, nil, db.Mapper} + rows, err := db.Query(query, args...) + return &Row{rows, err} } func (db *DB) QueryRowMap(query string, mp interface{}) *Row { query, args, err := MapToSlice(query, mp) if err != nil { - return &Row{nil, err, db.Mapper} + return &Row{nil, err} } return db.QueryRow(query, args...) } @@ -110,7 +184,7 @@ func (db *DB) QueryRowMap(query string, mp interface{}) *Row { func (db *DB) QueryRowStruct(query string, st interface{}) *Row { query, args, err := StructToSlice(query, st) if err != nil { - return &Row{nil, err, db.Mapper} + return &Row{nil, err} } return db.QueryRow(query, args...) } @@ -200,14 +274,14 @@ func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) { } func (s *Stmt) QueryRow(args ...interface{}) *Row { - row := s.Stmt.QueryRow(args...) - return &Row{row, nil, s.Mapper} + rows, err := s.Query(args...) + return &Row{rows, err} } func (s *Stmt) QueryRowMap(mp interface{}) *Row { vv := reflect.ValueOf(mp) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { - return &Row{nil, errors.New("mp should be a map's pointer"), s.Mapper} + return &Row{nil, errors.New("mp should be a map's pointer")} } args := make([]interface{}, len(s.names)) @@ -221,7 +295,7 @@ func (s *Stmt) QueryRowMap(mp interface{}) *Row { func (s *Stmt) QueryRowStruct(st interface{}) *Row { vv := reflect.ValueOf(st) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { - return &Row{nil, errors.New("st should be a struct's pointer"), s.Mapper} + return &Row{nil, errors.New("st should be a struct's pointer")} } args := make([]interface{}, len(s.names)) @@ -553,14 +627,14 @@ func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) { } func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { - row := tx.Tx.QueryRow(query, args...) - return &Row{row, nil, tx.Mapper} + rows, err := tx.Query(query, args...) + return &Row{rows, err} } func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row { query, args, err := MapToSlice(query, mp) if err != nil { - return &Row{nil, err, tx.Mapper} + return &Row{nil, err} } return tx.QueryRow(query, args...) } @@ -568,7 +642,7 @@ func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row { func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row { query, args, err := StructToSlice(query, st) if err != nil { - return &Row{nil, err, tx.Mapper} + return &Row{nil, err} } return tx.QueryRow(query, args...) } diff --git a/db_test.go b/db_test.go index 65bf2414..94c4ea4f 100644 --- a/db_test.go +++ b/db_test.go @@ -24,7 +24,7 @@ type User struct { Age float32 Alias string NickName string - Created time.Time + Created NullTime } func init() { @@ -85,7 +85,7 @@ func BenchmarkOriQuery(b *testing.B) { var Id int64 var Name, Title, Alias, NickName string var Age float32 - var Created time.Time + var Created NullTime err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName, &Created) if err != nil { b.Error(err) @@ -600,7 +600,7 @@ func TestExecStruct(t *testing.T) { Age: 1.2, Alias: "lunny", NickName: "lunny xiao", - Created: time.Now(), + Created: NullTime(time.Now()), } _, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+ @@ -645,7 +645,7 @@ func BenchmarkExecStruct(b *testing.B) { Age: 1.2, Alias: "lunny", NickName: "lunny xiao", - Created: time.Now(), + Created: NullTime(time.Now()), } for i := 0; i < b.N; i++ { diff --git a/scan.go b/scan.go new file mode 100644 index 00000000..7da338d8 --- /dev/null +++ b/scan.go @@ -0,0 +1,52 @@ +package core + +import ( + "database/sql/driver" + "fmt" + "time" +) + +type NullTime time.Time + +var ( + _ driver.Valuer = NullTime{} +) + +func (ns *NullTime) Scan(value interface{}) error { + if value == nil { + return nil + } + return convertTime(ns, value) +} + +// Value implements the driver Valuer interface. +func (ns NullTime) Value() (driver.Value, error) { + if (time.Time)(ns).IsZero() { + return nil, nil + } + return (time.Time)(ns).Format("2006-01-02 15:04:05"), nil +} + +func convertTime(dest *NullTime, src interface{}) error { + // Common cases, without reflect. + switch s := src.(type) { + case string: + t, err := time.Parse("2006-01-02 15:04:05", s) + if err != nil { + return err + } + *dest = NullTime(t) + return nil + case []uint8: + t, err := time.Parse("2006-01-02 15:04:05", string(s)) + if err != nil { + return err + } + *dest = NullTime(t) + return nil + case nil: + default: + return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest) + } + return nil +}