move tests to tests subdir & refactoring

This commit is contained in:
Lunny Xiao 2014-01-25 10:07:11 +08:00
parent 21c08ea09c
commit e77fca31ae
22 changed files with 721 additions and 328 deletions

View File

@ -1,107 +1,21 @@
package xorm //LRUCacher implements Cacher according to LRU algorithm
package caches
import ( import (
"container/list" "container/list"
"errors"
"fmt" "fmt"
"strconv"
"strings"
"sync" "sync"
"time" "time"
"github.com/lunny/xorm/core"
) )
const (
// default cache expired time
CacheExpired = 60 * time.Minute
// not use now
CacheMaxMemory = 256
// evey ten minutes to clear all expired nodes
CacheGcInterval = 10 * time.Minute
// each time when gc to removed max nodes
CacheGcMaxRemoved = 20
)
// CacheStore is a interface to store cache
type CacheStore interface {
Put(key, value interface{}) error
Get(key interface{}) (interface{}, error)
Del(key interface{}) error
}
// MemoryStore implements CacheStore provide local machine
// memory store
type MemoryStore struct {
store map[interface{}]interface{}
mutex sync.RWMutex
}
func NewMemoryStore() *MemoryStore {
return &MemoryStore{store: make(map[interface{}]interface{})}
}
func (s *MemoryStore) Put(key, value interface{}) error {
s.mutex.Lock()
defer s.mutex.Unlock()
s.store[key] = value
return nil
}
func (s *MemoryStore) Get(key interface{}) (interface{}, error) {
s.mutex.RLock()
defer s.mutex.RUnlock()
if v, ok := s.store[key]; ok {
return v, nil
}
return nil, ErrNotExist
}
func (s *MemoryStore) Del(key interface{}) error {
s.mutex.Lock()
defer s.mutex.Unlock()
delete(s.store, key)
return nil
}
// Cacher is an interface to provide cache
type Cacher interface {
GetIds(tableName, sql string) interface{}
GetBean(tableName string, id int64) interface{}
PutIds(tableName, sql string, ids interface{})
PutBean(tableName string, id int64, obj interface{})
DelIds(tableName, sql string)
DelBean(tableName string, id int64)
ClearIds(tableName string)
ClearBeans(tableName string)
}
type idNode struct {
tbName string
id int64
lastVisit time.Time
}
type sqlNode struct {
tbName string
sql string
lastVisit time.Time
}
func newIdNode(tbName string, id int64) *idNode {
return &idNode{tbName, id, time.Now()}
}
func newSqlNode(tbName, sql string) *sqlNode {
return &sqlNode{tbName, sql, time.Now()}
}
// LRUCacher implements Cacher according to LRU algorithm
type LRUCacher struct { type LRUCacher struct {
idList *list.List idList *list.List
sqlList *list.List sqlList *list.List
idIndex map[string]map[interface{}]*list.Element idIndex map[string]map[string]*list.Element
sqlIndex map[string]map[interface{}]*list.Element sqlIndex map[string]map[string]*list.Element
store CacheStore store core.CacheStore
Max int Max int
mutex sync.Mutex mutex sync.Mutex
Expired time.Duration Expired time.Duration
@ -109,25 +23,17 @@ type LRUCacher struct {
GcInterval time.Duration GcInterval time.Duration
} }
func newLRUCacher(store CacheStore, expired time.Duration, maxSize int, max int) *LRUCacher { func NewLRUCacher(store core.CacheStore, expired time.Duration, maxSize int, max int) *LRUCacher {
cacher := &LRUCacher{store: store, idList: list.New(), cacher := &LRUCacher{store: store, idList: list.New(),
sqlList: list.New(), Expired: expired, maxSize: maxSize, sqlList: list.New(), Expired: expired, maxSize: maxSize,
GcInterval: CacheGcInterval, Max: max, GcInterval: core.CacheGcInterval, Max: max,
sqlIndex: make(map[string]map[interface{}]*list.Element), sqlIndex: make(map[string]map[string]*list.Element),
idIndex: make(map[string]map[interface{}]*list.Element), idIndex: make(map[string]map[string]*list.Element),
} }
cacher.RunGC() cacher.RunGC()
return cacher return cacher
} }
func NewLRUCacher(store CacheStore, max int) *LRUCacher {
return newLRUCacher(store, CacheExpired, CacheMaxMemory, max)
}
func NewLRUCacher2(store CacheStore, expired time.Duration, max int) *LRUCacher {
return newLRUCacher(store, expired, 0, max)
}
//func NewLRUCacher3(store CacheStore, expired time.Duration, maxSize int) *LRUCacher { //func NewLRUCacher3(store CacheStore, expired time.Duration, maxSize int) *LRUCacher {
// return newLRUCacher(store, expired, maxSize, 0) // return newLRUCacher(store, expired, maxSize, 0)
//} //}
@ -148,7 +54,7 @@ func (m *LRUCacher) GC() {
defer m.mutex.Unlock() defer m.mutex.Unlock()
var removedNum int var removedNum int
for e := m.idList.Front(); e != nil; { for e := m.idList.Front(); e != nil; {
if removedNum <= CacheGcMaxRemoved && if removedNum <= core.CacheGcMaxRemoved &&
time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired { time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired {
removedNum++ removedNum++
next := e.Next() next := e.Next()
@ -164,7 +70,7 @@ func (m *LRUCacher) GC() {
removedNum = 0 removedNum = 0
for e := m.sqlList.Front(); e != nil; { for e := m.sqlList.Front(); e != nil; {
if removedNum <= CacheGcMaxRemoved && if removedNum <= core.CacheGcMaxRemoved &&
time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired { time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired {
removedNum++ removedNum++
next := e.Next() next := e.Next()
@ -184,7 +90,7 @@ func (m *LRUCacher) GetIds(tableName, sql string) interface{} {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if _, ok := m.sqlIndex[tableName]; !ok { if _, ok := m.sqlIndex[tableName]; !ok {
m.sqlIndex[tableName] = make(map[interface{}]*list.Element) m.sqlIndex[tableName] = make(map[string]*list.Element)
} }
if v, err := m.store.Get(sql); err == nil { if v, err := m.store.Get(sql); err == nil {
if el, ok := m.sqlIndex[tableName][sql]; !ok { if el, ok := m.sqlIndex[tableName][sql]; !ok {
@ -209,11 +115,11 @@ func (m *LRUCacher) GetIds(tableName, sql string) interface{} {
} }
// Get bean according tableName and id from cache // Get bean according tableName and id from cache
func (m *LRUCacher) GetBean(tableName string, id int64) interface{} { func (m *LRUCacher) GetBean(tableName string, id string) interface{} {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if _, ok := m.idIndex[tableName]; !ok { if _, ok := m.idIndex[tableName]; !ok {
m.idIndex[tableName] = make(map[interface{}]*list.Element) m.idIndex[tableName] = make(map[string]*list.Element)
} }
tid := genId(tableName, id) tid := genId(tableName, id)
if v, err := m.store.Get(tid); err == nil { if v, err := m.store.Get(tid); err == nil {
@ -248,7 +154,7 @@ func (m *LRUCacher) clearIds(tableName string) {
m.store.Del(sql) m.store.Del(sql)
} }
} }
m.sqlIndex[tableName] = make(map[interface{}]*list.Element) m.sqlIndex[tableName] = make(map[string]*list.Element)
} }
func (m *LRUCacher) ClearIds(tableName string) { func (m *LRUCacher) ClearIds(tableName string) {
@ -261,11 +167,11 @@ func (m *LRUCacher) clearBeans(tableName string) {
if tis, ok := m.idIndex[tableName]; ok { if tis, ok := m.idIndex[tableName]; ok {
for id, v := range tis { for id, v := range tis {
m.idList.Remove(v) m.idList.Remove(v)
tid := genId(tableName, id.(int64)) tid := genId(tableName, id)
m.store.Del(tid) m.store.Del(tid)
} }
} }
m.idIndex[tableName] = make(map[interface{}]*list.Element) m.idIndex[tableName] = make(map[string]*list.Element)
} }
func (m *LRUCacher) ClearBeans(tableName string) { func (m *LRUCacher) ClearBeans(tableName string) {
@ -278,7 +184,7 @@ func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if _, ok := m.sqlIndex[tableName]; !ok { if _, ok := m.sqlIndex[tableName]; !ok {
m.sqlIndex[tableName] = make(map[interface{}]*list.Element) m.sqlIndex[tableName] = make(map[string]*list.Element)
} }
if el, ok := m.sqlIndex[tableName][sql]; !ok { if el, ok := m.sqlIndex[tableName][sql]; !ok {
el = m.sqlList.PushBack(newSqlNode(tableName, sql)) el = m.sqlList.PushBack(newSqlNode(tableName, sql))
@ -294,7 +200,7 @@ func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) {
} }
} }
func (m *LRUCacher) PutBean(tableName string, id int64, obj interface{}) { func (m *LRUCacher) PutBean(tableName string, id string, obj interface{}) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
var el *list.Element var el *list.Element
@ -331,7 +237,7 @@ func (m *LRUCacher) DelIds(tableName, sql string) {
m.delIds(tableName, sql) m.delIds(tableName, sql)
} }
func (m *LRUCacher) delBean(tableName string, id int64) { func (m *LRUCacher) delBean(tableName string, id string) {
tid := genId(tableName, id) tid := genId(tableName, id)
if el, ok := m.idIndex[tableName][id]; ok { if el, ok := m.idIndex[tableName][id]; ok {
delete(m.idIndex[tableName], id) delete(m.idIndex[tableName], id)
@ -341,55 +247,36 @@ func (m *LRUCacher) delBean(tableName string, id int64) {
m.store.Del(tid) m.store.Del(tid)
} }
func (m *LRUCacher) DelBean(tableName string, id int64) { func (m *LRUCacher) DelBean(tableName string, id string) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.delBean(tableName, id) m.delBean(tableName, id)
} }
func encodeIds(ids []int64) (s string) { type idNode struct {
s = "[" tbName string
for _, id := range ids { id string
s += fmt.Sprintf("%v,", id) lastVisit time.Time
}
s = s[:len(s)-1] + "]"
return
} }
func decodeIds(s string) []int64 { type sqlNode struct {
res := make([]int64, 0) tbName string
if len(s) >= 2 { sql string
ss := strings.Split(s[1:len(s)-1], ",") lastVisit time.Time
for _, s := range ss {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return res
}
res = append(res, i)
}
}
return res
}
func getCacheSql(m Cacher, tableName, sql string, args interface{}) ([]int64, error) {
bytes := m.GetIds(tableName, genSqlKey(sql, args))
if bytes == nil {
return nil, errors.New("Not Exist")
}
objs := decodeIds(bytes.(string))
return objs, nil
}
func putCacheSql(m Cacher, ids []int64, tableName, sql string, args interface{}) error {
bytes := encodeIds(ids)
m.PutIds(tableName, genSqlKey(sql, args), bytes)
return nil
} }
func genSqlKey(sql string, args interface{}) string { func genSqlKey(sql string, args interface{}) string {
return fmt.Sprintf("%v-%v", sql, args) return fmt.Sprintf("%v-%v", sql, args)
} }
func genId(prefix string, id int64) string { func genId(prefix string, id string) string {
return fmt.Sprintf("%v-%v", prefix, id) return fmt.Sprintf("%v-%v", prefix, id)
} }
func newIdNode(tbName string, id string) *idNode {
return &idNode{tbName, id, time.Now()}
}
func newSqlNode(tbName, sql string) *sqlNode {
return &sqlNode{tbName, sql, time.Now()}
}

49
caches/memoryStore.go Normal file
View File

@ -0,0 +1,49 @@
// MemoryStore implements CacheStore provide local machine
package caches
import (
"errors"
"sync"
"github.com/lunny/xorm/core"
)
var (
ErrNotExist = errors.New("key not exist")
)
var _ core.CacheStore = NewMemoryStore()
// memory store
type MemoryStore struct {
store map[interface{}]interface{}
mutex sync.RWMutex
}
func NewMemoryStore() *MemoryStore {
return &MemoryStore{store: make(map[interface{}]interface{})}
}
func (s *MemoryStore) Put(key string, value interface{}) error {
s.mutex.Lock()
defer s.mutex.Unlock()
s.store[key] = value
return nil
}
func (s *MemoryStore) Get(key string) (interface{}, error) {
s.mutex.RLock()
defer s.mutex.RUnlock()
if v, ok := s.store[key]; ok {
return v, nil
}
return nil, ErrNotExist
}
func (s *MemoryStore) Del(key string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
delete(s.store, key)
return nil
}

77
core/cache.go Normal file
View File

@ -0,0 +1,77 @@
package core
import (
"encoding/json"
"errors"
"fmt"
"time"
)
const (
// default cache expired time
CacheExpired = 60 * time.Minute
// not use now
CacheMaxMemory = 256
// evey ten minutes to clear all expired nodes
CacheGcInterval = 10 * time.Minute
// each time when gc to removed max nodes
CacheGcMaxRemoved = 20
)
// CacheStore is a interface to store cache
type CacheStore interface {
// key is primary key or composite primary key or unique key's value
// value is struct's pointer
// key format : <tablename>-p-<pk1>-<pk2>...
Put(key string, value interface{}) error
Get(key string) (interface{}, error)
Del(key string) error
}
// Cacher is an interface to provide cache
// id format : u-<pk1>-<pk2>...
type Cacher interface {
GetIds(tableName, sql string) interface{}
GetBean(tableName string, id string) interface{}
PutIds(tableName, sql string, ids interface{})
PutBean(tableName string, id string, obj interface{})
DelIds(tableName, sql string)
DelBean(tableName string, id string)
ClearIds(tableName string)
ClearBeans(tableName string)
}
func encodeIds(ids []PK) (string, error) {
b, err := json.Marshal(ids)
if err != nil {
return "", err
}
return string(b), nil
}
func decodeIds(s string) ([]PK, error) {
pks := make([]PK, 0)
err := json.Unmarshal([]byte(s), &pks)
return pks, err
}
func GetCacheSql(m Cacher, tableName, sql string, args interface{}) ([]PK, error) {
bytes := m.GetIds(tableName, GenSqlKey(sql, args))
if bytes == nil {
return nil, errors.New("Not Exist")
}
return decodeIds(bytes.(string))
}
func PutCacheSql(m Cacher, ids []PK, tableName, sql string, args interface{}) error {
bytes, err := encodeIds(ids)
if err != nil {
return err
}
m.PutIds(tableName, GenSqlKey(sql, args), bytes)
return nil
}
func GenSqlKey(sql string, args interface{}) string {
return fmt.Sprintf("%v-%v", sql, args)
}

View File

@ -2,43 +2,138 @@ package core
import ( import (
"database/sql" "database/sql"
"errors"
"reflect" "reflect"
) )
type DB struct { type DB struct {
*sql.DB *sql.DB
Mapper IMapper
} }
func Open(driverName, dataSourceName string) (*DB, error) { func Open(driverName, dataSourceName string) (*DB, error) {
db, err := sql.Open(driverName, dataSourceName) db, err := sql.Open(driverName, dataSourceName)
return &DB{db}, err return &DB{db, &SnakeMapper{}}, err
} }
func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
rows, err := db.DB.Query(query, args...) rows, err := db.DB.Query(query, args...)
return &Rows{rows}, err return &Rows{rows, db.Mapper}, err
} }
type Rows struct { type Rows struct {
*sql.Rows *sql.Rows
Mapper IMapper
}
// scan data to a struct's pointer according field index
func (rs *Rows) ScanStruct(dest interface{}) error {
vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return errors.New("dest should be a struct's pointer")
} }
func (rs *Rows) Scan(dest ...interface{}) error {
newDest := make([]interface{}, 0)
for _, s := range dest {
vv := reflect.ValueOf(s)
switch vv.Kind() {
case reflect.Ptr:
vvv := vv.Elem() vvv := vv.Elem()
if vvv.Kind() == reflect.Struct { newDest := make([]interface{}, vvv.NumField())
for j := 0; j < vvv.NumField(); j++ { for j := 0; j < vvv.NumField(); j++ {
newDest = append(newDest, vvv.FieldByIndex([]int{j}).Addr().Interface()) newDest[j] = vvv.Field(j).Addr().Interface()
} }
return rs.Rows.Scan(newDest...)
}
// scan data to a struct's pointer according field name
func (rs *Rows) ScanStruct2(dest interface{}) error {
vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return errors.New("dest should be a struct's pointer")
}
cols, err := rs.Columns()
if err != nil {
return err
}
vvv := vv.Elem()
newDest := make([]interface{}, len(cols))
for j, name := range cols {
f := vvv.FieldByName(rs.Mapper.Table2Obj(name))
if f.IsValid() {
newDest[j] = f.Addr().Interface()
} else { } else {
newDest = append(newDest, s) var v interface{}
} newDest[j] = &v
} }
} }
return rs.Rows.Scan(newDest...) return rs.Rows.Scan(newDest...)
} }
// scan data to a slice's pointer, slice's length should equal to columns' number
func (rs *Rows) ScanSlice(dest interface{}) error {
vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice {
return errors.New("dest should be a slice's pointer")
}
vvv := vv.Elem()
cols, err := rs.Columns()
if err != nil {
return err
}
newDest := make([]interface{}, len(cols))
for j := 0; j < len(cols); j++ {
if j >= vvv.Len() {
newDest[j] = reflect.New(vvv.Type().Elem()).Interface()
} else {
newDest[j] = vvv.Index(j).Addr().Interface()
}
}
err = rs.Rows.Scan(newDest...)
if err != nil {
return err
}
for i, _ := range cols {
vvv = reflect.Append(vvv, reflect.ValueOf(newDest[i]).Elem())
}
return nil
}
// scan data to a map's pointer
func (rs *Rows) ScanMap(dest interface{}) error {
vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
return errors.New("dest should be a map's pointer")
}
cols, err := rs.Columns()
if err != nil {
return err
}
newDest := make([]interface{}, len(cols))
vvv := vv.Elem()
for i, _ := range cols {
v := reflect.New(vvv.Type().Elem())
newDest[i] = v.Interface()
}
err = rs.Rows.Scan(newDest...)
if err != nil {
return err
}
for i, name := range cols {
vname := reflect.ValueOf(name)
vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem())
}
return nil
}

View File

@ -2,7 +2,9 @@ package core
import ( import (
"fmt" "fmt"
"os"
"testing" "testing"
"time"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
@ -20,7 +22,8 @@ type User struct {
NickName string NickName string
} }
func TestQuery(t *testing.T) { func TestOriQuery(t *testing.T) {
os.Remove("./test.db")
db, err := Open("sqlite3", "./test.db") db, err := Open("sqlite3", "./test.db")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -31,23 +34,197 @@ func TestQuery(t *testing.T) {
t.Error(err) t.Error(err)
} }
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao") "xlw", "tester", 1.2, "lunny", "lunny xiao")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
}
rows, err := db.Query("select * from user")
if err != nil {
t.Error(err)
}
defer rows.Close()
start := time.Now()
for rows.Next() {
var Id int64
var Name, Title, Alias, NickName string
var Age float32
err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName)
if err != nil {
t.Error(err)
}
fmt.Println(Id, Name, Title, Age, Alias, NickName)
}
fmt.Println("ori ------", time.Now().Sub(start), "ns")
}
func TestStructQuery(t *testing.T) {
os.Remove("./test.db")
db, err := Open("sqlite3", "./test.db")
if err != nil {
t.Error(err)
}
_, err = db.Exec(createTableSqlite3)
if err != nil {
t.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao")
if err != nil {
t.Error(err)
}
}
rows, err := db.Query("select * from user")
if err != nil {
t.Error(err)
}
defer rows.Close()
start := time.Now()
for rows.Next() {
var user User
err = rows.ScanStruct(&user)
if err != nil {
t.Error(err)
}
fmt.Println(user)
}
fmt.Println("struct ------", time.Now().Sub(start))
}
func TestStruct2Query(t *testing.T) {
os.Remove("./test.db")
db, err := Open("sqlite3", "./test.db")
if err != nil {
t.Error(err)
}
_, err = db.Exec(createTableSqlite3)
if err != nil {
t.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao")
if err != nil {
t.Error(err)
}
}
db.Mapper = &SnakeMapper{}
rows, err := db.Query("select * from user")
if err != nil {
t.Error(err)
}
defer rows.Close()
start := time.Now()
for rows.Next() {
var user User
err = rows.ScanStruct2(&user)
if err != nil {
t.Error(err)
}
fmt.Println(user)
}
fmt.Println("struct2 ------", time.Now().Sub(start))
}
func TestSliceQuery(t *testing.T) {
os.Remove("./test.db")
db, err := Open("sqlite3", "./test.db")
if err != nil {
t.Error(err)
}
_, err = db.Exec(createTableSqlite3)
if err != nil {
t.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao")
if err != nil {
t.Error(err)
}
}
rows, err := db.Query("select * from user") rows, err := db.Query("select * from user")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
for rows.Next() { defer rows.Close()
var user User
err = rows.Scan(&user) cols, err := rows.Columns()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
fmt.Println(user)
start := time.Now()
for rows.Next() {
slice := make([]interface{}, len(cols))
err = rows.ScanSlice(&slice)
if err != nil {
t.Error(err)
}
fmt.Println(slice)
}
fmt.Println("slice ------", time.Now().Sub(start))
}
func TestMapQuery(t *testing.T) {
os.Remove("./test.db")
db, err := Open("sqlite3", "./test.db")
if err != nil {
t.Error(err)
}
_, err = db.Exec(createTableSqlite3)
if err != nil {
t.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao")
if err != nil {
t.Error(err)
} }
} }
rows, err := db.Query("select * from user")
if err != nil {
t.Error(err)
}
defer rows.Close()
start := time.Now()
for rows.Next() {
m := make(map[string]interface{})
err = rows.ScanMap(&m)
if err != nil {
t.Error(err)
}
fmt.Println(m)
}
fmt.Println("map ------", time.Now().Sub(start))
}

View File

@ -107,7 +107,7 @@ func (b *Base) CreateTableSql(table *Table, tableName, storeEngine, charset stri
if len(pkList) > 1 { if len(pkList) > 1 {
sql += "PRIMARY KEY ( " sql += "PRIMARY KEY ( "
sql += strings.Join(pkList, ",") sql += b.Quote(strings.Join(pkList, b.Quote(",")))
sql += " ), " sql += " ), "
} }

View File

@ -1,4 +1,4 @@
package xorm package core
import ( import (
"strings" "strings"

25
core/pk.go Normal file
View File

@ -0,0 +1,25 @@
package core
import (
"encoding/json"
)
type PK []interface{}
func NewPK(pks ...interface{}) *PK {
p := PK(pks)
return &p
}
func (p *PK) ToString() (string, error) {
bs, err := json.Marshal(*p)
if err != nil {
return "", nil
}
return string(bs), nil
}
func (p *PK) FromString(content string) error {
return json.Unmarshal([]byte(content), p)
}

22
core/pk_test.go Normal file
View File

@ -0,0 +1,22 @@
package core
import (
"fmt"
"testing"
)
func TestPK(t *testing.T) {
p := NewPK(1, 3, "string")
str, err := p.ToString()
if err != nil {
t.Error(err)
}
fmt.Println(str)
s := &PK{}
err = s.FromString(str)
if err != nil {
t.Error(err)
}
fmt.Println(s)
}

View File

@ -26,7 +26,7 @@ func (db *postgres) SqlType(c *Column) string {
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case TinyInt: case TinyInt:
res = SmallInt res = SmallInt
return res
case MediumInt, Int, Integer: case MediumInt, Int, Integer:
if c.IsAutoIncrement { if c.IsAutoIncrement {
return Serial return Serial

View File

@ -16,13 +16,11 @@ import (
"github.com/lunny/xorm/core" "github.com/lunny/xorm/core"
) )
type PK []interface{}
// Engine is the major struct of xorm, it means a database manager. // Engine is the major struct of xorm, it means a database manager.
// Commonly, an application only need one engine // Commonly, an application only need one engine
type Engine struct { type Engine struct {
ColumnMapper IMapper ColumnMapper core.IMapper
TableMapper IMapper TableMapper core.IMapper
TagIdentifier string TagIdentifier string
DriverName string DriverName string
DataSourceName string DataSourceName string
@ -37,20 +35,20 @@ type Engine struct {
Pool IConnectPool Pool IConnectPool
Filters []core.Filter Filters []core.Filter
Logger io.Writer Logger io.Writer
Cacher Cacher Cacher core.Cacher
tableCachers map[reflect.Type]Cacher tableCachers map[reflect.Type]core.Cacher
} }
func (engine *Engine) SetMapper(mapper IMapper) { func (engine *Engine) SetMapper(mapper core.IMapper) {
engine.SetTableMapper(mapper) engine.SetTableMapper(mapper)
engine.SetColumnMapper(mapper) engine.SetColumnMapper(mapper)
} }
func (engine *Engine) SetTableMapper(mapper IMapper) { func (engine *Engine) SetTableMapper(mapper core.IMapper) {
engine.TableMapper = mapper engine.TableMapper = mapper
} }
func (engine *Engine) SetColumnMapper(mapper IMapper) { func (engine *Engine) SetColumnMapper(mapper core.IMapper) {
engine.ColumnMapper = mapper engine.ColumnMapper = mapper
} }
@ -100,7 +98,7 @@ func (engine *Engine) SetMaxIdleConns(conns int) {
} }
// SetDefaltCacher set the default cacher. Xorm's default not enable cacher. // SetDefaltCacher set the default cacher. Xorm's default not enable cacher.
func (engine *Engine) SetDefaultCacher(cacher Cacher) { func (engine *Engine) SetDefaultCacher(cacher core.Cacher) {
engine.Cacher = cacher engine.Cacher = cacher
} }
@ -119,7 +117,7 @@ func (engine *Engine) NoCascade() *Session {
} }
// Set a table use a special cacher // Set a table use a special cacher
func (engine *Engine) MapCacher(bean interface{}, cacher Cacher) { func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) {
t := rType(bean) t := rType(bean)
engine.autoMapType(t) engine.autoMapType(t)
engine.tableCachers[t] = cacher engine.tableCachers[t] = cacher
@ -409,7 +407,7 @@ func (engine *Engine) mapType(t reflect.Type) *core.Table {
return mappingTable(t, engine.TableMapper, engine.ColumnMapper, engine.dialect, engine.TagIdentifier) return mappingTable(t, engine.TableMapper, engine.ColumnMapper, engine.dialect, engine.TagIdentifier)
} }
func mappingTable(t reflect.Type, tableMapper IMapper, colMapper IMapper, dialect core.Dialect, tagId string) *core.Table { func mappingTable(t reflect.Type, tableMapper core.IMapper, colMapper core.IMapper, dialect core.Dialect, tagId string) *core.Table {
table := core.NewEmptyTable() table := core.NewEmptyTable()
table.Name = tableMapper.Obj2Table(t.Name()) table.Name = tableMapper.Obj2Table(t.Name())
table.Type = t table.Type = t
@ -517,6 +515,7 @@ func mappingTable(t reflect.Type, tableMapper IMapper, colMapper IMapper, dialec
if col.Length2 == 0 { if col.Length2 == 0 {
col.Length2 = col.SQLType.DefaultLength2 col.Length2 = col.SQLType.DefaultLength2
} }
fmt.Println("======", col)
if col.Name == "" { if col.Name == "" {
col.Name = colMapper.Obj2Table(t.Field(i).Name) col.Name = colMapper.Obj2Table(t.Field(i).Name)
} }
@ -613,6 +612,24 @@ func (engine *Engine) IsTableExist(bean interface{}) (bool, error) {
return has, err return has, err
} }
func (engine *Engine) IdOf(bean interface{}) core.PK {
table := engine.autoMap(bean)
v := reflect.Indirect(reflect.ValueOf(bean))
pk := make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() {
pkField := v.FieldByName(col.FieldName)
switch pkField.Kind() {
case reflect.String:
pk[i] = pkField.String()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
pk[i] = pkField.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
pk[i] = pkField.Uint()
}
}
return core.PK(pk)
}
// create indexes // create indexes
func (engine *Engine) CreateIndexes(bean interface{}) error { func (engine *Engine) CreateIndexes(bean interface{}) error {
session := engine.NewSession() session := engine.NewSession()
@ -627,7 +644,7 @@ func (engine *Engine) CreateUniques(bean interface{}) error {
return session.CreateUniques(bean) return session.CreateUniques(bean)
} }
func (engine *Engine) getCacher(t reflect.Type) Cacher { func (engine *Engine) getCacher(t reflect.Type) core.Cacher {
if cacher, ok := engine.tableCachers[t]; ok { if cacher, ok := engine.tableCachers[t]; ok {
return cacher return cacher
} }
@ -635,7 +652,7 @@ func (engine *Engine) getCacher(t reflect.Type) Cacher {
} }
// If enabled cache, clear the cache bean // If enabled cache, clear the cache bean
func (engine *Engine) ClearCacheBean(bean interface{}, id int64) error { func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
t := rType(bean) t := rType(bean)
if t.Kind() != reflect.Struct { if t.Kind() != reflect.Struct {
return errors.New("error params") return errors.New("error params")

View File

@ -588,6 +588,7 @@ func (statement *Statement) convertIdSql(sqlStr string) string {
if len(sqls) != 2 { if len(sqls) != 2 {
return "" return ""
} }
fmt.Println("-----", col)
newsql := fmt.Sprintf("SELECT %v.%v FROM %v", statement.Engine.Quote(statement.TableName()), newsql := fmt.Sprintf("SELECT %v.%v FROM %v", statement.Engine.Quote(statement.TableName()),
statement.Engine.Quote(col.Name), sqls[1]) statement.Engine.Quote(col.Name), sqls[1])
return newsql return newsql
@ -612,14 +613,14 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
cacher := session.Engine.getCacher(session.Statement.RefTable.Type) cacher := session.Engine.getCacher(session.Statement.RefTable.Type)
tableName := session.Statement.TableName() tableName := session.Statement.TableName()
session.Engine.LogDebug("[xorm:cacheGet] find sql:", newsql, args) session.Engine.LogDebug("[xorm:cacheGet] find sql:", newsql, args)
ids, err := getCacheSql(cacher, tableName, newsql, args) ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil { if err != nil {
resultsSlice, err := session.query(newsql, args...) resultsSlice, err := session.query(newsql, args...)
if err != nil { if err != nil {
return false, err return false, err
} }
session.Engine.LogDebug("[xorm:cacheGet] query ids:", resultsSlice) session.Engine.LogDebug("[xorm:cacheGet] query ids:", resultsSlice)
ids = make([]int64, 0) ids = make([]core.PK, 0)
if len(resultsSlice) > 0 { if len(resultsSlice) > 0 {
data := resultsSlice[0] data := resultsSlice[0]
var id int64 var id int64
@ -631,10 +632,10 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return false, err return false, err
} }
} }
ids = append(ids, id) ids = append(ids, core.PK{id})
} }
session.Engine.LogDebug("[xorm:cacheGet] cache ids:", newsql, ids) session.Engine.LogDebug("[xorm:cacheGet] cache ids:", newsql, ids)
err = putCacheSql(cacher, ids, tableName, newsql, args) err = core.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -646,7 +647,11 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
structValue := reflect.Indirect(reflect.ValueOf(bean)) structValue := reflect.Indirect(reflect.ValueOf(bean))
id := ids[0] id := ids[0]
session.Engine.LogDebug("[xorm:cacheGet] get bean:", tableName, id) session.Engine.LogDebug("[xorm:cacheGet] get bean:", tableName, id)
cacheBean := cacher.GetBean(tableName, id) sid, err := id.ToString()
if err != nil {
return false, err
}
cacheBean := cacher.GetBean(tableName, sid)
if cacheBean == nil { if cacheBean == nil {
newSession := session.Engine.NewSession() newSession := session.Engine.NewSession()
defer newSession.Close() defer newSession.Close()
@ -664,7 +669,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
} }
session.Engine.LogDebug("[xorm:cacheGet] cache bean:", tableName, id, cacheBean) session.Engine.LogDebug("[xorm:cacheGet] cache bean:", tableName, id, cacheBean)
cacher.PutBean(tableName, id, cacheBean) cacher.PutBean(tableName, sid, cacheBean)
} else { } else {
session.Engine.LogDebug("[xorm:cacheGet] cached bean:", tableName, id, cacheBean) session.Engine.LogDebug("[xorm:cacheGet] cached bean:", tableName, id, cacheBean)
has = true has = true
@ -695,7 +700,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
table := session.Statement.RefTable table := session.Statement.RefTable
cacher := session.Engine.getCacher(t) cacher := session.Engine.getCacher(t)
ids, err := getCacheSql(cacher, session.Statement.TableName(), newsql, args) ids, err := core.GetCacheSql(cacher, session.Statement.TableName(), newsql, args)
if err != nil { if err != nil {
//session.Engine.LogError(err) //session.Engine.LogError(err)
resultsSlice, err := session.query(newsql, args...) resultsSlice, err := session.query(newsql, args...)
@ -709,7 +714,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
} }
tableName := session.Statement.TableName() tableName := session.Statement.TableName()
ids = make([]int64, 0) ids = make([]core.PK, 0)
if len(resultsSlice) > 0 { if len(resultsSlice) > 0 {
for _, data := range resultsSlice { for _, data := range resultsSlice {
//fmt.Println(data) //fmt.Println(data)
@ -722,11 +727,11 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
return err return err
} }
} }
ids = append(ids, id) ids = append(ids, core.PK{id})
} }
} }
session.Engine.LogDebug("[xorm:cacheFind] cache ids:", ids, tableName, newsql, args) session.Engine.LogDebug("[xorm:cacheFind] cache ids:", ids, tableName, newsql, args)
err = putCacheSql(cacher, ids, tableName, newsql, args) err = core.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil { if err != nil {
return err return err
} }
@ -735,34 +740,32 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
} }
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
pkFieldName := session.Statement.RefTable.PKColumns()[0].FieldName //pkFieldName := session.Statement.RefTable.PKColumns()[0].FieldName
ididxes := make(map[int64]int) ididxes := make(map[string]int)
var ides []interface{} = make([]interface{}, 0) var ides []core.PK = make([]core.PK, 0)
var temps []interface{} = make([]interface{}, len(ids)) var temps []interface{} = make([]interface{}, len(ids))
tableName := session.Statement.TableName() tableName := session.Statement.TableName()
for idx, id := range ids { for idx, id := range ids {
bean := cacher.GetBean(tableName, id) sid, err := id.ToString()
if err != nil {
return err
}
bean := cacher.GetBean(tableName, sid)
if bean == nil { if bean == nil {
ides = append(ides, id) ides = append(ides, id)
ididxes[id] = idx ididxes[sid] = idx
} else { } else {
session.Engine.LogDebug("[xorm:cacheFind] cached bean:", tableName, id, bean) session.Engine.LogDebug("[xorm:cacheFind] cached bean:", tableName, id, bean)
pkField := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(pkFieldName) pk := session.Engine.IdOf(bean)
xid, err := pk.ToString()
var sid int64 if err != nil {
switch pkField.Type().Kind() { return err
case reflect.Int32, reflect.Int, reflect.Int64:
sid = pkField.Int()
case reflect.Uint, reflect.Uint32, reflect.Uint64:
sid = int64(pkField.Uint())
default:
return ErrCacheFailed
} }
if sid != id { if sid != xid {
session.Engine.LogError("[xorm:cacheFind] error cache", id, sid, bean) session.Engine.LogError("[xorm:cacheFind] error cache", xid, sid, bean)
return ErrCacheFailed return ErrCacheFailed
} }
temps[idx] = bean temps[idx] = bean
@ -777,7 +780,19 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
beans := slices.Interface() beans := slices.Interface()
//beans := reflect.New(sliceValue.Type()).Interface() //beans := reflect.New(sliceValue.Type()).Interface()
//err = newSession.In("(id)", ides...).OrderBy(session.Statement.OrderStr).NoCache().Find(beans) //err = newSession.In("(id)", ides...).OrderBy(session.Statement.OrderStr).NoCache().Find(beans)
err = newSession.In("(id)", ides...).NoCache().Find(beans) ff := make([][]interface{}, len(table.PrimaryKeys))
for i, _ := range table.PrimaryKeys {
ff[i] = make([]interface{}, 0)
}
for _, ie := range ides {
for i, _ := range table.PrimaryKeys {
ff[i] = append(ff[i], ie[i])
}
}
for i, name := range table.PrimaryKeys {
newSession.In(name, ff[i]...)
}
err = newSession.NoCache().Find(beans)
if err != nil { if err != nil {
return err return err
} }
@ -789,12 +804,16 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
rv = rv.Addr() rv = rv.Addr()
} }
bean := rv.Interface() bean := rv.Interface()
id := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(pkFieldName).Int() id := session.Engine.IdOf(bean)
sid, err := id.ToString()
if err != nil {
return err
}
//bean := vs.Index(i).Addr().Interface() //bean := vs.Index(i).Addr().Interface()
temps[ididxes[id]] = bean temps[ididxes[sid]] = bean
//temps[idxes[i]] = bean //temps[idxes[i]] = bean
session.Engine.LogDebug("[xorm:cacheFind] cache bean:", tableName, id, bean) session.Engine.LogDebug("[xorm:cacheFind] cache bean:", tableName, id, bean)
cacher.PutBean(tableName, id, bean) cacher.PutBean(tableName, sid, bean)
} }
} }
@ -811,16 +830,21 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean)))) sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean))))
} }
} else if sliceValue.Kind() == reflect.Map { } else if sliceValue.Kind() == reflect.Map {
var key int64 var key core.PK
if table.PrimaryKeys[0] != "" { if table.PrimaryKeys[0] != "" {
key = ids[j] key = ids[j]
} else { }
key = int64(j)
if len(key) == 1 {
ikey, err := strconv.ParseInt(fmt.Sprintf("%v", key[0]), 10, 64)
if err != nil {
return err
} }
if t.Kind() == reflect.Ptr { if t.Kind() == reflect.Ptr {
sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(bean)) sliceValue.SetMapIndex(reflect.ValueOf(ikey), reflect.ValueOf(bean))
} else { } else {
sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(bean))) sliceValue.SetMapIndex(reflect.ValueOf(ikey), reflect.Indirect(reflect.ValueOf(bean)))
}
} }
} }
/*} else { /*} else {
@ -2762,7 +2786,7 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
cacher := session.Engine.getCacher(table.Type) cacher := session.Engine.getCacher(table.Type)
tableName := session.Statement.TableName() tableName := session.Statement.TableName()
session.Engine.LogDebug("[xorm:cacheUpdate] get cache sql", newsql, args[nStart:]) session.Engine.LogDebug("[xorm:cacheUpdate] get cache sql", newsql, args[nStart:])
ids, err := getCacheSql(cacher, tableName, newsql, args[nStart:]) ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:])
if err != nil { if err != nil {
resultsSlice, err := session.query(newsql, args[nStart:]...) resultsSlice, err := session.query(newsql, args[nStart:]...)
if err != nil { if err != nil {
@ -2770,7 +2794,7 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
} }
session.Engine.LogDebug("[xorm:cacheUpdate] find updated id", resultsSlice) session.Engine.LogDebug("[xorm:cacheUpdate] find updated id", resultsSlice)
ids = make([]int64, 0) ids = make([]core.PK, 0)
if len(resultsSlice) > 0 { if len(resultsSlice) > 0 {
for _, data := range resultsSlice { for _, data := range resultsSlice {
var id int64 var id int64
@ -2782,7 +2806,7 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
return err return err
} }
} }
ids = append(ids, id) ids = append(ids, core.PK{id})
} }
} }
} /*else { } /*else {
@ -2791,7 +2815,11 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
}*/ }*/
for _, id := range ids { for _, id := range ids {
if bean := cacher.GetBean(tableName, id); bean != nil { sid, err := id.ToString()
if err != nil {
return err
}
if bean := cacher.GetBean(tableName, sid); bean != nil {
sqls := splitNNoCase(sqlStr, "where", 2) sqls := splitNNoCase(sqlStr, "where", 2)
if len(sqls) == 0 || len(sqls) > 2 { if len(sqls) == 0 || len(sqls) > 2 {
return ErrCacheFailed return ErrCacheFailed
@ -2834,7 +2862,7 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
} }
session.Engine.LogDebug("[xorm:cacheUpdate] update cache", tableName, id, bean) session.Engine.LogDebug("[xorm:cacheUpdate] update cache", tableName, id, bean)
cacher.PutBean(tableName, id, bean) cacher.PutBean(tableName, sid, bean)
} }
} }
session.Engine.LogDebug("[xorm:cacheUpdate] clear cached table sql:", tableName) session.Engine.LogDebug("[xorm:cacheUpdate] clear cached table sql:", tableName)
@ -3047,13 +3075,13 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
cacher := session.Engine.getCacher(session.Statement.RefTable.Type) cacher := session.Engine.getCacher(session.Statement.RefTable.Type)
tableName := session.Statement.TableName() tableName := session.Statement.TableName()
ids, err := getCacheSql(cacher, tableName, newsql, args) ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil { if err != nil {
resultsSlice, err := session.query(newsql, args...) resultsSlice, err := session.query(newsql, args...)
if err != nil { if err != nil {
return err return err
} }
ids = make([]int64, 0) ids = make([]core.PK, 0)
if len(resultsSlice) > 0 { if len(resultsSlice) > 0 {
for _, data := range resultsSlice { for _, data := range resultsSlice {
var id int64 var id int64
@ -3065,7 +3093,7 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
return err return err
} }
} }
ids = append(ids, id) ids = append(ids, core.PK{id})
} }
} }
} /*else { } /*else {
@ -3075,7 +3103,11 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
for _, id := range ids { for _, id := range ids {
session.Engine.LogDebug("[xorm:cacheDelete] delete cache obj", tableName, id) session.Engine.LogDebug("[xorm:cacheDelete] delete cache obj", tableName, id)
cacher.DelBean(tableName, id) sid, err := id.ToString()
if err != nil {
return err
}
cacher.DelBean(tableName, sid)
} }
session.Engine.LogDebug("[xorm:cacheDelete] clear cache table", tableName) session.Engine.LogDebug("[xorm:cacheDelete] clear cache table", tableName)
cacher.ClearIds(tableName) cacher.ClearIds(tableName)

View File

@ -24,7 +24,7 @@ type Statement struct {
Start int Start int
LimitN int LimitN int
WhereStr string WhereStr string
IdParam *PK IdParam *core.PK
Params []interface{} Params []interface{}
OrderStr string OrderStr string
JoinStr string JoinStr string
@ -421,24 +421,28 @@ func (statement *Statement) TableName() string {
return "" return ""
} }
var (
ptrPkType = reflect.TypeOf(&core.PK{})
pkType = reflect.TypeOf(core.PK{})
)
// Generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?" // Generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?"
func (statement *Statement) Id(id interface{}) *Statement { func (statement *Statement) Id(id interface{}) *Statement {
idValue := reflect.ValueOf(id) idValue := reflect.ValueOf(id)
idType := reflect.TypeOf(idValue.Interface()) idType := reflect.TypeOf(idValue.Interface())
switch idType { switch idType {
case reflect.TypeOf(&PK{}): case ptrPkType:
if pkPtr, ok := (id).(*PK); ok { if pkPtr, ok := (id).(*core.PK); ok {
statement.IdParam = pkPtr statement.IdParam = pkPtr
} }
case reflect.TypeOf(PK{}): case pkType:
if pk, ok := (id).(PK); ok { if pk, ok := (id).(core.PK); ok {
statement.IdParam = &pk statement.IdParam = &pk
} }
default: default:
// TODO treat as int primitve for now, need to handle type check // TODO treat as int primitve for now, need to handle type check
statement.IdParam = &PK{id} statement.IdParam = &core.PK{id}
// !nashtsai! REVIEW although it will be user's mistake if called Id() twice with // !nashtsai! REVIEW although it will be user's mistake if called Id() twice with
// different value and Id should be PK's field name, however, at this stage probably // different value and Id should be PK's field name, however, at this stage probably
@ -789,31 +793,12 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
} }
func (statement *Statement) processIdParam() { func (statement *Statement) processIdParam() {
if statement.IdParam != nil { if statement.IdParam != nil {
i := 0 for i, col := range statement.RefTable.PKColumns() {
columns := statement.RefTable.ColumnsSeq() if i < len(*(statement.IdParam)) {
colCnt := len(columns) statement.And(fmt.Sprintf("%v=?", statement.Engine.Quote(col.Name)), (*(statement.IdParam))[i])
for _, elem := range *(statement.IdParam) { } else {
for ; i < colCnt; i++ { statement.And(fmt.Sprintf("%v=?", statement.Engine.Quote(col.Name)), "")
colName := columns[i]
col := statement.RefTable.GetColumn(colName)
if col.IsPrimaryKey {
statement.And(fmt.Sprintf("%v=?", col.Name), elem)
i++
break
}
}
}
// !nashtsai! REVIEW what if statement.IdParam has insufficient pk item? handle it
// as empty string for now, so this will result sql exec failed instead of unexpected
// false update/delete
for ; i < colCnt; i++ {
colName := columns[i]
col := statement.RefTable.GetColumn(colName)
if col.IsPrimaryKey {
statement.And(fmt.Sprintf("%v=?", col.Name), "")
} }
} }
} }

View File

@ -1,4 +1,4 @@
package xorm package tests
import ( import (
"errors" "errors"
@ -755,7 +755,7 @@ func orderSameMapper(engine *xorm.Engine, t *testing.T) {
func joinSameMapper(engine *xorm.Engine, t *testing.T) { func joinSameMapper(engine *xorm.Engine, t *testing.T) {
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
err := engine.Join("LEFT", "`Userdetail`", "`Userinfo`.`(id)`=`Userdetail`.`(id)`").Find(&users) err := engine.Join("LEFT", "`Userdetail`", "`Userinfo`.`(id)`=`Userdetail`.`Id`").Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -1080,7 +1080,7 @@ func testCols(engine *xorm.Engine, t *testing.T) {
func testColsSameMapper(engine *xorm.Engine, t *testing.T) { func testColsSameMapper(engine *xorm.Engine, t *testing.T) {
users := []Userinfo{} users := []Userinfo{}
err := engine.Cols("(id), Username").Find(&users) err := engine.Cols("id, Username").Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -1089,7 +1089,8 @@ func testColsSameMapper(engine *xorm.Engine, t *testing.T) {
fmt.Println(users) fmt.Println(users)
tmpUsers := []tempUser{} tmpUsers := []tempUser{}
err = engine.Table("Userinfo").Cols("(id), Username").Find(&tmpUsers) // TODO: should use cache
err = engine.NoCache().Table("Userinfo").Cols("id, Username").Find(&tmpUsers)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -2062,7 +2063,8 @@ func testVersion(engine *xorm.Engine, t *testing.T) {
func testDistinct(engine *xorm.Engine, t *testing.T) { func testDistinct(engine *xorm.Engine, t *testing.T) {
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
err := engine.Distinct("departname").Find(&users) departname := engine.TableMapper.Obj2Table("Departname")
err := engine.Distinct(departname).Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -2079,7 +2081,7 @@ func testDistinct(engine *xorm.Engine, t *testing.T) {
} }
users2 := make([]Depart, 0) users2 := make([]Depart, 0)
err = engine.Distinct("departname").Table(new(Userinfo)).Find(&users2) err = engine.Distinct(departname).Table(new(Userinfo)).Find(&users2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -2226,7 +2228,7 @@ func testPrefixTableName(engine *xorm.Engine, t *testing.T) {
panic(err) panic(err)
} }
tempEngine.ShowSQL = true tempEngine.ShowSQL = true
mapper := xorm.NewPrefixMapper(xorm.SnakeMapper{}, "xlw_") mapper := core.NewPrefixMapper(core.SnakeMapper{}, "xlw_")
//tempEngine.SetMapper(mapper) //tempEngine.SetMapper(mapper)
tempEngine.SetTableMapper(mapper) tempEngine.SetTableMapper(mapper)
exist, err := tempEngine.IsTableExist(&Userinfo{}) exist, err := tempEngine.IsTableExist(&Userinfo{})
@ -3738,7 +3740,7 @@ func testCompositeKey(engine *xorm.Engine, t *testing.T) {
} }
var compositeKeyVal CompositeKey var compositeKeyVal CompositeKey
has, err := engine.Id(xorm.PK{11, 22}).Get(&compositeKeyVal) has, err := engine.Id(core.PK{11, 22}).Get(&compositeKeyVal)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -3746,7 +3748,7 @@ func testCompositeKey(engine *xorm.Engine, t *testing.T) {
} }
// test passing PK ptr, this test seem failed withCache // test passing PK ptr, this test seem failed withCache
has, err = engine.Id(&xorm.PK{11, 22}).Get(&compositeKeyVal) has, err = engine.Id(&core.PK{11, 22}).Get(&compositeKeyVal)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -3754,14 +3756,14 @@ func testCompositeKey(engine *xorm.Engine, t *testing.T) {
} }
compositeKeyVal = CompositeKey{UpdateStr: "test1"} compositeKeyVal = CompositeKey{UpdateStr: "test1"}
cnt, err = engine.Id(xorm.PK{11, 22}).Update(&compositeKeyVal) cnt, err = engine.Id(core.PK{11, 22}).Update(&compositeKeyVal)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
t.Error(errors.New("can't update CompositeKey{11, 22}")) t.Error(errors.New("can't update CompositeKey{11, 22}"))
} }
cnt, err = engine.Id(xorm.PK{11, 22}).Delete(&CompositeKey{}) cnt, err = engine.Id(core.PK{11, 22}).Delete(&CompositeKey{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
@ -3803,7 +3805,7 @@ func testCompositeKey2(engine *xorm.Engine, t *testing.T) {
} }
var user User var user User
has, err := engine.Id(xorm.PK{"11", 22}).Get(&user) has, err := engine.Id(core.PK{"11", 22}).Get(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -3811,7 +3813,7 @@ func testCompositeKey2(engine *xorm.Engine, t *testing.T) {
} }
// test passing PK ptr, this test seem failed withCache // test passing PK ptr, this test seem failed withCache
has, err = engine.Id(&xorm.PK{"11", 22}).Get(&user) has, err = engine.Id(&core.PK{"11", 22}).Get(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -3819,14 +3821,14 @@ func testCompositeKey2(engine *xorm.Engine, t *testing.T) {
} }
user = User{NickName: "test1"} user = User{NickName: "test1"}
cnt, err = engine.Id(xorm.PK{"11", 22}).Update(&user) cnt, err = engine.Id(core.PK{"11", 22}).Update(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
t.Error(errors.New("can't update User{11, 22}")) t.Error(errors.New("can't update User{11, 22}"))
} }
cnt, err = engine.Id(xorm.PK{"11", 22}).Delete(&User{}) cnt, err = engine.Id(core.PK{"11", 22}).Delete(&User{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
@ -3870,16 +3872,12 @@ func testAll(engine *xorm.Engine, t *testing.T) {
} }
func testAll2(engine *xorm.Engine, t *testing.T) { func testAll2(engine *xorm.Engine, t *testing.T) {
fmt.Println("-------------- combineTransaction --------------")
combineTransaction(engine, t)
fmt.Println("-------------- table --------------") fmt.Println("-------------- table --------------")
table(engine, t) table(engine, t)
fmt.Println("-------------- createMultiTables --------------") fmt.Println("-------------- createMultiTables --------------")
createMultiTables(engine, t) createMultiTables(engine, t)
fmt.Println("-------------- tableOp --------------") fmt.Println("-------------- tableOp --------------")
tableOp(engine, t) tableOp(engine, t)
fmt.Println("-------------- testCols --------------")
testCols(engine, t)
fmt.Println("-------------- testCharst --------------") fmt.Println("-------------- testCharst --------------")
testCharst(engine, t) testCharst(engine, t)
fmt.Println("-------------- testStoreEngine --------------") fmt.Println("-------------- testStoreEngine --------------")
@ -3961,6 +3959,10 @@ func testAllSnakeMapper(engine *xorm.Engine, t *testing.T) {
join(engine, t) join(engine, t)
fmt.Println("-------------- having --------------") fmt.Println("-------------- having --------------")
having(engine, t) having(engine, t)
fmt.Println("-------------- combineTransaction --------------")
combineTransaction(engine, t)
fmt.Println("-------------- testCols --------------")
testCols(engine, t)
} }
func testAllSameMapper(engine *xorm.Engine, t *testing.T) { func testAllSameMapper(engine *xorm.Engine, t *testing.T) {
@ -3976,4 +3978,8 @@ func testAllSameMapper(engine *xorm.Engine, t *testing.T) {
joinSameMapper(engine, t) joinSameMapper(engine, t)
fmt.Println("-------------- having --------------") fmt.Println("-------------- having --------------")
havingSameMapper(engine, t) havingSameMapper(engine, t)
fmt.Println("-------------- combineTransaction --------------")
combineTransactionSameMapper(engine, t)
fmt.Println("-------------- testCols --------------")
testColsSameMapper(engine, t)
} }

View File

@ -1,4 +1,4 @@
package xorm package tests
import ( import (
"database/sql" "database/sql"

View File

@ -1,4 +1,4 @@
package xorm package tests
// //
// +build windows // +build windows
@ -9,6 +9,7 @@ import (
_ "github.com/lunny/godbc" _ "github.com/lunny/godbc"
"github.com/lunny/xorm" "github.com/lunny/xorm"
"github.com/lunny/xorm/caches"
) )
const mssqlConnStr = "driver={SQL Server};Server=192.168.20.135;Database=xorm_test; uid=sa; pwd=1234;" const mssqlConnStr = "driver={SQL Server};Server=192.168.20.135;Database=xorm_test; uid=sa; pwd=1234;"
@ -40,7 +41,7 @@ func TestMssqlWithCache(t *testing.T) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
@ -113,7 +114,7 @@ func BenchmarkMssqlCacheInsert(t *testing.B) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
@ -125,7 +126,7 @@ func BenchmarkMssqlCacheFind(t *testing.B) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
doBenchFind(engine, t) doBenchFind(engine, t)
} }
@ -137,7 +138,7 @@ func BenchmarkMssqlCacheFindPtr(t *testing.B) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }

View File

@ -1,10 +1,11 @@
package xorm package tests
import ( import (
"database/sql" "database/sql"
"testing" "testing"
"github.com/lunny/xorm" "github.com/lunny/xorm"
"github.com/lunny/xorm/caches"
_ "github.com/ziutek/mymysql/godrv" _ "github.com/ziutek/mymysql/godrv"
) )
@ -49,7 +50,7 @@ func TestMyMysqlWithCache(t *testing.T) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
@ -136,7 +137,7 @@ func BenchmarkMyMysqlCacheInsert(t *testing.B) {
} }
defer engine.Close() defer engine.Close()
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
@ -149,7 +150,7 @@ func BenchmarkMyMysqlCacheFind(t *testing.B) {
} }
defer engine.Close() defer engine.Close()
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
doBenchFind(engine, t) doBenchFind(engine, t)
} }
@ -162,7 +163,7 @@ func BenchmarkMyMysqlCacheFindPtr(t *testing.B) {
} }
defer engine.Close() defer engine.Close()
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }

View File

@ -1,4 +1,4 @@
package xorm package tests
import ( import (
"database/sql" "database/sql"
@ -6,6 +6,8 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/lunny/xorm" "github.com/lunny/xorm"
"github.com/lunny/xorm/caches"
"github.com/lunny/xorm/core"
) )
/* /*
@ -54,7 +56,7 @@ func TestMysqlSameMapper(t *testing.T) {
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
engine.ShowDebug = showTestSql engine.ShowDebug = showTestSql
engine.SetMapper(xorm.SameMapper{}) engine.SetMapper(core.SameMapper{})
testAll(engine, t) testAll(engine, t)
testAllSameMapper(engine, t) testAllSameMapper(engine, t)
@ -75,7 +77,7 @@ func TestMysqlWithCache(t *testing.T) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
@ -99,8 +101,8 @@ func TestMysqlWithCacheSameMapper(t *testing.T) {
t.Error(err) t.Error(err)
return return
} }
engine.SetMapper(xorm.SameMapper{}) engine.SetMapper(core.SameMapper{})
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
@ -190,7 +192,7 @@ func BenchmarkMysqlCacheInsert(t *testing.B) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
@ -202,7 +204,7 @@ func BenchmarkMysqlCacheFind(t *testing.B) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
doBenchFind(engine, t) doBenchFind(engine, t)
} }
@ -214,7 +216,7 @@ func BenchmarkMysqlCacheFindPtr(t *testing.B) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }

View File

@ -1,4 +1,4 @@
package xorm package tests
import ( import (
"database/sql" "database/sql"
@ -6,6 +6,8 @@ import (
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/lunny/xorm" "github.com/lunny/xorm"
"github.com/lunny/xorm/caches"
"github.com/lunny/xorm/core"
) )
//var connStr string = "dbname=xorm_test user=lunny password=1234 sslmode=disable" //var connStr string = "dbname=xorm_test user=lunny password=1234 sslmode=disable"
@ -59,7 +61,7 @@ func TestPostgresWithCache(t *testing.T) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
defer engine.Close() defer engine.Close()
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
@ -78,7 +80,7 @@ func TestPostgresSameMapper(t *testing.T) {
return return
} }
defer engine.Close() defer engine.Close()
engine.SetMapper(xorm.SameMapper{}) engine.SetMapper(core.SameMapper{})
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
@ -96,9 +98,9 @@ func TestPostgresWithCacheSameMapper(t *testing.T) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
defer engine.Close() defer engine.Close()
engine.SetMapper(xorm.SameMapper{}) engine.SetMapper(core.SameMapper{})
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
@ -168,7 +170,7 @@ func BenchmarkPostgresCacheInsert(t *testing.B) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
@ -181,7 +183,7 @@ func BenchmarkPostgresCacheFind(t *testing.B) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
doBenchFind(engine, t) doBenchFind(engine, t)
} }
@ -194,7 +196,7 @@ func BenchmarkPostgresCacheFindPtr(t *testing.B) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }

View File

@ -1,4 +1,4 @@
package xorm package tests
import ( import (
"database/sql" "database/sql"
@ -6,6 +6,8 @@ import (
"testing" "testing"
"github.com/lunny/xorm" "github.com/lunny/xorm"
"github.com/lunny/xorm/caches"
"github.com/lunny/xorm/core"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
@ -44,7 +46,7 @@ func TestSqlite3WithCache(t *testing.T) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
@ -62,7 +64,7 @@ func TestSqlite3SameMapper(t *testing.T) {
t.Error(err) t.Error(err)
return return
} }
engine.SetMapper(xorm.SameMapper{}) engine.SetMapper(core.SameMapper{})
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
@ -81,8 +83,8 @@ func TestSqlite3WithCacheSameMapper(t *testing.T) {
t.Error(err) t.Error(err)
return return
} }
engine.SetMapper(xorm.SameMapper{}) engine.SetMapper(core.SameMapper{})
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
@ -152,7 +154,7 @@ func BenchmarkSqlite3CacheInsert(t *testing.B) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
@ -164,7 +166,7 @@ func BenchmarkSqlite3CacheFind(t *testing.B) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
doBenchFind(engine, t) doBenchFind(engine, t)
} }
@ -176,6 +178,6 @@ func BenchmarkSqlite3CacheFindPtr(t *testing.B) {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000))
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }

View File

@ -1,5 +1,8 @@
--DROP DATABASE xorm_test; DROP DATABASE xorm_test;
--DROP DATABASE xorm_test2; DROP DATABASE xorm_test1;
DROP DATABASE xorm_test2;
DROP DATABASE xorm_test3;
CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET utf8 COLLATE utf8_general_ci; CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE DATABASE IF NOT EXISTS xorm_test1 CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE DATABASE IF NOT EXISTS xorm_test2 CHARACTER SET utf8 COLLATE utf8_general_ci; CREATE DATABASE IF NOT EXISTS xorm_test2 CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE DATABASE IF NOT EXISTS xorm_test3 CHARACTER SET utf8 COLLATE utf8_general_ci; CREATE DATABASE IF NOT EXISTS xorm_test3 CHARACTER SET utf8 COLLATE utf8_general_ci;

14
xorm.go
View File

@ -7,7 +7,9 @@ import (
"reflect" "reflect"
"runtime" "runtime"
"sync" "sync"
"time"
"github.com/lunny/xorm/caches"
"github.com/lunny/xorm/core" "github.com/lunny/xorm/core"
_ "github.com/lunny/xorm/dialects" _ "github.com/lunny/xorm/dialects"
_ "github.com/lunny/xorm/drivers" _ "github.com/lunny/xorm/drivers"
@ -46,9 +48,9 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
engine := &Engine{DriverName: driverName, engine := &Engine{DriverName: driverName,
DataSourceName: dataSourceName, dialect: dialect, DataSourceName: dataSourceName, dialect: dialect,
tableCachers: make(map[reflect.Type]Cacher)} tableCachers: make(map[reflect.Type]core.Cacher)}
engine.SetMapper(SnakeMapper{}) engine.SetMapper(core.SnakeMapper{})
engine.Filters = dialect.Filters() engine.Filters = dialect.Filters()
@ -65,3 +67,11 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
runtime.SetFinalizer(engine, close) runtime.SetFinalizer(engine, close)
return engine, err return engine, err
} }
func NewLRUCacher(store core.CacheStore, max int) *caches.LRUCacher {
return caches.NewLRUCacher(store, core.CacheExpired, core.CacheMaxMemory, max)
}
func NewLRUCacher2(store core.CacheStore, expired time.Duration, max int) *caches.LRUCacher {
return caches.NewLRUCacher(store, expired, 0, max)
}