diff --git a/QuickStart.md b/QuickStart.md index 00d24625..b2c6cfc5 100644 --- a/QuickStart.md +++ b/QuickStart.md @@ -232,6 +232,9 @@ fmt.Println(user.Id) * OrderBy() 按照指定的顺序进行排序 +* NoAutoTime() +如果此方法执行,则此次生成的语句中Created和Updated字段将不自动赋值为当前时间 + * In(string, …interface{}) 某字段在一些值中 @@ -310,9 +313,11 @@ affected, err := engine.Id(id).Update(&user) 删除数据`Delete`方法,参数为struct的指针并且成为查询条件。 ```Go user := new(User) -engine.Id(id).Delete(user) +affected, err := engine.Id(id).Delete(user) ``` +`Delete`的返回值第一个参数为删除的记录数,第二个参数为错误。 + ## 9.执行SQL查询 diff --git a/base_test.go b/base_test.go index 4147ddd9..800d4be7 100644 --- a/base_test.go +++ b/base_test.go @@ -778,6 +778,16 @@ func testCreatedAndUpdated(engine *Engine, t *testing.T) { t.Error(err) panic(err) } + + u.Id = 0 + u.Created = time.Now().Add(-time.Hour * 24 * 365) + u.Updated = u.Created + fmt.Println(u) + _, err = engine.NoAutoTime().Insert(u) + if err != nil { + t.Error(err) + panic(err) + } } type IndexOrUnique struct { diff --git a/cache.go b/cache.go new file mode 100644 index 00000000..8ae8e970 --- /dev/null +++ b/cache.go @@ -0,0 +1,173 @@ +package xorm + +import ( + "container/list" + //"encoding/json" + "errors" + "fmt" + "strconv" + "strings" + "sync" +) + +type CacheStore interface { + Put(key, value interface{}) error + Get(key interface{}) (interface{}, error) + Del(key interface{}) error +} + +type MemoryStore struct { + store map[interface{}]interface{} + mutex sync.Mutex +} + +func NewMemoryStore() *MemoryStore { + return &MemoryStore{store: make(map[interface{}]interface{})} +} + +func (s *MemoryStore) Put(key, value interface{}) error { + s.mutex.Lock() + defer s.mutex.Unlock() + s.store[key] = value + //fmt.Println("after put store:", s.store) + return nil +} + +func (s *MemoryStore) Get(key interface{}) (interface{}, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + //fmt.Println("before get store:", s.store) + if v, ok := s.store[key]; ok { + return v, nil + } + + return nil, ErrNotExist +} + +func (s *MemoryStore) Del(key interface{}) error { + s.mutex.Lock() + defer s.mutex.Unlock() + //fmt.Println("before del store:", s.store) + delete(s.store, key) + //fmt.Println("after del store:", s.store) + return nil +} + +type Cacher interface { + Get(id interface{}) interface{} + Put(id, obj interface{}) + Del(id interface{}) +} + +// LRUCacher implements Cacher according to LRU algorithm +type LRUCacher struct { + name string + list *list.List + index map[interface{}]*list.Element + store CacheStore + Max int + mutex sync.RWMutex +} + +func NewLRUCacher(store CacheStore, max int) *LRUCacher { + return &LRUCacher{store: store, list: list.New(), + index: make(map[interface{}]*list.Element), Max: max} +} + +func (m *LRUCacher) Get(id interface{}) interface{} { + m.mutex.RLock() + defer m.mutex.RUnlock() + if v, err := m.store.Get(id); err == nil { + el := m.index[id] + m.list.MoveToBack(el) + return v + } + return nil +} + +func (m *LRUCacher) Put(id interface{}, obj interface{}) { + m.mutex.Lock() + defer m.mutex.Unlock() + el := m.list.PushBack(id) + m.index[id] = el + m.store.Put(id, obj) + if m.list.Len() > m.Max { + e := m.list.Front() + m.store.Del(e.Value) + delete(m.index, e.Value) + m.list.Remove(e) + } +} + +func (m *LRUCacher) Del(id interface{}) { + m.mutex.Lock() + defer m.mutex.Unlock() + if el, ok := m.index[id]; ok { + m.store.Del(id) + delete(m.index, el.Value) + m.list.Remove(el) + } +} + +func encodeIds(ids []int64) (s string) { + s = "[" + for _, id := range ids { + s += fmt.Sprintf("%v,", id) + } + s = s[:len(s)-1] + "]" + return +} + +func decodeIds(s string) []int64 { + res := make([]int64, 0) + if len(s) >= 2 { + ss := strings.Split(s[1:len(s)-1], ",") + for _, s := range ss { + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return res + } + res = append(res, i) + } + } + return res +} + +func GetCacheSql(m Cacher, sql string) ([]int64, error) { + bytes := m.Get(sql) + if bytes == nil { + return nil, errors.New("Not Exist") + } + objs := decodeIds(bytes.(string)) + return objs, nil +} + +func PutCacheSql(m Cacher, sql string, ids []int64) error { + bytes := encodeIds(ids) + m.Put(sql, bytes) + return nil +} + +func DelCacheSql(m Cacher, sql string) error { + m.Del(sql) + return nil +} + +func genId(prefix string, id int64) string { + return fmt.Sprintf("%v-%v", prefix, id) +} + +func GetCacheId(m Cacher, prefix string, id int64) interface{} { + return m.Get(genId(prefix, id)) +} + +func PutCacheId(m Cacher, prefix string, id int64, bean interface{}) error { + m.Put(genId(prefix, id), bean) + return nil +} + +func DelCacheId(m Cacher, prefix string, id int64) error { + m.Del(genId(prefix, id)) + //TODO: should delete id from select + return nil +} diff --git a/engine.go b/engine.go index 9a2ea09a..a7497f09 100644 --- a/engine.go +++ b/engine.go @@ -36,9 +36,10 @@ type Engine struct { mutex *sync.Mutex ShowSQL bool Pool IConnectPool - CacheMapping bool Filters []Filter Logger io.Writer + Cacher Cacher + UseCache bool } func (engine *Engine) SupportInsertMany() bool { @@ -70,6 +71,25 @@ func (engine *Engine) SetMaxConns(conns int) { engine.Pool.SetMaxConns(conns) } +func (engine *Engine) SetDefaultCacher(cacher Cacher) { + if cacher == nil { + engine.UseCache = false + } else { + engine.UseCache = true + engine.Cacher = cacher + } +} + +func (engine *Engine) NoCache(bean interface{}) { + engine.MapCacher(bean, nil) +} + +func (engine *Engine) MapCacher(bean interface{}, cacher Cacher) { + t := Type(bean) + engine.AutoMapType(t) + engine.Tables[t].Cacher = cacher +} + func Type(bean interface{}) reflect.Type { sliceValue := reflect.Indirect(reflect.ValueOf(bean)) return reflect.TypeOf(sliceValue.Interface()) @@ -119,6 +139,12 @@ func (engine *Engine) Sql(querystring string, args ...interface{}) *Session { return session.Sql(querystring, args...) } +func (engine *Engine) NoAutoTime() *Session { + session := engine.NewSession() + session.IsAutoClose = true + return session.NoAutoTime() +} + func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { session := engine.NewSession() session.IsAutoClose = true @@ -155,11 +181,11 @@ func (engine *Engine) Cols(columns ...string) *Session { return session.Cols(columns...) } -func (engine *Engine) Trans(t string) *Session { +/*func (engine *Engine) Trans(t string) *Session { session := engine.NewSession() session.IsAutoClose = true return session.Trans(t) -} +}*/ func (engine *Engine) In(column string, args ...interface{}) *Session { session := engine.NewSession() @@ -233,8 +259,16 @@ func (engine *Engine) AutoMap(bean interface{}) *Table { return engine.AutoMapType(t) } +func (engine *Engine) newTable() *Table { + table := &Table{Indexes: map[string][]string{}, Uniques: map[string][]string{}} + table.Columns = make(map[string]*Column) + table.ColumnsSeq = make([]string, 0) + table.Cacher = engine.Cacher + return table +} + func (engine *Engine) MapType(t reflect.Type) *Table { - table := NewTable() + table := engine.newTable() table.Name = engine.Mapper.Obj2Table(t.Name()) table.Type = t diff --git a/error.go b/error.go index b275d26b..03520655 100644 --- a/error.go +++ b/error.go @@ -9,4 +9,5 @@ var ( ErrTableNotFound error = errors.New("not found table") ErrUnSupportedType error = errors.New("unsupported type error") ErrNotExist error = errors.New("not exist error") + ErrCacheFailed error = errors.New("cache failed") ) diff --git a/examples/cache.go b/examples/cache.go new file mode 100644 index 00000000..134386ce --- /dev/null +++ b/examples/cache.go @@ -0,0 +1,107 @@ +package main + +import ( + "fmt" + _ "github.com/mattn/go-sqlite3" + "os" + . "xorm" +) + +type User struct { + Id int64 + Name string +} + +func main() { + f := "cache.db" + os.Remove(f) + + Orm, err := NewEngine("sqlite3", f) + if err != nil { + fmt.Println(err) + return + } + Orm.ShowSQL = true + cacher := NewLRUCacher(NewMemoryStore(), 1000) + Orm.SetDefaultCacher(cacher) + + err = Orm.CreateTables(&User{}) + if err != nil { + fmt.Println(err) + return + } + + _, err = Orm.Insert(&User{Name: "xlw"}) + if err != nil { + fmt.Println(err) + return + } + + users := make([]User, 0) + err = Orm.Find(&users) + if err != nil { + fmt.Println(err) + return + } + + fmt.Println("users:", users) + + users2 := make([]User, 0) + + err = Orm.Find(&users2) + if err != nil { + fmt.Println(err) + return + } + + fmt.Println("users2:", users2) + + users3 := make([]User, 0) + + err = Orm.Find(&users3) + if err != nil { + fmt.Println(err) + return + } + + fmt.Println("users3:", users3) + + user4 := new(User) + has, err := Orm.Id(1).Get(user4) + if err != nil { + fmt.Println(err) + return + } + + fmt.Println("user4:", has, user4) + + user4.Name = "xiaolunwen" + _, err = Orm.Id(1).Update(user4) + if err != nil { + fmt.Println(err) + return + } + fmt.Println("user4:", user4) + + user5 := new(User) + has, err = Orm.Id(1).Get(user5) + if err != nil { + fmt.Println(err) + return + } + fmt.Println("user5:", has, user5) + + _, err = Orm.Id(1).Delete(new(User)) + if err != nil { + fmt.Println(err) + return + } + + user6 := new(User) + has, err = Orm.Id(1).Get(user6) + if err != nil { + fmt.Println(err) + return + } + fmt.Println("user6:", has, user6) +} diff --git a/examples/cachegoroutine.go b/examples/cachegoroutine.go new file mode 100644 index 00000000..af3c62b9 --- /dev/null +++ b/examples/cachegoroutine.go @@ -0,0 +1,112 @@ +package main + +import ( + //xorm "github.com/lunny/xorm" + "fmt" + _ "github.com/go-sql-driver/mysql" + _ "github.com/mattn/go-sqlite3" + "os" + //"time" + //"sync/atomic" + xorm "xorm" +) + +type User struct { + Id int64 + Name string +} + +func sqliteEngine() (*xorm.Engine, error) { + os.Remove("./test.db") + return xorm.NewEngine("sqlite3", "./goroutine.db") +} + +func mysqlEngine() (*xorm.Engine, error) { + return xorm.NewEngine("mysql", "root:@/test?charset=utf8") +} + +var u *User = &User{} + +func test(engine *xorm.Engine) { + err := engine.CreateTables(u) + if err != nil { + fmt.Println(err) + return + } + + size := 500 + queue := make(chan int, size) + + for i := 0; i < size; i++ { + go func(x int) { + //x := i + err := engine.Test() + if err != nil { + fmt.Println(err) + } else { + err = engine.Map(u) + if err != nil { + fmt.Println("Map user failed") + } else { + for j := 0; j < 10; j++ { + if x+j < 2 { + _, err = engine.Get(u) + } else if x+j < 4 { + users := make([]User, 0) + err = engine.Find(&users) + } else if x+j < 8 { + _, err = engine.Count(u) + } else if x+j < 16 { + _, err = engine.Insert(&User{Name: "xlw"}) + } else if x+j < 32 { + _, err = engine.Id(1).Delete(u) + } + if err != nil { + fmt.Println(err) + queue <- x + return + } + } + fmt.Printf("%v success!\n", x) + } + } + queue <- x + }(i) + } + + for i := 0; i < size; i++ { + <-queue + } + + //conns := atomic.LoadInt32(&xorm.ConnectionNum) + //fmt.Println("connection number:", conns) + fmt.Println("end") +} + +func main() { + fmt.Println("-----start sqlite go routines-----") + engine, err := sqliteEngine() + if err != nil { + fmt.Println(err) + return + } + engine.ShowSQL = true + cacher := xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000) + engine.SetDefaultCacher(cacher) + fmt.Println(engine) + test(engine) + fmt.Println("test end") + engine.Close() + + fmt.Println("-----start mysql go routines-----") + engine, err = mysqlEngine() + engine.ShowSQL = true + cacher = xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000) + engine.SetDefaultCacher(cacher) + if err != nil { + fmt.Println(err) + return + } + defer engine.Close() + test(engine) +} diff --git a/examples/maxconnect.go b/examples/maxconnect.go index bb979298..b930abca 100644 --- a/examples/maxconnect.go +++ b/examples/maxconnect.go @@ -95,7 +95,7 @@ func main() { engine.ShowSQL = true fmt.Println(engine) test(engine) - fmt.Println("test end") + fmt.Println("------------------------") engine.Close() engine, err = mysqlEngine() diff --git a/pool.go b/pool.go index 21bda9af..ea88af11 100644 --- a/pool.go +++ b/pool.go @@ -5,6 +5,7 @@ import ( "fmt" "sync" //"sync/atomic" + "container/list" "time" ) @@ -81,8 +82,7 @@ type SysConnectPool struct { maxConns int curConns int mutex *sync.Mutex - condMutex *sync.Mutex - cond *sync.Cond + queue *list.List } // NewSysConnectPool new a SysConnectPool. @@ -101,43 +101,61 @@ func (s *SysConnectPool) Init(engine *Engine) error { s.maxConns = -1 s.curConns = 0 s.mutex = &sync.Mutex{} - s.condMutex = &sync.Mutex{} - s.cond = sync.NewCond(s.condMutex) + s.queue = list.New() return nil } +type node struct { + mutex sync.Mutex + cond *sync.Cond +} + +func NewNode() *node { + n := &node{} + n.cond = sync.NewCond(&n.mutex) + return n +} + // RetrieveDB just return the only db -func (p *SysConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) { - if p.maxConns > 0 { - p.condMutex.Lock() - fmt.Println("before retrieve - current connections:", p.curConns, p.maxConns) - for p.curConns >= p.maxConns { - fmt.Println("waiting...", p.curConns) - p.cond.Wait() +func (s *SysConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) { + if s.maxConns > 0 { + fmt.Println("before retrieve") + s.mutex.Lock() + for s.curConns >= s.maxConns { + fmt.Println("before waiting...", s.curConns, s.queue.Len()) + s.mutex.Unlock() + n := NewNode() + n.cond.L.Lock() + s.queue.PushBack(n) + n.cond.Wait() + n.cond.L.Unlock() + s.mutex.Lock() + fmt.Println("after waiting...", s.curConns, s.queue.Len()) } - //p.mutex.Lock() - p.curConns += 1 - p.cond.Signal() - //p.mutex.Lock() - p.condMutex.Unlock() + s.curConns += 1 + s.mutex.Unlock() + fmt.Println("after retrieve") } - return p.db, nil + return s.db, nil } // ReleaseDB do nothing -func (p *SysConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { - if p.maxConns > 0 { - p.condMutex.Lock() - fmt.Println("before release - current connections:", p.curConns, p.maxConns) - //if p.curConns >= p.maxConns-2 { - fmt.Println("signaling...") - //p.mutex.Lock() - p.curConns -= 1 - //p.mutex.Unlock() - p.cond.Signal() - //} - p.condMutex.Unlock() +func (s *SysConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { + if s.maxConns > 0 { + s.mutex.Lock() + fmt.Println("before release", s.queue.Len()) + s.curConns -= 1 + if e := s.queue.Front(); e != nil { + n := e.Value.(*node) + //n.cond.L.Lock() + n.cond.Signal() + fmt.Println("signaled...") + s.queue.Remove(e) + //n.cond.L.Unlock() + } + fmt.Println("after released", s.queue.Len()) + s.mutex.Unlock() } } diff --git a/session.go b/session.go index b8c094e9..6234e14a 100644 --- a/session.go +++ b/session.go @@ -71,11 +71,16 @@ func (session *Session) Cols(columns ...string) *Session { return session } -func (session *Session) Trans(t string) *Session { - session.TransType = t +func (session *Session) NoAutoTime() *Session { + session.Statement.UseAutoTime = false return session } +/*func (session *Session) Trans(t string) *Session { + session.TransType = t + return session +}*/ + func (session *Session) Limit(limit int, start ...int) *Session { session.Statement.Limit(limit, start...) return session @@ -121,6 +126,11 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session { return session } +func (session *Session) NoCache() *Session { + session.Statement.UseCache = false + return session +} + //The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (session *Session) Join(join_operator, tablename, condition string) *Session { session.Statement.Join(join_operator, tablename, condition) @@ -240,12 +250,7 @@ func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, return res, nil } -func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { - err := session.newDb() - if err != nil { - return nil, err - } - +func (session *Session) exec(sql string, args ...interface{}) (sql.Result, error) { for _, filter := range session.Engine.Filters { sql = filter.Do(sql, session) } @@ -259,6 +264,19 @@ func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error return session.Tx.Exec(sql, args...) } +func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { + err := session.newDb() + if err != nil { + return nil, err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } + + return session.exec(sql, args...) +} + // this function create a table according a bean func (session *Session) CreateTable(bean interface{}) error { session.Statement.RefTable = session.Engine.AutoMap(bean) @@ -267,17 +285,21 @@ func (session *Session) CreateTable(bean interface{}) error { if err != nil { return err } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } return session.createOneTable() } func (session *Session) createOneTable() error { sql := session.Statement.genCreateSQL() - _, err := session.Exec(sql) + _, err := session.exec(sql) if err == nil { sqls := session.Statement.genIndexSQL() for _, sql := range sqls { - _, err = session.Exec(sql) + _, err = session.exec(sql) if err != nil { return err } @@ -286,7 +308,7 @@ func (session *Session) createOneTable() error { if err == nil { sqls := session.Statement.genUniqueSQL() for _, sql := range sqls { - _, err = session.Exec(sql) + _, err = session.exec(sql) if err != nil { return err } @@ -300,7 +322,7 @@ func (session *Session) CreateAll() error { if err != nil { return err } - + defer session.Statement.Init() if session.IsAutoClose { defer session.Close() } @@ -321,6 +343,7 @@ func (session *Session) DropTable(bean interface{}) error { return err } + defer session.Statement.Init() if session.IsAutoClose { defer session.Close() } @@ -336,33 +359,227 @@ func (session *Session) DropTable(bean interface{}) error { } sql := session.Statement.genDropSQL() - _, err = session.Exec(sql) + _, err = session.exec(sql) return err } +func (statement *Statement) convertIdSql(sql string) string { + if statement.RefTable != nil { + col := statement.RefTable.PKColumn() + if col != nil { + sql = strings.ToLower(sql) + sqls := strings.SplitN(sql, "from", 2) + if len(sqls) != 2 { + return "" + } + return fmt.Sprintf("SELECT %v FROM %v", statement.Engine.Quote(col.Name), sqls[1]) + } + } + return "" +} + +func (session *Session) cacheGet(bean interface{}, sql string, args ...interface{}) (has bool, err error) { + if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { + return false, ErrCacheFailed + } + for _, filter := range session.Engine.Filters { + sql = filter.Do(sql, session) + } + newsql := session.Statement.convertIdSql(sql) + if newsql == "" { + return false, ErrCacheFailed + } + + cacher := session.Statement.RefTable.Cacher + ids, err := GetCacheSql(cacher, newsql) + if err != nil { + fmt.Println(err) + resultsSlice, err := session.query(newsql, args...) + if err != nil { + return false, err + } + + ids = make([]int64, 0) + if len(resultsSlice) > 0 { + data := resultsSlice[0] + //fmt.Println(data) + var id int64 + if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + return false, errors.New("no id") + } else { + id, err = strconv.ParseInt(string(v), 10, 64) + if err != nil { + return false, err + } + } + ids = append(ids, id) + } + err = PutCacheSql(cacher, newsql, ids) + if err != nil { + fmt.Println(err) + } + } else { + fmt.Printf("-----Cached SQL: %v.\n", newsql) + } + + structValue := reflect.Indirect(reflect.ValueOf(bean)) + //fmt.Println("xxxxxxx", ids) + if len(ids) > 0 { + id := ids[0] + tableName := session.Statement.TableName() + bean = GetCacheId(cacher, tableName, id) + if bean == nil { + fmt.Printf("----Object Id %v no cached.\n", id) + newSession := session.Engine.NewSession() + defer newSession.Close() + bean = reflect.New(structValue.Type()).Interface() + has, err = newSession.Id(id).NoCache().Get(bean) + if err != nil { + return has, err + } + //fmt.Println(bean) + PutCacheId(cacher, tableName, id, bean) + } else { + fmt.Printf("-----Cached Object: %v\n", bean) + has = true + } + + structValue.Set(reflect.ValueOf(bean).Elem()) + return has, nil + } + return false, nil +} + +func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr interface{}, args ...interface{}) (err error) { + if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { + return ErrCacheFailed + } + + for _, filter := range session.Engine.Filters { + sql = filter.Do(sql, session) + } + + newsql := session.Statement.convertIdSql(sql) + if newsql == "" { + return ErrCacheFailed + } + + table := session.Statement.RefTable + cacher := table.Cacher + ids, err := GetCacheSql(cacher, newsql) + if err != nil { + fmt.Println(err) + resultsSlice, err := session.query(newsql, args...) + if err != nil { + return err + } + // 查询数目太大,采用缓存将不是一个很好的方式。 + if len(resultsSlice) > 20 { + return ErrCacheFailed + } + ids = make([]int64, 0) + if len(resultsSlice) > 0 { + for _, data := range resultsSlice { + //fmt.Println(data) + var id int64 + if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + return errors.New("no id") + } else { + id, err = strconv.ParseInt(string(v), 10, 64) + if err != nil { + return err + } + } + ids = append(ids, id) + } + } + err = PutCacheSql(cacher, newsql, ids) + if err != nil { + fmt.Println(err) + } + } else { + fmt.Printf("-----Cached SQL: %v.\n", newsql) + } + + sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) + //fmt.Println("xxxxxxx", ids) + var idxes []int = make([]int, 0) + var ides []interface{} = make([]interface{}, 0) + var temps []interface{} = make([]interface{}, len(ids)) + for idx, id := range ids { + tableName := session.Statement.TableName() + bean := GetCacheId(cacher, tableName, id) + if bean == nil { + fmt.Printf("----Object Id %v no cached.\n", id) + idxes = append(idxes, idx) + ides = append(ides, id) + /*newSession := session.Engine.NewSession() + defer newSession.Close() + bean = reflect.New(t).Interface() + _, err = newSession.Id(id).In(, ...).NoCache().Get(bean) + if err != nil { + return err + } + + PutCacheId(cacher, tableName, id, bean)*/ + } else { + fmt.Printf("-----Cached Object: %v\n", bean) + temps[idx] = bean + } + + //sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(bean).Elem())) + } + + newSession := session.Engine.NewSession() + defer newSession.Close() + + beans := reflect.New(sliceValue.Type()).Interface() + err = newSession.In("(id)", ides...).OrderBy(session.Statement.OrderStr).NoCache().Find(beans) + if err != nil { + return err + } + + vs := reflect.Indirect(reflect.ValueOf(beans)) + for i := 0; i < vs.Len(); i++ { + temps[idxes[i]] = vs.Index(i).Interface() + } + + //sliceValue.SetPointer(x) + + return nil +} + +// get retrieve one record from database func (session *Session) Get(bean interface{}) (bool, error) { err := session.newDb() if err != nil { return false, err } + defer session.Statement.Init() if session.IsAutoClose { defer session.Close() } - defer session.Statement.Init() session.Statement.Limit(1) var sql string var args []interface{} + session.Statement.RefTable = session.Engine.AutoMap(bean) if session.Statement.RawSQL == "" { sql, args = session.Statement.genGetSql(bean) } else { sql = session.Statement.RawSQL args = session.Statement.RawParams - session.Engine.AutoMap(bean) } - resultsSlice, err := session.Query(sql, args...) + if session.Statement.RefTable.Cacher != nil && session.Statement.UseCache { + has, err := session.cacheGet(bean, sql, args...) + if err != ErrCacheFailed { + return has, err + } + } + + resultsSlice, err := session.query(sql, args...) if err != nil { return false, err } @@ -390,11 +607,11 @@ func (session *Session) Count(bean interface{}) (int64, error) { return 0, err } + defer session.Statement.Init() if session.IsAutoClose { defer session.Close() } - defer session.Statement.Init() var sql string var args []interface{} if session.Statement.RawSQL == "" { @@ -404,7 +621,7 @@ func (session *Session) Count(bean interface{}) (int64, error) { args = session.Statement.RawParams } - resultsSlice, err := session.Query(sql, args...) + resultsSlice, err := session.query(sql, args...) if err != nil { return 0, err } @@ -426,12 +643,11 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) if err != nil { return err } - + defer session.Statement.Init() if session.IsAutoClose { defer session.Close() } - defer session.Statement.Init() sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { return errors.New("needs a pointer to a slice or a map") @@ -461,8 +677,14 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) args = session.Statement.RawParams } - resultsSlice, err := session.Query(sql, args...) + if table.Cacher != nil && session.Statement.UseCache { + err = session.cacheFind(sliceElementType, sql, rowsSlicePtr, args...) + if err != ErrCacheFailed { + return err + } + } + resultsSlice, err := session.query(sql, args...) if err != nil { return err } @@ -497,7 +719,7 @@ func (session *Session) Ping() error { if err != nil { return err } - + defer session.Statement.Init() if session.IsAutoClose { defer session.Close() } @@ -510,7 +732,7 @@ func (session *Session) DropAll() error { if err != nil { return err } - + defer session.Statement.Init() if session.IsAutoClose { defer session.Close() } @@ -519,7 +741,7 @@ func (session *Session) DropAll() error { session.Statement.Init() session.Statement.RefTable = table sql := session.Statement.genDropSQL() - _, err := session.Exec(sql) + _, err := session.exec(sql) if err != nil { return err } @@ -527,16 +749,7 @@ func (session *Session) DropAll() error { return nil } -func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { - err = session.newDb() - if err != nil { - return nil, err - } - - if session.IsAutoClose { - defer session.Close() - } - +func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { for _, filter := range session.Engine.Filters { sql = filter.Do(sql, session) } @@ -616,19 +829,28 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice return resultsSlice, nil } +func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { + err = session.newDb() + if err != nil { + return nil, err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } + + return session.query(sql, paramStr...) +} + +// insert one or more beans func (session *Session) Insert(beans ...interface{}) (int64, error) { var lastId int64 = -1 var err error = nil - isInTransaction := !session.IsAutoCommit - - if !isInTransaction { - err = session.Begin() - //defer session.Close() - if err != nil { - return 0, err - } + err = session.newDb() + if err != nil { + return 0, err } - + defer session.Statement.Init() if session.IsAutoClose { defer session.Close() } @@ -639,13 +861,6 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { if session.Engine.SupportInsertMany() { lastId, err = session.innerInsertMulti(bean) if err != nil { - if !isInTransaction { - err1 := session.Rollback() - if err1 == nil { - return lastId, err - } - err = err1 - } return lastId, err } } else { @@ -653,13 +868,6 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { for i := 0; i < size; i++ { lastId, err = session.innerInsert(sliceValue.Index(i).Interface()) if err != nil { - if !isInTransaction { - err1 := session.Rollback() - if err1 == nil { - return lastId, err - } - err = err1 - } return lastId, err } } @@ -667,20 +875,11 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { } else { lastId, err = session.innerInsert(bean) if err != nil { - if !isInTransaction { - err1 := session.Rollback() - if err1 == nil { - return lastId, err - } - err = err1 - } return lastId, err } } } - if !isInTransaction { - err = session.Commit() - } + return lastId, err } @@ -721,7 +920,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error continue } } - if col.IsCreated || col.IsUpdated { + if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { args = append(args, time.Now()) } else { arg, err := session.value2Interface(col, fieldValue) @@ -749,7 +948,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error continue } } - if col.IsCreated || col.IsUpdated { + if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { args = append(args, time.Now()) } else { arg, err := session.value2Interface(col, fieldValue) @@ -774,7 +973,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error session.Engine.QuoteStr(), strings.Join(colMultiPlaces, "),(")) - res, err := session.Exec(statement, args...) + res, err := session.exec(statement, args...) if err != nil { return -1, err } @@ -790,13 +989,13 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { err := session.newDb() - if session.IsAutoClose { - defer session.Close() - } - if err != nil { return 0, err } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } return session.innerInsertMulti(rowsSlicePtr) } @@ -1024,7 +1223,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } } - if col.IsCreated || col.IsUpdated { + if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { args = append(args, time.Now()) } else { arg, err := session.value2Interface(col, fieldValue) @@ -1047,7 +1246,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.Engine.QuoteStr(), strings.Join(colPlaces, ", ")) - res, err := session.Exec(sql, args...) + res, err := session.exec(sql, args...) if err != nil { return 0, err } @@ -1081,32 +1280,134 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { func (session *Session) InsertOne(bean interface{}) (int64, error) { err := session.newDb() - if session.IsAutoClose { - defer session.Close() - } if err != nil { return 0, err } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } return session.innerInsert(bean) } +func (statement *Statement) convertUpdateSql(sql string) (string, string) { + if statement.RefTable == nil || statement.RefTable.PrimaryKey == "" { + return "", "" + } + sqls := strings.SplitN(strings.ToLower(sql), "where", 2) + if len(sqls) != 2 { + return "", "" + } + + return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", + statement.Engine.Quote(statement.RefTable.PrimaryKey), statement.Engine.Quote(statement.TableName()), + sqls[1]) +} + +func (session *Session) cacheUpdate(sql string, args ...interface{}) error { + if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { + return ErrCacheFailed + } + + for _, filter := range session.Engine.Filters { + sql = filter.Do(sql, session) + } + + oldhead, newsql := session.Statement.convertUpdateSql(sql) + if newsql == "" { + return ErrCacheFailed + } + + var nStart int + if len(args) > 0 { + if strings.Index(sql, "?") > -1 { + nStart = strings.Count(oldhead, "?") + } else { + // for pq, TODO: if any other databse? + nStart = strings.Count(oldhead, "$") + } + } + table := session.Statement.RefTable + cacher := table.Cacher + ids, err := GetCacheSql(cacher, newsql) + if err != nil { + resultsSlice, err := session.query(newsql, args[nStart:]...) + if err != nil { + return err + } + ids = make([]int64, 0) + if len(resultsSlice) > 0 { + for _, data := range resultsSlice { + var id int64 + if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + return errors.New("no id") + } else { + id, err = strconv.ParseInt(string(v), 10, 64) + if err != nil { + return err + } + } + ids = append(ids, id) + } + } + } else { + fmt.Printf("-----Cached SQL: %v.\n", newsql) + DelCacheSql(cacher, newsql) + } + + for _, id := range ids { + if bean := GetCacheId(cacher, session.Statement.TableName(), id); bean != nil { + sqls := strings.SplitN(strings.ToLower(sql), "where", 2) + if len(sqls) != 2 { + return nil + } + sqls = strings.SplitN(sqls[0], "set", 2) + if len(sqls) != 2 { + return nil + } + kvs := strings.Split(strings.TrimSpace(sqls[1]), ",") + for idx, kv := range kvs { + sps := strings.SplitN(kv, "=", 2) + sps2 := strings.Split(sps[0], ".") + colName := sps2[len(sps2)-1] + if strings.Contains(colName, "`") { + colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1)) + } else if strings.Contains(colName, session.Engine.QuoteStr()) { + colName = strings.TrimSpace(strings.Replace(colName, session.Engine.QuoteStr(), "", -1)) + } + //fmt.Println("find", colName) + if col, ok := table.Columns[colName]; ok { + fieldValue := col.ValueOf(bean) + //session.bytes2Value(col, fieldValue, []byte(args[idx])) + fieldValue.Set(reflect.ValueOf(args[idx])) + } + } + + PutCacheId(cacher, session.Statement.TableName(), id, bean) + } + } + return nil +} + func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) { err := session.newDb() - if session.IsAutoClose { - defer session.Close() - } if err != nil { return 0, err } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } t := Type(bean) var colNames []string var args []interface{} + var table *Table if t.Kind() == reflect.Struct { - table := session.Engine.AutoMap(bean) + table = session.Engine.AutoMap(bean) session.Statement.RefTable = table colNames, args = BuildConditions(session.Engine, table, bean) if table.Updated != "" { @@ -1117,7 +1418,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if session.Statement.RefTable == nil { return -1, ErrTableNotFound } - table := session.Statement.RefTable + table = session.Statement.RefTable colNames = make([]string, 0) args = make([]interface{}, 0) bValue := reflect.ValueOf(bean) @@ -1163,38 +1464,84 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 strings.Join(colNames, ", "), condition) - eargs := append(append(args, st.Params...), condiArgs...) - res, err := session.Exec(sql, eargs...) + args = append(append(args, st.Params...), condiArgs...) + + res, err := session.exec(sql, args...) if err != nil { - return -1, err + return 0, err } - rows, err := res.RowsAffected() - - if err != nil { - return -1, err + if table.Cacher != nil && session.Statement.UseCache { + session.cacheUpdate(sql, args...) } - return rows, nil + + return res.RowsAffected() +} + +func (session *Session) cacheDelete(sql string, args ...interface{}) error { + if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { + return ErrCacheFailed + } + + for _, filter := range session.Engine.Filters { + sql = filter.Do(sql, session) + } + + newsql := session.Statement.convertIdSql(sql) + if newsql == "" { + return ErrCacheFailed + } + + cacher := session.Statement.RefTable.Cacher + ids, err := GetCacheSql(cacher, newsql) + if err != nil { + resultsSlice, err := session.query(newsql, args...) + if err != nil { + return err + } + ids = make([]int64, 0) + if len(resultsSlice) > 0 { + for _, data := range resultsSlice { + var id int64 + if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + return errors.New("no id") + } else { + id, err = strconv.ParseInt(string(v), 10, 64) + if err != nil { + return err + } + } + ids = append(ids, id) + } + } + } else { + fmt.Printf("-----Cached SQL: %v.\n", newsql) + DelCacheSql(cacher, newsql) + } + + for _, id := range ids { + DelCacheId(cacher, session.Statement.TableName(), id) + } + return nil } func (session *Session) Delete(bean interface{}) (int64, error) { err := session.newDb() - if session.IsAutoClose { - defer session.Close() - } if err != nil { return 0, err } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } table := session.Engine.AutoMap(bean) session.Statement.RefTable = table colNames, args := BuildConditions(session.Engine, table, bean) var condition = "" - st := session.Statement - defer session.Statement.Init() - if st.WhereStr != "" { - condition = fmt.Sprintf("WHERE %v", st.WhereStr) + if session.Statement.WhereStr != "" { + condition = fmt.Sprintf("WHERE %v", session.Statement.WhereStr) if len(colNames) > 0 { condition += " and " condition += strings.Join(colNames, " and ") @@ -1203,22 +1550,22 @@ func (session *Session) Delete(bean interface{}) (int64, error) { condition = "WHERE " + strings.Join(colNames, " and ") } - statement := fmt.Sprintf("DELETE FROM %v%v%v %v", + sql := fmt.Sprintf("DELETE FROM %v%v%v %v", session.Engine.QuoteStr(), session.Statement.TableName(), session.Engine.QuoteStr(), condition) - res, err := session.Exec(statement, append(st.Params, args...)...) + args = append(session.Statement.Params, args...) - if err != nil { - return -1, err + if table.Cacher != nil && session.Statement.UseCache { + session.cacheDelete(sql, args...) } - id, err := res.RowsAffected() - + res, err := session.exec(sql, args...) if err != nil { - return -1, err + return 0, err } - return id, nil + + return res.RowsAffected() } diff --git a/statement.go b/statement.go index 2ef417e3..2343ccb6 100644 --- a/statement.go +++ b/statement.go @@ -31,6 +31,8 @@ type Statement struct { StoreEngine string Charset string BeanArgs []interface{} + UseCache bool + UseAutoTime bool } func MakeArray(elem string, count int) []string { @@ -59,6 +61,8 @@ func (statement *Statement) Init() { statement.RawSQL = "" statement.RawParams = make([]interface{}, 0) statement.BeanArgs = make([]interface{}, 0) + statement.UseCache = statement.Engine.UseCache + statement.UseAutoTime = true } func (statement *Statement) Sql(querystring string, args ...interface{}) { diff --git a/table.go b/table.go index b93dd075..f9316f24 100644 --- a/table.go +++ b/table.go @@ -232,6 +232,7 @@ type Table struct { PrimaryKey string Created string Updated string + Cacher Cacher } func (table *Table) PKColumn() *Column { @@ -243,13 +244,6 @@ func (table *Table) AddColumn(col *Column) { table.Columns[col.Name] = col } -func NewTable() *Table { - table := &Table{Indexes: map[string][]string{}, Uniques: map[string][]string{}} - table.Columns = make(map[string]*Column) - table.ColumnsSeq = make([]string, 0) - return table -} - type Conversion interface { FromDB([]byte) error ToDB() ([]byte, error) diff --git a/xorm.go b/xorm.go index 3671a2b8..84859207 100644 --- a/xorm.go +++ b/xorm.go @@ -40,6 +40,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { //engine.Pool = NewSimpleConnectPool() //engine.Pool = NewNoneConnectPool() + //engine.Cacher = NewLRUCacher() err := engine.SetPool(NewSysConnectPool()) return engine, err