From 1ce77128c161b651edeb9a68f3daa76d2bb1704d Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 23 Sep 2013 22:31:51 +0800 Subject: [PATCH] bug fixed --- cache.go | 184 +++++++++++++++++++++++++++++++++++++---------------- session.go | 59 ++++++++++++----- 2 files changed, 173 insertions(+), 70 deletions(-) diff --git a/cache.go b/cache.go index ea2abb62..e2111acc 100644 --- a/cache.go +++ b/cache.go @@ -54,58 +54,154 @@ func (s *MemoryStore) Del(key interface{}) error { } type Cacher interface { - Get(id interface{}) interface{} - Put(id, obj interface{}) - Del(id interface{}) + GetIds(tableName, sql string) interface{} + GetBean(tableName string, id int64) interface{} + PutIds(tableName, sql string, ids interface{}) + PutBean(tableName string, id int64, obj interface{}) + DelIds(tableName, sql string) + DelBean(tableName string, id int64) + ClearIds(tableName string) } // LRUCacher implements Cacher according to LRU algorithm type LRUCacher struct { - name string - list *list.List - index map[interface{}]*list.Element - store CacheStore - Max int - mutex sync.RWMutex + idList *list.List + sqlList *list.List + idIndex map[interface{}]*list.Element + sqlIndex map[string]map[interface{}]*list.Element + store CacheStore + Max int + mutex sync.Mutex } func NewLRUCacher(store CacheStore, max int) *LRUCacher { - return &LRUCacher{store: store, list: list.New(), - index: make(map[interface{}]*list.Element), Max: max} + cacher := &LRUCacher{store: store, idList: list.New(), + sqlList: list.New(), idIndex: make(map[interface{}]*list.Element), + Max: max} + cacher.sqlIndex = make(map[string]map[interface{}]*list.Element) + return cacher } -func (m *LRUCacher) Get(id interface{}) interface{} { - m.mutex.RLock() - defer m.mutex.RUnlock() - if v, err := m.store.Get(id); err == nil { - el := m.index[id] - m.list.MoveToBack(el) +func (m *LRUCacher) GetIds(tableName, sql string) interface{} { + m.mutex.Lock() + defer m.mutex.Unlock() + if v, err := m.store.Get(sql); err == nil { + if _, ok := m.sqlIndex[tableName]; !ok { + m.sqlIndex[tableName] = make(map[interface{}]*list.Element) + } + if el, ok := m.sqlIndex[tableName][sql]; !ok { + el = m.sqlList.PushBack(sql) + m.sqlIndex[tableName][sql] = el + } else { + m.sqlList.MoveToBack(el) + } return v } + if tel, ok := m.sqlIndex[tableName]; ok { + if el, ok := tel[sql]; ok { + delete(m.sqlIndex[tableName], sql) + m.sqlList.Remove(el) + } + } return nil } -func (m *LRUCacher) Put(id interface{}, obj interface{}) { +func (m *LRUCacher) GetBean(tableName string, id int64) interface{} { m.mutex.Lock() defer m.mutex.Unlock() - el := m.list.PushBack(id) - m.index[id] = el - m.store.Put(id, obj) - if m.list.Len() > m.Max { - e := m.list.Front() + tid := genId(tableName, id) + if v, err := m.store.Get(tid); err == nil { + if el, ok := m.idIndex[tid]; ok { + m.idList.MoveToBack(el) + } else { + el = m.idList.PushBack(tid) + m.idIndex[tid] = el + } + return v + } + if el, ok := m.idIndex[tid]; ok { + delete(m.idIndex, tid) + m.idList.Remove(el) + if ms, ok := m.sqlIndex[tableName]; ok { + for _, v := range ms { + m.sqlList.Remove(v) + } + m.sqlIndex[tableName] = make(map[interface{}]*list.Element) + } + } + return nil +} + +func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) { + m.mutex.Lock() + defer m.mutex.Unlock() + if _, ok := m.sqlIndex[tableName]; !ok { + m.sqlIndex[tableName] = make(map[interface{}]*list.Element) + } + if el, ok := m.sqlIndex[tableName][sql]; !ok { + el = m.sqlList.PushBack(sql) + m.sqlIndex[tableName][sql] = el + } + m.store.Put(sql, ids) +} + +func (m *LRUCacher) PutBean(tableName string, id int64, obj interface{}) { + m.mutex.Lock() + defer m.mutex.Unlock() + var el *list.Element + var ok bool + tid := genId(tableName, id) + if el, ok = m.idIndex[tid]; !ok { + el = m.idList.PushBack(tid) + m.idIndex[tid] = el + } + + m.store.Put(tid, obj) + if m.idList.Len() > m.Max { + e := m.idList.Front() m.store.Del(e.Value) - delete(m.index, e.Value) - m.list.Remove(e) + delete(m.idIndex, e.Value) + m.idList.Remove(e) } } -func (m *LRUCacher) Del(id interface{}) { +func (m *LRUCacher) DelIds(tableName, sql string) { m.mutex.Lock() defer m.mutex.Unlock() - if el, ok := m.index[id]; ok { - m.store.Del(id) - delete(m.index, el.Value) - m.list.Remove(el) + if _, ok := m.sqlIndex[tableName]; ok { + if el, ok := m.sqlIndex[tableName][sql]; ok { + m.store.Del(sql) + delete(m.sqlIndex, sql) + m.sqlList.Remove(el) + } + } +} + +func (m *LRUCacher) DelBean(tableName string, id int64) { + m.mutex.Lock() + defer m.mutex.Unlock() + tid := genId(tableName, id) + if el, ok := m.idIndex[tid]; ok { + m.store.Del(tid) + delete(m.idIndex, tid) + m.idList.Remove(el) + if tis, ok := m.sqlIndex[tableName]; ok { + for _, v := range tis { + m.sqlList.Remove(v) + } + m.sqlIndex[tableName] = make(map[interface{}]*list.Element) + } + } +} + +func (m *LRUCacher) ClearIds(tableName string) { + m.mutex.Lock() + defer m.mutex.Unlock() + if tis, ok := m.sqlIndex[tableName]; ok { + for _, v := range tis { + m.sqlList.Remove(v) + } + m.sqlIndex[tableName] = make(map[interface{}]*list.Element) } } @@ -133,8 +229,8 @@ func decodeIds(s string) []int64 { return res } -func getCacheSql(m Cacher, sql string, args interface{}) ([]int64, error) { - bytes := m.Get(genSqlKey(sql, args)) +func getCacheSql(m Cacher, tableName, sql string, args interface{}) ([]int64, error) { + bytes := m.GetIds(tableName, genSqlKey(sql, args)) if bytes == nil { return nil, errors.New("Not Exist") } @@ -142,14 +238,9 @@ func getCacheSql(m Cacher, sql string, args interface{}) ([]int64, error) { return objs, nil } -func putCacheSql(m Cacher, ids []int64, sql string, args interface{}) error { +func putCacheSql(m Cacher, ids []int64, tableName, sql string, args interface{}) error { bytes := encodeIds(ids) - m.Put(genSqlKey(sql, args), bytes) - return nil -} - -func delCacheSql(m Cacher, sql string, args interface{}) error { - m.Del(genSqlKey(sql, args)) + m.PutIds(tableName, genSqlKey(sql, args), bytes) return nil } @@ -160,18 +251,3 @@ func genSqlKey(sql string, args interface{}) string { func genId(prefix string, id int64) string { return fmt.Sprintf("%v-%v", prefix, id) } - -func getCacheId(m Cacher, prefix string, id int64) interface{} { - return m.Get(genId(prefix, id)) -} - -func putCacheId(m Cacher, prefix string, id int64, bean interface{}) error { - m.Put(genId(prefix, id), bean) - return nil -} - -func delCacheId(m Cacher, prefix string, id int64) error { - m.Del(genId(prefix, id)) - //TODO: should delete id from select - return nil -} diff --git a/session.go b/session.go index 35fe0bae..05e83ceb 100644 --- a/session.go +++ b/session.go @@ -391,7 +391,7 @@ func (session *Session) cacheGet(bean interface{}, sql string, args ...interface } cacher := session.Statement.RefTable.Cacher - ids, err := getCacheSql(cacher, newsql, args) + ids, err := getCacheSql(cacher, session.Statement.TableName(), newsql, args) if err != nil { //fmt.Println(err) resultsSlice, err := session.query(newsql, args...) @@ -414,7 +414,7 @@ func (session *Session) cacheGet(bean interface{}, sql string, args ...interface } ids = append(ids, id) } - err = putCacheSql(cacher, ids, newsql, args) + err = putCacheSql(cacher, ids, session.Statement.TableName(), newsql, args) if err != nil { //fmt.Println(err) return false, err @@ -429,7 +429,7 @@ func (session *Session) cacheGet(bean interface{}, sql string, args ...interface structValue := reflect.Indirect(reflect.ValueOf(bean)) id := ids[0] tableName := session.Statement.TableName() - cacheBean := getCacheId(cacher, tableName, id) + cacheBean := cacher.GetBean(tableName, id) if cacheBean == nil { //fmt.Printf("----Object Id %v no cached.\n", id) newSession := session.Engine.NewSession() @@ -440,7 +440,7 @@ func (session *Session) cacheGet(bean interface{}, sql string, args ...interface return has, err } //fmt.Println(bean) - putCacheId(cacher, tableName, id, cacheBean) + cacher.PutBean(tableName, id, cacheBean) } else { //fmt.Printf("-----Cached Object: %v\n", cacheBean) has = true @@ -470,7 +470,7 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter table := session.Statement.RefTable cacher := table.Cacher - ids, err := getCacheSql(cacher, newsql, args) + ids, err := getCacheSql(cacher, session.Statement.TableName(), newsql, args) if err != nil { session.Engine.LogError(err) resultsSlice, err := session.query(newsql, args...) @@ -497,7 +497,7 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter ids = append(ids, id) } } - err = putCacheSql(cacher, ids, newsql, args) + err = putCacheSql(cacher, ids, session.Statement.TableName(), newsql, args) if err != nil { //fmt.Println(err) return err @@ -513,7 +513,7 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter var temps []interface{} = make([]interface{}, len(ids)) tableName := session.Statement.TableName() for idx, id := range ids { - bean := getCacheId(cacher, tableName, id) + bean := cacher.GetBean(tableName, id) if bean == nil { //fmt.Printf("----Object Id %v no cached.\n", id) idxes = append(idxes, idx) @@ -538,13 +538,17 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter for i := 0; i < vs.Len(); i++ { bean := vs.Index(i).Addr().Interface() temps[idxes[i]] = bean - putCacheId(cacher, tableName, ides[i].(int64), bean) + cacher.PutBean(tableName, ides[i].(int64), bean) } } for j := 0; j < len(temps); j++ { bean := temps[j] - sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean)))) + if bean != nil { + sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean)))) + } else { + cacher.DelBean(tableName, ides[j].(int64)) + } } return nil @@ -979,6 +983,10 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return -1, err } + if table.Cacher != nil && session.Statement.UseCache { + session.cacheInsert(session.Statement.TableName()) + } + id, err := res.LastInsertId() if err != nil { @@ -1252,6 +1260,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 0, err } + if table.Cacher != nil && session.Statement.UseCache { + session.cacheInsert(session.Statement.TableName()) + } + if table.PrimaryKey == "" { return 0, nil } @@ -1306,6 +1318,21 @@ func (statement *Statement) convertUpdateSql(sql string) (string, string) { sqls[1]) } +func (session *Session) cacheInsert(tables ...string) error { + if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { + return ErrCacheFailed + } + + table := session.Statement.RefTable + cacher := table.Cacher + + for _, t := range tables { + cacher.ClearIds(t) + } + + return nil +} + func (session *Session) cacheUpdate(sql string, args ...interface{}) error { if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { return ErrCacheFailed @@ -1331,7 +1358,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { } table := session.Statement.RefTable cacher := table.Cacher - ids, err := getCacheSql(cacher, newsql, args) + ids, err := getCacheSql(cacher, session.Statement.TableName(), newsql, args) if err != nil { resultsSlice, err := session.query(newsql, args[nStart:]...) if err != nil { @@ -1354,11 +1381,11 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { } } else { //fmt.Printf("-----Cached SQL: %v.\n", newsql) - delCacheSql(cacher, newsql, args) + cacher.DelIds(session.Statement.TableName(), genSqlKey(newsql, args)) } for _, id := range ids { - if bean := getCacheId(cacher, session.Statement.TableName(), id); bean != nil { + if bean := cacher.GetBean(session.Statement.TableName(), id); bean != nil { sqls := strings.SplitN(strings.ToLower(sql), "where", 2) if len(sqls) != 2 { return nil @@ -1385,7 +1412,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { } } - putCacheId(cacher, session.Statement.TableName(), id, bean) + cacher.PutBean(session.Statement.TableName(), id, bean) } } return nil @@ -1495,7 +1522,7 @@ func (session *Session) cacheDelete(sql string, args ...interface{}) error { } cacher := session.Statement.RefTable.Cacher - ids, err := getCacheSql(cacher, newsql, args) + ids, err := getCacheSql(cacher, session.Statement.TableName(), newsql, args) if err != nil { resultsSlice, err := session.query(newsql, args...) if err != nil { @@ -1518,11 +1545,11 @@ func (session *Session) cacheDelete(sql string, args ...interface{}) error { } } else { //fmt.Printf("-----Cached SQL: %v.\n", newsql) - delCacheSql(cacher, newsql, args) + cacher.DelIds(session.Statement.TableName(), genSqlKey(newsql, args)) } for _, id := range ids { - delCacheId(cacher, session.Statement.TableName(), id) + cacher.DelBean(session.Statement.TableName(), id) } return nil }