// xorm/core/db.go
//
// 183 lines, 3.7 KiB, Go
// Raw | Normal View | History
//
// 2014-01-07 09:33:27 +00:00

package core
import (
"database/sql"
"errors"
// 2014-01-07 09:33:27 +00:00
"reflect"
"sync"
// 2014-01-07 09:33:27 +00:00
)
type DB struct {
*sql.DB
Mapper IMapper
2014-01-07 09:33:27 +00:00
}
func Open(driverName, dataSourceName string) (*DB, error) {
db, err := sql.Open(driverName, dataSourceName)
return &DB{db, NewCacheMapper(&SnakeMapper{})}, err
2014-01-07 09:33:27 +00:00
}
func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
rows, err := db.DB.Query(query, args...)
return &Rows{rows, db.Mapper}, err
2014-01-07 09:33:27 +00:00
}
type Rows struct {
*sql.Rows
Mapper IMapper
2014-01-07 09:33:27 +00:00
}
// scan data to a struct's pointer according field index
2014-01-25 14:29:40 +00:00
func (rs *Rows) ScanStruct(dest ...interface{}) error {
if len(dest) == 0 {
return errors.New("at least one struct")
}
2014-01-25 14:29:40 +00:00
vvvs := make([]reflect.Value, len(dest))
for i, s := range dest {
vv := reflect.ValueOf(s)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return errors.New("dest should be a struct's pointer")
}
2014-01-25 14:29:40 +00:00
vvvs[i] = vv.Elem()
}
cols, err := rs.Columns()
if err != nil {
return err
}
newDest := make([]interface{}, len(cols))
var i = 0
for _, vvv := range vvvs {
for j := 0; j < vvv.NumField(); j++ {
newDest[i] = vvv.Field(j).Addr().Interface()
i = i + 1
}
}
return rs.Rows.Scan(newDest...)
}
// EmptyScanner accepts any source value and throws it away. Its Scan method
// matches the database/sql Scanner signature. ScanStruct2 uses it as the
// destination for result columns that have no matching struct field.
type EmptyScanner struct{}

// Scan ignores its argument and always reports success.
func (EmptyScanner) Scan(interface{}) error { return nil }
// fieldCache memoizes, per struct type, the column-name -> field-index table
// that ScanStruct2 derives from the struct's fields.
var fieldCache = map[reflect.Type]map[string]int{}

// fieldCacheMutex guards every read and write of fieldCache.
var fieldCacheMutex sync.RWMutex
// scan data to a struct's pointer according field name
func (rs *Rows) ScanStruct2(dest interface{}) error {
vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return errors.New("dest should be a struct's pointer")
}
cols, err := rs.Columns()
if err != nil {
return err
}
vvv := vv.Elem()
t := vvv.Type()
fieldCacheMutex.RLock()
cache, ok := fieldCache[t]
fieldCacheMutex.RUnlock()
if !ok {
cache = make(map[string]int)
for i := 0; i < vvv.NumField(); i++ {
2014-01-27 13:01:52 +00:00
cache[rs.Mapper.Obj2Table(vvv.Type().Field(i).Name)] = i
}
fieldCacheMutex.Lock()
fieldCache[t] = cache
fieldCacheMutex.Unlock()
}
newDest := make([]interface{}, len(cols))
var v EmptyScanner
for j, name := range cols {
2014-01-27 13:01:52 +00:00
if i, ok := cache[name]; ok {
newDest[j] = vvv.Field(i).Addr().Interface()
} else {
newDest[j] = &v
2014-01-07 09:33:27 +00:00
}
}
return rs.Rows.Scan(newDest...)
}
// scan data to a slice's pointer, slice's length should equal to columns' number
func (rs *Rows) ScanSlice(dest interface{}) error {
vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice {
return errors.New("dest should be a slice's pointer")
}
vvv := vv.Elem()
cols, err := rs.Columns()
if err != nil {
return err
}
newDest := make([]interface{}, len(cols))
for j := 0; j < len(cols); j++ {
if j >= vvv.Len() {
newDest[j] = reflect.New(vvv.Type().Elem()).Interface()
} else {
newDest[j] = vvv.Index(j).Addr().Interface()
}
}
err = rs.Rows.Scan(newDest...)
if err != nil {
return err
}
2014-01-27 13:01:52 +00:00
srcLen := vvv.Len()
for i := srcLen; i < len(cols); i++ {
vvv = reflect.Append(vvv, reflect.ValueOf(newDest[i]).Elem())
}
return nil
}
// ScanMap scans the current row into the map pointed to by dest. Each
// column's value is stored under the column name. Every value is scanned
// into a fresh value of the map's element type.
//
// The map's key type must have string kind; named string types are
// supported through conversion. A nil map is allocated, and the new map is
// stored through dest, before values are written. Existing entries whose
// keys are not column names are left in place.
func (rs *Rows) ScanMap(dest interface{}) error {
	vv := reflect.ValueOf(dest)
	if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
		return errors.New("dest should be a map's pointer")
	}
	vvv := vv.Elem()
	mapType := vvv.Type()
	// Column names are strings. Any other key kind would make the
	// conversion or SetMapIndex below panic.
	if mapType.Key().Kind() != reflect.String {
		return errors.New("dest map's key should be string")
	}

	cols, err := rs.Columns()
	if err != nil {
		return err
	}

	newDest := make([]interface{}, len(cols))
	for i := range cols {
		newDest[i] = reflect.New(mapType.Elem()).Interface()
	}
	if err = rs.Rows.Scan(newDest...); err != nil {
		return err
	}

	// Writing into a nil map panics, so allocate one through the pointer
	// (so the caller sees it) before storing the scanned values.
	if vvv.IsNil() {
		vvv.Set(reflect.MakeMap(mapType))
	}
	for i, name := range cols {
		key := reflect.ValueOf(name).Convert(mapType.Key())
		vvv.SetMapIndex(key, reflect.ValueOf(newDest[i]).Elem())
	}
	return nil
}