diff --git a/core/db.go b/core/db.go index 7fedca73..97103039 100644 --- a/core/db.go +++ b/core/db.go @@ -8,17 +8,44 @@ import ( "sync" ) +var ( + ErrNoMapPointer = errors.New("mp should be a map's pointer") + ErrNoStructPointer = errors.New("mp should be a map's pointer") +) + +func MapToSlice(query string, mp interface{}) (string, []interface{}, error) { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return "", []interface{}{}, ErrNoMapPointer + } + + args := make([]interface{}, 0) + query = re.ReplaceAllStringFunc(query, func(src string) string { + args = append(args, vv.Elem().MapIndex(reflect.ValueOf(src[1:])).Interface()) + return "?" + }) + return query, args, nil +} + +func StructToSlice(query string, st interface{}) (string, []interface{}, error) { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return "", []interface{}{}, ErrNoStructPointer + } + + args := make([]interface{}, 0) + query = re.ReplaceAllStringFunc(query, func(src string) string { + args = append(args, vv.Elem().FieldByName(src[1:]).Interface()) + return "?" + }) + return query, args, nil +} + type DB struct { *sql.DB Mapper IMapper } -type Stmt struct { - *sql.Stmt - Mapper IMapper - names map[string]int -} - func Open(driverName, dataSourceName string) (*DB, error) { db, err := sql.Open(driverName, dataSourceName) return &DB{db, NewCacheMapper(&SnakeMapper{})}, err @@ -30,41 +57,60 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { } func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) { - vv := reflect.ValueOf(mp) - if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { - return nil, errors.New("mp should be a map's pointer") + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err } - - args := make([]interface{}, 0) - query = re.ReplaceAllStringFunc(query, func(src string) string { - args = append(args, vv.Elem().MapIndex(reflect.ValueOf(src[1:])).Interface()) - return "?" - }) return db.Query(query, args...) } func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) { - vv := reflect.ValueOf(st) - if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { - return nil, errors.New("mp should be a map's pointer") + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err } - - args := make([]interface{}, 0) - query = re.ReplaceAllStringFunc(query, func(src string) string { - args = append(args, vv.Elem().FieldByName(src[1:]).Interface()) - return "?" - }) return db.Query(query, args...) } type Row struct { *sql.Row + // One of these two will be non-nil: + err error // deferred error for easy chaining Mapper IMapper } +func (row *Row) Scan(dest ...interface{}) error { + if row.err != nil { + return row.err + } + return row.Row.Scan(dest...) +} + func (db *DB) QueryRow(query string, args ...interface{}) *Row { row := db.DB.QueryRow(query, args...) - return &Row{row, db.Mapper} + return &Row{row, nil, db.Mapper} +} + +func (db *DB) QueryRowMap(query string, mp interface{}) *Row { + query, args, err := MapToSlice(query, mp) + if err != nil { + return &Row{nil, err, db.Mapper} + } + return db.QueryRow(query, args...) +} + +func (db *DB) QueryRowStruct(query string, st interface{}) *Row { + query, args, err := StructToSlice(query, st) + if err != nil { + return &Row{nil, err, db.Mapper} + } + return db.QueryRow(query, args...) +} + +type Stmt struct { + *sql.Stmt + Mapper IMapper + names map[string]int } func (db *DB) Prepare(query string) (*Stmt, error) { @@ -116,32 +162,18 @@ var ( // insert into (name) values (?) // insert into (name) values (?name) func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) { - vv := reflect.ValueOf(mp) - if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { - return nil, errors.New("mp should be a map's pointer") + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err } - - args := make([]interface{}, 0) - query = re.ReplaceAllStringFunc(query, func(src string) string { - args = append(args, vv.Elem().MapIndex(reflect.ValueOf(src[1:])).Interface()) - return "?" - }) - return db.DB.Exec(query, args...) } func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) { - vv := reflect.ValueOf(st) - if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { - return nil, errors.New("mp should be a map's pointer") + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err } - - args := make([]interface{}, 0) - query = re.ReplaceAllStringFunc(query, func(src string) string { - args = append(args, vv.Elem().FieldByName(src[1:]).Interface()) - return "?" - }) - return db.DB.Exec(query, args...) } @@ -368,3 +400,98 @@ func (rs *Rows) ScanMap(dest interface{}) error { return nil }*/ + +type Tx struct { + *sql.Tx + Mapper IMapper +} + +func (db *DB) Begin() (*Tx, error) { + tx, err := db.DB.Begin() + if err != nil { + return nil, err + } + return &Tx{tx, db.Mapper}, nil +} + +func (tx *Tx) Prepare(query string) (*Stmt, error) { + names := make(map[string]int) + var i int + query = re.ReplaceAllStringFunc(query, func(src string) string { + names[src[1:]] = i + i += 1 + return "?" + }) + + stmt, err := tx.Tx.Prepare(query) + if err != nil { + return nil, err + } + return &Stmt{stmt, tx.Mapper, names}, nil +} + +func (tx *Tx) Stmt(stmt *Stmt) *Stmt { + // TODO: + return stmt +} + +func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) { + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err + } + return tx.Tx.Exec(query, args...) +} + +func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) { + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err + } + return tx.Tx.Exec(query, args...) +} + +func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { + rows, err := tx.Tx.Query(query, args...) + if err != nil { + return nil, err + } + return &Rows{rows, tx.Mapper}, nil +} + +func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) { + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err + } + return tx.Query(query, args...) +} + +func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) { + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err + } + return tx.Query(query, args...) +} + +func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { + row := tx.Tx.QueryRow(query, args...) + return &Row{row, nil, tx.Mapper} +} + +func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row { + query, args, err := MapToSlice(query, mp) + if err != nil { + return &Row{nil, err, tx.Mapper} + } + return tx.QueryRow(query, args...) +} + +func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row { + query, args, err := StructToSlice(query, st) + if err != nil { + return &Row{nil, err, tx.Mapper} + } + return tx.QueryRow(query, args...) +} diff --git a/core/db_test.go b/core/db_test.go index 8324b385..9836a589 100644 --- a/core/db_test.go +++ b/core/db_test.go @@ -5,12 +5,14 @@ import ( "fmt" "os" "testing" + "time" _ "github.com/mattn/go-sqlite3" ) var ( - createTableSqlite3 = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, `title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL);" + createTableSqlite3 = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " + + "`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);" ) type User struct { @@ -20,6 +22,7 @@ type User struct { Age float32 Alias string NickName string + Created time.Time } func BenchmarkOriQuery(b *testing.B) { @@ -37,8 +40,8 @@ func BenchmarkOriQuery(b *testing.B) { } for i := 0; i < 50; i++ { - _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", - "xlw", "tester", 1.2, "lunny", "lunny xiao") + _, err = db.Exec("insert into user (name, title, age, alias, nick_name, created) values (?,?,?,?,?, ?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) if err != nil { b.Error(err) } @@ -56,7 +59,8 @@ func BenchmarkOriQuery(b *testing.B) { var Id int64 var Name, Title, Alias, NickName string var Age float32 - err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName) + var Created time.Time + err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName, &Created) if err != nil { b.Error(err) } @@ -81,8 +85,8 @@ func BenchmarkStructQuery(b *testing.B) { } for i := 0; i < 50; i++ { - _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", - "xlw", "tester", 1.2, "lunny", "lunny xiao") + _, err = db.Exec("insert into user (name, title, age, alias, nick_name, created) values (?,?,?,?,?, ?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) if err != nil { b.Error(err) } @@ -126,8 +130,8 @@ func BenchmarkStruct2Query(b *testing.B) { } for i := 0; i < 50; i++ { - _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", - "xlw", "tester", 1.2, "lunny", "lunny xiao") + _, err = db.Exec("insert into user (name, title, age, alias, nick_name, created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) if err != nil { b.Error(err) } @@ -172,8 +176,8 @@ func BenchmarkSliceInterfaceQuery(b *testing.B) { } for i := 0; i < 50; i++ { - _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", - "xlw", "tester", 1.2, "lunny", "lunny xiao") + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) if err != nil { b.Error(err) } @@ -207,7 +211,8 @@ func BenchmarkSliceInterfaceQuery(b *testing.B) { rows.Close() } } -func BenchmarkSliceBytesQuery(b *testing.B) { + +/*func BenchmarkSliceBytesQuery(b *testing.B) { b.StopTimer() os.Remove("./test.db") db, err := Open("sqlite3", "./test.db") @@ -222,8 +227,8 @@ func BenchmarkSliceBytesQuery(b *testing.B) { } for i := 0; i < 50; i++ { - _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", - "xlw", "tester", 1.2, "lunny", "lunny xiao") + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) if err != nil { b.Error(err) } @@ -273,8 +278,8 @@ func BenchmarkSliceStringQuery(b *testing.B) { } for i := 0; i < 50; i++ { - _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", - "xlw", "tester", 1.2, "lunny", "lunny xiao") + _, err = db.Exec("insert into user (name, title, age, alias, nick_name, created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) if err != nil { b.Error(err) } @@ -307,7 +312,7 @@ func BenchmarkSliceStringQuery(b *testing.B) { rows.Close() } -} +}*/ func BenchmarkMapInterfaceQuery(b *testing.B) { b.StopTimer() @@ -324,8 +329,8 @@ func BenchmarkMapInterfaceQuery(b *testing.B) { } for i := 0; i < 50; i++ { - _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", - "xlw", "tester", 1.2, "lunny", "lunny xiao") + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) if err != nil { b.Error(err) } @@ -355,7 +360,7 @@ func BenchmarkMapInterfaceQuery(b *testing.B) { } } -func BenchmarkMapBytesQuery(b *testing.B) { +/*func BenchmarkMapBytesQuery(b *testing.B) { b.StopTimer() os.Remove("./test.db") db, err := Open("sqlite3", "./test.db") @@ -370,8 +375,8 @@ func BenchmarkMapBytesQuery(b *testing.B) { } for i := 0; i < 50; i++ { - _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", - "xlw", "tester", 1.2, "lunny", "lunny xiao") + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) if err != nil { b.Error(err) } @@ -416,8 +421,8 @@ func BenchmarkMapStringQuery(b *testing.B) { } for i := 0; i < 50; i++ { - _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", - "xlw", "tester", 1.2, "lunny", "lunny xiao") + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) if err != nil { b.Error(err) } @@ -445,7 +450,7 @@ func BenchmarkMapStringQuery(b *testing.B) { rows.Close() } -} +}*/ func BenchmarkExec(b *testing.B) { b.StopTimer() @@ -464,8 +469,8 @@ func BenchmarkExec(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", - "xlw", "tester", 1.2, "lunny", "lunny xiao") + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) if err != nil { b.Error(err) } @@ -494,11 +499,12 @@ func BenchmarkExecMap(b *testing.B) { "age": 1.2, "alias": "lunny", "nick_name": "lunny xiao", + "created": time.Now(), } for i := 0; i < b.N; i++ { - _, err = db.ExecMap(`insert into user (name, title, age, alias, nick_name) - values (?name,?title,?age,?alias,?nick_name)`, + _, err = db.ExecMap(`insert into user (name, title, age, alias, nick_name, created) + values (?name,?title,?age,?alias,?nick_name,?created)`, &mp) if err != nil { b.Error(err) @@ -525,10 +531,11 @@ func TestExecMap(t *testing.T) { "age": 1.2, "alias": "lunny", "nick_name": "lunny xiao", + "created": time.Now(), } - _, err = db.ExecMap(`insert into user (name, title, age, alias, nick_name) - values (?name,?title,?age,?alias,?nick_name)`, + _, err = db.ExecMap(`insert into user (name, title, age, alias, nick_name,created) + values (?name,?title,?age,?alias,?nick_name,?created)`, &mp) if err != nil { t.Error(err) @@ -567,10 +574,11 @@ func TestExecStruct(t *testing.T) { Age: 1.2, Alias: "lunny", NickName: "lunny xiao", + Created: time.Now(), } - _, err = db.ExecStruct(`insert into user (name, title, age, alias, nick_name) - values (?Name,?Title,?Age,?Alias,?NickName)`, + _, err = db.ExecStruct(`insert into user (name, title, age, alias, nick_name,created) + values (?Name,?Title,?Age,?Alias,?NickName,?Created)`, &user) if err != nil { t.Error(err) @@ -612,11 +620,12 @@ func BenchmarkExecStruct(b *testing.B) { Age: 1.2, Alias: "lunny", NickName: "lunny xiao", + Created: time.Now(), } for i := 0; i < b.N; i++ { - _, err = db.ExecStruct(`insert into user (name, title, age, alias, nick_name) - values (?Name,?Title,?Age,?Alias,?NickName)`, + _, err = db.ExecStruct(`insert into user (name, title, age, alias, nick_name,created) + values (?Name,?Title,?Age,?Alias,?NickName,?Created)`, &user) if err != nil { b.Error(err)