diff --git a/core/db.go b/core/db.go index 21d276f0..7fedca73 100644 --- a/core/db.go +++ b/core/db.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "reflect" + "regexp" "sync" ) @@ -12,6 +13,12 @@ type DB struct { Mapper IMapper } +type Stmt struct { + *sql.Stmt + Mapper IMapper + names map[string]int +} + func Open(driverName, dataSourceName string) (*DB, error) { db, err := sql.Open(driverName, dataSourceName) return &DB{db, NewCacheMapper(&SnakeMapper{})}, err @@ -22,6 +29,122 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { return &Rows{rows, db.Mapper}, err } +func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, 0) + query = re.ReplaceAllStringFunc(query, func(src string) string { + args = append(args, vv.Elem().MapIndex(reflect.ValueOf(src[1:])).Interface()) + return "?" + }) + return db.Query(query, args...) +} + +func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, 0) + query = re.ReplaceAllStringFunc(query, func(src string) string { + args = append(args, vv.Elem().FieldByName(src[1:]).Interface()) + return "?" + }) + return db.Query(query, args...) +} + +type Row struct { + *sql.Row + Mapper IMapper +} + +func (db *DB) QueryRow(query string, args ...interface{}) *Row { + row := db.DB.QueryRow(query, args...) + return &Row{row, db.Mapper} +} + +func (db *DB) Prepare(query string) (*Stmt, error) { + names := make(map[string]int) + var i int + query = re.ReplaceAllStringFunc(query, func(src string) string { + names[src[1:]] = i + i += 1 + return "?" + }) + + stmt, err := db.DB.Prepare(query) + if err != nil { + return nil, err + } + return &Stmt{stmt, db.Mapper, names}, nil +} + +func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface() + } + return s.Stmt.Exec(args...) +} + +func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().FieldByName(k).Interface() + } + return s.Stmt.Exec(args...) +} + +var ( + re = regexp.MustCompile(`[?](\w+)`) +) + +// insert into (name) values (?) +// insert into (name) values (?name) +func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, 0) + query = re.ReplaceAllStringFunc(query, func(src string) string { + args = append(args, vv.Elem().MapIndex(reflect.ValueOf(src[1:])).Interface()) + return "?" + }) + + return db.DB.Exec(query, args...) +} + +func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, 0) + query = re.ReplaceAllStringFunc(query, func(src string) string { + args = append(args, vv.Elem().FieldByName(src[1:]).Interface()) + return "?" + }) + + return db.DB.Exec(query, args...) +} + type Rows struct { *sql.Rows Mapper IMapper @@ -72,6 +195,28 @@ var ( fieldCacheMutex sync.RWMutex ) +func fieldByName(v reflect.Value, name string) reflect.Value { + t := v.Type() + fieldCacheMutex.RLock() + cache, ok := fieldCache[t] + fieldCacheMutex.RUnlock() + if !ok { + cache = make(map[string]int) + for i := 0; i < v.NumField(); i++ { + cache[t.Field(i).Name] = i + } + fieldCacheMutex.Lock() + fieldCache[t] = cache + fieldCacheMutex.Unlock() + } + + if i, ok := cache[name]; ok { + return v.Field(i) + } + + return reflect.Zero(t) +} + // scan data to a struct's pointer according field name func (rs *Rows) ScanStruct2(dest interface{}) error { vv := reflect.ValueOf(dest) @@ -84,27 +229,12 @@ func (rs *Rows) ScanStruct2(dest interface{}) error { return err } - vvv := vv.Elem() - t := vvv.Type() - - fieldCacheMutex.RLock() - cache, ok := fieldCache[t] - fieldCacheMutex.RUnlock() - if !ok { - cache = make(map[string]int) - for i := 0; i < vvv.NumField(); i++ { - cache[rs.Mapper.Obj2Table(vvv.Type().Field(i).Name)] = i - } - fieldCacheMutex.Lock() - fieldCache[t] = cache - fieldCacheMutex.Unlock() - } - newDest := make([]interface{}, len(cols)) var v EmptyScanner for j, name := range cols { - if i, ok := cache[name]; ok { - newDest[j] = vvv.Field(i).Addr().Interface() + f := fieldByName(vv.Elem(), rs.Mapper.Table2Obj(name)) + if f.IsValid() { + newDest[j] = f.Addr().Interface() } else { newDest[j] = &v } @@ -113,6 +243,36 @@ func (rs *Rows) ScanStruct2(dest interface{}) error { return rs.Rows.Scan(newDest...) } +type cacheStruct struct { + value reflect.Value + idx int +} + +var ( + reflectCache = make(map[reflect.Type]*cacheStruct) + reflectCacheMutex sync.RWMutex +) + +func ReflectNew(typ reflect.Type) reflect.Value { + reflectCacheMutex.RLock() + cs, ok := reflectCache[typ] + reflectCacheMutex.RUnlock() + + const newSize = 200 + + if !ok || cs.idx+1 > newSize-1 { + cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), newSize, newSize), 0} + reflectCacheMutex.Lock() + reflectCache[typ] = cs + reflectCacheMutex.Unlock() + } else { + reflectCacheMutex.Lock() + cs.idx = cs.idx + 1 + reflectCacheMutex.Unlock() + } + return cs.value.Index(cs.idx).Addr() +} + // scan data to a slice's pointer, slice's length should equal to columns' number func (rs *Rows) ScanSlice(dest interface{}) error { vv := reflect.ValueOf(dest) @@ -164,8 +324,9 @@ func (rs *Rows) ScanMap(dest interface{}) error { vvv := vv.Elem() for i, _ := range cols { - v := reflect.New(vvv.Type().Elem()) - newDest[i] = v.Interface() + newDest[i] = ReflectNew(vvv.Type().Elem()).Interface() + //v := reflect.New(vvv.Type().Elem()) + //newDest[i] = v.Interface() } err = rs.Rows.Scan(newDest...) @@ -180,3 +341,30 @@ func (rs *Rows) ScanMap(dest interface{}) error { return nil } + +/*func (rs *Rows) ScanMap(dest interface{}) error { + vv := reflect.ValueOf(dest) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return errors.New("dest should be a map's pointer") + } + + cols, err := rs.Columns() + if err != nil { + return err + } + + newDest := make([]interface{}, len(cols)) + err = rs.ScanSlice(newDest) + if err != nil { + return err + } + + vvv := vv.Elem() + + for i, name := range cols { + vname := reflect.ValueOf(name) + vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem()) + } + + return nil +}*/ diff --git a/core/db_test.go b/core/db_test.go index 2d0e85c5..8324b385 100644 --- a/core/db_test.go +++ b/core/db_test.go @@ -29,6 +29,7 @@ func BenchmarkOriQuery(b *testing.B) { if err != nil { b.Error(err) } + defer db.Close() _, err = db.Exec(createTableSqlite3) if err != nil { @@ -72,6 +73,7 @@ func BenchmarkStructQuery(b *testing.B) { if err != nil { b.Error(err) } + defer db.Close() _, err = db.Exec(createTableSqlite3) if err != nil { @@ -116,6 +118,7 @@ func BenchmarkStruct2Query(b *testing.B) { if err != nil { b.Error(err) } + defer db.Close() _, err = db.Exec(createTableSqlite3) if err != nil { @@ -154,13 +157,14 @@ func BenchmarkStruct2Query(b *testing.B) { } } -func BenchmarkSliceQuery(b *testing.B) { +func BenchmarkSliceInterfaceQuery(b *testing.B) { b.StopTimer() os.Remove("./test.db") db, err := Open("sqlite3", "./test.db") if err != nil { b.Error(err) } + defer db.Close() _, err = db.Exec(createTableSqlite3) if err != nil { @@ -203,6 +207,107 @@ func BenchmarkSliceQuery(b *testing.B) { rows.Close() } } +func BenchmarkSliceBytesQuery(b *testing.B) { + b.StopTimer() + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSqlite3) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao") + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + cols, err := rows.Columns() + if err != nil { + b.Error(err) + } + + for rows.Next() { + slice := make([][]byte, len(cols)) + err = rows.ScanSlice(&slice) + if err != nil { + b.Error(err) + } + if string(slice[1]) != "xlw" { + fmt.Println(slice) + b.Error(errors.New("name should be xlw")) + } + } + + rows.Close() + } +} + +func BenchmarkSliceStringQuery(b *testing.B) { + b.StopTimer() + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSqlite3) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao") + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + cols, err := rows.Columns() + if err != nil { + b.Error(err) + } + + for rows.Next() { + slice := make([]string, len(cols)) + err = rows.ScanSlice(&slice) + if err != nil { + b.Error(err) + } + if slice[1] != "xlw" { + fmt.Println(slice) + b.Error(errors.New("name should be xlw")) + } + } + + rows.Close() + } +} func BenchmarkMapInterfaceQuery(b *testing.B) { b.StopTimer() @@ -211,6 +316,7 @@ func BenchmarkMapInterfaceQuery(b *testing.B) { if err != nil { b.Error(err) } + defer db.Close() _, err = db.Exec(createTableSqlite3) if err != nil { @@ -256,6 +362,7 @@ func BenchmarkMapBytesQuery(b *testing.B) { if err != nil { b.Error(err) } + defer db.Close() _, err = db.Exec(createTableSqlite3) if err != nil { @@ -301,6 +408,7 @@ func BenchmarkMapStringQuery(b *testing.B) { if err != nil { b.Error(err) } + defer db.Close() _, err = db.Exec(createTableSqlite3) if err != nil { @@ -338,3 +446,180 @@ func BenchmarkMapStringQuery(b *testing.B) { rows.Close() } } + +func BenchmarkExec(b *testing.B) { + b.StopTimer() + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSqlite3) + if err != nil { + b.Error(err) + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao") + if err != nil { + b.Error(err) + } + } +} + +func BenchmarkExecMap(b *testing.B) { + b.StopTimer() + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSqlite3) + if err != nil { + b.Error(err) + } + + b.StartTimer() + + mp := map[string]interface{}{ + "name": "xlw", + "title": "tester", + "age": 1.2, + "alias": "lunny", + "nick_name": "lunny xiao", + } + + for i := 0; i < b.N; i++ { + _, err = db.ExecMap(`insert into user (name, title, age, alias, nick_name) + values (?name,?title,?age,?alias,?nick_name)`, + &mp) + if err != nil { + b.Error(err) + } + } +} + +func TestExecMap(t *testing.T) { + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + t.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSqlite3) + if err != nil { + t.Error(err) + } + + mp := map[string]interface{}{ + "name": "xlw", + "title": "tester", + "age": 1.2, + "alias": "lunny", + "nick_name": "lunny xiao", + } + + _, err = db.ExecMap(`insert into user (name, title, age, alias, nick_name) + values (?name,?title,?age,?alias,?nick_name)`, + &mp) + if err != nil { + t.Error(err) + } + + rows, err := db.Query("select * from user") + if err != nil { + t.Error(err) + } + + for rows.Next() { + var user User + err = rows.ScanStruct2(&user) + if err != nil { + t.Error(err) + } + fmt.Println("--", user) + } +} + +func TestExecStruct(t *testing.T) { + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + t.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSqlite3) + if err != nil { + t.Error(err) + } + + user := User{Name: "xlw", + Title: "tester", + Age: 1.2, + Alias: "lunny", + NickName: "lunny xiao", + } + + _, err = db.ExecStruct(`insert into user (name, title, age, alias, nick_name) + values (?Name,?Title,?Age,?Alias,?NickName)`, + &user) + if err != nil { + t.Error(err) + } + + rows, err := db.QueryStruct("select * from user where name = ?Name", &user) + if err != nil { + t.Error(err) + } + + for rows.Next() { + var user User + err = rows.ScanStruct2(&user) + if err != nil { + t.Error(err) + } + fmt.Println("1--", user) + } +} + +func BenchmarkExecStruct(b *testing.B) { + b.StopTimer() + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSqlite3) + if err != nil { + b.Error(err) + } + + b.StartTimer() + + user := User{Name: "xlw", + Title: "tester", + Age: 1.2, + Alias: "lunny", + NickName: "lunny xiao", + } + + for i := 0; i < b.N; i++ { + _, err = db.ExecStruct(`insert into user (name, title, age, alias, nick_name) + values (?Name,?Title,?Age,?Alias,?NickName)`, + &user) + if err != nil { + b.Error(err) + } + } +}