Lunny Xiao 2013-12-18 11:31:32 +08:00
parent 286b8725ed
commit 59412a951c
23 changed files with 8388 additions and 8392 deletions

.gitignore

@@ -25,3 +25,4 @@ vendor
*.log
.vendor
temp_test.go

File diff suppressed because it is too large


@@ -1,174 +1,174 @@
package xorm

import (
    "database/sql"
    "testing"
)

type BigStruct struct {
    Id       int64
    Name     string
    Title    string
    Age      string
    Alias    string
    NickName string
}

func doBenchDriverInsert(db *sql.DB, b *testing.B) {
    b.StartTimer()
    for i := 0; i < b.N; i++ {
        _, err := db.Exec(`insert into big_struct (name, title, age, alias, nick_name)
            values ('fafdasf', 'fadfa', 'afadfsaf', 'fadfafdsafd', 'fadfafdsaf')`)
        if err != nil {
            b.Error(err)
            return
        }
    }
    b.StopTimer()
}

func doBenchDriverFind(db *sql.DB, b *testing.B) {
    b.StopTimer()
    for i := 0; i < 50; i++ {
        _, err := db.Exec(`insert into big_struct (name, title, age, alias, nick_name)
            values ('fafdasf', 'fadfa', 'afadfsaf', 'fadfafdsafd', 'fadfafdsaf')`)
        if err != nil {
            b.Error(err)
            return
        }
    }
    b.StartTimer()
    for i := 0; i < b.N/50; i++ {
        rows, err := db.Query("select * from big_struct limit 50")
        if err != nil {
            b.Error(err)
            return
        }
        for rows.Next() {
            s := &BigStruct{}
            rows.Scan(&s.Id, &s.Name, &s.Title, &s.Age, &s.Alias, &s.NickName)
        }
    }
    b.StopTimer()
}

func doBenchDriver(newdriver func() (*sql.DB, error), createTableSql,
    dropTableSql string, opFunc func(*sql.DB, *testing.B), t *testing.B) {
    db, err := newdriver()
    if err != nil {
        t.Error(err)
        return
    }
    defer db.Close()

    _, err = db.Exec(createTableSql)
    if err != nil {
        t.Error(err)
        return
    }

    opFunc(db, t)

    _, err = db.Exec(dropTableSql)
    if err != nil {
        t.Error(err)
        return
    }
}

func doBenchInsert(engine *Engine, b *testing.B) {
    b.StopTimer()
    bs := &BigStruct{0, "fafdasf", "fadfa", "afadfsaf", "fadfafdsafd", "fadfafdsaf"}
    err := engine.CreateTables(bs)
    if err != nil {
        b.Error(err)
        return
    }
    b.StartTimer()
    for i := 0; i < b.N; i++ {
        bs.Id = 0
        _, err = engine.Insert(bs)
        if err != nil {
            b.Error(err)
            return
        }
    }
    b.StopTimer()
    err = engine.DropTables(bs)
    if err != nil {
        b.Error(err)
        return
    }
}

func doBenchFind(engine *Engine, b *testing.B) {
    b.StopTimer()
    bs := &BigStruct{0, "fafdasf", "fadfa", "afadfsaf", "fadfafdsafd", "fadfafdsaf"}
    err := engine.CreateTables(bs)
    if err != nil {
        b.Error(err)
        return
    }
    for i := 0; i < 100; i++ {
        bs.Id = 0
        _, err = engine.Insert(bs)
        if err != nil {
            b.Error(err)
            return
        }
    }
    b.StartTimer()
    for i := 0; i < b.N/50; i++ {
        bss := new([]BigStruct)
        err = engine.Limit(50).Find(bss)
        if err != nil {
            b.Error(err)
            return
        }
    }
    b.StopTimer()
    err = engine.DropTables(bs)
    if err != nil {
        b.Error(err)
        return
    }
}

func doBenchFindPtr(engine *Engine, b *testing.B) {
    b.StopTimer()
    bs := &BigStruct{0, "fafdasf", "fadfa", "afadfsaf", "fadfafdsafd", "fadfafdsaf"}
    err := engine.CreateTables(bs)
    if err != nil {
        b.Error(err)
        return
    }
    for i := 0; i < 100; i++ {
        bs.Id = 0
        _, err = engine.Insert(bs)
        if err != nil {
            b.Error(err)
            return
        }
    }
    b.StartTimer()
    for i := 0; i < b.N/50; i++ {
        bss := new([]*BigStruct)
        err = engine.Limit(50).Find(bss)
        if err != nil {
            b.Error(err)
            return
        }
    }
    b.StopTimer()
    err = engine.DropTables(bs)
    if err != nil {
        b.Error(err)
        return
    }
}

cache.go

@@ -1,131 +1,131 @@
package xorm

import (
    "container/list"
    "errors"
    "fmt"
    "strconv"
    "strings"
    "sync"
    "time"
)

const (
    // default cache expired time
    CacheExpired = 60 * time.Minute
    // not use now
    CacheMaxMemory = 256
    // evey ten minutes to clear all expired nodes
    CacheGcInterval = 10 * time.Minute
    // each time when gc to removed max nodes
    CacheGcMaxRemoved = 20
)

// CacheStore is a interface to store cache
type CacheStore interface {
    Put(key, value interface{}) error
    Get(key interface{}) (interface{}, error)
    Del(key interface{}) error
}

// MemoryStore implements CacheStore provide local machine
// memory store
type MemoryStore struct {
    store map[interface{}]interface{}
    mutex sync.RWMutex
}

func NewMemoryStore() *MemoryStore {
    return &MemoryStore{store: make(map[interface{}]interface{})}
}

func (s *MemoryStore) Put(key, value interface{}) error {
    s.mutex.Lock()
    defer s.mutex.Unlock()
    s.store[key] = value
    return nil
}

func (s *MemoryStore) Get(key interface{}) (interface{}, error) {
    s.mutex.RLock()
    defer s.mutex.RUnlock()
    if v, ok := s.store[key]; ok {
        return v, nil
    }
    return nil, ErrNotExist
}

func (s *MemoryStore) Del(key interface{}) error {
    s.mutex.Lock()
    defer s.mutex.Unlock()
    delete(s.store, key)
    return nil
}

// Cacher is an interface to provide cache
type Cacher interface {
    GetIds(tableName, sql string) interface{}
    GetBean(tableName string, id int64) interface{}
    PutIds(tableName, sql string, ids interface{})
    PutBean(tableName string, id int64, obj interface{})
    DelIds(tableName, sql string)
    DelBean(tableName string, id int64)
    ClearIds(tableName string)
    ClearBeans(tableName string)
}

type idNode struct {
    tbName    string
    id        int64
    lastVisit time.Time
}

type sqlNode struct {
    tbName    string
    sql       string
    lastVisit time.Time
}

func newIdNode(tbName string, id int64) *idNode {
    return &idNode{tbName, id, time.Now()}
}

func newSqlNode(tbName, sql string) *sqlNode {
    return &sqlNode{tbName, sql, time.Now()}
}

// LRUCacher implements Cacher according to LRU algorithm
type LRUCacher struct {
    idList     *list.List
    sqlList    *list.List
    idIndex    map[string]map[interface{}]*list.Element
    sqlIndex   map[string]map[interface{}]*list.Element
    store      CacheStore
    Max        int
    mutex      sync.Mutex
    Expired    time.Duration
    maxSize    int
    GcInterval time.Duration
}

func newLRUCacher(store CacheStore, expired time.Duration, maxSize int, max int) *LRUCacher {
    cacher := &LRUCacher{store: store, idList: list.New(),
        sqlList: list.New(), Expired: expired, maxSize: maxSize,
        GcInterval: CacheGcInterval, Max: max,
        sqlIndex: make(map[string]map[interface{}]*list.Element),
        idIndex:  make(map[string]map[interface{}]*list.Element),
    }
    cacher.RunGC()
    return cacher
}

func NewLRUCacher(store CacheStore, max int) *LRUCacher {
    return newLRUCacher(store, CacheExpired, CacheMaxMemory, max)
}

func NewLRUCacher2(store CacheStore, expired time.Duration, max int) *LRUCacher {
    return newLRUCacher(store, expired, 0, max)
}

//func NewLRUCacher3(store CacheStore, expired time.Duration, maxSize int) *LRUCacher {
@@ -134,262 +134,262 @@ func NewLRUCacher2(store CacheStore, expired time.Duration, max int) *LRUCacher

// RunGC run once every m.GcInterval
func (m *LRUCacher) RunGC() {
    time.AfterFunc(m.GcInterval, func() {
        m.RunGC()
        m.GC()
    })
}

// GC check ids lit and sql list to remove all element expired
func (m *LRUCacher) GC() {
    //fmt.Println("begin gc ...")
    //defer fmt.Println("end gc ...")
    m.mutex.Lock()
    defer m.mutex.Unlock()
    var removedNum int
    for e := m.idList.Front(); e != nil; {
        if removedNum <= CacheGcMaxRemoved &&
            time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired {
            removedNum++
            next := e.Next()
            //fmt.Println("removing ...", e.Value)
            node := e.Value.(*idNode)
            m.delBean(node.tbName, node.id)
            e = next
        } else {
            //fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.idList.Len())
            break
        }
    }
    removedNum = 0
    for e := m.sqlList.Front(); e != nil; {
        if removedNum <= CacheGcMaxRemoved &&
            time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired {
            removedNum++
            next := e.Next()
            //fmt.Println("removing ...", e.Value)
            node := e.Value.(*sqlNode)
            m.delIds(node.tbName, node.sql)
            e = next
        } else {
            //fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.sqlList.Len())
            break
        }
    }
}

// Get all bean's ids according to sql and parameter from cache
func (m *LRUCacher) GetIds(tableName, sql string) interface{} {
    m.mutex.Lock()
    defer m.mutex.Unlock()
    if _, ok := m.sqlIndex[tableName]; !ok {
        m.sqlIndex[tableName] = make(map[interface{}]*list.Element)
    }
    if v, err := m.store.Get(sql); err == nil {
        if el, ok := m.sqlIndex[tableName][sql]; !ok {
            el = m.sqlList.PushBack(newSqlNode(tableName, sql))
            m.sqlIndex[tableName][sql] = el
        } else {
            lastTime := el.Value.(*sqlNode).lastVisit
            // if expired, remove the node and return nil
            if time.Now().Sub(lastTime) > m.Expired {
                m.delIds(tableName, sql)
                return nil
            }
            m.sqlList.MoveToBack(el)
            el.Value.(*sqlNode).lastVisit = time.Now()
        }
        return v
    } else {
        m.delIds(tableName, sql)
    }
    return nil
}

// Get bean according tableName and id from cache
func (m *LRUCacher) GetBean(tableName string, id int64) interface{} {
    m.mutex.Lock()
    defer m.mutex.Unlock()
    if _, ok := m.idIndex[tableName]; !ok {
        m.idIndex[tableName] = make(map[interface{}]*list.Element)
    }
    tid := genId(tableName, id)
    if v, err := m.store.Get(tid); err == nil {
        if el, ok := m.idIndex[tableName][id]; ok {
            lastTime := el.Value.(*idNode).lastVisit
            // if expired, remove the node and return nil
            if time.Now().Sub(lastTime) > m.Expired {
                m.delBean(tableName, id)
                //m.clearIds(tableName)
                return nil
            }
            m.idList.MoveToBack(el)
            el.Value.(*idNode).lastVisit = time.Now()
        } else {
            el = m.idList.PushBack(newIdNode(tableName, id))
            m.idIndex[tableName][id] = el
        }
        return v
    } else {
        // store bean is not exist, then remove memory's index
        m.delBean(tableName, id)
        //m.clearIds(tableName)
        return nil
    }
}

// Clear all sql-ids mapping on table tableName from cache
func (m *LRUCacher) clearIds(tableName string) {
    if tis, ok := m.sqlIndex[tableName]; ok {
        for sql, v := range tis {
            m.sqlList.Remove(v)
            m.store.Del(sql)
        }
    }
    m.sqlIndex[tableName] = make(map[interface{}]*list.Element)
}

func (m *LRUCacher) ClearIds(tableName string) {
    m.mutex.Lock()
    defer m.mutex.Unlock()
    m.clearIds(tableName)
}

func (m *LRUCacher) clearBeans(tableName string) {
    if tis, ok := m.idIndex[tableName]; ok {
        for id, v := range tis {
            m.idList.Remove(v)
            tid := genId(tableName, id.(int64))
            m.store.Del(tid)
        }
    }
    m.idIndex[tableName] = make(map[interface{}]*list.Element)
}

func (m *LRUCacher) ClearBeans(tableName string) {
    m.mutex.Lock()
    defer m.mutex.Unlock()
    m.clearBeans(tableName)
}

func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) {
    m.mutex.Lock()
    defer m.mutex.Unlock()
    if _, ok := m.sqlIndex[tableName]; !ok {
        m.sqlIndex[tableName] = make(map[interface{}]*list.Element)
    }
    if el, ok := m.sqlIndex[tableName][sql]; !ok {
        el = m.sqlList.PushBack(newSqlNode(tableName, sql))
        m.sqlIndex[tableName][sql] = el
    } else {
        el.Value.(*sqlNode).lastVisit = time.Now()
    }
    m.store.Put(sql, ids)
    if m.sqlList.Len() > m.Max {
        e := m.sqlList.Front()
        node := e.Value.(*sqlNode)
        m.delIds(node.tbName, node.sql)
    }
}

func (m *LRUCacher) PutBean(tableName string, id int64, obj interface{}) {
    m.mutex.Lock()
    defer m.mutex.Unlock()
    var el *list.Element
    var ok bool
    if el, ok = m.idIndex[tableName][id]; !ok {
        el = m.idList.PushBack(newIdNode(tableName, id))
        m.idIndex[tableName][id] = el
    } else {
        el.Value.(*idNode).lastVisit = time.Now()
    }
    m.store.Put(genId(tableName, id), obj)
    if m.idList.Len() > m.Max {
        e := m.idList.Front()
        node := e.Value.(*idNode)
        m.delBean(node.tbName, node.id)
    }
}

func (m *LRUCacher) delIds(tableName, sql string) {
    if _, ok := m.sqlIndex[tableName]; ok {
        if el, ok := m.sqlIndex[tableName][sql]; ok {
            delete(m.sqlIndex[tableName], sql)
            m.sqlList.Remove(el)
        }
    }
    m.store.Del(sql)
}

func (m *LRUCacher) DelIds(tableName, sql string) {
    m.mutex.Lock()
    defer m.mutex.Unlock()
    m.delIds(tableName, sql)
}

func (m *LRUCacher) delBean(tableName string, id int64) {
    tid := genId(tableName, id)
    if el, ok := m.idIndex[tableName][id]; ok {
        delete(m.idIndex[tableName], id)
        m.idList.Remove(el)
        m.clearIds(tableName)
    }
    m.store.Del(tid)
}

func (m *LRUCacher) DelBean(tableName string, id int64) {
    m.mutex.Lock()
    defer m.mutex.Unlock()
    m.delBean(tableName, id)
}

func encodeIds(ids []int64) (s string) {
    s = "["
    for _, id := range ids {
        s += fmt.Sprintf("%v,", id)
    }
    s = s[:len(s)-1] + "]"
    return
}

func decodeIds(s string) []int64 {
    res := make([]int64, 0)
    if len(s) >= 2 {
        ss := strings.Split(s[1:len(s)-1], ",")
        for _, s := range ss {
            i, err := strconv.ParseInt(s, 10, 64)
            if err != nil {
                return res
            }
            res = append(res, i)
        }
    }
    return res
}

func getCacheSql(m Cacher, tableName, sql string, args interface{}) ([]int64, error) {
    bytes := m.GetIds(tableName, genSqlKey(sql, args))
    if bytes == nil {
        return nil, errors.New("Not Exist")
    }
    objs := decodeIds(bytes.(string))
    return objs, nil
}

func putCacheSql(m Cacher, ids []int64, tableName, sql string, args interface{}) error {
    bytes := encodeIds(ids)
    m.PutIds(tableName, genSqlKey(sql, args), bytes)
    return nil
}

func genSqlKey(sql string, args interface{}) string {
    return fmt.Sprintf("%v-%v", sql, args)
}

func genId(prefix string, id int64) string {
    return fmt.Sprintf("%v-%v", prefix, id)
}
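
A minimal usage sketch for the cache types above, written as if it lived next to the test files in this same commit (package xorm); the mymysql DSN is a placeholder and the helper name is illustrative, not part of the commit:

// exampleDefaultCacher is a hypothetical helper showing how the pieces fit
// together: MemoryStore keeps cached values in a mutex-guarded map, and
// LRUCacher evicts the least-recently-used entry once more than max (here
// 1000) entries are tracked, clearing expired nodes every CacheGcInterval.
func exampleDefaultCacher() (*Engine, error) {
    engine, err := NewEngine("mymysql", "xorm_test/root/") // placeholder DSN
    if err != nil {
        return nil, err
    }
    engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
    // NewLRUCacher2 takes an explicit expiry instead of the default CacheExpired:
    // engine.SetDefaultCacher(NewLRUCacher2(NewMemoryStore(), 30*time.Minute, 1000))
    return engine, nil
}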

engine.go

File diff suppressed because it is too large


@@ -1,15 +1,15 @@
package xorm

import (
    "errors"
)

var (
    ErrParamsType      error = errors.New("Params type error")
    ErrTableNotFound   error = errors.New("Not found table")
    ErrUnSupportedType error = errors.New("Unsupported type error")
    ErrNotExist        error = errors.New("Not exist error")
    ErrCacheFailed     error = errors.New("Cache failed")
    ErrNeedDeletedCond error = errors.New("Delete need at least one condition")
    ErrNotImplemented  error = errors.New("Not implemented.")
)


@@ -1,13 +1,13 @@
package xorm

import (
    "fmt"
    "strings"
)

// Filter is an interface to filter SQL
type Filter interface {
    Do(sql string, session *Session) string
}

// PgSeqFilter filter SQL replace ?, ? ... to $1, $2 ...
@@ -15,16 +15,16 @@ type PgSeqFilter struct {
}

func (s *PgSeqFilter) Do(sql string, session *Session) string {
    segs := strings.Split(sql, "?")
    size := len(segs)
    res := ""
    for i, c := range segs {
        if i < size-1 {
            res += c + fmt.Sprintf("$%v", i+1)
        }
    }
    res += segs[size-1]
    return res
}

// QuoteFilter filter SQL replace ` to database's own quote character
@@ -32,7 +32,7 @@ type QuoteFilter struct {
}

func (s *QuoteFilter) Do(sql string, session *Session) string {
    return strings.Replace(sql, "`", session.Engine.QuoteStr(), -1)
}

// IdFilter filter SQL replace (id) to primary key column name
@@ -40,10 +40,10 @@ type IdFilter struct {
}

func (i *IdFilter) Do(sql string, session *Session) string {
    if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" {
        sql = strings.Replace(sql, "`(id)`", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1)
        sql = strings.Replace(sql, session.Engine.Quote("(id)"), session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1)
        return strings.Replace(sql, "(id)", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1)
    }
    return sql
}
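
To make the filter behaviour concrete, a small illustrative sketch (not part of the commit); PgSeqFilter.Do never touches the session, so passing nil is safe here:

// examplePgSeqFilter is hypothetical: it shows the placeholder rewrite that
// PgSeqFilter performs for PostgreSQL. The input below produces
// "insert into user (name, age) values ($1, $2)".
func examplePgSeqFilter() string {
    f := &PgSeqFilter{}
    return f.Do("insert into user (name, age) values (?, ?)", nil)
}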


@@ -1,63 +1,63 @@
package xorm

import (
    "reflect"
    "strings"
)

func indexNoCase(s, sep string) int {
    return strings.Index(strings.ToLower(s), strings.ToLower(sep))
}

func splitNoCase(s, sep string) []string {
    idx := indexNoCase(s, sep)
    if idx < 0 {
        return []string{s}
    }
    return strings.Split(s, s[idx:idx+len(sep)])
}

func splitNNoCase(s, sep string, n int) []string {
    idx := indexNoCase(s, sep)
    if idx < 0 {
        return []string{s}
    }
    return strings.SplitN(s, s[idx:idx+len(sep)], n)
}

func makeArray(elem string, count int) []string {
    res := make([]string, count)
    for i := 0; i < count; i++ {
        res[i] = elem
    }
    return res
}

func rType(bean interface{}) reflect.Type {
    sliceValue := reflect.Indirect(reflect.ValueOf(bean))
    return reflect.TypeOf(sliceValue.Interface())
}

func structName(v reflect.Type) string {
    for v.Kind() == reflect.Ptr {
        v = v.Elem()
    }
    return v.Name()
}

func sliceEq(left, right []string) bool {
    for _, l := range left {
        var find bool
        for _, r := range right {
            if l == r {
                find = true
                break
            }
        }
        if !find {
            return false
        }
    }
    return true
}
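
A quick illustration of the case-insensitive split helpers above (the values in the comments are what the code computes; the snippet itself is not part of the commit):

// exampleSplitNoCase is hypothetical: splitNoCase splits on the separator as
// it actually appears in the string, matched case-insensitively.
func exampleSplitNoCase() {
    parts := splitNoCase("select id FROM user", "from") // ["select id ", " user"]
    head := splitNNoCase("a AND b AND c", "and", 2)     // ["a ", " b AND c"]
    _, _ = parts, head
}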


@@ -1,13 +1,13 @@
package xorm

import (
    "strings"
)

// name translation between struct, fields names and table, column names
type IMapper interface {
    Obj2Table(string) string
    Table2Obj(string) string
}

// SameMapper implements IMapper and provides same name between struct and
@@ -16,11 +16,11 @@ type SameMapper struct {
}

func (m SameMapper) Obj2Table(o string) string {
    return o
}

func (m SameMapper) Table2Obj(t string) string {
    return t
}

// SnakeMapper implements IMapper and provides name transaltion between
@@ -29,18 +29,18 @@ type SnakeMapper struct {
}

func snakeCasedName(name string) string {
    newstr := make([]rune, 0)
    for idx, chr := range name {
        if isUpper := 'A' <= chr && chr <= 'Z'; isUpper {
            if idx > 0 {
                newstr = append(newstr, '_')
            }
            chr -= ('A' - 'a')
        }
        newstr = append(newstr, chr)
    }
    return string(newstr)
}

/*func pascal2Sql(s string) (d string) {
@@ -63,69 +63,69 @@ func snakeCasedName(name string) string {
}*/

func (mapper SnakeMapper) Obj2Table(name string) string {
    return snakeCasedName(name)
}

func titleCasedName(name string) string {
    newstr := make([]rune, 0)
    upNextChar := true
    name = strings.ToLower(name)
    for _, chr := range name {
        switch {
        case upNextChar:
            upNextChar = false
            if 'a' <= chr && chr <= 'z' {
                chr -= ('a' - 'A')
            }
        case chr == '_':
            upNextChar = true
            continue
        }
        newstr = append(newstr, chr)
    }
    return string(newstr)
}

func (mapper SnakeMapper) Table2Obj(name string) string {
    return titleCasedName(name)
}

// provide prefix table name support
type PrefixMapper struct {
    Mapper IMapper
    Prefix string
}

func (mapper PrefixMapper) Obj2Table(name string) string {
    return mapper.Prefix + mapper.Mapper.Obj2Table(name)
}

func (mapper PrefixMapper) Table2Obj(name string) string {
    return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):])
}

func NewPrefixMapper(mapper IMapper, prefix string) PrefixMapper {
    return PrefixMapper{mapper, prefix}
}

// provide suffix table name support
type SuffixMapper struct {
    Mapper IMapper
    Suffix string
}

func (mapper SuffixMapper) Obj2Table(name string) string {
    return mapper.Suffix + mapper.Mapper.Obj2Table(name)
}

func (mapper SuffixMapper) Table2Obj(name string) string {
    return mapper.Mapper.Table2Obj(name[len(mapper.Suffix):])
}

func NewSuffixMapper(mapper IMapper, suffix string) SuffixMapper {
    return SuffixMapper{mapper, suffix}
}
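
A short sketch of the translations the mappers above produce (UserInfo and the tb_ prefix are illustrative values, not part of the commit):

// exampleMappers is hypothetical and only documents the name translations.
func exampleMappers() {
    snake := SnakeMapper{}
    _ = snake.Obj2Table("UserInfo")  // "user_info"
    _ = snake.Table2Obj("user_info") // "UserInfo"

    // PrefixMapper wraps another mapper and prepends a fixed table prefix.
    prefixed := NewPrefixMapper(SnakeMapper{}, "tb_")
    _ = prefixed.Obj2Table("UserInfo")     // "tb_user_info"
    _ = prefixed.Table2Obj("tb_user_info") // "UserInfo"
}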


@@ -1,67 +1,67 @@
package xorm

import (
    "errors"
    "strings"
    "time"
)

type mymysql struct {
    mysql
}

type mymysqlParser struct {
}

func (p *mymysqlParser) parse(driverName, dataSourceName string) (*uri, error) {
    db := &uri{dbType: MYSQL}

    pd := strings.SplitN(dataSourceName, "*", 2)
    if len(pd) == 2 {
        // Parse protocol part of URI
        p := strings.SplitN(pd[0], ":", 2)
        if len(p) != 2 {
            return nil, errors.New("Wrong protocol part of URI")
        }
        db.proto = p[0]
        options := strings.Split(p[1], ",")
        db.raddr = options[0]
        for _, o := range options[1:] {
            kv := strings.SplitN(o, "=", 2)
            var k, v string
            if len(kv) == 2 {
                k, v = kv[0], kv[1]
            } else {
                k, v = o, "true"
            }
            switch k {
            case "laddr":
                db.laddr = v
            case "timeout":
                to, err := time.ParseDuration(v)
                if err != nil {
                    return nil, err
                }
                db.timeout = to
            default:
                return nil, errors.New("Unknown option: " + k)
            }
        }
        // Remove protocol part
        pd = pd[1:]
    }
    // Parse database part of URI
    dup := strings.SplitN(pd[0], "/", 3)
    if len(dup) != 3 {
        return nil, errors.New("Wrong database part of URI")
    }
    db.dbName = dup[0]
    db.user = dup[1]
    db.passwd = dup[2]
    return db, nil
}

func (db *mymysql) Init(drivername, uri string) error {
    return db.mysql.base.init(&mymysqlParser{}, drivername, uri)
}


@@ -1,10 +1,10 @@
package xorm

import (
    "database/sql"
    "testing"

    _ "github.com/ziutek/mymysql/godrv"
)

/*
@@ -15,153 +15,153 @@ utf8 COLLATE utf8_general_ci;
var showTestSql bool = true

func TestMyMysql(t *testing.T) {
    err := mymysqlDdlImport()
    if err != nil {
        t.Error(err)
        return
    }

    engine, err := NewEngine("mymysql", "xorm_test/root/")
    defer engine.Close()
    if err != nil {
        t.Error(err)
        return
    }
    engine.ShowSQL = showTestSql
    engine.ShowErr = showTestSql
    engine.ShowWarn = showTestSql
    engine.ShowDebug = showTestSql

    testAll(engine, t)
    testAll2(engine, t)
    testAll3(engine, t)
}

func TestMyMysqlWithCache(t *testing.T) {
    err := mymysqlDdlImport()
    if err != nil {
        t.Error(err)
        return
    }
    engine, err := NewEngine("mymysql", "xorm_test2/root/")
    defer engine.Close()
    if err != nil {
        t.Error(err)
        return
    }
    engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
    engine.ShowSQL = showTestSql
    engine.ShowErr = showTestSql
    engine.ShowWarn = showTestSql
    engine.ShowDebug = showTestSql

    testAll(engine, t)
    testAll2(engine, t)
}

func newMyMysqlEngine() (*Engine, error) {
    return NewEngine("mymysql", "xorm_test2/root/")
}

func newMyMysqlDriverDB() (*sql.DB, error) {
    return sql.Open("mymysql", "xorm_test2/root/")
}

func BenchmarkMyMysqlDriverInsert(t *testing.B) {
    doBenchDriver(newMyMysqlDriverDB, createTableMySql, dropTableMySql,
        doBenchDriverInsert, t)
}

func BenchmarkMyMysqlDriverFind(t *testing.B) {
    doBenchDriver(newMyMysqlDriverDB, createTableMySql, dropTableMySql,
        doBenchDriverFind, t)
}

func mymysqlDdlImport() error {
    engine, err := NewEngine("mymysql", "/root/")
    if err != nil {
        return err
    }
    engine.ShowSQL = showTestSql
    engine.ShowErr = showTestSql
    engine.ShowWarn = showTestSql
    engine.ShowDebug = showTestSql

    sqlResults, _ := engine.Import("tests/mysql_ddl.sql")
    engine.LogDebug("sql results: %v", sqlResults)
    engine.Close()
    return nil
}

func BenchmarkMyMysqlNoCacheInsert(t *testing.B) {
    engine, err := newMyMysqlEngine()
    if err != nil {
        t.Error(err)
        return
    }
    defer engine.Close()
    doBenchInsert(engine, t)
}

func BenchmarkMyMysqlNoCacheFind(t *testing.B) {
    engine, err := newMyMysqlEngine()
    if err != nil {
        t.Error(err)
        return
    }
    defer engine.Close()
    //engine.ShowSQL = true
    doBenchFind(engine, t)
}

func BenchmarkMyMysqlNoCacheFindPtr(t *testing.B) {
    engine, err := newMyMysqlEngine()
    if err != nil {
        t.Error(err)
        return
    }
    defer engine.Close()
    //engine.ShowSQL = true
    doBenchFindPtr(engine, t)
}

func BenchmarkMyMysqlCacheInsert(t *testing.B) {
    engine, err := newMyMysqlEngine()
    if err != nil {
        t.Error(err)
        return
    }
    defer engine.Close()
    engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
    doBenchInsert(engine, t)
}

func BenchmarkMyMysqlCacheFind(t *testing.B) {
    engine, err := newMyMysqlEngine()
    if err != nil {
        t.Error(err)
        return
    }
    defer engine.Close()
    engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
    doBenchFind(engine, t)
}

func BenchmarkMyMysqlCacheFindPtr(t *testing.B) {
    engine, err := newMyMysqlEngine()
    if err != nil {
        t.Error(err)
        return
    }
    defer engine.Close()
    engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
    doBenchFindPtr(engine, t)
}

mysql.go

@ -1,323 +1,323 @@
package xorm package xorm
import ( import (
"crypto/tls" "crypto/tls"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
) )
type uri struct { type uri struct {
dbType string dbType string
proto string proto string
host string host string
port string port string
dbName string dbName string
user string user string
passwd string passwd string
charset string charset string
laddr string laddr string
raddr string raddr string
timeout time.Duration timeout time.Duration
} }
type parser interface { type parser interface {
parse(driverName, dataSourceName string) (*uri, error) parse(driverName, dataSourceName string) (*uri, error)
} }
type mysqlParser struct { type mysqlParser struct {
} }
func (p *mysqlParser) parse(driverName, dataSourceName string) (*uri, error) { func (p *mysqlParser) parse(driverName, dataSourceName string) (*uri, error) {
//cfg.params = make(map[string]string) //cfg.params = make(map[string]string)
dsnPattern := regexp.MustCompile( dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@] `^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]] `(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname `\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN] `(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
matches := dsnPattern.FindStringSubmatch(dataSourceName) matches := dsnPattern.FindStringSubmatch(dataSourceName)
//tlsConfigRegister := make(map[string]*tls.Config) //tlsConfigRegister := make(map[string]*tls.Config)
names := dsnPattern.SubexpNames() names := dsnPattern.SubexpNames()
uri := &uri{dbType: MYSQL} uri := &uri{dbType: MYSQL}
for i, match := range matches { for i, match := range matches {
switch names[i] { switch names[i] {
case "dbname": case "dbname":
uri.dbName = match uri.dbName = match
} }
} }
return uri, nil return uri, nil
} }
type base struct { type base struct {
parser parser parser parser
driverName string driverName string
dataSourceName string dataSourceName string
*uri *uri
} }
func (b *base) init(parser parser, drivername, dataSourceName string) (err error) { func (b *base) init(parser parser, drivername, dataSourceName string) (err error) {
b.parser = parser b.parser = parser
b.driverName, b.dataSourceName = drivername, dataSourceName b.driverName, b.dataSourceName = drivername, dataSourceName
b.uri, err = b.parser.parse(b.driverName, b.dataSourceName) b.uri, err = b.parser.parse(b.driverName, b.dataSourceName)
return return
} }
type mysql struct { type mysql struct {
base base
net string net string
addr string addr string
params map[string]string params map[string]string
loc *time.Location loc *time.Location
timeout time.Duration timeout time.Duration
tls *tls.Config tls *tls.Config
allowAllFiles bool allowAllFiles bool
allowOldPasswords bool allowOldPasswords bool
clientFoundRows bool clientFoundRows bool
} }
func (db *mysql) Init(drivername, uri string) error { func (db *mysql) Init(drivername, uri string) error {
return db.base.init(&mysqlParser{}, drivername, uri) return db.base.init(&mysqlParser{}, drivername, uri)
} }
func (db *mysql) SqlType(c *Column) string { func (db *mysql) SqlType(c *Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case Bool: case Bool:
res = TinyInt res = TinyInt
case Serial: case Serial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = Int res = Int
case BigSerial: case BigSerial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = BigInt res = BigInt
case Bytea: case Bytea:
res = Blob res = Blob
case TimeStampz: case TimeStampz:
res = Char res = Char
c.Length = 64 c.Length = 64
default: default:
res = t res = t
} }
var hasLen1 bool = (c.Length > 0) var hasLen1 bool = (c.Length > 0)
var hasLen2 bool = (c.Length2 > 0) var hasLen2 bool = (c.Length2 > 0)
if hasLen1 { if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.Itoa(c.Length) + ")"
} else if hasLen2 { } else if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
} }
return res return res
} }
func (db *mysql) SupportInsertMany() bool { func (db *mysql) SupportInsertMany() bool {
return true return true
} }
func (db *mysql) QuoteStr() string { func (db *mysql) QuoteStr() string {
return "`" return "`"
} }
func (db *mysql) SupportEngine() bool { func (db *mysql) SupportEngine() bool {
return true return true
} }
func (db *mysql) AutoIncrStr() string { func (db *mysql) AutoIncrStr() string {
return "AUTO_INCREMENT" return "AUTO_INCREMENT"
} }
func (db *mysql) SupportCharset() bool { func (db *mysql) SupportCharset() bool {
return true return true
} }
func (db *mysql) IndexOnTable() bool { func (db *mysql) IndexOnTable() bool {
return true return true
} }
func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{db.dbName, tableName, idxName} args := []interface{}{db.dbName, tableName, idxName}
sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`" sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`"
sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?" sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?"
return sql, args return sql, args
} }
func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) { func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{db.dbName, tableName, colName} args := []interface{}{db.dbName, tableName, colName}
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
return sql, args return sql, args
} }
func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) { func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{db.dbName, tableName} args := []interface{}{db.dbName, tableName}
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
return sql, args return sql, args
} }
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{db.dbName, tableName} args := []interface{}{db.dbName, tableName}
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, record := range res { for _, record := range res {
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "COLUMN_NAME": case "COLUMN_NAME":
col.Name = strings.Trim(string(content), "` ") col.Name = strings.Trim(string(content), "` ")
case "IS_NULLABLE": case "IS_NULLABLE":
if "YES" == string(content) { if "YES" == string(content) {
col.Nullable = true col.Nullable = true
} }
case "COLUMN_DEFAULT": case "COLUMN_DEFAULT":
// add '' // add ''
col.Default = string(content) col.Default = string(content)
case "COLUMN_TYPE": case "COLUMN_TYPE":
cts := strings.Split(string(content), "(") cts := strings.Split(string(content), "(")
var len1, len2 int var len1, len2 int
if len(cts) == 2 { if len(cts) == 2 {
idx := strings.Index(cts[1], ")") idx := strings.Index(cts[1], ")")
lens := strings.Split(cts[1][0:idx], ",") lens := strings.Split(cts[1][0:idx], ",")
len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) len1, err = strconv.Atoi(strings.TrimSpace(lens[0]))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if len(lens) == 2 { if len(lens) == 2 {
len2, err = strconv.Atoi(lens[1]) len2, err = strconv.Atoi(lens[1])
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
} }
} }
colName := cts[0] colName := cts[0]
colType := strings.ToUpper(colName) colType := strings.ToUpper(colName)
col.Length = len1 col.Length = len1
col.Length2 = len2 col.Length2 = len2
if _, ok := sqlTypes[colType]; ok { if _, ok := sqlTypes[colType]; ok {
col.SQLType = SQLType{colType, len1, len2} col.SQLType = SQLType{colType, len1, len2}
} else { } else {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType)) return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType))
} }
case "COLUMN_KEY": case "COLUMN_KEY":
key := string(content) key := string(content)
if key == "PRI" { if key == "PRI" {
col.IsPrimaryKey = true col.IsPrimaryKey = true
} }
if key == "UNI" { if key == "UNI" {
//col.is //col.is
} }
case "EXTRA": case "EXTRA":
extra := string(content) extra := string(content)
if extra == "auto_increment" { if extra == "auto_increment" {
col.IsAutoIncrement = true col.IsAutoIncrement = true
} }
} }
} }
if col.SQLType.IsText() { if col.SQLType.IsText() {
if col.Default != "" { if col.Default != "" {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
} }
} }
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name) colSeq = append(colSeq, col.Name)
} }
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *mysql) GetTables() ([]*Table, error) { func (db *mysql) GetTables() ([]*Table, error) {
args := []interface{}{db.dbName} args := []interface{}{db.dbName}
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?" s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?"
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*Table, 0)
for _, record := range res { for _, record := range res {
table := new(Table) table := new(Table)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "TABLE_NAME": case "TABLE_NAME":
table.Name = strings.Trim(string(content), "` ") table.Name = strings.Trim(string(content), "` ")
case "ENGINE": case "ENGINE":
} }
} }
tables = append(tables, table) tables = append(tables, table)
} }
return tables, nil return tables, nil
} }
func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{db.dbName, tableName} args := []interface{}{db.dbName, tableName}
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*Index, 0)
for _, record := range res { for _, record := range res {
var indexType int var indexType int
var indexName, colName string var indexName, colName string
for name, content := range record { for name, content := range record {
switch name { switch name {
case "NON_UNIQUE": case "NON_UNIQUE":
if "YES" == string(content) || string(content) == "1" { if "YES" == string(content) || string(content) == "1" {
indexType = IndexType indexType = IndexType
} else { } else {
indexType = UniqueType indexType = UniqueType
} }
case "INDEX_NAME": case "INDEX_NAME":
indexName = string(content) indexName = string(content)
case "COLUMN_NAME": case "COLUMN_NAME":
colName = strings.Trim(string(content), "` ") colName = strings.Trim(string(content), "` ")
} }
} }
if indexName == "PRIMARY" { if indexName == "PRIMARY" {
continue continue
} }
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
indexName = indexName[5+len(tableName) : len(indexName)] indexName = indexName[5+len(tableName) : len(indexName)]
} }
var index *Index var index *Index
var ok bool var ok bool
if index, ok = indexes[indexName]; !ok { if index, ok = indexes[indexName]; !ok {
index = new(Index) index = new(Index)
index.Type = indexType index.Type = indexType
index.Name = indexName index.Name = indexName
indexes[indexName] = index indexes[indexName] = index
} }
index.AddColumn(colName) index.AddColumn(colName)
} }
return indexes, nil return indexes, nil
} }
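For reference, the indexes xorm creates itself are named IDX_<table>_<name> and UQE_<table>_<name>, and the slice above strips that prefix (the 4-character IDX_/UQE_ plus the underscore after the table name) before the map is returned. A minimal sketch of the same trimming, written as if it sat next to the code above; trimIndexName is a hypothetical helper, not part of xorm:

    func trimIndexName(indexName, tableName string) string {
        // "UQE_user_name" on table "user" -> "name"; 5 = len("UQE_") + len("_")
        if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
            return indexName[5+len(tableName):]
        }
        return indexName
    }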


@@ -1,10 +1,10 @@
package xorm package xorm
import ( import (
"database/sql" "database/sql"
"testing" "testing"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
) )
/* /*
@@ -15,155 +15,155 @@ utf8 COLLATE utf8_general_ci;
var mysqlShowTestSql bool = true var mysqlShowTestSql bool = true
func TestMysql(t *testing.T) { func TestMysql(t *testing.T) {
err := mysqlDdlImport() err := mysqlDdlImport()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8") engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8")
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.ShowSQL = mysqlShowTestSql engine.ShowSQL = mysqlShowTestSql
engine.ShowErr = mysqlShowTestSql engine.ShowErr = mysqlShowTestSql
engine.ShowWarn = mysqlShowTestSql engine.ShowWarn = mysqlShowTestSql
engine.ShowDebug = mysqlShowTestSql engine.ShowDebug = mysqlShowTestSql
testAll(engine, t) testAll(engine, t)
testAll2(engine, t) testAll2(engine, t)
testAll3(engine, t) testAll3(engine, t)
} }
func TestMysqlWithCache(t *testing.T) { func TestMysqlWithCache(t *testing.T) {
err := mysqlDdlImport() err := mysqlDdlImport()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8") engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8")
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
engine.ShowSQL = mysqlShowTestSql engine.ShowSQL = mysqlShowTestSql
engine.ShowErr = mysqlShowTestSql engine.ShowErr = mysqlShowTestSql
engine.ShowWarn = mysqlShowTestSql engine.ShowWarn = mysqlShowTestSql
engine.ShowDebug = mysqlShowTestSql engine.ShowDebug = mysqlShowTestSql
testAll(engine, t) testAll(engine, t)
testAll2(engine, t) testAll2(engine, t)
} }
func newMysqlEngine() (*Engine, error) { func newMysqlEngine() (*Engine, error) {
return NewEngine("mysql", "root:@/xorm_test?charset=utf8") return NewEngine("mysql", "root:@/xorm_test?charset=utf8")
} }
func mysqlDdlImport() error { func mysqlDdlImport() error {
engine, err := NewEngine("mysql", "root:@/?charset=utf8") engine, err := NewEngine("mysql", "root:@/?charset=utf8")
if err != nil { if err != nil {
return err return err
} }
engine.ShowSQL = mysqlShowTestSql engine.ShowSQL = mysqlShowTestSql
engine.ShowErr = mysqlShowTestSql engine.ShowErr = mysqlShowTestSql
engine.ShowWarn = mysqlShowTestSql engine.ShowWarn = mysqlShowTestSql
engine.ShowDebug = mysqlShowTestSql engine.ShowDebug = mysqlShowTestSql
sqlResults, _ := engine.Import("tests/mysql_ddl.sql") sqlResults, _ := engine.Import("tests/mysql_ddl.sql")
engine.LogDebug("sql results: %v", sqlResults) engine.LogDebug("sql results: %v", sqlResults)
engine.Close() engine.Close()
return nil return nil
} }
func newMysqlDriverDB() (*sql.DB, error) { func newMysqlDriverDB() (*sql.DB, error) {
return sql.Open("mysql", "root:@/xorm_test?charset=utf8") return sql.Open("mysql", "root:@/xorm_test?charset=utf8")
} }
const ( const (
createTableMySql = "CREATE TABLE IF NOT EXISTS `big_struct` (`id` BIGINT PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` VARCHAR(255) NULL, `title` VARCHAR(255) NULL, `age` VARCHAR(255) NULL, `alias` VARCHAR(255) NULL, `nick_name` VARCHAR(255) NULL);" createTableMySql = "CREATE TABLE IF NOT EXISTS `big_struct` (`id` BIGINT PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` VARCHAR(255) NULL, `title` VARCHAR(255) NULL, `age` VARCHAR(255) NULL, `alias` VARCHAR(255) NULL, `nick_name` VARCHAR(255) NULL);"
dropTableMySql = "DROP TABLE IF EXISTS `big_struct`;" dropTableMySql = "DROP TABLE IF EXISTS `big_struct`;"
) )
func BenchmarkMysqlDriverInsert(t *testing.B) { func BenchmarkMysqlDriverInsert(t *testing.B) {
doBenchDriver(newMysqlDriverDB, createTableMySql, dropTableMySql, doBenchDriver(newMysqlDriverDB, createTableMySql, dropTableMySql,
doBenchDriverInsert, t) doBenchDriverInsert, t)
} }
func BenchmarkMysqlDriverFind(t *testing.B) { func BenchmarkMysqlDriverFind(t *testing.B) {
doBenchDriver(newMysqlDriverDB, createTableMySql, dropTableMySql, doBenchDriver(newMysqlDriverDB, createTableMySql, dropTableMySql,
doBenchDriverFind, t) doBenchDriverFind, t)
} }
func BenchmarkMysqlNoCacheInsert(t *testing.B) { func BenchmarkMysqlNoCacheInsert(t *testing.B) {
engine, err := newMysqlEngine() engine, err := newMysqlEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
func BenchmarkMysqlNoCacheFind(t *testing.B) { func BenchmarkMysqlNoCacheFind(t *testing.B) {
engine, err := newMysqlEngine() engine, err := newMysqlEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchFind(engine, t) doBenchFind(engine, t)
} }
func BenchmarkMysqlNoCacheFindPtr(t *testing.B) { func BenchmarkMysqlNoCacheFindPtr(t *testing.B) {
engine, err := newMysqlEngine() engine, err := newMysqlEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }
func BenchmarkMysqlCacheInsert(t *testing.B) { func BenchmarkMysqlCacheInsert(t *testing.B) {
engine, err := newMysqlEngine() engine, err := newMysqlEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
func BenchmarkMysqlCacheFind(t *testing.B) { func BenchmarkMysqlCacheFind(t *testing.B) {
engine, err := newMysqlEngine() engine, err := newMysqlEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchFind(engine, t) doBenchFind(engine, t)
} }
func BenchmarkMysqlCacheFindPtr(t *testing.B) { func BenchmarkMysqlCacheFindPtr(t *testing.B) {
engine, err := newMysqlEngine() engine, err := newMysqlEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }
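The same four logging switches and the LRU cacher are set in every test above. A small helper that collects them, written as if inside the xorm package and using only the knobs that already appear in these tests (the helper itself is hypothetical):

    func configureTestEngine(engine *Engine, verbose bool) {
        engine.ShowSQL = verbose
        engine.ShowErr = verbose
        engine.ShowWarn = verbose
        engine.ShowDebug = verbose
        // 1000 is the same per-table cache size the cached tests use
        engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
    }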

pool.go

@@ -1,13 +1,13 @@
package xorm package xorm
import ( import (
"database/sql" "database/sql"
//"fmt" //"fmt"
"sync" "sync"
//"sync/atomic" //"sync/atomic"
"container/list" "container/list"
"reflect" "reflect"
"time" "time"
) )
// Interface IConnecPool is a connection pool interface, all implements should implement // Interface IConnecPool is a connection pool interface, all implements should implement
@@ -17,14 +17,14 @@ import (
// ReleaseDB for releasing a db connection; // ReleaseDB for releasing a db connection;
// Close for invoking when engine.Close // Close for invoking when engine.Close
type IConnectPool interface { type IConnectPool interface {
Init(engine *Engine) error Init(engine *Engine) error
RetrieveDB(engine *Engine) (*sql.DB, error) RetrieveDB(engine *Engine) (*sql.DB, error)
ReleaseDB(engine *Engine, db *sql.DB) ReleaseDB(engine *Engine, db *sql.DB)
Close(engine *Engine) error Close(engine *Engine) error
SetMaxIdleConns(conns int) SetMaxIdleConns(conns int)
MaxIdleConns() int MaxIdleConns() int
SetMaxConns(conns int) SetMaxConns(conns int)
MaxConns() int MaxConns() int
} }
// Struct NoneConnectPool is a implement for IConnectPool. It provides directly invoke driver's // Struct NoneConnectPool is a implement for IConnectPool. It provides directly invoke driver's
@@ -34,35 +34,35 @@ type NoneConnectPool struct {
// NewNoneConnectPool new a NoneConnectPool. // NewNoneConnectPool new a NoneConnectPool.
func NewNoneConnectPool() IConnectPool { func NewNoneConnectPool() IConnectPool {
return &NoneConnectPool{} return &NoneConnectPool{}
} }
// Init do nothing // Init do nothing
func (p *NoneConnectPool) Init(engine *Engine) error { func (p *NoneConnectPool) Init(engine *Engine) error {
return nil return nil
} }
// RetrieveDB directly open a connection // RetrieveDB directly open a connection
func (p *NoneConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) { func (p *NoneConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) {
db, err = engine.OpenDB() db, err = engine.OpenDB()
return return
} }
// ReleaseDB directly close a connection // ReleaseDB directly close a connection
func (p *NoneConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { func (p *NoneConnectPool) ReleaseDB(engine *Engine, db *sql.DB) {
db.Close() db.Close()
} }
// Close do nothing // Close do nothing
func (p *NoneConnectPool) Close(engine *Engine) error { func (p *NoneConnectPool) Close(engine *Engine) error {
return nil return nil
} }
func (p *NoneConnectPool) SetMaxIdleConns(conns int) { func (p *NoneConnectPool) SetMaxIdleConns(conns int) {
} }
func (p *NoneConnectPool) MaxIdleConns() int { func (p *NoneConnectPool) MaxIdleConns() int {
return 0 return 0
} }
// not implemented // not implemented
@@ -71,133 +71,133 @@ func (p *NoneConnectPool) SetMaxConns(conns int) {
// not implemented // not implemented
func (p *NoneConnectPool) MaxConns() int { func (p *NoneConnectPool) MaxConns() int {
return -1 return -1
} }
// Struct SysConnectPool is a simple wrapper for using system default connection pool. // Struct SysConnectPool is a simple wrapper for using system default connection pool.
// About the system connection pool, you can review the code database/sql/sql.go // About the system connection pool, you can review the code database/sql/sql.go
// It's currently default Pool implments. // It's currently default Pool implments.
type SysConnectPool struct { type SysConnectPool struct {
db *sql.DB db *sql.DB
maxIdleConns int maxIdleConns int
maxConns int maxConns int
curConns int curConns int
mutex *sync.Mutex mutex *sync.Mutex
queue *list.List queue *list.List
} }
// NewSysConnectPool new a SysConnectPool. // NewSysConnectPool new a SysConnectPool.
func NewSysConnectPool() IConnectPool { func NewSysConnectPool() IConnectPool {
return &SysConnectPool{} return &SysConnectPool{}
} }
// Init create a db immediately and keep it util engine closed. // Init create a db immediately and keep it util engine closed.
func (s *SysConnectPool) Init(engine *Engine) error { func (s *SysConnectPool) Init(engine *Engine) error {
db, err := engine.OpenDB() db, err := engine.OpenDB()
if err != nil { if err != nil {
return err return err
} }
s.db = db s.db = db
s.maxIdleConns = 2 s.maxIdleConns = 2
s.maxConns = -1 s.maxConns = -1
s.curConns = 0 s.curConns = 0
s.mutex = &sync.Mutex{} s.mutex = &sync.Mutex{}
s.queue = list.New() s.queue = list.New()
return nil return nil
} }
type node struct { type node struct {
mutex sync.Mutex mutex sync.Mutex
cond *sync.Cond cond *sync.Cond
} }
func newCondNode() *node { func newCondNode() *node {
n := &node{} n := &node{}
n.cond = sync.NewCond(&n.mutex) n.cond = sync.NewCond(&n.mutex)
return n return n
} }
// RetrieveDB just return the only db // RetrieveDB just return the only db
func (s *SysConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) { func (s *SysConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) {
/*if s.maxConns > 0 { /*if s.maxConns > 0 {
fmt.Println("before retrieve") fmt.Println("before retrieve")
s.mutex.Lock() s.mutex.Lock()
for s.curConns >= s.maxConns { for s.curConns >= s.maxConns {
fmt.Println("before waiting...", s.curConns, s.queue.Len()) fmt.Println("before waiting...", s.curConns, s.queue.Len())
s.mutex.Unlock() s.mutex.Unlock()
n := NewNode() n := NewNode()
n.cond.L.Lock() n.cond.L.Lock()
s.queue.PushBack(n) s.queue.PushBack(n)
n.cond.Wait() n.cond.Wait()
n.cond.L.Unlock() n.cond.L.Unlock()
s.mutex.Lock() s.mutex.Lock()
fmt.Println("after waiting...", s.curConns, s.queue.Len()) fmt.Println("after waiting...", s.curConns, s.queue.Len())
} }
s.curConns += 1 s.curConns += 1
s.mutex.Unlock() s.mutex.Unlock()
fmt.Println("after retrieve") fmt.Println("after retrieve")
}*/ }*/
return s.db, nil return s.db, nil
} }
// ReleaseDB do nothing // ReleaseDB do nothing
func (s *SysConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { func (s *SysConnectPool) ReleaseDB(engine *Engine, db *sql.DB) {
/*if s.maxConns > 0 { /*if s.maxConns > 0 {
s.mutex.Lock() s.mutex.Lock()
fmt.Println("before release", s.queue.Len()) fmt.Println("before release", s.queue.Len())
s.curConns -= 1 s.curConns -= 1
if e := s.queue.Front(); e != nil { if e := s.queue.Front(); e != nil {
n := e.Value.(*node) n := e.Value.(*node)
//n.cond.L.Lock() //n.cond.L.Lock()
n.cond.Signal() n.cond.Signal()
fmt.Println("signaled...") fmt.Println("signaled...")
s.queue.Remove(e) s.queue.Remove(e)
//n.cond.L.Unlock() //n.cond.L.Unlock()
} }
fmt.Println("after released", s.queue.Len()) fmt.Println("after released", s.queue.Len())
s.mutex.Unlock() s.mutex.Unlock()
}*/ }*/
} }
// Close closed the only db // Close closed the only db
func (p *SysConnectPool) Close(engine *Engine) error { func (p *SysConnectPool) Close(engine *Engine) error {
return p.db.Close() return p.db.Close()
} }
func (p *SysConnectPool) SetMaxIdleConns(conns int) { func (p *SysConnectPool) SetMaxIdleConns(conns int) {
p.db.SetMaxIdleConns(conns) p.db.SetMaxIdleConns(conns)
p.maxIdleConns = conns p.maxIdleConns = conns
} }
func (p *SysConnectPool) MaxIdleConns() int { func (p *SysConnectPool) MaxIdleConns() int {
return p.maxIdleConns return p.maxIdleConns
} }
// not implemented // not implemented
func (p *SysConnectPool) SetMaxConns(conns int) { func (p *SysConnectPool) SetMaxConns(conns int) {
p.maxConns = conns p.maxConns = conns
// if support SetMaxOpenConns, go 1.2+, then set // if support SetMaxOpenConns, go 1.2+, then set
if reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").IsValid() { if reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").IsValid() {
reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").Call([]reflect.Value{reflect.ValueOf(conns)}) reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").Call([]reflect.Value{reflect.ValueOf(conns)})
} }
//p.db.SetMaxOpenConns(conns) //p.db.SetMaxOpenConns(conns)
} }
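The reflection above exists because *sql.DB only gained SetMaxOpenConns in Go 1.2, while this code still had to compile on Go 1.1. A standalone sketch of the same guard, assuming a file that imports database/sql and reflect (the helper name is made up for illustration):

    func setMaxOpenConnsIfSupported(db *sql.DB, conns int) bool {
        m := reflect.ValueOf(db).MethodByName("SetMaxOpenConns")
        if !m.IsValid() {
            return false // this database/sql predates the method: silently skip
        }
        m.Call([]reflect.Value{reflect.ValueOf(conns)})
        return true
    }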
// not implemented // not implemented
func (p *SysConnectPool) MaxConns() int { func (p *SysConnectPool) MaxConns() int {
return p.maxConns return p.maxConns
} }
// NewSimpleConnectPool new a SimpleConnectPool // NewSimpleConnectPool new a SimpleConnectPool
func NewSimpleConnectPool() IConnectPool { func NewSimpleConnectPool() IConnectPool {
return &SimpleConnectPool{releasedConnects: make([]*sql.DB, 10), return &SimpleConnectPool{releasedConnects: make([]*sql.DB, 10),
usingConnects: map[*sql.DB]time.Time{}, usingConnects: map[*sql.DB]time.Time{},
cur: -1, cur: -1,
maxWaitTimeOut: 14400, maxWaitTimeOut: 14400,
maxIdleConns: 10, maxIdleConns: 10,
mutex: &sync.Mutex{}, mutex: &sync.Mutex{},
} }
} }
// Struct SimpleConnectPool is a simple implementation for IConnectPool. // Struct SimpleConnectPool is a simple implementation for IConnectPool.
@@ -205,75 +205,75 @@ func NewSimpleConnectPool() IConnectPool {
// Opening or Closing a database connection must be enter a lock. // Opening or Closing a database connection must be enter a lock.
// This implements will be improved in furture. // This implements will be improved in furture.
type SimpleConnectPool struct { type SimpleConnectPool struct {
releasedConnects []*sql.DB releasedConnects []*sql.DB
cur int cur int
usingConnects map[*sql.DB]time.Time usingConnects map[*sql.DB]time.Time
maxWaitTimeOut int maxWaitTimeOut int
mutex *sync.Mutex mutex *sync.Mutex
maxIdleConns int maxIdleConns int
} }
func (s *SimpleConnectPool) Init(engine *Engine) error { func (s *SimpleConnectPool) Init(engine *Engine) error {
return nil return nil
} }
// RetrieveDB get a connection from connection pool // RetrieveDB get a connection from connection pool
func (p *SimpleConnectPool) RetrieveDB(engine *Engine) (*sql.DB, error) { func (p *SimpleConnectPool) RetrieveDB(engine *Engine) (*sql.DB, error) {
p.mutex.Lock() p.mutex.Lock()
defer p.mutex.Unlock() defer p.mutex.Unlock()
var db *sql.DB = nil var db *sql.DB = nil
var err error = nil var err error = nil
//fmt.Printf("%x, rbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) //fmt.Printf("%x, rbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects))
if p.cur < 0 { if p.cur < 0 {
db, err = engine.OpenDB() db, err = engine.OpenDB()
if err != nil { if err != nil {
return nil, err return nil, err
} }
p.usingConnects[db] = time.Now() p.usingConnects[db] = time.Now()
} else { } else {
db = p.releasedConnects[p.cur] db = p.releasedConnects[p.cur]
p.usingConnects[db] = time.Now() p.usingConnects[db] = time.Now()
p.releasedConnects[p.cur] = nil p.releasedConnects[p.cur] = nil
p.cur = p.cur - 1 p.cur = p.cur - 1
} }
//fmt.Printf("%x, rend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) //fmt.Printf("%x, rend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects))
return db, nil return db, nil
} }
// ReleaseDB release a db from connection pool // ReleaseDB release a db from connection pool
func (p *SimpleConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { func (p *SimpleConnectPool) ReleaseDB(engine *Engine, db *sql.DB) {
p.mutex.Lock() p.mutex.Lock()
defer p.mutex.Unlock() defer p.mutex.Unlock()
//fmt.Printf("%x, lbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) //fmt.Printf("%x, lbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects))
if p.cur >= p.maxIdleConns-1 { if p.cur >= p.maxIdleConns-1 {
db.Close() db.Close()
} else { } else {
p.cur = p.cur + 1 p.cur = p.cur + 1
p.releasedConnects[p.cur] = db p.releasedConnects[p.cur] = db
} }
delete(p.usingConnects, db) delete(p.usingConnects, db)
//fmt.Printf("%x, lend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) //fmt.Printf("%x, lend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects))
} }
// Close release all db // Close release all db
func (p *SimpleConnectPool) Close(engine *Engine) error { func (p *SimpleConnectPool) Close(engine *Engine) error {
p.mutex.Lock() p.mutex.Lock()
defer p.mutex.Unlock() defer p.mutex.Unlock()
for len(p.releasedConnects) > 0 { for len(p.releasedConnects) > 0 {
p.releasedConnects[0].Close() p.releasedConnects[0].Close()
p.releasedConnects = p.releasedConnects[1:] p.releasedConnects = p.releasedConnects[1:]
} }
return nil return nil
} }
func (p *SimpleConnectPool) SetMaxIdleConns(conns int) { func (p *SimpleConnectPool) SetMaxIdleConns(conns int) {
p.maxIdleConns = conns p.maxIdleConns = conns
} }
func (p *SimpleConnectPool) MaxIdleConns() int { func (p *SimpleConnectPool) MaxIdleConns() int {
return p.maxIdleConns return p.maxIdleConns
} }
// not implemented // not implemented
@@ -282,5 +282,5 @@ func (p *SimpleConnectPool) SetMaxConns(conns int) {
// not implemented // not implemented
func (p *SimpleConnectPool) MaxConns() int { func (p *SimpleConnectPool) MaxConns() int {
return -1 return -1
} }
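All three pools sit behind the same IConnectPool contract, so callers exercise them identically. A sketch of the lifecycle, written as if inside the xorm package and using only methods defined above:

    func examplePoolLifecycle(engine *Engine) error {
        pool := NewSimpleConnectPool()
        pool.SetMaxIdleConns(5)
        if err := pool.Init(engine); err != nil { // a no-op for SimpleConnectPool, kept for the contract
            return err
        }
        db, err := pool.RetrieveDB(engine) // opens a fresh *sql.DB when the free list is empty
        if err != nil {
            return err
        }
        // ... run queries on db ...
        pool.ReleaseDB(engine, db) // parked for reuse, or closed once maxIdleConns is reached
        return pool.Close(engine)
    }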


@@ -1,305 +1,305 @@
package xorm package xorm
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
) )
type postgres struct { type postgres struct {
base base
} }
type values map[string]string type values map[string]string
func (vs values) Set(k, v string) { func (vs values) Set(k, v string) {
vs[k] = v vs[k] = v
} }
func (vs values) Get(k string) (v string) { func (vs values) Get(k string) (v string) {
return vs[k] return vs[k]
} }
func errorf(s string, args ...interface{}) { func errorf(s string, args ...interface{}) {
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
} }
func parseOpts(name string, o values) { func parseOpts(name string, o values) {
if len(name) == 0 { if len(name) == 0 {
return return
} }
name = strings.TrimSpace(name) name = strings.TrimSpace(name)
ps := strings.Split(name, " ") ps := strings.Split(name, " ")
for _, p := range ps { for _, p := range ps {
kv := strings.Split(p, "=") kv := strings.Split(p, "=")
if len(kv) < 2 { if len(kv) < 2 {
errorf("invalid option: %q", p) errorf("invalid option: %q", p)
} }
o.Set(kv[0], kv[1]) o.Set(kv[0], kv[1])
} }
} }
type postgresParser struct { type postgresParser struct {
} }
func (p *postgresParser) parse(driverName, dataSourceName string) (*uri, error) { func (p *postgresParser) parse(driverName, dataSourceName string) (*uri, error) {
db := &uri{dbType: POSTGRES} db := &uri{dbType: POSTGRES}
o := make(values) o := make(values)
parseOpts(dataSourceName, o) parseOpts(dataSourceName, o)
db.dbName = o.Get("dbname") db.dbName = o.Get("dbname")
if db.dbName == "" { if db.dbName == "" {
return nil, errors.New("dbname is empty") return nil, errors.New("dbname is empty")
} }
return db, nil return db, nil
} }
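Only dbname matters to this parser: parseOpts splits the space-separated key=value pairs, the dbname value is recorded, and the remaining keys are left for the driver, which still receives the full data source string through sql.Open elsewhere in this file. A small in-package sketch (note that errorf panics on a malformed pair):

    func exampleParsePostgresDSN() (string, error) {
        p := &postgresParser{}
        u, err := p.parse("postgres", "user=postgres dbname=xorm_test sslmode=disable")
        if err != nil {
            return "", err
        }
        return u.dbName, nil // "xorm_test"
    }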
func (db *postgres) Init(drivername, uri string) error { func (db *postgres) Init(drivername, uri string) error {
return db.base.init(&postgresParser{}, drivername, uri) return db.base.init(&postgresParser{}, drivername, uri)
} }
func (db *postgres) SqlType(c *Column) string { func (db *postgres) SqlType(c *Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case TinyInt: case TinyInt:
res = SmallInt res = SmallInt
case MediumInt, Int, Integer: case MediumInt, Int, Integer:
return Integer return Integer
case Serial, BigSerial: case Serial, BigSerial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.Nullable = false c.Nullable = false
res = t res = t
case Binary, VarBinary: case Binary, VarBinary:
return Bytea return Bytea
case DateTime: case DateTime:
res = TimeStamp res = TimeStamp
case TimeStampz: case TimeStampz:
return "timestamp with time zone" return "timestamp with time zone"
case Float: case Float:
res = Real res = Real
case TinyText, MediumText, LongText: case TinyText, MediumText, LongText:
res = Text res = Text
case Blob, TinyBlob, MediumBlob, LongBlob: case Blob, TinyBlob, MediumBlob, LongBlob:
return Bytea return Bytea
case Double: case Double:
return "DOUBLE PRECISION" return "DOUBLE PRECISION"
default: default:
if c.IsAutoIncrement { if c.IsAutoIncrement {
return Serial return Serial
} }
res = t res = t
} }
var hasLen1 bool = (c.Length > 0) var hasLen1 bool = (c.Length > 0)
var hasLen2 bool = (c.Length2 > 0) var hasLen2 bool = (c.Length2 > 0)
if hasLen1 { if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.Itoa(c.Length) + ")"
} else if hasLen2 { } else if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
} }
return res return res
} }
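A few representative mappings, sketched as if inside the xorm package; the comments show the strings the method above returns:

    func examplePostgresSqlTypes() []string {
        db := &postgres{}
        return []string{
            db.SqlType(&Column{SQLType: SQLType{Varchar, 0, 0}, Length: 64}), // "VARCHAR(64)"
            db.SqlType(&Column{SQLType: SQLType{DateTime, 0, 0}}),            // "TIMESTAMP"
            db.SqlType(&Column{SQLType: SQLType{Double, 0, 0}}),              // "DOUBLE PRECISION"
            db.SqlType(&Column{SQLType: SQLType{TinyText, 0, 0}}),            // "TEXT"
        }
    }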
func (db *postgres) SupportInsertMany() bool { func (db *postgres) SupportInsertMany() bool {
return true return true
} }
func (db *postgres) QuoteStr() string { func (db *postgres) QuoteStr() string {
return "\"" return "\""
} }
func (db *postgres) AutoIncrStr() string { func (db *postgres) AutoIncrStr() string {
return "" return ""
} }
func (db *postgres) SupportEngine() bool { func (db *postgres) SupportEngine() bool {
return false return false
} }
func (db *postgres) SupportCharset() bool { func (db *postgres) SupportCharset() bool {
return false return false
} }
func (db *postgres) IndexOnTable() bool { func (db *postgres) IndexOnTable() bool {
return false return false
} }
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{tableName, idxName} args := []interface{}{tableName, idxName}
return `SELECT indexname FROM pg_indexes ` + return `SELECT indexname FROM pg_indexes ` +
`WHERE tablename = ? AND indexname = ?`, args `WHERE tablename = ? AND indexname = ?`, args
} }
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
} }
func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) { func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName, colName} args := []interface{}{tableName, colName}
return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" +
" AND column_name = ?", args " AND column_name = ?", args
} }
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" +
", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1"
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, record := range res { for _, record := range res {
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "column_name": case "column_name":
col.Name = strings.Trim(string(content), `" `) col.Name = strings.Trim(string(content), `" `)
case "column_default": case "column_default":
if strings.HasPrefix(string(content), "nextval") { if strings.HasPrefix(string(content), "nextval") {
col.IsPrimaryKey = true col.IsPrimaryKey = true
} else { } else {
col.Default = string(content) col.Default = string(content)
} }
case "is_nullable": case "is_nullable":
if string(content) == "YES" { if string(content) == "YES" {
col.Nullable = true col.Nullable = true
} else { } else {
col.Nullable = false col.Nullable = false
} }
case "data_type": case "data_type":
ct := string(content) ct := string(content)
switch ct { switch ct {
case "character varying", "character": case "character varying", "character":
col.SQLType = SQLType{Varchar, 0, 0} col.SQLType = SQLType{Varchar, 0, 0}
case "timestamp without time zone": case "timestamp without time zone":
col.SQLType = SQLType{DateTime, 0, 0} col.SQLType = SQLType{DateTime, 0, 0}
case "timestamp with time zone": case "timestamp with time zone":
col.SQLType = SQLType{TimeStampz, 0, 0} col.SQLType = SQLType{TimeStampz, 0, 0}
case "double precision": case "double precision":
col.SQLType = SQLType{Double, 0, 0} col.SQLType = SQLType{Double, 0, 0}
case "boolean": case "boolean":
col.SQLType = SQLType{Bool, 0, 0} col.SQLType = SQLType{Bool, 0, 0}
case "time without time zone": case "time without time zone":
col.SQLType = SQLType{Time, 0, 0} col.SQLType = SQLType{Time, 0, 0}
default: default:
col.SQLType = SQLType{strings.ToUpper(ct), 0, 0} col.SQLType = SQLType{strings.ToUpper(ct), 0, 0}
} }
if _, ok := sqlTypes[col.SQLType.Name]; !ok { if _, ok := sqlTypes[col.SQLType.Name]; !ok {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct)) return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct))
} }
case "character_maximum_length": case "character_maximum_length":
i, err := strconv.Atoi(string(content)) i, err := strconv.Atoi(string(content))
if err != nil { if err != nil {
return nil, nil, errors.New("retrieve length error") return nil, nil, errors.New("retrieve length error")
} }
col.Length = i col.Length = i
case "numeric_precision": case "numeric_precision":
case "numeric_precision_radix": case "numeric_precision_radix":
} }
} }
if col.SQLType.IsText() { if col.SQLType.IsText() {
if col.Default != "" { if col.Default != "" {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
} }
} }
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name) colSeq = append(colSeq, col.Name)
} }
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *postgres) GetTables() ([]*Table, error) { func (db *postgres) GetTables() ([]*Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT tablename FROM pg_tables where schemaname = 'public'" s := "SELECT tablename FROM pg_tables where schemaname = 'public'"
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*Table, 0)
for _, record := range res { for _, record := range res {
table := new(Table) table := new(Table)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "tablename": case "tablename":
table.Name = string(content) table.Name = string(content)
} }
} }
tables = append(tables, table) tables = append(tables, table)
} }
return tables, nil return tables, nil
} }
func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1" s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1"
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*Index, 0)
for _, record := range res { for _, record := range res {
var indexType int var indexType int
var indexName string var indexName string
var colNames []string var colNames []string
for name, content := range record { for name, content := range record {
switch name { switch name {
case "indexname": case "indexname":
indexName = strings.Trim(string(content), `" `) indexName = strings.Trim(string(content), `" `)
case "indexdef": case "indexdef":
c := string(content) c := string(content)
if strings.HasPrefix(c, "CREATE UNIQUE INDEX") { if strings.HasPrefix(c, "CREATE UNIQUE INDEX") {
indexType = UniqueType indexType = UniqueType
} else { } else {
indexType = IndexType indexType = IndexType
} }
cs := strings.Split(c, "(") cs := strings.Split(c, "(")
colNames = strings.Split(cs[1][0:len(cs[1])-1], ",") colNames = strings.Split(cs[1][0:len(cs[1])-1], ",")
} }
} }
if strings.HasSuffix(indexName, "_pkey") { if strings.HasSuffix(indexName, "_pkey") {
continue continue
} }
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
newIdxName := indexName[5+len(tableName) : len(indexName)] newIdxName := indexName[5+len(tableName) : len(indexName)]
if newIdxName != "" { if newIdxName != "" {
indexName = newIdxName indexName = newIdxName
} }
} }
index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
for _, colName := range colNames { for _, colName := range colNames {
index.Cols = append(index.Cols, strings.Trim(colName, `" `)) index.Cols = append(index.Cols, strings.Trim(colName, `" `))
} }
indexes[index.Name] = index indexes[index.Name] = index
} }
return indexes, nil return indexes, nil
} }
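The interesting part is the indexdef parsing: the column list is simply whatever sits inside the first pair of parentheses of the CREATE INDEX statement returned by pg_indexes. A standalone sketch of that reduction using the same string operations, assuming a file that imports strings and the IndexType/UniqueType constants used throughout this diff:

    func exampleParseIndexDef() (int, []string) {
        indexdef := `CREATE UNIQUE INDEX uqe_user_name ON "user" USING btree (name, nick_name)`
        indexType := IndexType
        if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") {
            indexType = UniqueType
        }
        inner := strings.Split(indexdef, "(")[1]
        cols := make([]string, 0)
        for _, c := range strings.Split(inner[:len(inner)-1], ",") {
            cols = append(cols, strings.Trim(c, `" `))
        }
        return indexType, cols // UniqueType, ["name" "nick_name"]
    }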


@@ -1,52 +1,52 @@
package xorm package xorm
import ( import (
"database/sql" "database/sql"
"testing" "testing"
_ "github.com/lib/pq" _ "github.com/lib/pq"
) )
func newPostgresEngine() (*Engine, error) { func newPostgresEngine() (*Engine, error) {
return NewEngine("postgres", "dbname=xorm_test sslmode=disable") return NewEngine("postgres", "dbname=xorm_test sslmode=disable")
} }
func newPostgresDriverDB() (*sql.DB, error) { func newPostgresDriverDB() (*sql.DB, error) {
return sql.Open("postgres", "dbname=xorm_test sslmode=disable") return sql.Open("postgres", "dbname=xorm_test sslmode=disable")
} }
func TestPostgres(t *testing.T) { func TestPostgres(t *testing.T) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer engine.Close() defer engine.Close()
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
engine.ShowDebug = showTestSql engine.ShowDebug = showTestSql
testAll(engine, t) testAll(engine, t)
testAll2(engine, t) testAll2(engine, t)
testAll3(engine, t) testAll3(engine, t)
} }
func TestPostgresWithCache(t *testing.T) { func TestPostgresWithCache(t *testing.T) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
defer engine.Close() defer engine.Close()
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
engine.ShowDebug = showTestSql engine.ShowDebug = showTestSql
testAll(engine, t) testAll(engine, t)
testAll2(engine, t) testAll2(engine, t)
} }
/* /*
@@ -147,91 +147,91 @@ func TestPostgres2(t *testing.T) {
}*/ }*/
const ( const (
createTablePostgres = `CREATE TABLE IF NOT EXISTS "big_struct" ("id" SERIAL PRIMARY KEY NOT NULL, "name" VARCHAR(255) NULL, "title" VARCHAR(255) NULL, "age" VARCHAR(255) NULL, "alias" VARCHAR(255) NULL, "nick_name" VARCHAR(255) NULL);` createTablePostgres = `CREATE TABLE IF NOT EXISTS "big_struct" ("id" SERIAL PRIMARY KEY NOT NULL, "name" VARCHAR(255) NULL, "title" VARCHAR(255) NULL, "age" VARCHAR(255) NULL, "alias" VARCHAR(255) NULL, "nick_name" VARCHAR(255) NULL);`
dropTablePostgres = `DROP TABLE IF EXISTS "big_struct";` dropTablePostgres = `DROP TABLE IF EXISTS "big_struct";`
) )
func BenchmarkPostgresDriverInsert(t *testing.B) { func BenchmarkPostgresDriverInsert(t *testing.B) {
doBenchDriver(newPostgresDriverDB, createTablePostgres, dropTablePostgres, doBenchDriver(newPostgresDriverDB, createTablePostgres, dropTablePostgres,
doBenchDriverInsert, t) doBenchDriverInsert, t)
} }
func BenchmarkPostgresDriverFind(t *testing.B) { func BenchmarkPostgresDriverFind(t *testing.B) {
doBenchDriver(newPostgresDriverDB, createTablePostgres, dropTablePostgres, doBenchDriver(newPostgresDriverDB, createTablePostgres, dropTablePostgres,
doBenchDriverFind, t) doBenchDriverFind, t)
} }
func BenchmarkPostgresNoCacheInsert(t *testing.B) { func BenchmarkPostgresNoCacheInsert(t *testing.B) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
func BenchmarkPostgresNoCacheFind(t *testing.B) { func BenchmarkPostgresNoCacheFind(t *testing.B) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchFind(engine, t) doBenchFind(engine, t)
} }
func BenchmarkPostgresNoCacheFindPtr(t *testing.B) { func BenchmarkPostgresNoCacheFindPtr(t *testing.B) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }
func BenchmarkPostgresCacheInsert(t *testing.B) { func BenchmarkPostgresCacheInsert(t *testing.B) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
func BenchmarkPostgresCacheFind(t *testing.B) { func BenchmarkPostgresCacheFind(t *testing.B) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchFind(engine, t) doBenchFind(engine, t)
} }
func BenchmarkPostgresCacheFindPtr(t *testing.B) { func BenchmarkPostgresCacheFindPtr(t *testing.B) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }


@@ -2,17 +2,17 @@ package xorm
// Executed before an object is initially persisted to the database // Executed before an object is initially persisted to the database
type BeforeInsertProcessor interface { type BeforeInsertProcessor interface {
BeforeInsert() BeforeInsert()
} }
// Executed before an object is updated // Executed before an object is updated
type BeforeUpdateProcessor interface { type BeforeUpdateProcessor interface {
BeforeUpdate() BeforeUpdate()
} }
// Executed before an object is deleted // Executed before an object is deleted
type BeforeDeleteProcessor interface { type BeforeDeleteProcessor interface {
BeforeDelete() BeforeDelete()
} }
// !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations // !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations
@@ -24,16 +24,15 @@ type BeforeDeleteProcessor interface {
// Executed after an object is persisted to the database // Executed after an object is persisted to the database
type AfterInsertProcessor interface { type AfterInsertProcessor interface {
AfterInsert() AfterInsert()
} }
// Executed after an object has been updated // Executed after an object has been updated
type AfterUpdateProcessor interface { type AfterUpdateProcessor interface {
AfterUpdate() AfterUpdate()
} }
// Executed after an object has been deleted // Executed after an object has been deleted
type AfterDeleteProcessor interface { type AfterDeleteProcessor interface {
AfterDelete() AfterDelete()
} }
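Any struct persisted through xorm opts into these hooks simply by implementing the matching method; the comments above describe when each one fires. A sketch with a hypothetical Account bean (the struct, its fields, and the logging are made up for illustration; assumes a file importing log and time):

    type Account struct {
        Id      int64
        Name    string
        Created int64
    }

    func (a *Account) BeforeInsert() {
        a.Created = time.Now().Unix() // stamp the row just before the INSERT runs
    }

    func (a *Account) AfterDelete() {
        log.Printf("account %d removed", a.Id)
    }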

session.go

File diff suppressed because it is too large.


@@ -1,229 +1,229 @@
package xorm package xorm
import ( import (
"database/sql" "database/sql"
"strings" "strings"
) )
type sqlite3 struct { type sqlite3 struct {
base base
} }
type sqlite3Parser struct { type sqlite3Parser struct {
} }
func (p *sqlite3Parser) parse(driverName, dataSourceName string) (*uri, error) { func (p *sqlite3Parser) parse(driverName, dataSourceName string) (*uri, error) {
return &uri{dbType: SQLITE, dbName: dataSourceName}, nil return &uri{dbType: SQLITE, dbName: dataSourceName}, nil
} }
func (db *sqlite3) Init(drivername, dataSourceName string) error { func (db *sqlite3) Init(drivername, dataSourceName string) error {
return db.base.init(&sqlite3Parser{}, drivername, dataSourceName) return db.base.init(&sqlite3Parser{}, drivername, dataSourceName)
} }
func (db *sqlite3) SqlType(c *Column) string { func (db *sqlite3) SqlType(c *Column) string {
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case Date, DateTime, TimeStamp, Time: case Date, DateTime, TimeStamp, Time:
return Numeric return Numeric
case TimeStampz: case TimeStampz:
return Text return Text
case Char, Varchar, TinyText, Text, MediumText, LongText: case Char, Varchar, TinyText, Text, MediumText, LongText:
return Text return Text
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool: case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool:
return Integer return Integer
case Float, Double, Real: case Float, Double, Real:
return Real return Real
case Decimal, Numeric: case Decimal, Numeric:
return Numeric return Numeric
case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary: case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary:
return Blob return Blob
case Serial, BigSerial: case Serial, BigSerial:
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.Nullable = false c.Nullable = false
return Integer return Integer
default: default:
return t return t
} }
} }
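Because SQLite only distinguishes a handful of storage classes, nearly every declared type collapses onto TEXT, INTEGER, REAL, NUMERIC or BLOB, and no length suffix is emitted. Sketched as if inside the xorm package:

    func exampleSqlite3SqlTypes() []string {
        db := &sqlite3{}
        return []string{
            db.SqlType(&Column{SQLType: SQLType{Varchar, 0, 0}, Length: 64}), // "TEXT" (the length is dropped)
            db.SqlType(&Column{SQLType: SQLType{Bool, 0, 0}}),                // "INTEGER"
            db.SqlType(&Column{SQLType: SQLType{DateTime, 0, 0}}),            // "NUMERIC"
            db.SqlType(&Column{SQLType: SQLType{VarBinary, 0, 0}}),           // "BLOB"
        }
    }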
func (db *sqlite3) SupportInsertMany() bool { func (db *sqlite3) SupportInsertMany() bool {
return true return true
} }
func (db *sqlite3) QuoteStr() string { func (db *sqlite3) QuoteStr() string {
return "`" return "`"
} }
func (db *sqlite3) AutoIncrStr() string { func (db *sqlite3) AutoIncrStr() string {
return "AUTOINCREMENT" return "AUTOINCREMENT"
} }
func (db *sqlite3) SupportEngine() bool { func (db *sqlite3) SupportEngine() bool {
return false return false
} }
func (db *sqlite3) SupportCharset() bool { func (db *sqlite3) SupportCharset() bool {
return false return false
} }
func (db *sqlite3) IndexOnTable() bool { func (db *sqlite3) IndexOnTable() bool {
return false return false
} }
func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{idxName} args := []interface{}{idxName}
return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args
} }
func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) { func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
} }
func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))"
return sql, args return sql, args
} }
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
var sql string var sql string
for _, record := range res { for _, record := range res {
for name, content := range record { for name, content := range record {
if name == "sql" { if name == "sql" {
sql = string(content) sql = string(content)
} }
} }
} }
nStart := strings.Index(sql, "(") nStart := strings.Index(sql, "(")
nEnd := strings.Index(sql, ")") nEnd := strings.Index(sql, ")")
colCreates := strings.Split(sql[nStart+1:nEnd], ",") colCreates := strings.Split(sql[nStart+1:nEnd], ",")
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, colStr := range colCreates { for _, colStr := range colCreates {
fields := strings.Fields(strings.TrimSpace(colStr)) fields := strings.Fields(strings.TrimSpace(colStr))
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
col.Nullable = true col.Nullable = true
for idx, field := range fields { for idx, field := range fields {
if idx == 0 { if idx == 0 {
col.Name = strings.Trim(field, "`[] ") col.Name = strings.Trim(field, "`[] ")
continue continue
} else if idx == 1 { } else if idx == 1 {
col.SQLType = SQLType{field, 0, 0} col.SQLType = SQLType{field, 0, 0}
} }
switch field { switch field {
case "PRIMARY": case "PRIMARY":
col.IsPrimaryKey = true col.IsPrimaryKey = true
case "AUTOINCREMENT": case "AUTOINCREMENT":
col.IsAutoIncrement = true col.IsAutoIncrement = true
case "NULL": case "NULL":
if fields[idx-1] == "NOT" { if fields[idx-1] == "NOT" {
col.Nullable = false col.Nullable = false
} else { } else {
col.Nullable = true col.Nullable = true
} }
} }
} }
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name) colSeq = append(colSeq, col.Name)
} }
return colSeq, cols, nil return colSeq, cols, nil
} }
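Unlike the other dialects, SQLite exposes no information_schema, so the columns are scraped out of the stored CREATE TABLE text between its first "(" and first ")". That holds together because this dialect never emits parenthesised lengths (see SqlType above). A standalone sketch of the scrape, assuming a file that imports strings:

    func exampleSqliteColumnScrape() []string {
        create := "CREATE TABLE `big_struct` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL)"
        body := create[strings.Index(create, "(")+1 : strings.Index(create, ")")]
        names := make([]string, 0)
        for _, colStr := range strings.Split(body, ",") {
            fields := strings.Fields(strings.TrimSpace(colStr))
            names = append(names, strings.Trim(fields[0], "`[] "))
        }
        return names // ["id" "name"]
    }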
func (db *sqlite3) GetTables() ([]*Table, error) { func (db *sqlite3) GetTables() ([]*Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT name FROM sqlite_master WHERE type='table'" s := "SELECT name FROM sqlite_master WHERE type='table'"
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*Table, 0)
for _, record := range res { for _, record := range res {
table := new(Table) table := new(Table)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "name": case "name":
table.Name = string(content) table.Name = string(content)
} }
} }
if table.Name == "sqlite_sequence" { if table.Name == "sqlite_sequence" {
continue continue
} }
tables = append(tables, table) tables = append(tables, table)
} }
return tables, nil return tables, nil
} }
func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*Index, 0)
for _, record := range res { for _, record := range res {
var sql string var sql string
index := new(Index) index := new(Index)
for name, content := range record { for name, content := range record {
if name == "sql" { if name == "sql" {
sql = string(content) sql = string(content)
} }
} }
nNStart := strings.Index(sql, "INDEX") nNStart := strings.Index(sql, "INDEX")
nNEnd := strings.Index(sql, "ON") nNEnd := strings.Index(sql, "ON")
indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []")
//fmt.Println(indexName) //fmt.Println(indexName)
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
index.Name = indexName[5+len(tableName) : len(indexName)] index.Name = indexName[5+len(tableName) : len(indexName)]
} else { } else {
index.Name = indexName index.Name = indexName
} }
if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") {
index.Type = UniqueType index.Type = UniqueType
} else { } else {
index.Type = IndexType index.Type = IndexType
} }
nStart := strings.Index(sql, "(") nStart := strings.Index(sql, "(")
nEnd := strings.Index(sql, ")") nEnd := strings.Index(sql, ")")
colIndexes := strings.Split(sql[nStart+1:nEnd], ",") colIndexes := strings.Split(sql[nStart+1:nEnd], ",")
index.Cols = make([]string, 0) index.Cols = make([]string, 0)
for _, col := range colIndexes { for _, col := range colIndexes {
index.Cols = append(index.Cols, strings.Trim(col, "` []")) index.Cols = append(index.Cols, strings.Trim(col, "` []"))
} }
indexes[index.Name] = index indexes[index.Name] = index
} }
return indexes, nil return indexes, nil
} }


@@ -1,140 +1,140 @@
package xorm package xorm
import ( import (
"database/sql" "database/sql"
"os" "os"
"testing" "testing"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
func newSqlite3Engine() (*Engine, error) { func newSqlite3Engine() (*Engine, error) {
os.Remove("./test.db") os.Remove("./test.db")
return NewEngine("sqlite3", "./test.db") return NewEngine("sqlite3", "./test.db")
} }
func newSqlite3DriverDB() (*sql.DB, error) { func newSqlite3DriverDB() (*sql.DB, error) {
os.Remove("./test.db") os.Remove("./test.db")
return sql.Open("sqlite3", "./test.db") return sql.Open("sqlite3", "./test.db")
} }
func TestSqlite3(t *testing.T) { func TestSqlite3(t *testing.T) {
engine, err := newSqlite3Engine() engine, err := newSqlite3Engine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
engine.ShowDebug = showTestSql engine.ShowDebug = showTestSql
testAll(engine, t) testAll(engine, t)
testAll2(engine, t) testAll2(engine, t)
testAll3(engine, t) testAll3(engine, t)
} }
func TestSqlite3WithCache(t *testing.T) { func TestSqlite3WithCache(t *testing.T) {
engine, err := newSqlite3Engine() engine, err := newSqlite3Engine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
engine.ShowDebug = showTestSql engine.ShowDebug = showTestSql
testAll(engine, t) testAll(engine, t)
testAll2(engine, t) testAll2(engine, t)
} }
const ( const (
createTableSqlite3 = "CREATE TABLE IF NOT EXISTS `big_struct` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, `title` TEXT NULL, `age` TEXT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL);" createTableSqlite3 = "CREATE TABLE IF NOT EXISTS `big_struct` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, `title` TEXT NULL, `age` TEXT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL);"
dropTableSqlite3 = "DROP TABLE IF EXISTS `big_struct`;" dropTableSqlite3 = "DROP TABLE IF EXISTS `big_struct`;"
) )
func BenchmarkSqlite3DriverInsert(t *testing.B) {
    doBenchDriver(newSqlite3DriverDB, createTableSqlite3, dropTableSqlite3,
        doBenchDriverInsert, t)
}

func BenchmarkSqlite3DriverFind(t *testing.B) {
    doBenchDriver(newSqlite3DriverDB, createTableSqlite3, dropTableSqlite3,
        doBenchDriverFind, t)
}

func BenchmarkSqlite3NoCacheInsert(t *testing.B) {
    t.StopTimer()
    engine, err := newSqlite3Engine()
    defer engine.Close()
    if err != nil {
        t.Error(err)
        return
    }
    //engine.ShowSQL = true
    doBenchInsert(engine, t)
}

func BenchmarkSqlite3NoCacheFind(t *testing.B) {
    t.StopTimer()
    engine, err := newSqlite3Engine()
    defer engine.Close()
    if err != nil {
        t.Error(err)
        return
    }
    //engine.ShowSQL = true
    doBenchFind(engine, t)
}

func BenchmarkSqlite3NoCacheFindPtr(t *testing.B) {
    t.StopTimer()
    engine, err := newSqlite3Engine()
    defer engine.Close()
    if err != nil {
        t.Error(err)
        return
    }
    //engine.ShowSQL = true
    doBenchFindPtr(engine, t)
}

func BenchmarkSqlite3CacheInsert(t *testing.B) {
    t.StopTimer()
    engine, err := newSqlite3Engine()
    defer engine.Close()
    if err != nil {
        t.Error(err)
        return
    }
    engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
    doBenchInsert(engine, t)
}

func BenchmarkSqlite3CacheFind(t *testing.B) {
    t.StopTimer()
    engine, err := newSqlite3Engine()
    defer engine.Close()
    if err != nil {
        t.Error(err)
        return
    }
    engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
    doBenchFind(engine, t)
}

func BenchmarkSqlite3CacheFindPtr(t *testing.B) {
    t.StopTimer()
    engine, err := newSqlite3Engine()
    defer engine.Close()
    if err != nil {
        t.Error(err)
        return
    }
    engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
    doBenchFindPtr(engine, t)
}
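These Benchmark* functions use the standard testing.B harness, so once a sqlite3 driver is available they can be driven through go test from this package directory; a hedged example invocation (the -bench pattern here is just one possibility):

go test -run=XXX -bench=Sqlite3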

File diff suppressed because it is too large

600
table.go

@@ -1,335 +1,335 @@
package xorm

import (
    "reflect"
    "sort"
    "strings"
    "time"
)

// xorm SQL types
type SQLType struct {
    Name           string
    DefaultLength  int
    DefaultLength2 int
}

func (s *SQLType) IsText() bool {
    return s.Name == Char || s.Name == Varchar || s.Name == TinyText ||
        s.Name == Text || s.Name == MediumText || s.Name == LongText
}

func (s *SQLType) IsBlob() bool {
    return (s.Name == TinyBlob) || (s.Name == Blob) ||
        s.Name == MediumBlob || s.Name == LongBlob ||
        s.Name == Binary || s.Name == VarBinary || s.Name == Bytea
}

const ()
var (
    Bit        = "BIT"
    TinyInt    = "TINYINT"
    SmallInt   = "SMALLINT"
    MediumInt  = "MEDIUMINT"
    Int        = "INT"
    Integer    = "INTEGER"
    BigInt     = "BIGINT"

    Char       = "CHAR"
    Varchar    = "VARCHAR"
    TinyText   = "TINYTEXT"
    Text       = "TEXT"
    MediumText = "MEDIUMTEXT"
    LongText   = "LONGTEXT"

    Date       = "DATE"
    DateTime   = "DATETIME"
    Time       = "TIME"
    TimeStamp  = "TIMESTAMP"
    TimeStampz = "TIMESTAMPZ"

    Decimal = "DECIMAL"
    Numeric = "NUMERIC"

    Real   = "REAL"
    Float  = "FLOAT"
    Double = "DOUBLE"

    Binary    = "BINARY"
    VarBinary = "VARBINARY"

    TinyBlob   = "TINYBLOB"
    Blob       = "BLOB"
    MediumBlob = "MEDIUMBLOB"
    LongBlob   = "LONGBLOB"
    Bytea      = "BYTEA"

    Bool = "BOOL"

    Serial    = "SERIAL"
    BigSerial = "BIGSERIAL"

    sqlTypes = map[string]bool{
        Bit:        true,
        TinyInt:    true,
        SmallInt:   true,
        MediumInt:  true,
        Int:        true,
        Integer:    true,
        BigInt:     true,

        Char:       true,
        Varchar:    true,
        TinyText:   true,
        Text:       true,
        MediumText: true,
        LongText:   true,

        Date:       true,
        DateTime:   true,
        Time:       true,
        TimeStamp:  true,
        TimeStampz: true,

        Decimal: true,
        Numeric: true,

        Binary:    true,
        VarBinary: true,

        Real:   true,
        Float:  true,
        Double: true,

        TinyBlob:   true,
        Blob:       true,
        MediumBlob: true,
        LongBlob:   true,
        Bytea:      true,

        Bool: true,

        Serial:    true,
        BigSerial: true,
    }

    intTypes  = sort.StringSlice{"*int", "*int16", "*int32", "*int8"}
    uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"}
)
var b byte
var tm time.Time

func Type2SQLType(t reflect.Type) (st SQLType) {
    switch k := t.Kind(); k {
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
        st = SQLType{Int, 0, 0}
    case reflect.Int64, reflect.Uint64:
        st = SQLType{BigInt, 0, 0}
    case reflect.Float32:
        st = SQLType{Float, 0, 0}
    case reflect.Float64:
        st = SQLType{Double, 0, 0}
    case reflect.Complex64, reflect.Complex128:
        st = SQLType{Varchar, 64, 0}
    case reflect.Array, reflect.Slice, reflect.Map:
        if t.Elem() == reflect.TypeOf(b) {
            st = SQLType{Blob, 0, 0}
        } else {
            st = SQLType{Text, 0, 0}
        }
    case reflect.Bool:
        st = SQLType{Bool, 0, 0}
    case reflect.String:
        st = SQLType{Varchar, 255, 0}
    case reflect.Struct:
        if t == reflect.TypeOf(tm) {
            st = SQLType{DateTime, 0, 0}
        } else {
            st = SQLType{Text, 0, 0}
        }
    case reflect.Ptr:
        st, _ = ptrType2SQLType(t)
    default:
        st = SQLType{Text, 0, 0}
    }
    return
}
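To make the mapping above concrete, here is a minimal sketch (not part of the commit) that walks a struct's fields through Type2SQLType; the Account type and its field names are invented for illustration, and the snippet assumes "fmt" and "reflect" are imported:

type Account struct {
    Id      int64
    Name    string
    Balance float64
    Avatar  []byte
}

func printSQLTypes() {
    t := reflect.TypeOf(Account{})
    for i := 0; i < t.NumField(); i++ {
        f := t.Field(i)
        st := Type2SQLType(f.Type)
        // Id -> BIGINT, Name -> VARCHAR (length 255), Balance -> DOUBLE, Avatar -> BLOB
        fmt.Printf("%s -> %s (default length %d)\n", f.Name, st.Name, st.DefaultLength)
    }
}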
func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) {
    has = true
    switch t {
    case reflect.TypeOf(&c_EMPTY_STRING):
        st = SQLType{Varchar, 255, 0}
        return
    case reflect.TypeOf(&c_BOOL_DEFAULT):
        st = SQLType{Bool, 0, 0}
    case reflect.TypeOf(&c_COMPLEX64_DEFAULT), reflect.TypeOf(&c_COMPLEX128_DEFAULT):
        st = SQLType{Varchar, 64, 0}
    case reflect.TypeOf(&c_FLOAT32_DEFAULT):
        st = SQLType{Float, 0, 0}
    case reflect.TypeOf(&c_FLOAT64_DEFAULT):
        st = SQLType{Double, 0, 0}
    case reflect.TypeOf(&c_INT64_DEFAULT), reflect.TypeOf(&c_UINT64_DEFAULT):
        st = SQLType{BigInt, 0, 0}
    case reflect.TypeOf(&c_TIME_DEFAULT):
        st = SQLType{DateTime, 0, 0}
    case reflect.TypeOf(&c_INT_DEFAULT), reflect.TypeOf(&c_INT32_DEFAULT), reflect.TypeOf(&c_INT8_DEFAULT), reflect.TypeOf(&c_INT16_DEFAULT), reflect.TypeOf(&c_UINT_DEFAULT), reflect.TypeOf(&c_UINT32_DEFAULT), reflect.TypeOf(&c_UINT8_DEFAULT), reflect.TypeOf(&c_UINT16_DEFAULT):
        st = SQLType{Int, 0, 0}
    default:
        has = false
    }
    return
}
// default mapping from SQL types back to Go types
func SQLType2Type(st SQLType) reflect.Type {
    name := strings.ToUpper(st.Name)
    switch name {
    case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial:
        return reflect.TypeOf(1)
    case BigInt, BigSerial:
        return reflect.TypeOf(int64(1))
    case Float, Real:
        return reflect.TypeOf(float32(1))
    case Double:
        return reflect.TypeOf(float64(1))
    case Char, Varchar, TinyText, Text, MediumText, LongText:
        return reflect.TypeOf("")
    case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary:
        return reflect.TypeOf([]byte{})
    case Bool:
        return reflect.TypeOf(true)
    case DateTime, Date, Time, TimeStamp, TimeStampz:
        return reflect.TypeOf(tm)
    case Decimal, Numeric:
        return reflect.TypeOf("")
    default:
        return reflect.TypeOf("")
    }
}
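Going the other way, SQLType2Type picks the Go type a column value is scanned into; a short hedged sketch, using only the function as shown:

// A VARCHAR(255) column scans into a Go string.
st := SQLType{Varchar, 255, 0}
goType := SQLType2Type(st)           // reflect.TypeOf("")
holder := reflect.New(goType).Elem() // a settable string-valued reflect.Value
holder.SetString("example")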
const (
    IndexType = iota + 1
    UniqueType
)

// database index
type Index struct {
    Name string
    Type int
    Cols []string
}
// add columns which will make up the composite index
func (index *Index) AddColumn(cols ...string) {
    for _, col := range cols {
        index.Cols = append(index.Cols, col)
    }
}
// create a new index
func NewIndex(name string, indexType int) *Index {
    return &Index{name, indexType, make([]string, 0)}
}
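A small usage sketch for the index helpers above; the index and column names are invented for illustration:

// composite unique constraint over two columns
uq := NewIndex("UQE_user_name_email", UniqueType)
uq.AddColumn("name", "email") // uq.Cols == []string{"name", "email"}

// plain (non-unique) index over one column
ix := NewIndex("IDX_user_created", IndexType)
ix.AddColumn("created")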
const (
    TWOSIDES = iota + 1
    ONLYTODB
    ONLYFROMDB
)
// database column
type Column struct {
    Name            string
    FieldName       string
    SQLType         SQLType
    Length          int
    Length2         int
    Nullable        bool
    Default         string
    Indexes         map[string]bool
    IsPrimaryKey    bool
    IsAutoIncrement bool
    MapType         int
    IsCreated       bool
    IsUpdated       bool
    IsCascade       bool
    IsVersion       bool
}
// generate the column description string according to the dialect
func (col *Column) String(d dialect) string {
    sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
    sql += d.SqlType(col) + " "
    if col.IsPrimaryKey {
        sql += "PRIMARY KEY "
        if col.IsAutoIncrement {
            sql += d.AutoIncrStr() + " "
        }
    }
    if col.Nullable {
        sql += "NULL "
    } else {
        sql += "NOT NULL "
    }
    if col.Default != "" {
        sql += "DEFAULT " + col.Default + " "
    }
    return sql
}

func (col *Column) stringNoPk(d dialect) string {
    sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
    sql += d.SqlType(col) + " "
    if col.Nullable {
        sql += "NULL "
    } else {
        sql += "NOT NULL "
    }
    if col.Default != "" {
        sql += "DEFAULT " + col.Default + " "
    }
    return sql
}
// return the reflect.Value of the struct field this column maps to
func (col *Column) ValueOf(bean interface{}) reflect.Value {
    var fieldValue reflect.Value
    if strings.Contains(col.FieldName, ".") {
        fields := strings.Split(col.FieldName, ".")
        if len(fields) > 2 {
            return reflect.ValueOf(nil)
        }
        fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(fields[0])
        fieldValue = fieldValue.FieldByName(fields[1])
    } else {
        fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
    }
    return fieldValue
}
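A hedged sketch of how ValueOf resolves a field by FieldName (the user type and the column wiring here are made up for illustration; the one-level "Outer.Inner" form follows the same path through the split branch above):

type user struct {
    Id   int64
    Name string
}

func exampleValueOf() {
    col := &Column{Name: "name", FieldName: "Name"}
    u := user{Id: 1, Name: "xorm"}
    v := col.ValueOf(&u) // reflect.Value holding "xorm"
    _ = v.String()       // "xorm"
}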
// database table
type Table struct {
    Name       string
    Type       reflect.Type
    ColumnsSeq []string
    Columns    map[string]*Column
    Indexes    map[string]*Index
    PrimaryKey string
    Created    map[string]bool
    Updated    string
    Version    string
    Cacher     Cacher
}
/*
@@ -344,90 +344,90 @@ func NewTable(name string, t reflect.Type) *Table {
// return the primary key column, if the table has one
func (table *Table) PKColumn() *Column {
    return table.Columns[table.PrimaryKey]
}

func (table *Table) VersionColumn() *Column {
    return table.Columns[table.Version]
}
// add a column to table
func (table *Table) AddColumn(col *Column) {
    table.ColumnsSeq = append(table.ColumnsSeq, col.Name)
    table.Columns[col.Name] = col
    if col.IsPrimaryKey {
        table.PrimaryKey = col.Name
    }
    if col.IsCreated {
        table.Created[col.Name] = true
    }
    if col.IsUpdated {
        table.Updated = col.Name
    }
    if col.IsVersion {
        table.Version = col.Name
    }
}
// add an index or a unique constraint to the table
func (table *Table) AddIndex(index *Index) {
    table.Indexes[index.Name] = index
}
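A hedged sketch of wiring a Table up by hand with the two helpers above. In real use the field mapper builds this, and NewTable (whose body sits in the elided hunk) is assumed to initialize the maps, so they are initialized explicitly here; names are made up:

t := &Table{
    Name:       "user",
    Type:       reflect.TypeOf(user{}), // user as in the earlier sketch
    ColumnsSeq: make([]string, 0),
    Columns:    make(map[string]*Column),
    Indexes:    make(map[string]*Index),
    Created:    make(map[string]bool),
}
t.AddColumn(&Column{Name: "id", FieldName: "Id", IsPrimaryKey: true, IsAutoIncrement: true})
t.AddColumn(&Column{Name: "name", FieldName: "Name"})
t.AddIndex(NewIndex("IDX_user_name", IndexType))
// t.PrimaryKey == "id"; t.PKColumn() now returns the id column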
func (table *Table) genCols(session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) {
    colNames := make([]string, 0)
    args := make([]interface{}, 0)
    for _, col := range table.Columns {
        if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated {
            if _, ok := session.Statement.columnMap[col.Name]; !ok {
                continue
            }
        }
        if col.MapType == ONLYFROMDB {
            continue
        }
        fieldValue := col.ValueOf(bean)
        if col.IsAutoIncrement && fieldValue.Int() == 0 {
            continue
        }
        if session.Statement.ColumnStr != "" {
            if _, ok := session.Statement.columnMap[col.Name]; !ok {
                continue
            }
        }
        if session.Statement.OmitStr != "" {
            if _, ok := session.Statement.columnMap[col.Name]; ok {
                continue
            }
        }
        if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
            args = append(args, time.Now())
        } else if col.IsVersion && session.Statement.checkVersion {
            args = append(args, 1)
        } else {
            arg, err := session.value2Interface(col, fieldValue)
            if err != nil {
                return colNames, args, err
            }
            args = append(args, arg)
        }
        if includeQuote {
            colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?")
        } else {
            colNames = append(colNames, col.Name)
        }
    }
    return colNames, args, nil
}
// Conversion is an interface. A type that implements Conversion supplies its own
// methods for marshalling into the database and unmarshalling back out of it.
type Conversion interface {
    FromDB([]byte) error
    ToDB() ([]byte, error)
}
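A hedged sketch of a type satisfying Conversion by round-tripping through JSON; the Settings type and the choice of JSON are this example's, not something prescribed by xorm, and the snippet assumes "encoding/json" is imported:

type Settings struct {
    Theme string
    Tags  []string
}

func (s *Settings) FromDB(data []byte) error {
    return json.Unmarshal(data, s)
}

func (s *Settings) ToDB() ([]byte, error) {
    return json.Marshal(s)
}

// *Settings now satisfies Conversion, so the value can be stored via ToDB
// and restored via FromDB.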

78
xorm.go

@@ -1,58 +1,58 @@
package xorm

import (
    "errors"
    "fmt"
    "os"
    "reflect"
    "runtime"
    "sync"
)

const (
    version string = "0.2.3"
)
func close(engine *Engine) {
    engine.Close()
}
// create a new db manager (Engine) for the given driver and data source.
// Four drivers are currently supported.
func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
    engine := &Engine{DriverName: driverName,
        DataSourceName: dataSourceName, Filters: make([]Filter, 0)}
    engine.SetMapper(SnakeMapper{})
    if driverName == SQLITE {
        engine.dialect = &sqlite3{}
    } else if driverName == MYSQL {
        engine.dialect = &mysql{}
    } else if driverName == POSTGRES {
        engine.dialect = &postgres{}
        engine.Filters = append(engine.Filters, &PgSeqFilter{})
        engine.Filters = append(engine.Filters, &QuoteFilter{})
    } else if driverName == MYMYSQL {
        engine.dialect = &mymysql{}
    } else {
        return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName))
    }
    err := engine.dialect.Init(driverName, dataSourceName)
    if err != nil {
        return nil, err
    }
    engine.Tables = make(map[reflect.Type]*Table)
    engine.mutex = &sync.Mutex{}
    engine.TagIdentifier = "xorm"
    engine.Filters = append(engine.Filters, &IdFilter{})
    engine.Logger = os.Stdout
    //engine.Pool = NewSimpleConnectPool()
    //engine.Pool = NewNoneConnectPool()
    //engine.Cacher = NewLRUCacher()
    err = engine.SetPool(NewSysConnectPool())
    runtime.SetFinalizer(engine, close)
    return engine, err
}
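A minimal usage sketch, assuming a sqlite3 database/sql driver (for example github.com/mattn/go-sqlite3, as the tests in this commit use) is imported somewhere in the program so the "sqlite3" driver name resolves:

engine, err := NewEngine("sqlite3", "./test.db")
if err != nil {
    panic(err) // unsupported driver or dialect init failure
}
defer engine.Close()
engine.ShowSQL = true // echo generated SQL to engine.Logger (os.Stdout by default)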