diff --git a/cache.go b/caches/lruCacher.go similarity index 59% rename from cache.go rename to caches/lruCacher.go index e1ccc0d1..6f851922 100644 --- a/cache.go +++ b/caches/lruCacher.go @@ -1,107 +1,21 @@ -package xorm +//LRUCacher implements Cacher according to LRU algorithm +package caches import ( "container/list" - "errors" "fmt" - "strconv" - "strings" "sync" "time" + + "github.com/lunny/xorm/core" ) -const ( - // default cache expired time - CacheExpired = 60 * time.Minute - // not use now - CacheMaxMemory = 256 - // evey ten minutes to clear all expired nodes - CacheGcInterval = 10 * time.Minute - // each time when gc to removed max nodes - CacheGcMaxRemoved = 20 -) - -// CacheStore is a interface to store cache -type CacheStore interface { - Put(key, value interface{}) error - Get(key interface{}) (interface{}, error) - Del(key interface{}) error -} - -// MemoryStore implements CacheStore provide local machine -// memory store -type MemoryStore struct { - store map[interface{}]interface{} - mutex sync.RWMutex -} - -func NewMemoryStore() *MemoryStore { - return &MemoryStore{store: make(map[interface{}]interface{})} -} - -func (s *MemoryStore) Put(key, value interface{}) error { - s.mutex.Lock() - defer s.mutex.Unlock() - s.store[key] = value - return nil -} - -func (s *MemoryStore) Get(key interface{}) (interface{}, error) { - s.mutex.RLock() - defer s.mutex.RUnlock() - if v, ok := s.store[key]; ok { - return v, nil - } - - return nil, ErrNotExist -} - -func (s *MemoryStore) Del(key interface{}) error { - s.mutex.Lock() - defer s.mutex.Unlock() - delete(s.store, key) - return nil -} - -// Cacher is an interface to provide cache -type Cacher interface { - GetIds(tableName, sql string) interface{} - GetBean(tableName string, id int64) interface{} - PutIds(tableName, sql string, ids interface{}) - PutBean(tableName string, id int64, obj interface{}) - DelIds(tableName, sql string) - DelBean(tableName string, id int64) - ClearIds(tableName string) - ClearBeans(tableName string) -} - -type idNode struct { - tbName string - id int64 - lastVisit time.Time -} - -type sqlNode struct { - tbName string - sql string - lastVisit time.Time -} - -func newIdNode(tbName string, id int64) *idNode { - return &idNode{tbName, id, time.Now()} -} - -func newSqlNode(tbName, sql string) *sqlNode { - return &sqlNode{tbName, sql, time.Now()} -} - -// LRUCacher implements Cacher according to LRU algorithm type LRUCacher struct { idList *list.List sqlList *list.List - idIndex map[string]map[interface{}]*list.Element - sqlIndex map[string]map[interface{}]*list.Element - store CacheStore + idIndex map[string]map[string]*list.Element + sqlIndex map[string]map[string]*list.Element + store core.CacheStore Max int mutex sync.Mutex Expired time.Duration @@ -109,25 +23,17 @@ type LRUCacher struct { GcInterval time.Duration } -func newLRUCacher(store CacheStore, expired time.Duration, maxSize int, max int) *LRUCacher { +func NewLRUCacher(store core.CacheStore, expired time.Duration, maxSize int, max int) *LRUCacher { cacher := &LRUCacher{store: store, idList: list.New(), sqlList: list.New(), Expired: expired, maxSize: maxSize, - GcInterval: CacheGcInterval, Max: max, - sqlIndex: make(map[string]map[interface{}]*list.Element), - idIndex: make(map[string]map[interface{}]*list.Element), + GcInterval: core.CacheGcInterval, Max: max, + sqlIndex: make(map[string]map[string]*list.Element), + idIndex: make(map[string]map[string]*list.Element), } cacher.RunGC() return cacher } -func NewLRUCacher(store CacheStore, max int) *LRUCacher { - return newLRUCacher(store, CacheExpired, CacheMaxMemory, max) -} - -func NewLRUCacher2(store CacheStore, expired time.Duration, max int) *LRUCacher { - return newLRUCacher(store, expired, 0, max) -} - //func NewLRUCacher3(store CacheStore, expired time.Duration, maxSize int) *LRUCacher { // return newLRUCacher(store, expired, maxSize, 0) //} @@ -148,7 +54,7 @@ func (m *LRUCacher) GC() { defer m.mutex.Unlock() var removedNum int for e := m.idList.Front(); e != nil; { - if removedNum <= CacheGcMaxRemoved && + if removedNum <= core.CacheGcMaxRemoved && time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired { removedNum++ next := e.Next() @@ -164,7 +70,7 @@ func (m *LRUCacher) GC() { removedNum = 0 for e := m.sqlList.Front(); e != nil; { - if removedNum <= CacheGcMaxRemoved && + if removedNum <= core.CacheGcMaxRemoved && time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired { removedNum++ next := e.Next() @@ -184,7 +90,7 @@ func (m *LRUCacher) GetIds(tableName, sql string) interface{} { m.mutex.Lock() defer m.mutex.Unlock() if _, ok := m.sqlIndex[tableName]; !ok { - m.sqlIndex[tableName] = make(map[interface{}]*list.Element) + m.sqlIndex[tableName] = make(map[string]*list.Element) } if v, err := m.store.Get(sql); err == nil { if el, ok := m.sqlIndex[tableName][sql]; !ok { @@ -209,11 +115,11 @@ func (m *LRUCacher) GetIds(tableName, sql string) interface{} { } // Get bean according tableName and id from cache -func (m *LRUCacher) GetBean(tableName string, id int64) interface{} { +func (m *LRUCacher) GetBean(tableName string, id string) interface{} { m.mutex.Lock() defer m.mutex.Unlock() if _, ok := m.idIndex[tableName]; !ok { - m.idIndex[tableName] = make(map[interface{}]*list.Element) + m.idIndex[tableName] = make(map[string]*list.Element) } tid := genId(tableName, id) if v, err := m.store.Get(tid); err == nil { @@ -248,7 +154,7 @@ func (m *LRUCacher) clearIds(tableName string) { m.store.Del(sql) } } - m.sqlIndex[tableName] = make(map[interface{}]*list.Element) + m.sqlIndex[tableName] = make(map[string]*list.Element) } func (m *LRUCacher) ClearIds(tableName string) { @@ -261,11 +167,11 @@ func (m *LRUCacher) clearBeans(tableName string) { if tis, ok := m.idIndex[tableName]; ok { for id, v := range tis { m.idList.Remove(v) - tid := genId(tableName, id.(int64)) + tid := genId(tableName, id) m.store.Del(tid) } } - m.idIndex[tableName] = make(map[interface{}]*list.Element) + m.idIndex[tableName] = make(map[string]*list.Element) } func (m *LRUCacher) ClearBeans(tableName string) { @@ -278,7 +184,7 @@ func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) { m.mutex.Lock() defer m.mutex.Unlock() if _, ok := m.sqlIndex[tableName]; !ok { - m.sqlIndex[tableName] = make(map[interface{}]*list.Element) + m.sqlIndex[tableName] = make(map[string]*list.Element) } if el, ok := m.sqlIndex[tableName][sql]; !ok { el = m.sqlList.PushBack(newSqlNode(tableName, sql)) @@ -294,7 +200,7 @@ func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) { } } -func (m *LRUCacher) PutBean(tableName string, id int64, obj interface{}) { +func (m *LRUCacher) PutBean(tableName string, id string, obj interface{}) { m.mutex.Lock() defer m.mutex.Unlock() var el *list.Element @@ -331,7 +237,7 @@ func (m *LRUCacher) DelIds(tableName, sql string) { m.delIds(tableName, sql) } -func (m *LRUCacher) delBean(tableName string, id int64) { +func (m *LRUCacher) delBean(tableName string, id string) { tid := genId(tableName, id) if el, ok := m.idIndex[tableName][id]; ok { delete(m.idIndex[tableName], id) @@ -341,55 +247,36 @@ func (m *LRUCacher) delBean(tableName string, id int64) { m.store.Del(tid) } -func (m *LRUCacher) DelBean(tableName string, id int64) { +func (m *LRUCacher) DelBean(tableName string, id string) { m.mutex.Lock() defer m.mutex.Unlock() m.delBean(tableName, id) } -func encodeIds(ids []int64) (s string) { - s = "[" - for _, id := range ids { - s += fmt.Sprintf("%v,", id) - } - s = s[:len(s)-1] + "]" - return +type idNode struct { + tbName string + id string + lastVisit time.Time } -func decodeIds(s string) []int64 { - res := make([]int64, 0) - if len(s) >= 2 { - ss := strings.Split(s[1:len(s)-1], ",") - for _, s := range ss { - i, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return res - } - res = append(res, i) - } - } - return res -} - -func getCacheSql(m Cacher, tableName, sql string, args interface{}) ([]int64, error) { - bytes := m.GetIds(tableName, genSqlKey(sql, args)) - if bytes == nil { - return nil, errors.New("Not Exist") - } - objs := decodeIds(bytes.(string)) - return objs, nil -} - -func putCacheSql(m Cacher, ids []int64, tableName, sql string, args interface{}) error { - bytes := encodeIds(ids) - m.PutIds(tableName, genSqlKey(sql, args), bytes) - return nil +type sqlNode struct { + tbName string + sql string + lastVisit time.Time } func genSqlKey(sql string, args interface{}) string { return fmt.Sprintf("%v-%v", sql, args) } -func genId(prefix string, id int64) string { +func genId(prefix string, id string) string { return fmt.Sprintf("%v-%v", prefix, id) } + +func newIdNode(tbName string, id string) *idNode { + return &idNode{tbName, id, time.Now()} +} + +func newSqlNode(tbName, sql string) *sqlNode { + return &sqlNode{tbName, sql, time.Now()} +} diff --git a/caches/memoryStore.go b/caches/memoryStore.go new file mode 100644 index 00000000..8bbb6b99 --- /dev/null +++ b/caches/memoryStore.go @@ -0,0 +1,49 @@ +// MemoryStore implements CacheStore provide local machine +package caches + +import ( + "errors" + "sync" + + "github.com/lunny/xorm/core" +) + +var ( + ErrNotExist = errors.New("key not exist") +) + +var _ core.CacheStore = NewMemoryStore() + +// memory store +type MemoryStore struct { + store map[interface{}]interface{} + mutex sync.RWMutex +} + +func NewMemoryStore() *MemoryStore { + return &MemoryStore{store: make(map[interface{}]interface{})} +} + +func (s *MemoryStore) Put(key string, value interface{}) error { + s.mutex.Lock() + defer s.mutex.Unlock() + s.store[key] = value + return nil +} + +func (s *MemoryStore) Get(key string) (interface{}, error) { + s.mutex.RLock() + defer s.mutex.RUnlock() + if v, ok := s.store[key]; ok { + return v, nil + } + + return nil, ErrNotExist +} + +func (s *MemoryStore) Del(key string) error { + s.mutex.Lock() + defer s.mutex.Unlock() + delete(s.store, key) + return nil +} diff --git a/core/cache.go b/core/cache.go new file mode 100644 index 00000000..bedd22a7 --- /dev/null +++ b/core/cache.go @@ -0,0 +1,77 @@ +package core + +import ( + "encoding/json" + "errors" + "fmt" + "time" +) + +const ( + // default cache expired time + CacheExpired = 60 * time.Minute + // not use now + CacheMaxMemory = 256 + // evey ten minutes to clear all expired nodes + CacheGcInterval = 10 * time.Minute + // each time when gc to removed max nodes + CacheGcMaxRemoved = 20 +) + +// CacheStore is a interface to store cache +type CacheStore interface { + // key is primary key or composite primary key or unique key's value + // value is struct's pointer + // key format : -p--... + Put(key string, value interface{}) error + Get(key string) (interface{}, error) + Del(key string) error +} + +// Cacher is an interface to provide cache +// id format : u--... +type Cacher interface { + GetIds(tableName, sql string) interface{} + GetBean(tableName string, id string) interface{} + PutIds(tableName, sql string, ids interface{}) + PutBean(tableName string, id string, obj interface{}) + DelIds(tableName, sql string) + DelBean(tableName string, id string) + ClearIds(tableName string) + ClearBeans(tableName string) +} + +func encodeIds(ids []PK) (string, error) { + b, err := json.Marshal(ids) + if err != nil { + return "", err + } + return string(b), nil +} + +func decodeIds(s string) ([]PK, error) { + pks := make([]PK, 0) + err := json.Unmarshal([]byte(s), &pks) + return pks, err +} + +func GetCacheSql(m Cacher, tableName, sql string, args interface{}) ([]PK, error) { + bytes := m.GetIds(tableName, GenSqlKey(sql, args)) + if bytes == nil { + return nil, errors.New("Not Exist") + } + return decodeIds(bytes.(string)) +} + +func PutCacheSql(m Cacher, ids []PK, tableName, sql string, args interface{}) error { + bytes, err := encodeIds(ids) + if err != nil { + return err + } + m.PutIds(tableName, GenSqlKey(sql, args), bytes) + return nil +} + +func GenSqlKey(sql string, args interface{}) string { + return fmt.Sprintf("%v-%v", sql, args) +} diff --git a/core/db.go b/core/db.go index 9249c2a8..3f96b343 100644 --- a/core/db.go +++ b/core/db.go @@ -2,43 +2,138 @@ package core import ( "database/sql" + "errors" "reflect" ) type DB struct { *sql.DB + Mapper IMapper } func Open(driverName, dataSourceName string) (*DB, error) { db, err := sql.Open(driverName, dataSourceName) - return &DB{db}, err + return &DB{db, &SnakeMapper{}}, err } func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { rows, err := db.DB.Query(query, args...) - return &Rows{rows}, err + return &Rows{rows, db.Mapper}, err } type Rows struct { *sql.Rows + Mapper IMapper } -func (rs *Rows) Scan(dest ...interface{}) error { - newDest := make([]interface{}, 0) - for _, s := range dest { - vv := reflect.ValueOf(s) - switch vv.Kind() { - case reflect.Ptr: - vvv := vv.Elem() - if vvv.Kind() == reflect.Struct { - for j := 0; j < vvv.NumField(); j++ { - newDest = append(newDest, vvv.FieldByIndex([]int{j}).Addr().Interface()) - } - } else { - newDest = append(newDest, s) - } +// scan data to a struct's pointer according field index +func (rs *Rows) ScanStruct(dest interface{}) error { + vv := reflect.ValueOf(dest) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return errors.New("dest should be a struct's pointer") + } + + vvv := vv.Elem() + newDest := make([]interface{}, vvv.NumField()) + + for j := 0; j < vvv.NumField(); j++ { + newDest[j] = vvv.Field(j).Addr().Interface() + } + + return rs.Rows.Scan(newDest...) +} + +// scan data to a struct's pointer according field name +func (rs *Rows) ScanStruct2(dest interface{}) error { + vv := reflect.ValueOf(dest) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return errors.New("dest should be a struct's pointer") + } + + cols, err := rs.Columns() + if err != nil { + return err + } + + vvv := vv.Elem() + newDest := make([]interface{}, len(cols)) + + for j, name := range cols { + f := vvv.FieldByName(rs.Mapper.Table2Obj(name)) + if f.IsValid() { + newDest[j] = f.Addr().Interface() + } else { + var v interface{} + newDest[j] = &v } } return rs.Rows.Scan(newDest...) } + +// scan data to a slice's pointer, slice's length should equal to columns' number +func (rs *Rows) ScanSlice(dest interface{}) error { + vv := reflect.ValueOf(dest) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice { + return errors.New("dest should be a slice's pointer") + } + + vvv := vv.Elem() + cols, err := rs.Columns() + if err != nil { + return err + } + + newDest := make([]interface{}, len(cols)) + + for j := 0; j < len(cols); j++ { + if j >= vvv.Len() { + newDest[j] = reflect.New(vvv.Type().Elem()).Interface() + } else { + newDest[j] = vvv.Index(j).Addr().Interface() + } + } + + err = rs.Rows.Scan(newDest...) + if err != nil { + return err + } + + for i, _ := range cols { + vvv = reflect.Append(vvv, reflect.ValueOf(newDest[i]).Elem()) + } + return nil +} + +// scan data to a map's pointer +func (rs *Rows) ScanMap(dest interface{}) error { + vv := reflect.ValueOf(dest) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return errors.New("dest should be a map's pointer") + } + + cols, err := rs.Columns() + if err != nil { + return err + } + + newDest := make([]interface{}, len(cols)) + vvv := vv.Elem() + + for i, _ := range cols { + v := reflect.New(vvv.Type().Elem()) + newDest[i] = v.Interface() + } + + err = rs.Rows.Scan(newDest...) + if err != nil { + return err + } + + for i, name := range cols { + vname := reflect.ValueOf(name) + vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem()) + } + + return nil +} diff --git a/core/db_test.go b/core/db_test.go index 8894bd17..26f2e04d 100644 --- a/core/db_test.go +++ b/core/db_test.go @@ -2,7 +2,9 @@ package core import ( "fmt" + "os" "testing" + "time" _ "github.com/mattn/go-sqlite3" ) @@ -20,7 +22,8 @@ type User struct { NickName string } -func TestQuery(t *testing.T) { +func TestOriQuery(t *testing.T) { + os.Remove("./test.db") db, err := Open("sqlite3", "./test.db") if err != nil { t.Error(err) @@ -31,23 +34,197 @@ func TestQuery(t *testing.T) { t.Error(err) } - _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", - "xlw", "tester", 1.2, "lunny", "lunny xiao") + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao") + if err != nil { + t.Error(err) + } + } + + rows, err := db.Query("select * from user") if err != nil { t.Error(err) } + defer rows.Close() + + start := time.Now() + + for rows.Next() { + var Id int64 + var Name, Title, Alias, NickName string + var Age float32 + err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName) + if err != nil { + t.Error(err) + } + fmt.Println(Id, Name, Title, Age, Alias, NickName) + } + + fmt.Println("ori ------", time.Now().Sub(start), "ns") +} + +func TestStructQuery(t *testing.T) { + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + t.Error(err) + } + + _, err = db.Exec(createTableSqlite3) + if err != nil { + t.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao") + if err != nil { + t.Error(err) + } + } + + rows, err := db.Query("select * from user") + if err != nil { + t.Error(err) + } + defer rows.Close() + start := time.Now() + + for rows.Next() { + var user User + err = rows.ScanStruct(&user) + if err != nil { + t.Error(err) + } + fmt.Println(user) + } + fmt.Println("struct ------", time.Now().Sub(start)) +} + +func TestStruct2Query(t *testing.T) { + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + t.Error(err) + } + + _, err = db.Exec(createTableSqlite3) + if err != nil { + t.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao") + if err != nil { + t.Error(err) + } + } + + db.Mapper = &SnakeMapper{} + + rows, err := db.Query("select * from user") + if err != nil { + t.Error(err) + } + defer rows.Close() + start := time.Now() + + for rows.Next() { + var user User + err = rows.ScanStruct2(&user) + if err != nil { + t.Error(err) + } + fmt.Println(user) + } + fmt.Println("struct2 ------", time.Now().Sub(start)) +} + +func TestSliceQuery(t *testing.T) { + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + t.Error(err) + } + + _, err = db.Exec(createTableSqlite3) + if err != nil { + t.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao") + if err != nil { + t.Error(err) + } + } rows, err := db.Query("select * from user") if err != nil { t.Error(err) } + defer rows.Close() + + cols, err := rows.Columns() + if err != nil { + t.Error(err) + } + + start := time.Now() + for rows.Next() { - var user User - err = rows.Scan(&user) + slice := make([]interface{}, len(cols)) + err = rows.ScanSlice(&slice) if err != nil { t.Error(err) } - fmt.Println(user) + fmt.Println(slice) } + + fmt.Println("slice ------", time.Now().Sub(start)) +} + +func TestMapQuery(t *testing.T) { + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + t.Error(err) + } + + _, err = db.Exec(createTableSqlite3) + if err != nil { + t.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao") + if err != nil { + t.Error(err) + } + } + + rows, err := db.Query("select * from user") + if err != nil { + t.Error(err) + } + + defer rows.Close() + + start := time.Now() + + for rows.Next() { + m := make(map[string]interface{}) + err = rows.ScanMap(&m) + if err != nil { + t.Error(err) + } + fmt.Println(m) + } + + fmt.Println("map ------", time.Now().Sub(start)) } diff --git a/core/dialect.go b/core/dialect.go index fd7a483e..ecba0844 100644 --- a/core/dialect.go +++ b/core/dialect.go @@ -107,7 +107,7 @@ func (b *Base) CreateTableSql(table *Table, tableName, storeEngine, charset stri if len(pkList) > 1 { sql += "PRIMARY KEY ( " - sql += strings.Join(pkList, ",") + sql += b.Quote(strings.Join(pkList, b.Quote(","))) sql += " ), " } diff --git a/mapper.go b/core/mapper.go similarity index 99% rename from mapper.go rename to core/mapper.go index 2e9c220a..0011dde5 100644 --- a/mapper.go +++ b/core/mapper.go @@ -1,4 +1,4 @@ -package xorm +package core import ( "strings" diff --git a/core/pk.go b/core/pk.go new file mode 100644 index 00000000..61d1371e --- /dev/null +++ b/core/pk.go @@ -0,0 +1,25 @@ +package core + +import ( + "encoding/json" +) + +type PK []interface{} + +func NewPK(pks ...interface{}) *PK { + p := PK(pks) + return &p +} + +func (p *PK) ToString() (string, error) { + bs, err := json.Marshal(*p) + if err != nil { + return "", nil + } + + return string(bs), nil +} + +func (p *PK) FromString(content string) error { + return json.Unmarshal([]byte(content), p) +} diff --git a/core/pk_test.go b/core/pk_test.go new file mode 100644 index 00000000..5245e574 --- /dev/null +++ b/core/pk_test.go @@ -0,0 +1,22 @@ +package core + +import ( + "fmt" + "testing" +) + +func TestPK(t *testing.T) { + p := NewPK(1, 3, "string") + str, err := p.ToString() + if err != nil { + t.Error(err) + } + fmt.Println(str) + + s := &PK{} + err = s.FromString(str) + if err != nil { + t.Error(err) + } + fmt.Println(s) +} diff --git a/dialects/postgres.go b/dialects/postgres.go index dbf9ca09..8a37d117 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -26,7 +26,7 @@ func (db *postgres) SqlType(c *Column) string { switch t := c.SQLType.Name; t { case TinyInt: res = SmallInt - + return res case MediumInt, Int, Integer: if c.IsAutoIncrement { return Serial diff --git a/engine.go b/engine.go index a6509828..dc3fd637 100644 --- a/engine.go +++ b/engine.go @@ -16,13 +16,11 @@ import ( "github.com/lunny/xorm/core" ) -type PK []interface{} - // Engine is the major struct of xorm, it means a database manager. // Commonly, an application only need one engine type Engine struct { - ColumnMapper IMapper - TableMapper IMapper + ColumnMapper core.IMapper + TableMapper core.IMapper TagIdentifier string DriverName string DataSourceName string @@ -37,20 +35,20 @@ type Engine struct { Pool IConnectPool Filters []core.Filter Logger io.Writer - Cacher Cacher - tableCachers map[reflect.Type]Cacher + Cacher core.Cacher + tableCachers map[reflect.Type]core.Cacher } -func (engine *Engine) SetMapper(mapper IMapper) { +func (engine *Engine) SetMapper(mapper core.IMapper) { engine.SetTableMapper(mapper) engine.SetColumnMapper(mapper) } -func (engine *Engine) SetTableMapper(mapper IMapper) { +func (engine *Engine) SetTableMapper(mapper core.IMapper) { engine.TableMapper = mapper } -func (engine *Engine) SetColumnMapper(mapper IMapper) { +func (engine *Engine) SetColumnMapper(mapper core.IMapper) { engine.ColumnMapper = mapper } @@ -100,7 +98,7 @@ func (engine *Engine) SetMaxIdleConns(conns int) { } // SetDefaltCacher set the default cacher. Xorm's default not enable cacher. -func (engine *Engine) SetDefaultCacher(cacher Cacher) { +func (engine *Engine) SetDefaultCacher(cacher core.Cacher) { engine.Cacher = cacher } @@ -119,7 +117,7 @@ func (engine *Engine) NoCascade() *Session { } // Set a table use a special cacher -func (engine *Engine) MapCacher(bean interface{}, cacher Cacher) { +func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) { t := rType(bean) engine.autoMapType(t) engine.tableCachers[t] = cacher @@ -409,7 +407,7 @@ func (engine *Engine) mapType(t reflect.Type) *core.Table { return mappingTable(t, engine.TableMapper, engine.ColumnMapper, engine.dialect, engine.TagIdentifier) } -func mappingTable(t reflect.Type, tableMapper IMapper, colMapper IMapper, dialect core.Dialect, tagId string) *core.Table { +func mappingTable(t reflect.Type, tableMapper core.IMapper, colMapper core.IMapper, dialect core.Dialect, tagId string) *core.Table { table := core.NewEmptyTable() table.Name = tableMapper.Obj2Table(t.Name()) table.Type = t @@ -517,6 +515,7 @@ func mappingTable(t reflect.Type, tableMapper IMapper, colMapper IMapper, dialec if col.Length2 == 0 { col.Length2 = col.SQLType.DefaultLength2 } + fmt.Println("======", col) if col.Name == "" { col.Name = colMapper.Obj2Table(t.Field(i).Name) } @@ -613,6 +612,24 @@ func (engine *Engine) IsTableExist(bean interface{}) (bool, error) { return has, err } +func (engine *Engine) IdOf(bean interface{}) core.PK { + table := engine.autoMap(bean) + v := reflect.Indirect(reflect.ValueOf(bean)) + pk := make([]interface{}, len(table.PrimaryKeys)) + for i, col := range table.PKColumns() { + pkField := v.FieldByName(col.FieldName) + switch pkField.Kind() { + case reflect.String: + pk[i] = pkField.String() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + pk[i] = pkField.Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + pk[i] = pkField.Uint() + } + } + return core.PK(pk) +} + // create indexes func (engine *Engine) CreateIndexes(bean interface{}) error { session := engine.NewSession() @@ -627,7 +644,7 @@ func (engine *Engine) CreateUniques(bean interface{}) error { return session.CreateUniques(bean) } -func (engine *Engine) getCacher(t reflect.Type) Cacher { +func (engine *Engine) getCacher(t reflect.Type) core.Cacher { if cacher, ok := engine.tableCachers[t]; ok { return cacher } @@ -635,7 +652,7 @@ func (engine *Engine) getCacher(t reflect.Type) Cacher { } // If enabled cache, clear the cache bean -func (engine *Engine) ClearCacheBean(bean interface{}, id int64) error { +func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { t := rType(bean) if t.Kind() != reflect.Struct { return errors.New("error params") diff --git a/session.go b/session.go index a983237b..382370c7 100644 --- a/session.go +++ b/session.go @@ -588,6 +588,7 @@ func (statement *Statement) convertIdSql(sqlStr string) string { if len(sqls) != 2 { return "" } + fmt.Println("-----", col) newsql := fmt.Sprintf("SELECT %v.%v FROM %v", statement.Engine.Quote(statement.TableName()), statement.Engine.Quote(col.Name), sqls[1]) return newsql @@ -612,14 +613,14 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf cacher := session.Engine.getCacher(session.Statement.RefTable.Type) tableName := session.Statement.TableName() session.Engine.LogDebug("[xorm:cacheGet] find sql:", newsql, args) - ids, err := getCacheSql(cacher, tableName, newsql, args) + ids, err := core.GetCacheSql(cacher, tableName, newsql, args) if err != nil { resultsSlice, err := session.query(newsql, args...) if err != nil { return false, err } session.Engine.LogDebug("[xorm:cacheGet] query ids:", resultsSlice) - ids = make([]int64, 0) + ids = make([]core.PK, 0) if len(resultsSlice) > 0 { data := resultsSlice[0] var id int64 @@ -631,10 +632,10 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf return false, err } } - ids = append(ids, id) + ids = append(ids, core.PK{id}) } session.Engine.LogDebug("[xorm:cacheGet] cache ids:", newsql, ids) - err = putCacheSql(cacher, ids, tableName, newsql, args) + err = core.PutCacheSql(cacher, ids, tableName, newsql, args) if err != nil { return false, err } @@ -646,7 +647,11 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf structValue := reflect.Indirect(reflect.ValueOf(bean)) id := ids[0] session.Engine.LogDebug("[xorm:cacheGet] get bean:", tableName, id) - cacheBean := cacher.GetBean(tableName, id) + sid, err := id.ToString() + if err != nil { + return false, err + } + cacheBean := cacher.GetBean(tableName, sid) if cacheBean == nil { newSession := session.Engine.NewSession() defer newSession.Close() @@ -664,7 +669,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf } session.Engine.LogDebug("[xorm:cacheGet] cache bean:", tableName, id, cacheBean) - cacher.PutBean(tableName, id, cacheBean) + cacher.PutBean(tableName, sid, cacheBean) } else { session.Engine.LogDebug("[xorm:cacheGet] cached bean:", tableName, id, cacheBean) has = true @@ -695,7 +700,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in table := session.Statement.RefTable cacher := session.Engine.getCacher(t) - ids, err := getCacheSql(cacher, session.Statement.TableName(), newsql, args) + ids, err := core.GetCacheSql(cacher, session.Statement.TableName(), newsql, args) if err != nil { //session.Engine.LogError(err) resultsSlice, err := session.query(newsql, args...) @@ -709,7 +714,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in } tableName := session.Statement.TableName() - ids = make([]int64, 0) + ids = make([]core.PK, 0) if len(resultsSlice) > 0 { for _, data := range resultsSlice { //fmt.Println(data) @@ -722,11 +727,11 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in return err } } - ids = append(ids, id) + ids = append(ids, core.PK{id}) } } session.Engine.LogDebug("[xorm:cacheFind] cache ids:", ids, tableName, newsql, args) - err = putCacheSql(cacher, ids, tableName, newsql, args) + err = core.PutCacheSql(cacher, ids, tableName, newsql, args) if err != nil { return err } @@ -735,34 +740,32 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in } sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) - pkFieldName := session.Statement.RefTable.PKColumns()[0].FieldName + //pkFieldName := session.Statement.RefTable.PKColumns()[0].FieldName - ididxes := make(map[int64]int) - var ides []interface{} = make([]interface{}, 0) + ididxes := make(map[string]int) + var ides []core.PK = make([]core.PK, 0) var temps []interface{} = make([]interface{}, len(ids)) tableName := session.Statement.TableName() for idx, id := range ids { - bean := cacher.GetBean(tableName, id) + sid, err := id.ToString() + if err != nil { + return err + } + bean := cacher.GetBean(tableName, sid) if bean == nil { ides = append(ides, id) - ididxes[id] = idx + ididxes[sid] = idx } else { session.Engine.LogDebug("[xorm:cacheFind] cached bean:", tableName, id, bean) - pkField := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(pkFieldName) - - var sid int64 - switch pkField.Type().Kind() { - case reflect.Int32, reflect.Int, reflect.Int64: - sid = pkField.Int() - case reflect.Uint, reflect.Uint32, reflect.Uint64: - sid = int64(pkField.Uint()) - default: - return ErrCacheFailed + pk := session.Engine.IdOf(bean) + xid, err := pk.ToString() + if err != nil { + return err } - if sid != id { - session.Engine.LogError("[xorm:cacheFind] error cache", id, sid, bean) + if sid != xid { + session.Engine.LogError("[xorm:cacheFind] error cache", xid, sid, bean) return ErrCacheFailed } temps[idx] = bean @@ -777,7 +780,19 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in beans := slices.Interface() //beans := reflect.New(sliceValue.Type()).Interface() //err = newSession.In("(id)", ides...).OrderBy(session.Statement.OrderStr).NoCache().Find(beans) - err = newSession.In("(id)", ides...).NoCache().Find(beans) + ff := make([][]interface{}, len(table.PrimaryKeys)) + for i, _ := range table.PrimaryKeys { + ff[i] = make([]interface{}, 0) + } + for _, ie := range ides { + for i, _ := range table.PrimaryKeys { + ff[i] = append(ff[i], ie[i]) + } + } + for i, name := range table.PrimaryKeys { + newSession.In(name, ff[i]...) + } + err = newSession.NoCache().Find(beans) if err != nil { return err } @@ -789,12 +804,16 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in rv = rv.Addr() } bean := rv.Interface() - id := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(pkFieldName).Int() + id := session.Engine.IdOf(bean) + sid, err := id.ToString() + if err != nil { + return err + } //bean := vs.Index(i).Addr().Interface() - temps[ididxes[id]] = bean + temps[ididxes[sid]] = bean //temps[idxes[i]] = bean session.Engine.LogDebug("[xorm:cacheFind] cache bean:", tableName, id, bean) - cacher.PutBean(tableName, id, bean) + cacher.PutBean(tableName, sid, bean) } } @@ -811,16 +830,21 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean)))) } } else if sliceValue.Kind() == reflect.Map { - var key int64 + var key core.PK if table.PrimaryKeys[0] != "" { key = ids[j] - } else { - key = int64(j) } - if t.Kind() == reflect.Ptr { - sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(bean)) - } else { - sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(bean))) + + if len(key) == 1 { + ikey, err := strconv.ParseInt(fmt.Sprintf("%v", key[0]), 10, 64) + if err != nil { + return err + } + if t.Kind() == reflect.Ptr { + sliceValue.SetMapIndex(reflect.ValueOf(ikey), reflect.ValueOf(bean)) + } else { + sliceValue.SetMapIndex(reflect.ValueOf(ikey), reflect.Indirect(reflect.ValueOf(bean))) + } } } /*} else { @@ -2762,7 +2786,7 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { cacher := session.Engine.getCacher(table.Type) tableName := session.Statement.TableName() session.Engine.LogDebug("[xorm:cacheUpdate] get cache sql", newsql, args[nStart:]) - ids, err := getCacheSql(cacher, tableName, newsql, args[nStart:]) + ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:]) if err != nil { resultsSlice, err := session.query(newsql, args[nStart:]...) if err != nil { @@ -2770,7 +2794,7 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { } session.Engine.LogDebug("[xorm:cacheUpdate] find updated id", resultsSlice) - ids = make([]int64, 0) + ids = make([]core.PK, 0) if len(resultsSlice) > 0 { for _, data := range resultsSlice { var id int64 @@ -2782,7 +2806,7 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { return err } } - ids = append(ids, id) + ids = append(ids, core.PK{id}) } } } /*else { @@ -2791,7 +2815,11 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { }*/ for _, id := range ids { - if bean := cacher.GetBean(tableName, id); bean != nil { + sid, err := id.ToString() + if err != nil { + return err + } + if bean := cacher.GetBean(tableName, sid); bean != nil { sqls := splitNNoCase(sqlStr, "where", 2) if len(sqls) == 0 || len(sqls) > 2 { return ErrCacheFailed @@ -2834,7 +2862,7 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { } session.Engine.LogDebug("[xorm:cacheUpdate] update cache", tableName, id, bean) - cacher.PutBean(tableName, id, bean) + cacher.PutBean(tableName, sid, bean) } } session.Engine.LogDebug("[xorm:cacheUpdate] clear cached table sql:", tableName) @@ -3047,13 +3075,13 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error { cacher := session.Engine.getCacher(session.Statement.RefTable.Type) tableName := session.Statement.TableName() - ids, err := getCacheSql(cacher, tableName, newsql, args) + ids, err := core.GetCacheSql(cacher, tableName, newsql, args) if err != nil { resultsSlice, err := session.query(newsql, args...) if err != nil { return err } - ids = make([]int64, 0) + ids = make([]core.PK, 0) if len(resultsSlice) > 0 { for _, data := range resultsSlice { var id int64 @@ -3065,7 +3093,7 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error { return err } } - ids = append(ids, id) + ids = append(ids, core.PK{id}) } } } /*else { @@ -3075,7 +3103,11 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error { for _, id := range ids { session.Engine.LogDebug("[xorm:cacheDelete] delete cache obj", tableName, id) - cacher.DelBean(tableName, id) + sid, err := id.ToString() + if err != nil { + return err + } + cacher.DelBean(tableName, sid) } session.Engine.LogDebug("[xorm:cacheDelete] clear cache table", tableName) cacher.ClearIds(tableName) diff --git a/statement.go b/statement.go index 8b5ff430..23cb0591 100644 --- a/statement.go +++ b/statement.go @@ -24,7 +24,7 @@ type Statement struct { Start int LimitN int WhereStr string - IdParam *PK + IdParam *core.PK Params []interface{} OrderStr string JoinStr string @@ -421,24 +421,28 @@ func (statement *Statement) TableName() string { return "" } +var ( + ptrPkType = reflect.TypeOf(&core.PK{}) + pkType = reflect.TypeOf(core.PK{}) +) + // Generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?" func (statement *Statement) Id(id interface{}) *Statement { - idValue := reflect.ValueOf(id) idType := reflect.TypeOf(idValue.Interface()) switch idType { - case reflect.TypeOf(&PK{}): - if pkPtr, ok := (id).(*PK); ok { + case ptrPkType: + if pkPtr, ok := (id).(*core.PK); ok { statement.IdParam = pkPtr } - case reflect.TypeOf(PK{}): - if pk, ok := (id).(PK); ok { + case pkType: + if pk, ok := (id).(core.PK); ok { statement.IdParam = &pk } default: // TODO treat as int primitve for now, need to handle type check - statement.IdParam = &PK{id} + statement.IdParam = &core.PK{id} // !nashtsai! REVIEW although it will be user's mistake if called Id() twice with // different value and Id should be PK's field name, however, at this stage probably @@ -789,31 +793,12 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { } func (statement *Statement) processIdParam() { - if statement.IdParam != nil { - i := 0 - columns := statement.RefTable.ColumnsSeq() - colCnt := len(columns) - for _, elem := range *(statement.IdParam) { - for ; i < colCnt; i++ { - colName := columns[i] - col := statement.RefTable.GetColumn(colName) - if col.IsPrimaryKey { - statement.And(fmt.Sprintf("%v=?", col.Name), elem) - i++ - break - } - } - } - - // !nashtsai! REVIEW what if statement.IdParam has insufficient pk item? handle it - // as empty string for now, so this will result sql exec failed instead of unexpected - // false update/delete - for ; i < colCnt; i++ { - colName := columns[i] - col := statement.RefTable.GetColumn(colName) - if col.IsPrimaryKey { - statement.And(fmt.Sprintf("%v=?", col.Name), "") + for i, col := range statement.RefTable.PKColumns() { + if i < len(*(statement.IdParam)) { + statement.And(fmt.Sprintf("%v=?", statement.Engine.Quote(col.Name)), (*(statement.IdParam))[i]) + } else { + statement.And(fmt.Sprintf("%v=?", statement.Engine.Quote(col.Name)), "") } } } diff --git a/tests/base_test.go b/tests/base_test.go index cf12066a..26d6cf9c 100644 --- a/tests/base_test.go +++ b/tests/base_test.go @@ -1,4 +1,4 @@ -package xorm +package tests import ( "errors" @@ -755,7 +755,7 @@ func orderSameMapper(engine *xorm.Engine, t *testing.T) { func joinSameMapper(engine *xorm.Engine, t *testing.T) { users := make([]Userinfo, 0) - err := engine.Join("LEFT", "`Userdetail`", "`Userinfo`.`(id)`=`Userdetail`.`(id)`").Find(&users) + err := engine.Join("LEFT", "`Userdetail`", "`Userinfo`.`(id)`=`Userdetail`.`Id`").Find(&users) if err != nil { t.Error(err) panic(err) @@ -1080,7 +1080,7 @@ func testCols(engine *xorm.Engine, t *testing.T) { func testColsSameMapper(engine *xorm.Engine, t *testing.T) { users := []Userinfo{} - err := engine.Cols("(id), Username").Find(&users) + err := engine.Cols("id, Username").Find(&users) if err != nil { t.Error(err) panic(err) @@ -1089,7 +1089,8 @@ func testColsSameMapper(engine *xorm.Engine, t *testing.T) { fmt.Println(users) tmpUsers := []tempUser{} - err = engine.Table("Userinfo").Cols("(id), Username").Find(&tmpUsers) + // TODO: should use cache + err = engine.NoCache().Table("Userinfo").Cols("id, Username").Find(&tmpUsers) if err != nil { t.Error(err) panic(err) @@ -2062,7 +2063,8 @@ func testVersion(engine *xorm.Engine, t *testing.T) { func testDistinct(engine *xorm.Engine, t *testing.T) { users := make([]Userinfo, 0) - err := engine.Distinct("departname").Find(&users) + departname := engine.TableMapper.Obj2Table("Departname") + err := engine.Distinct(departname).Find(&users) if err != nil { t.Error(err) panic(err) @@ -2079,7 +2081,7 @@ func testDistinct(engine *xorm.Engine, t *testing.T) { } users2 := make([]Depart, 0) - err = engine.Distinct("departname").Table(new(Userinfo)).Find(&users2) + err = engine.Distinct(departname).Table(new(Userinfo)).Find(&users2) if err != nil { t.Error(err) panic(err) @@ -2226,7 +2228,7 @@ func testPrefixTableName(engine *xorm.Engine, t *testing.T) { panic(err) } tempEngine.ShowSQL = true - mapper := xorm.NewPrefixMapper(xorm.SnakeMapper{}, "xlw_") + mapper := core.NewPrefixMapper(core.SnakeMapper{}, "xlw_") //tempEngine.SetMapper(mapper) tempEngine.SetTableMapper(mapper) exist, err := tempEngine.IsTableExist(&Userinfo{}) @@ -3738,7 +3740,7 @@ func testCompositeKey(engine *xorm.Engine, t *testing.T) { } var compositeKeyVal CompositeKey - has, err := engine.Id(xorm.PK{11, 22}).Get(&compositeKeyVal) + has, err := engine.Id(core.PK{11, 22}).Get(&compositeKeyVal) if err != nil { t.Error(err) } else if !has { @@ -3746,7 +3748,7 @@ func testCompositeKey(engine *xorm.Engine, t *testing.T) { } // test passing PK ptr, this test seem failed withCache - has, err = engine.Id(&xorm.PK{11, 22}).Get(&compositeKeyVal) + has, err = engine.Id(&core.PK{11, 22}).Get(&compositeKeyVal) if err != nil { t.Error(err) } else if !has { @@ -3754,14 +3756,14 @@ func testCompositeKey(engine *xorm.Engine, t *testing.T) { } compositeKeyVal = CompositeKey{UpdateStr: "test1"} - cnt, err = engine.Id(xorm.PK{11, 22}).Update(&compositeKeyVal) + cnt, err = engine.Id(core.PK{11, 22}).Update(&compositeKeyVal) if err != nil { t.Error(err) } else if cnt != 1 { t.Error(errors.New("can't update CompositeKey{11, 22}")) } - cnt, err = engine.Id(xorm.PK{11, 22}).Delete(&CompositeKey{}) + cnt, err = engine.Id(core.PK{11, 22}).Delete(&CompositeKey{}) if err != nil { t.Error(err) } else if cnt != 1 { @@ -3803,7 +3805,7 @@ func testCompositeKey2(engine *xorm.Engine, t *testing.T) { } var user User - has, err := engine.Id(xorm.PK{"11", 22}).Get(&user) + has, err := engine.Id(core.PK{"11", 22}).Get(&user) if err != nil { t.Error(err) } else if !has { @@ -3811,7 +3813,7 @@ func testCompositeKey2(engine *xorm.Engine, t *testing.T) { } // test passing PK ptr, this test seem failed withCache - has, err = engine.Id(&xorm.PK{"11", 22}).Get(&user) + has, err = engine.Id(&core.PK{"11", 22}).Get(&user) if err != nil { t.Error(err) } else if !has { @@ -3819,14 +3821,14 @@ func testCompositeKey2(engine *xorm.Engine, t *testing.T) { } user = User{NickName: "test1"} - cnt, err = engine.Id(xorm.PK{"11", 22}).Update(&user) + cnt, err = engine.Id(core.PK{"11", 22}).Update(&user) if err != nil { t.Error(err) } else if cnt != 1 { t.Error(errors.New("can't update User{11, 22}")) } - cnt, err = engine.Id(xorm.PK{"11", 22}).Delete(&User{}) + cnt, err = engine.Id(core.PK{"11", 22}).Delete(&User{}) if err != nil { t.Error(err) } else if cnt != 1 { @@ -3870,16 +3872,12 @@ func testAll(engine *xorm.Engine, t *testing.T) { } func testAll2(engine *xorm.Engine, t *testing.T) { - fmt.Println("-------------- combineTransaction --------------") - combineTransaction(engine, t) fmt.Println("-------------- table --------------") table(engine, t) fmt.Println("-------------- createMultiTables --------------") createMultiTables(engine, t) fmt.Println("-------------- tableOp --------------") tableOp(engine, t) - fmt.Println("-------------- testCols --------------") - testCols(engine, t) fmt.Println("-------------- testCharst --------------") testCharst(engine, t) fmt.Println("-------------- testStoreEngine --------------") @@ -3961,6 +3959,10 @@ func testAllSnakeMapper(engine *xorm.Engine, t *testing.T) { join(engine, t) fmt.Println("-------------- having --------------") having(engine, t) + fmt.Println("-------------- combineTransaction --------------") + combineTransaction(engine, t) + fmt.Println("-------------- testCols --------------") + testCols(engine, t) } func testAllSameMapper(engine *xorm.Engine, t *testing.T) { @@ -3976,4 +3978,8 @@ func testAllSameMapper(engine *xorm.Engine, t *testing.T) { joinSameMapper(engine, t) fmt.Println("-------------- having --------------") havingSameMapper(engine, t) + fmt.Println("-------------- combineTransaction --------------") + combineTransactionSameMapper(engine, t) + fmt.Println("-------------- testCols --------------") + testColsSameMapper(engine, t) } diff --git a/tests/benchmark_base_test.go b/tests/benchmark_base_test.go index 5ea598e5..5570fa39 100644 --- a/tests/benchmark_base_test.go +++ b/tests/benchmark_base_test.go @@ -1,4 +1,4 @@ -package xorm +package tests import ( "database/sql" diff --git a/tests/mssql_test.go b/tests/mssql_test.go index 101d07ae..e814d855 100644 --- a/tests/mssql_test.go +++ b/tests/mssql_test.go @@ -1,4 +1,4 @@ -package xorm +package tests // // +build windows @@ -9,6 +9,7 @@ import ( _ "github.com/lunny/godbc" "github.com/lunny/xorm" + "github.com/lunny/xorm/caches" ) const mssqlConnStr = "driver={SQL Server};Server=192.168.20.135;Database=xorm_test; uid=sa; pwd=1234;" @@ -40,7 +41,7 @@ func TestMssqlWithCache(t *testing.T) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) engine.ShowSQL = showTestSql engine.ShowErr = showTestSql engine.ShowWarn = showTestSql @@ -113,7 +114,7 @@ func BenchmarkMssqlCacheInsert(t *testing.B) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) doBenchInsert(engine, t) } @@ -125,7 +126,7 @@ func BenchmarkMssqlCacheFind(t *testing.B) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) doBenchFind(engine, t) } @@ -137,7 +138,7 @@ func BenchmarkMssqlCacheFindPtr(t *testing.B) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) doBenchFindPtr(engine, t) } diff --git a/tests/mymysql_test.go b/tests/mymysql_test.go index c64261dc..5646f95c 100644 --- a/tests/mymysql_test.go +++ b/tests/mymysql_test.go @@ -1,10 +1,11 @@ -package xorm +package tests import ( "database/sql" "testing" "github.com/lunny/xorm" + "github.com/lunny/xorm/caches" _ "github.com/ziutek/mymysql/godrv" ) @@ -49,7 +50,7 @@ func TestMyMysqlWithCache(t *testing.T) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) engine.ShowSQL = showTestSql engine.ShowErr = showTestSql engine.ShowWarn = showTestSql @@ -136,7 +137,7 @@ func BenchmarkMyMysqlCacheInsert(t *testing.B) { } defer engine.Close() - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) doBenchInsert(engine, t) } @@ -149,7 +150,7 @@ func BenchmarkMyMysqlCacheFind(t *testing.B) { } defer engine.Close() - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) doBenchFind(engine, t) } @@ -162,7 +163,7 @@ func BenchmarkMyMysqlCacheFindPtr(t *testing.B) { } defer engine.Close() - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) doBenchFindPtr(engine, t) } diff --git a/tests/mysql_test.go b/tests/mysql_test.go index f42be512..38eebc55 100644 --- a/tests/mysql_test.go +++ b/tests/mysql_test.go @@ -1,4 +1,4 @@ -package xorm +package tests import ( "database/sql" @@ -6,6 +6,8 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/lunny/xorm" + "github.com/lunny/xorm/caches" + "github.com/lunny/xorm/core" ) /* @@ -54,7 +56,7 @@ func TestMysqlSameMapper(t *testing.T) { engine.ShowErr = showTestSql engine.ShowWarn = showTestSql engine.ShowDebug = showTestSql - engine.SetMapper(xorm.SameMapper{}) + engine.SetMapper(core.SameMapper{}) testAll(engine, t) testAllSameMapper(engine, t) @@ -75,7 +77,7 @@ func TestMysqlWithCache(t *testing.T) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) engine.ShowSQL = showTestSql engine.ShowErr = showTestSql engine.ShowWarn = showTestSql @@ -99,8 +101,8 @@ func TestMysqlWithCacheSameMapper(t *testing.T) { t.Error(err) return } - engine.SetMapper(xorm.SameMapper{}) - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetMapper(core.SameMapper{}) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) engine.ShowSQL = showTestSql engine.ShowErr = showTestSql engine.ShowWarn = showTestSql @@ -190,7 +192,7 @@ func BenchmarkMysqlCacheInsert(t *testing.B) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) doBenchInsert(engine, t) } @@ -202,7 +204,7 @@ func BenchmarkMysqlCacheFind(t *testing.B) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) doBenchFind(engine, t) } @@ -214,7 +216,7 @@ func BenchmarkMysqlCacheFindPtr(t *testing.B) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) doBenchFindPtr(engine, t) } diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 9c2ee2d8..85e926cc 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -1,4 +1,4 @@ -package xorm +package tests import ( "database/sql" @@ -6,6 +6,8 @@ import ( _ "github.com/lib/pq" "github.com/lunny/xorm" + "github.com/lunny/xorm/caches" + "github.com/lunny/xorm/core" ) //var connStr string = "dbname=xorm_test user=lunny password=1234 sslmode=disable" @@ -59,7 +61,7 @@ func TestPostgresWithCache(t *testing.T) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) defer engine.Close() engine.ShowSQL = showTestSql engine.ShowErr = showTestSql @@ -78,7 +80,7 @@ func TestPostgresSameMapper(t *testing.T) { return } defer engine.Close() - engine.SetMapper(xorm.SameMapper{}) + engine.SetMapper(core.SameMapper{}) engine.ShowSQL = showTestSql engine.ShowErr = showTestSql engine.ShowWarn = showTestSql @@ -96,9 +98,9 @@ func TestPostgresWithCacheSameMapper(t *testing.T) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) defer engine.Close() - engine.SetMapper(xorm.SameMapper{}) + engine.SetMapper(core.SameMapper{}) engine.ShowSQL = showTestSql engine.ShowErr = showTestSql engine.ShowWarn = showTestSql @@ -168,7 +170,7 @@ func BenchmarkPostgresCacheInsert(t *testing.B) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) doBenchInsert(engine, t) } @@ -181,7 +183,7 @@ func BenchmarkPostgresCacheFind(t *testing.B) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) doBenchFind(engine, t) } @@ -194,7 +196,7 @@ func BenchmarkPostgresCacheFindPtr(t *testing.B) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) doBenchFindPtr(engine, t) } diff --git a/tests/sqlite3_test.go b/tests/sqlite3_test.go index aa10af64..1f982e29 100644 --- a/tests/sqlite3_test.go +++ b/tests/sqlite3_test.go @@ -1,4 +1,4 @@ -package xorm +package tests import ( "database/sql" @@ -6,6 +6,8 @@ import ( "testing" "github.com/lunny/xorm" + "github.com/lunny/xorm/caches" + "github.com/lunny/xorm/core" _ "github.com/mattn/go-sqlite3" ) @@ -44,7 +46,7 @@ func TestSqlite3WithCache(t *testing.T) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) engine.ShowSQL = showTestSql engine.ShowErr = showTestSql engine.ShowWarn = showTestSql @@ -62,7 +64,7 @@ func TestSqlite3SameMapper(t *testing.T) { t.Error(err) return } - engine.SetMapper(xorm.SameMapper{}) + engine.SetMapper(core.SameMapper{}) engine.ShowSQL = showTestSql engine.ShowErr = showTestSql engine.ShowWarn = showTestSql @@ -81,8 +83,8 @@ func TestSqlite3WithCacheSameMapper(t *testing.T) { t.Error(err) return } - engine.SetMapper(xorm.SameMapper{}) - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetMapper(core.SameMapper{}) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) engine.ShowSQL = showTestSql engine.ShowErr = showTestSql engine.ShowWarn = showTestSql @@ -152,7 +154,7 @@ func BenchmarkSqlite3CacheInsert(t *testing.B) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) doBenchInsert(engine, t) } @@ -164,7 +166,7 @@ func BenchmarkSqlite3CacheFind(t *testing.B) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) doBenchFind(engine, t) } @@ -176,6 +178,6 @@ func BenchmarkSqlite3CacheFindPtr(t *testing.B) { t.Error(err) return } - engine.SetDefaultCacher(xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)) + engine.SetDefaultCacher(xorm.NewLRUCacher(caches.NewMemoryStore(), 1000)) doBenchFindPtr(engine, t) } diff --git a/tests/testdata/mysql_ddl.sql b/tests/testdata/mysql_ddl.sql index db20aa33..b1fcca5b 100644 --- a/tests/testdata/mysql_ddl.sql +++ b/tests/testdata/mysql_ddl.sql @@ -1,5 +1,8 @@ ---DROP DATABASE xorm_test; ---DROP DATABASE xorm_test2; +DROP DATABASE xorm_test; +DROP DATABASE xorm_test1; +DROP DATABASE xorm_test2; +DROP DATABASE xorm_test3; CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET utf8 COLLATE utf8_general_ci; +CREATE DATABASE IF NOT EXISTS xorm_test1 CHARACTER SET utf8 COLLATE utf8_general_ci; CREATE DATABASE IF NOT EXISTS xorm_test2 CHARACTER SET utf8 COLLATE utf8_general_ci; CREATE DATABASE IF NOT EXISTS xorm_test3 CHARACTER SET utf8 COLLATE utf8_general_ci; diff --git a/xorm.go b/xorm.go index 5826db60..7ea51473 100644 --- a/xorm.go +++ b/xorm.go @@ -7,7 +7,9 @@ import ( "reflect" "runtime" "sync" + "time" + "github.com/lunny/xorm/caches" "github.com/lunny/xorm/core" _ "github.com/lunny/xorm/dialects" _ "github.com/lunny/xorm/drivers" @@ -46,9 +48,9 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { engine := &Engine{DriverName: driverName, DataSourceName: dataSourceName, dialect: dialect, - tableCachers: make(map[reflect.Type]Cacher)} + tableCachers: make(map[reflect.Type]core.Cacher)} - engine.SetMapper(SnakeMapper{}) + engine.SetMapper(core.SnakeMapper{}) engine.Filters = dialect.Filters() @@ -65,3 +67,11 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { runtime.SetFinalizer(engine, close) return engine, err } + +func NewLRUCacher(store core.CacheStore, max int) *caches.LRUCacher { + return caches.NewLRUCacher(store, core.CacheExpired, core.CacheMaxMemory, max) +} + +func NewLRUCacher2(store core.CacheStore, expired time.Duration, max int) *caches.LRUCacher { + return caches.NewLRUCacher(store, expired, 0, max) +}