From d611b575eed2e76b3d7e58f4463c96e0577224a5 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 17 Nov 2013 00:52:43 +0800 Subject: [PATCH] fixed bool bug; added cache expired --- base_test.go | 42 +++++++++++ cache.go | 171 ++++++++++++++++++++++++++++++++---------- examples/cache.go | 14 ++-- examples/goroutine.go | 2 + mymysql_test.go | 22 +++++- mysql_test.go | 8 +- postgres_test.go | 20 +++++ session.go | 46 +++++++++--- sqlite3_test.go | 2 +- 9 files changed, 268 insertions(+), 59 deletions(-) diff --git a/base_test.go b/base_test.go index f2102100..4ef655b1 100644 --- a/base_test.go +++ b/base_test.go @@ -1399,6 +1399,46 @@ func testUseBool(engine *Engine, t *testing.T) { } } +func testBool(engine *Engine, t *testing.T) { + _, err := engine.UseBool().Update(&Userinfo{IsMan: true}) + if err != nil { + t.Error(err) + panic(err) + } + users := make([]Userinfo, 0) + err = engine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + for _, user := range users { + if !user.IsMan { + err = errors.New("update bool or find bool error") + t.Error(err) + panic(err) + } + } + + _, err = engine.UseBool().Update(&Userinfo{IsMan: false}) + if err != nil { + t.Error(err) + panic(err) + } + users = make([]Userinfo, 0) + err = engine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + for _, user := range users { + if user.IsMan { + err = errors.New("update bool or find bool error") + t.Error(err) + panic(err) + } + } +} + func testAll(engine *Engine, t *testing.T) { fmt.Println("-------------- directCreateTable --------------") directCreateTable(engine, t) @@ -1489,6 +1529,8 @@ func testAll2(engine *Engine, t *testing.T) { testDistinct(engine, t) fmt.Println("-------------- testUseBool --------------") testUseBool(engine, t) + fmt.Println("-------------- testBool --------------") + testBool(engine, t) fmt.Println("-------------- transaction --------------") transaction(engine, t) } diff --git a/cache.go b/cache.go index e51f96c2..c16fe15b 100644 --- a/cache.go +++ b/cache.go @@ -2,7 +2,6 @@ package xorm import ( "container/list" - //"encoding/json" "errors" "fmt" "strconv" @@ -11,6 +10,15 @@ import ( "time" ) +const ( + CacheExpired = 60 * time.Minute + CacheMaxMemory = 256 + // evey ten minutes to clear all expired nodes + CacheGcInterval = 10 * time.Minute + // each time when gc to removed max nodes + CacheGcMaxRemoved = 20 +) + type CacheStore interface { Put(key, value interface{}) error Get(key interface{}) (interface{}, error) @@ -30,7 +38,6 @@ func (s *MemoryStore) Put(key, value interface{}) error { s.mutex.Lock() defer s.mutex.Unlock() s.store[key] = value - //fmt.Println(s.store) return nil } @@ -69,34 +76,99 @@ type idNode struct { } type sqlNode struct { + tbName string sql string lastVisit time.Time } -func newNode(tbName string, id int64) *idNode { +func newIdNode(tbName string, id int64) *idNode { return &idNode{tbName, id, time.Now()} } +func newSqlNode(tbName, sql string) *sqlNode { + return &sqlNode{tbName, sql, time.Now()} +} + // LRUCacher implements Cacher according to LRU algorithm type LRUCacher struct { - idList *list.List - sqlList *list.List - idIndex map[string]map[interface{}]*list.Element - sqlIndex map[string]map[interface{}]*list.Element - store CacheStore - Max int - mutex sync.Mutex - expired int + idList *list.List + sqlList *list.List + idIndex map[string]map[interface{}]*list.Element + sqlIndex map[string]map[interface{}]*list.Element + store CacheStore + Max int + mutex sync.Mutex + Expired time.Duration + maxSize int + GcInterval time.Duration } -func NewLRUCacher(store CacheStore, max int) *LRUCacher { +func newLRUCacher(store CacheStore, expired time.Duration, maxSize int, max int) *LRUCacher { cacher := &LRUCacher{store: store, idList: list.New(), - sqlList: list.New(), Max: max} - cacher.sqlIndex = make(map[string]map[interface{}]*list.Element) - cacher.idIndex = make(map[string]map[interface{}]*list.Element) + sqlList: list.New(), Expired: expired, maxSize: maxSize, + GcInterval: CacheGcInterval, Max: max, + sqlIndex: make(map[string]map[interface{}]*list.Element), + idIndex: make(map[string]map[interface{}]*list.Element), + } + cacher.RunGC() return cacher } +func NewLRUCacher(store CacheStore, max int) *LRUCacher { + return newLRUCacher(store, CacheExpired, CacheMaxMemory, max) +} + +func NewLRUCacher2(store CacheStore, expired time.Duration, max int) *LRUCacher { + return newLRUCacher(store, expired, 0, max) +} + +func NewLRUCacher3(store CacheStore, expired time.Duration, maxSize int) *LRUCacher { + return newLRUCacher(store, expired, maxSize, 0) +} + +// RunGC run once every m.GcInterval +func (m *LRUCacher) RunGC() { + time.AfterFunc(m.GcInterval, func() { + m.RunGC() + m.GC() + }) +} + +// GC check ids lit and sql list to remove all element expired +func (m *LRUCacher) GC() { + m.mutex.Lock() + defer m.mutex.Unlock() + var removedNum int + for e := m.idList.Front(); e != nil; { + if removedNum <= CacheGcMaxRemoved && + time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired { + removedNum++ + next := e.Next() + //fmt.Println("removing ...", e.Value) + node := e.Value.(*idNode) + m.delBean(node.tbName, node.id) + e = next + } else { + break + } + } + + for e := m.sqlList.Front(); e != nil; { + if removedNum <= CacheGcMaxRemoved && + time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired { + removedNum++ + next := e.Next() + //fmt.Println("removing ...", e.Value) + node := e.Value.(*sqlNode) + m.DelIds(node.tbName, node.sql) + e = next + } else { + break + } + } +} + +// Get all bean's ids according to sql and parameter from cache func (m *LRUCacher) GetIds(tableName, sql string) interface{} { m.mutex.Lock() defer m.mutex.Unlock() @@ -105,48 +177,60 @@ func (m *LRUCacher) GetIds(tableName, sql string) interface{} { } if v, err := m.store.Get(sql); err == nil { if el, ok := m.sqlIndex[tableName][sql]; !ok { - el = m.sqlList.PushBack(sql) + el = m.sqlList.PushBack(newSqlNode(tableName, sql)) m.sqlIndex[tableName][sql] = el } else { + lastTime := el.Value.(*sqlNode).lastVisit + // if expired, remove the node and return nil + if time.Now().Sub(lastTime) > m.Expired { + m.delIds(tableName, sql) + return nil + } m.sqlList.MoveToBack(el) + el.Value.(*sqlNode).lastVisit = time.Now() } return v } else { - if el, ok := m.sqlIndex[tableName][sql]; ok { - delete(m.sqlIndex[tableName], sql) - m.sqlList.Remove(el) - } + m.delIds(tableName, sql) } return nil } +// Get bean according tableName and id from cache func (m *LRUCacher) GetBean(tableName string, id int64) interface{} { m.mutex.Lock() defer m.mutex.Unlock() if _, ok := m.idIndex[tableName]; !ok { m.idIndex[tableName] = make(map[interface{}]*list.Element) } - if v, err := m.store.Get(genId(tableName, id)); err == nil { + tid := genId(tableName, id) + if v, err := m.store.Get(tid); err == nil { if el, ok := m.idIndex[tableName][id]; ok { + lastTime := el.Value.(*idNode).lastVisit + // if expired, remove the node and return nil + if time.Now().Sub(lastTime) > m.Expired { + m.delBean(tableName, id) + //m.clearIds(tableName) + return nil + } m.idList.MoveToBack(el) + el.Value.(*idNode).lastVisit = time.Now() } else { - el = m.idList.PushBack(newNode(tableName, id)) + el = m.idList.PushBack(newIdNode(tableName, id)) m.idIndex[tableName][id] = el } return v } else { // store bean is not exist, then remove memory's index - if _, ok := m.idIndex[tableName][id]; ok { - m.delBean(tableName, id) - m.clearIds(tableName) - } + m.delBean(tableName, id) + //m.clearIds(tableName) return nil } } +// Clear all sql-ids mapping on table tableName from cache func (m *LRUCacher) clearIds(tableName string) { - //fmt.Println("clear ids") if tis, ok := m.sqlIndex[tableName]; ok { for sql, v := range tis { m.sqlList.Remove(v) @@ -163,15 +247,12 @@ func (m *LRUCacher) ClearIds(tableName string) { } func (m *LRUCacher) clearBeans(tableName string) { - //fmt.Println("clear beans") if tis, ok := m.idIndex[tableName]; ok { - //fmt.Println("before clear", len(m.idIndex[tableName])) for id, v := range tis { m.idList.Remove(v) tid := genId(tableName, id.(int64)) m.store.Del(tid) } - //fmt.Println("after clear", len(m.idIndex[tableName])) } m.idIndex[tableName] = make(map[interface{}]*list.Element) } @@ -189,15 +270,17 @@ func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) { m.sqlIndex[tableName] = make(map[interface{}]*list.Element) } if el, ok := m.sqlIndex[tableName][sql]; !ok { - el = m.sqlList.PushBack(sql) + el = m.sqlList.PushBack(newSqlNode(tableName, sql)) m.sqlIndex[tableName][sql] = el + } else { + el.Value.(*sqlNode).lastVisit = time.Now() } m.store.Put(sql, ids) - /*if m.sqlList.Len() > m.Max { + if m.sqlList.Len() > m.Max { e := m.sqlList.Front() - node := e.Value.(*idNode) - m.delBean(node.tbName, node.id) - }*/ + node := e.Value.(*sqlNode) + m.delIds(node.tbName, node.sql) + } } func (m *LRUCacher) PutBean(tableName string, id int64, obj interface{}) { @@ -207,8 +290,10 @@ func (m *LRUCacher) PutBean(tableName string, id int64, obj interface{}) { var ok bool if el, ok = m.idIndex[tableName][id]; !ok { - el = m.idList.PushBack(newNode(tableName, id)) + el = m.idList.PushBack(newIdNode(tableName, id)) m.idIndex[tableName][id] = el + } else { + el.Value.(*idNode).lastVisit = time.Now() } m.store.Put(genId(tableName, id), obj) @@ -219,16 +304,20 @@ func (m *LRUCacher) PutBean(tableName string, id int64, obj interface{}) { } } -func (m *LRUCacher) DelIds(tableName, sql string) { - m.mutex.Lock() - defer m.mutex.Unlock() +func (m *LRUCacher) delIds(tableName, sql string) { if _, ok := m.sqlIndex[tableName]; ok { if el, ok := m.sqlIndex[tableName][sql]; ok { - m.store.Del(sql) - delete(m.sqlIndex, sql) + delete(m.sqlIndex[tableName], sql) m.sqlList.Remove(el) } } + m.store.Del(sql) +} + +func (m *LRUCacher) DelIds(tableName, sql string) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.delIds(tableName, sql) } func (m *LRUCacher) delBean(tableName string, id int64) { diff --git a/examples/cache.go b/examples/cache.go index 134386ce..4e4327cd 100644 --- a/examples/cache.go +++ b/examples/cache.go @@ -97,11 +97,13 @@ func main() { return } - user6 := new(User) - has, err = Orm.Id(1).Get(user6) - if err != nil { - fmt.Println(err) - return + for { + user6 := new(User) + has, err = Orm.Id(1).Get(user6) + if err != nil { + fmt.Println(err) + return + } + fmt.Println("user6:", has, user6) } - fmt.Println("user6:", has, user6) } diff --git a/examples/goroutine.go b/examples/goroutine.go index 23c57a76..02415d2c 100644 --- a/examples/goroutine.go +++ b/examples/goroutine.go @@ -8,6 +8,7 @@ import ( "os" //"time" //"sync/atomic" + "runtime" xorm "xorm" ) @@ -84,6 +85,7 @@ func test(engine *xorm.Engine) { } func main() { + runtime.GOMAXPROCS(2) fmt.Println("-----start sqlite go routines-----") engine, err := sqliteEngine() if err != nil { diff --git a/mymysql_test.go b/mymysql_test.go index 141b4803..e13a7ab5 100644 --- a/mymysql_test.go +++ b/mymysql_test.go @@ -13,13 +13,33 @@ utf8 COLLATE utf8_general_ci; var showTestSql bool = true func TestMyMysql(t *testing.T) { - engine, err := NewEngine("mymysql", "xorm_test2/root/") + engine, err := NewEngine("mymysql", "xorm_test/root/") defer engine.Close() if err != nil { t.Error(err) return } engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql + + testAll(engine, t) + testAll2(engine, t) +} + +func TestMyMysqlWithCache(t *testing.T) { + engine, err := NewEngine("mymysql", "xorm_test2/root/") + defer engine.Close() + if err != nil { + t.Error(err) + return + } + engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) + engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql testAll(engine, t) testAll2(engine, t) diff --git a/mysql_test.go b/mysql_test.go index 41f7e9d1..106e898e 100644 --- a/mysql_test.go +++ b/mysql_test.go @@ -18,13 +18,16 @@ func TestMysql(t *testing.T) { return } engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql testAll(engine, t) testAll2(engine, t) } func TestMysqlWithCache(t *testing.T) { - engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8") + engine, err := NewEngine("mysql", "root:@/xorm_test2?charset=utf8") defer engine.Close() if err != nil { t.Error(err) @@ -32,6 +35,9 @@ func TestMysqlWithCache(t *testing.T) { } engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql testAll(engine, t) testAll2(engine, t) diff --git a/postgres_test.go b/postgres_test.go index 97c58133..3cea129a 100644 --- a/postgres_test.go +++ b/postgres_test.go @@ -15,6 +15,26 @@ func TestPostgres(t *testing.T) { } defer engine.Close() engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql + + testAll(engine, t) + testAll2(engine, t) +} + +func TestPostgresWithCache(t *testing.T) { + engine, err := NewEngine("postgres", "dbname=xorm_test2 sslmode=disable") + if err != nil { + t.Error(err) + return + } + engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) + defer engine.Close() + engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql testAll(engine, t) testAll2(engine, t) diff --git a/session.go b/session.go index 45f74abc..63e87219 100644 --- a/session.go +++ b/session.go @@ -252,6 +252,7 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b for key, data := range objMap { if _, ok := table.Columns[key]; !ok { + session.Engine.LogWarn("table %v's has not column %v.", table.Name, key) continue } col := table.Columns[key] @@ -270,6 +271,8 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b fieldValue = dataStruct.FieldByName(fieldName) } if !fieldValue.IsValid() || !fieldValue.CanSet() { + session.Engine.LogWarn("table %v's column %v is not valid or cannot set", + table.Name, key) continue } @@ -546,7 +549,10 @@ func (session *Session) cacheGet(bean interface{}, sql string, args ...interface } func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr interface{}, args ...interface{}) (err error) { - if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { + if session.Statement.RefTable == nil || + session.Statement.RefTable.PrimaryKey == "" || + indexNoCase(sql, "having") != -1 || + indexNoCase(sql, "group by") != -1 { return ErrCacheFailed } @@ -1140,16 +1146,16 @@ func row2map(rows *sql.Rows, fields []string) (resultsMap map[string][]byte, err vv := reflect.ValueOf(rawValue.Interface()) var str string switch aa.Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: str = strconv.FormatInt(vv.Int(), 10) result[key] = []byte(str) - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: str = strconv.FormatUint(vv.Uint(), 10) result[key] = []byte(str) case reflect.Float32, reflect.Float64: str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) result[key] = []byte(str) - case reflect.Slice: + case reflect.Array, reflect.Slice: switch aa.Elem().Kind() { case reflect.Uint8: result[key] = rawValue.Interface().([]byte) @@ -1165,10 +1171,22 @@ func row2map(rows *sql.Rows, fields []string) (resultsMap map[string][]byte, err str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700") result[key] = []byte(str) } else { - //session.Engine.LogError("Unsupported struct type") + return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) } + case reflect.Bool: + str := strconv.FormatBool(vv.Bool()) + result[key] = []byte(str) + case reflect.Complex128, reflect.Complex64: + result[key] = []byte(fmt.Sprintf("%v", vv.Complex())) + /* TODO: unsupported types below + case reflect.Map: + case reflect.Ptr: + case reflect.Uintptr: + case reflect.UnsafePointer: + case reflect.Chan, reflect.Func, reflect.Interface: + */ default: - //session.Engine.LogError("Unsupported type") + return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) } } return result, nil @@ -1428,6 +1446,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var v interface{} key := col.Name fieldType := fieldValue.Type() + //fmt.Println("column name:", key, ", fieldType:", fieldType.String()) switch fieldType.Kind() { case reflect.Complex64, reflect.Complex128: x := reflect.New(fieldType) @@ -1468,7 +1487,9 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data case reflect.String: fieldValue.SetString(string(data)) case reflect.Bool: - v, err := strconv.ParseBool(string(data)) + d := string(data) + //fmt.Println("------", d, "-------") + v, err := strconv.ParseBool(d) if err != nil { return errors.New("arg " + key + " as bool: " + err.Error()) } @@ -1738,6 +1759,11 @@ func (statement *Statement) convertUpdateSql(sql string) (string, string) { } sqls := splitNNoCase(sql, "where", 2) if len(sqls) != 2 { + if len(sqls) == 1 { + return sqls[0], fmt.Sprintf("SELECT %v FROM %v", + statement.Engine.Quote(statement.RefTable.PrimaryKey), + statement.Engine.Quote(statement.RefTable.Name)) + } return "", "" } @@ -1832,7 +1858,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { for _, id := range ids { if bean := cacher.GetBean(tableName, id); bean != nil { sqls := splitNNoCase(sql, "where", 2) - if len(sqls) != 2 { + if len(sqls) == 0 || len(sqls) > 2 { return ErrCacheFailed } @@ -1858,6 +1884,9 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { fieldValue := col.ValueOf(bean) session.Engine.LogDebug("[xorm:cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) fieldValue.Set(reflect.ValueOf(args[idx])) + } else { + session.Engine.LogError("[xorm:cacheUpdate] ERROR: column %v is not table %v's", + colName, table.Name) } } @@ -1966,7 +1995,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if err != nil { return 0, err } - if table.Cacher != nil && session.Statement.UseCache { session.cacheUpdate(sql, args...) } diff --git a/sqlite3_test.go b/sqlite3_test.go index 0a1c33a8..fa401119 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1,7 +1,7 @@ package xorm import ( - _ "github.com/mattn/go-sqlite3" + //_ "github.com/mattn/go-sqlite3" "os" "testing" )