replace tab to 4 spaces for all codes

This commit is contained in:
Lunny Xiao 2013-12-09 10:29:23 +08:00
parent e84e14f972
commit c70b4ad8d3
39 changed files with 9759 additions and 9759 deletions

File diff suppressed because it is too large Load Diff

View File

@ -1,180 +1,180 @@
package xorm package xorm
import ( import (
"database/sql" "database/sql"
"testing" "testing"
) )
type BigStruct struct { type BigStruct struct {
Id int64 Id int64
Name string Name string
Title string Title string
Age string Age string
Alias string Alias string
NickName string NickName string
} }
func doBenchDriverInsert(engine *Engine, db *sql.DB, b *testing.B) { func doBenchDriverInsert(engine *Engine, db *sql.DB, b *testing.B) {
b.StopTimer() b.StopTimer()
err := engine.CreateTables(&BigStruct{}) err := engine.CreateTables(&BigStruct{})
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
doBenchDriverInsertS(db, b) doBenchDriverInsertS(db, b)
err = engine.DropTables(&BigStruct{}) err = engine.DropTables(&BigStruct{})
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
} }
func doBenchDriverInsertS(db *sql.DB, b *testing.B) { func doBenchDriverInsertS(db *sql.DB, b *testing.B) {
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := db.Exec(`insert into big_struct (name, title, age, alias, nick_name) _, err := db.Exec(`insert into big_struct (name, title, age, alias, nick_name)
values ('fafdasf', 'fadfa', 'afadfsaf', 'fadfafdsafd', 'fadfafdsaf')`) values ('fafdasf', 'fadfa', 'afadfsaf', 'fadfafdsafd', 'fadfafdsaf')`)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
} }
b.StopTimer() b.StopTimer()
} }
func doBenchDriverFind(engine *Engine, db *sql.DB, b *testing.B) { func doBenchDriverFind(engine *Engine, db *sql.DB, b *testing.B) {
b.StopTimer() b.StopTimer()
err := engine.CreateTables(&BigStruct{}) err := engine.CreateTables(&BigStruct{})
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
doBenchDriverFindS(db, b) doBenchDriverFindS(db, b)
err = engine.DropTables(&BigStruct{}) err = engine.DropTables(&BigStruct{})
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
} }
func doBenchDriverFindS(db *sql.DB, b *testing.B) { func doBenchDriverFindS(db *sql.DB, b *testing.B) {
b.StopTimer() b.StopTimer()
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
_, err := db.Exec(`insert into big_struct (name, title, age, alias, nick_name) _, err := db.Exec(`insert into big_struct (name, title, age, alias, nick_name)
values ('fafdasf', 'fadfa', 'afadfsaf', 'fadfafdsafd', 'fadfafdsaf')`) values ('fafdasf', 'fadfa', 'afadfsaf', 'fadfafdsafd', 'fadfafdsaf')`)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
} }
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := db.Query("select * from big_struct limit 50") _, err := db.Query("select * from big_struct limit 50")
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
} }
b.StopTimer() b.StopTimer()
} }
func doBenchInsert(engine *Engine, b *testing.B) { func doBenchInsert(engine *Engine, b *testing.B) {
b.StopTimer() b.StopTimer()
bs := &BigStruct{0, "fafdasf", "fadfa", "afadfsaf", "fadfafdsafd", "fadfafdsaf"} bs := &BigStruct{0, "fafdasf", "fadfa", "afadfsaf", "fadfafdsafd", "fadfafdsaf"}
err := engine.CreateTables(bs) err := engine.CreateTables(bs)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
bs.Id = 0 bs.Id = 0
_, err = engine.Insert(bs) _, err = engine.Insert(bs)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
} }
b.StopTimer() b.StopTimer()
err = engine.DropTables(bs) err = engine.DropTables(bs)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
} }
func doBenchFind(engine *Engine, b *testing.B) { func doBenchFind(engine *Engine, b *testing.B) {
b.StopTimer() b.StopTimer()
bs := &BigStruct{0, "fafdasf", "fadfa", "afadfsaf", "fadfafdsafd", "fadfafdsaf"} bs := &BigStruct{0, "fafdasf", "fadfa", "afadfsaf", "fadfafdsafd", "fadfafdsaf"}
err := engine.CreateTables(bs) err := engine.CreateTables(bs)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
bs.Id = 0 bs.Id = 0
_, err = engine.Insert(bs) _, err = engine.Insert(bs)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
} }
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
bss := new([]BigStruct) bss := new([]BigStruct)
err = engine.Limit(50).Find(bss) err = engine.Limit(50).Find(bss)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
} }
b.StopTimer() b.StopTimer()
err = engine.DropTables(bs) err = engine.DropTables(bs)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
} }
func doBenchFindPtr(engine *Engine, b *testing.B) { func doBenchFindPtr(engine *Engine, b *testing.B) {
b.StopTimer() b.StopTimer()
bs := &BigStruct{0, "fafdasf", "fadfa", "afadfsaf", "fadfafdsafd", "fadfafdsaf"} bs := &BigStruct{0, "fafdasf", "fadfa", "afadfsaf", "fadfafdsafd", "fadfafdsaf"}
err := engine.CreateTables(bs) err := engine.CreateTables(bs)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
bs.Id = 0 bs.Id = 0
_, err = engine.Insert(bs) _, err = engine.Insert(bs)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
} }
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
bss := new([]*BigStruct) bss := new([]*BigStruct)
err = engine.Limit(50).Find(bss) err = engine.Limit(50).Find(bss)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
} }
b.StopTimer() b.StopTimer()
err = engine.DropTables(bs) err = engine.DropTables(bs)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
return return
} }
} }

530
cache.go
View File

@ -1,395 +1,395 @@
package xorm package xorm
import ( import (
"container/list" "container/list"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
) )
const ( const (
// default cache expired time // default cache expired time
CacheExpired = 60 * time.Minute CacheExpired = 60 * time.Minute
// not use now // not use now
CacheMaxMemory = 256 CacheMaxMemory = 256
// evey ten minutes to clear all expired nodes // evey ten minutes to clear all expired nodes
CacheGcInterval = 10 * time.Minute CacheGcInterval = 10 * time.Minute
// each time when gc to removed max nodes // each time when gc to removed max nodes
CacheGcMaxRemoved = 20 CacheGcMaxRemoved = 20
) )
// CacheStore is a interface to store cache // CacheStore is a interface to store cache
type CacheStore interface { type CacheStore interface {
Put(key, value interface{}) error Put(key, value interface{}) error
Get(key interface{}) (interface{}, error) Get(key interface{}) (interface{}, error)
Del(key interface{}) error Del(key interface{}) error
} }
// MemoryStore implements CacheStore provide local machine // MemoryStore implements CacheStore provide local machine
// memory store // memory store
type MemoryStore struct { type MemoryStore struct {
store map[interface{}]interface{} store map[interface{}]interface{}
mutex sync.RWMutex mutex sync.RWMutex
} }
func NewMemoryStore() *MemoryStore { func NewMemoryStore() *MemoryStore {
return &MemoryStore{store: make(map[interface{}]interface{})} return &MemoryStore{store: make(map[interface{}]interface{})}
} }
func (s *MemoryStore) Put(key, value interface{}) error { func (s *MemoryStore) Put(key, value interface{}) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
s.store[key] = value s.store[key] = value
return nil return nil
} }
func (s *MemoryStore) Get(key interface{}) (interface{}, error) { func (s *MemoryStore) Get(key interface{}) (interface{}, error) {
s.mutex.RLock() s.mutex.RLock()
defer s.mutex.RUnlock() defer s.mutex.RUnlock()
if v, ok := s.store[key]; ok { if v, ok := s.store[key]; ok {
return v, nil return v, nil
} }
return nil, ErrNotExist return nil, ErrNotExist
} }
func (s *MemoryStore) Del(key interface{}) error { func (s *MemoryStore) Del(key interface{}) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
delete(s.store, key) delete(s.store, key)
return nil return nil
} }
// Cacher is an interface to provide cache // Cacher is an interface to provide cache
type Cacher interface { type Cacher interface {
GetIds(tableName, sql string) interface{} GetIds(tableName, sql string) interface{}
GetBean(tableName string, id int64) interface{} GetBean(tableName string, id int64) interface{}
PutIds(tableName, sql string, ids interface{}) PutIds(tableName, sql string, ids interface{})
PutBean(tableName string, id int64, obj interface{}) PutBean(tableName string, id int64, obj interface{})
DelIds(tableName, sql string) DelIds(tableName, sql string)
DelBean(tableName string, id int64) DelBean(tableName string, id int64)
ClearIds(tableName string) ClearIds(tableName string)
ClearBeans(tableName string) ClearBeans(tableName string)
} }
type idNode struct { type idNode struct {
tbName string tbName string
id int64 id int64
lastVisit time.Time lastVisit time.Time
} }
type sqlNode struct { type sqlNode struct {
tbName string tbName string
sql string sql string
lastVisit time.Time lastVisit time.Time
} }
func newIdNode(tbName string, id int64) *idNode { func newIdNode(tbName string, id int64) *idNode {
return &idNode{tbName, id, time.Now()} return &idNode{tbName, id, time.Now()}
} }
func newSqlNode(tbName, sql string) *sqlNode { func newSqlNode(tbName, sql string) *sqlNode {
return &sqlNode{tbName, sql, time.Now()} return &sqlNode{tbName, sql, time.Now()}
} }
// LRUCacher implements Cacher according to LRU algorithm // LRUCacher implements Cacher according to LRU algorithm
type LRUCacher struct { type LRUCacher struct {
idList *list.List idList *list.List
sqlList *list.List sqlList *list.List
idIndex map[string]map[interface{}]*list.Element idIndex map[string]map[interface{}]*list.Element
sqlIndex map[string]map[interface{}]*list.Element sqlIndex map[string]map[interface{}]*list.Element
store CacheStore store CacheStore
Max int Max int
mutex sync.Mutex mutex sync.Mutex
Expired time.Duration Expired time.Duration
maxSize int maxSize int
GcInterval time.Duration GcInterval time.Duration
} }
func newLRUCacher(store CacheStore, expired time.Duration, maxSize int, max int) *LRUCacher { func newLRUCacher(store CacheStore, expired time.Duration, maxSize int, max int) *LRUCacher {
cacher := &LRUCacher{store: store, idList: list.New(), cacher := &LRUCacher{store: store, idList: list.New(),
sqlList: list.New(), Expired: expired, maxSize: maxSize, sqlList: list.New(), Expired: expired, maxSize: maxSize,
GcInterval: CacheGcInterval, Max: max, GcInterval: CacheGcInterval, Max: max,
sqlIndex: make(map[string]map[interface{}]*list.Element), sqlIndex: make(map[string]map[interface{}]*list.Element),
idIndex: make(map[string]map[interface{}]*list.Element), idIndex: make(map[string]map[interface{}]*list.Element),
} }
cacher.RunGC() cacher.RunGC()
return cacher return cacher
} }
func NewLRUCacher(store CacheStore, max int) *LRUCacher { func NewLRUCacher(store CacheStore, max int) *LRUCacher {
return newLRUCacher(store, CacheExpired, CacheMaxMemory, max) return newLRUCacher(store, CacheExpired, CacheMaxMemory, max)
} }
func NewLRUCacher2(store CacheStore, expired time.Duration, max int) *LRUCacher { func NewLRUCacher2(store CacheStore, expired time.Duration, max int) *LRUCacher {
return newLRUCacher(store, expired, 0, max) return newLRUCacher(store, expired, 0, max)
} }
//func NewLRUCacher3(store CacheStore, expired time.Duration, maxSize int) *LRUCacher { //func NewLRUCacher3(store CacheStore, expired time.Duration, maxSize int) *LRUCacher {
// return newLRUCacher(store, expired, maxSize, 0) // return newLRUCacher(store, expired, maxSize, 0)
//} //}
// RunGC run once every m.GcInterval // RunGC run once every m.GcInterval
func (m *LRUCacher) RunGC() { func (m *LRUCacher) RunGC() {
time.AfterFunc(m.GcInterval, func() { time.AfterFunc(m.GcInterval, func() {
m.RunGC() m.RunGC()
m.GC() m.GC()
}) })
} }
// GC check ids lit and sql list to remove all element expired // GC check ids lit and sql list to remove all element expired
func (m *LRUCacher) GC() { func (m *LRUCacher) GC() {
//fmt.Println("begin gc ...") //fmt.Println("begin gc ...")
//defer fmt.Println("end gc ...") //defer fmt.Println("end gc ...")
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
var removedNum int var removedNum int
for e := m.idList.Front(); e != nil; { for e := m.idList.Front(); e != nil; {
if removedNum <= CacheGcMaxRemoved && if removedNum <= CacheGcMaxRemoved &&
time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired { time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired {
removedNum++ removedNum++
next := e.Next() next := e.Next()
//fmt.Println("removing ...", e.Value) //fmt.Println("removing ...", e.Value)
node := e.Value.(*idNode) node := e.Value.(*idNode)
m.delBean(node.tbName, node.id) m.delBean(node.tbName, node.id)
e = next e = next
} else { } else {
//fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.idList.Len()) //fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.idList.Len())
break break
} }
} }
removedNum = 0 removedNum = 0
for e := m.sqlList.Front(); e != nil; { for e := m.sqlList.Front(); e != nil; {
if removedNum <= CacheGcMaxRemoved && if removedNum <= CacheGcMaxRemoved &&
time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired { time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired {
removedNum++ removedNum++
next := e.Next() next := e.Next()
//fmt.Println("removing ...", e.Value) //fmt.Println("removing ...", e.Value)
node := e.Value.(*sqlNode) node := e.Value.(*sqlNode)
m.delIds(node.tbName, node.sql) m.delIds(node.tbName, node.sql)
e = next e = next
} else { } else {
//fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.sqlList.Len()) //fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.sqlList.Len())
break break
} }
} }
} }
// Get all bean's ids according to sql and parameter from cache // Get all bean's ids according to sql and parameter from cache
func (m *LRUCacher) GetIds(tableName, sql string) interface{} { func (m *LRUCacher) GetIds(tableName, sql string) interface{} {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if _, ok := m.sqlIndex[tableName]; !ok { if _, ok := m.sqlIndex[tableName]; !ok {
m.sqlIndex[tableName] = make(map[interface{}]*list.Element) m.sqlIndex[tableName] = make(map[interface{}]*list.Element)
} }
if v, err := m.store.Get(sql); err == nil { if v, err := m.store.Get(sql); err == nil {
if el, ok := m.sqlIndex[tableName][sql]; !ok { if el, ok := m.sqlIndex[tableName][sql]; !ok {
el = m.sqlList.PushBack(newSqlNode(tableName, sql)) el = m.sqlList.PushBack(newSqlNode(tableName, sql))
m.sqlIndex[tableName][sql] = el m.sqlIndex[tableName][sql] = el
} else { } else {
lastTime := el.Value.(*sqlNode).lastVisit lastTime := el.Value.(*sqlNode).lastVisit
// if expired, remove the node and return nil // if expired, remove the node and return nil
if time.Now().Sub(lastTime) > m.Expired { if time.Now().Sub(lastTime) > m.Expired {
m.delIds(tableName, sql) m.delIds(tableName, sql)
return nil return nil
} }
m.sqlList.MoveToBack(el) m.sqlList.MoveToBack(el)
el.Value.(*sqlNode).lastVisit = time.Now() el.Value.(*sqlNode).lastVisit = time.Now()
} }
return v return v
} else { } else {
m.delIds(tableName, sql) m.delIds(tableName, sql)
} }
return nil return nil
} }
// Get bean according tableName and id from cache // Get bean according tableName and id from cache
func (m *LRUCacher) GetBean(tableName string, id int64) interface{} { func (m *LRUCacher) GetBean(tableName string, id int64) interface{} {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if _, ok := m.idIndex[tableName]; !ok { if _, ok := m.idIndex[tableName]; !ok {
m.idIndex[tableName] = make(map[interface{}]*list.Element) m.idIndex[tableName] = make(map[interface{}]*list.Element)
} }
tid := genId(tableName, id) tid := genId(tableName, id)
if v, err := m.store.Get(tid); err == nil { if v, err := m.store.Get(tid); err == nil {
if el, ok := m.idIndex[tableName][id]; ok { if el, ok := m.idIndex[tableName][id]; ok {
lastTime := el.Value.(*idNode).lastVisit lastTime := el.Value.(*idNode).lastVisit
// if expired, remove the node and return nil // if expired, remove the node and return nil
if time.Now().Sub(lastTime) > m.Expired { if time.Now().Sub(lastTime) > m.Expired {
m.delBean(tableName, id) m.delBean(tableName, id)
//m.clearIds(tableName) //m.clearIds(tableName)
return nil return nil
} }
m.idList.MoveToBack(el) m.idList.MoveToBack(el)
el.Value.(*idNode).lastVisit = time.Now() el.Value.(*idNode).lastVisit = time.Now()
} else { } else {
el = m.idList.PushBack(newIdNode(tableName, id)) el = m.idList.PushBack(newIdNode(tableName, id))
m.idIndex[tableName][id] = el m.idIndex[tableName][id] = el
} }
return v return v
} else { } else {
// store bean is not exist, then remove memory's index // store bean is not exist, then remove memory's index
m.delBean(tableName, id) m.delBean(tableName, id)
//m.clearIds(tableName) //m.clearIds(tableName)
return nil return nil
} }
} }
// Clear all sql-ids mapping on table tableName from cache // Clear all sql-ids mapping on table tableName from cache
func (m *LRUCacher) clearIds(tableName string) { func (m *LRUCacher) clearIds(tableName string) {
if tis, ok := m.sqlIndex[tableName]; ok { if tis, ok := m.sqlIndex[tableName]; ok {
for sql, v := range tis { for sql, v := range tis {
m.sqlList.Remove(v) m.sqlList.Remove(v)
m.store.Del(sql) m.store.Del(sql)
} }
} }
m.sqlIndex[tableName] = make(map[interface{}]*list.Element) m.sqlIndex[tableName] = make(map[interface{}]*list.Element)
} }
func (m *LRUCacher) ClearIds(tableName string) { func (m *LRUCacher) ClearIds(tableName string) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.clearIds(tableName) m.clearIds(tableName)
} }
func (m *LRUCacher) clearBeans(tableName string) { func (m *LRUCacher) clearBeans(tableName string) {
if tis, ok := m.idIndex[tableName]; ok { if tis, ok := m.idIndex[tableName]; ok {
for id, v := range tis { for id, v := range tis {
m.idList.Remove(v) m.idList.Remove(v)
tid := genId(tableName, id.(int64)) tid := genId(tableName, id.(int64))
m.store.Del(tid) m.store.Del(tid)
} }
} }
m.idIndex[tableName] = make(map[interface{}]*list.Element) m.idIndex[tableName] = make(map[interface{}]*list.Element)
} }
func (m *LRUCacher) ClearBeans(tableName string) { func (m *LRUCacher) ClearBeans(tableName string) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.clearBeans(tableName) m.clearBeans(tableName)
} }
func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) { func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if _, ok := m.sqlIndex[tableName]; !ok { if _, ok := m.sqlIndex[tableName]; !ok {
m.sqlIndex[tableName] = make(map[interface{}]*list.Element) m.sqlIndex[tableName] = make(map[interface{}]*list.Element)
} }
if el, ok := m.sqlIndex[tableName][sql]; !ok { if el, ok := m.sqlIndex[tableName][sql]; !ok {
el = m.sqlList.PushBack(newSqlNode(tableName, sql)) el = m.sqlList.PushBack(newSqlNode(tableName, sql))
m.sqlIndex[tableName][sql] = el m.sqlIndex[tableName][sql] = el
} else { } else {
el.Value.(*sqlNode).lastVisit = time.Now() el.Value.(*sqlNode).lastVisit = time.Now()
} }
m.store.Put(sql, ids) m.store.Put(sql, ids)
if m.sqlList.Len() > m.Max { if m.sqlList.Len() > m.Max {
e := m.sqlList.Front() e := m.sqlList.Front()
node := e.Value.(*sqlNode) node := e.Value.(*sqlNode)
m.delIds(node.tbName, node.sql) m.delIds(node.tbName, node.sql)
} }
} }
func (m *LRUCacher) PutBean(tableName string, id int64, obj interface{}) { func (m *LRUCacher) PutBean(tableName string, id int64, obj interface{}) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
var el *list.Element var el *list.Element
var ok bool var ok bool
if el, ok = m.idIndex[tableName][id]; !ok { if el, ok = m.idIndex[tableName][id]; !ok {
el = m.idList.PushBack(newIdNode(tableName, id)) el = m.idList.PushBack(newIdNode(tableName, id))
m.idIndex[tableName][id] = el m.idIndex[tableName][id] = el
} else { } else {
el.Value.(*idNode).lastVisit = time.Now() el.Value.(*idNode).lastVisit = time.Now()
} }
m.store.Put(genId(tableName, id), obj) m.store.Put(genId(tableName, id), obj)
if m.idList.Len() > m.Max { if m.idList.Len() > m.Max {
e := m.idList.Front() e := m.idList.Front()
node := e.Value.(*idNode) node := e.Value.(*idNode)
m.delBean(node.tbName, node.id) m.delBean(node.tbName, node.id)
} }
} }
func (m *LRUCacher) delIds(tableName, sql string) { func (m *LRUCacher) delIds(tableName, sql string) {
if _, ok := m.sqlIndex[tableName]; ok { if _, ok := m.sqlIndex[tableName]; ok {
if el, ok := m.sqlIndex[tableName][sql]; ok { if el, ok := m.sqlIndex[tableName][sql]; ok {
delete(m.sqlIndex[tableName], sql) delete(m.sqlIndex[tableName], sql)
m.sqlList.Remove(el) m.sqlList.Remove(el)
} }
} }
m.store.Del(sql) m.store.Del(sql)
} }
func (m *LRUCacher) DelIds(tableName, sql string) { func (m *LRUCacher) DelIds(tableName, sql string) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.delIds(tableName, sql) m.delIds(tableName, sql)
} }
func (m *LRUCacher) delBean(tableName string, id int64) { func (m *LRUCacher) delBean(tableName string, id int64) {
tid := genId(tableName, id) tid := genId(tableName, id)
if el, ok := m.idIndex[tableName][id]; ok { if el, ok := m.idIndex[tableName][id]; ok {
delete(m.idIndex[tableName], id) delete(m.idIndex[tableName], id)
m.idList.Remove(el) m.idList.Remove(el)
m.clearIds(tableName) m.clearIds(tableName)
} }
m.store.Del(tid) m.store.Del(tid)
} }
func (m *LRUCacher) DelBean(tableName string, id int64) { func (m *LRUCacher) DelBean(tableName string, id int64) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.delBean(tableName, id) m.delBean(tableName, id)
} }
func encodeIds(ids []int64) (s string) { func encodeIds(ids []int64) (s string) {
s = "[" s = "["
for _, id := range ids { for _, id := range ids {
s += fmt.Sprintf("%v,", id) s += fmt.Sprintf("%v,", id)
} }
s = s[:len(s)-1] + "]" s = s[:len(s)-1] + "]"
return return
} }
func decodeIds(s string) []int64 { func decodeIds(s string) []int64 {
res := make([]int64, 0) res := make([]int64, 0)
if len(s) >= 2 { if len(s) >= 2 {
ss := strings.Split(s[1:len(s)-1], ",") ss := strings.Split(s[1:len(s)-1], ",")
for _, s := range ss { for _, s := range ss {
i, err := strconv.ParseInt(s, 10, 64) i, err := strconv.ParseInt(s, 10, 64)
if err != nil { if err != nil {
return res return res
} }
res = append(res, i) res = append(res, i)
} }
} }
return res return res
} }
func getCacheSql(m Cacher, tableName, sql string, args interface{}) ([]int64, error) { func getCacheSql(m Cacher, tableName, sql string, args interface{}) ([]int64, error) {
bytes := m.GetIds(tableName, genSqlKey(sql, args)) bytes := m.GetIds(tableName, genSqlKey(sql, args))
if bytes == nil { if bytes == nil {
return nil, errors.New("Not Exist") return nil, errors.New("Not Exist")
} }
objs := decodeIds(bytes.(string)) objs := decodeIds(bytes.(string))
return objs, nil return objs, nil
} }
func putCacheSql(m Cacher, ids []int64, tableName, sql string, args interface{}) error { func putCacheSql(m Cacher, ids []int64, tableName, sql string, args interface{}) error {
bytes := encodeIds(ids) bytes := encodeIds(ids)
m.PutIds(tableName, genSqlKey(sql, args), bytes) m.PutIds(tableName, genSqlKey(sql, args), bytes)
return nil return nil
} }
func genSqlKey(sql string, args interface{}) string { func genSqlKey(sql string, args interface{}) string {
return fmt.Sprintf("%v-%v", sql, args) return fmt.Sprintf("%v-%v", sql, args)
} }
func genId(prefix string, id int64) string { func genId(prefix string, id int64) string {
return fmt.Sprintf("%v-%v", prefix, id) return fmt.Sprintf("%v-%v", prefix, id)
} }

102
doc.go
View File

@ -9,13 +9,13 @@ Installation
Make sure you have installed Go 1.1+ and then: Make sure you have installed Go 1.1+ and then:
go get github.com/lunny/xorm go get github.com/lunny/xorm
Create Engine Create Engine
Firstly, we should new an engine for a database Firstly, we should new an engine for a database
engine, err := xorm.NewEngine(driverName, dataSourceName) engine, err := xorm.NewEngine(driverName, dataSourceName)
Method NewEngine's parameters is the same as sql.Open. It depends Method NewEngine's parameters is the same as sql.Open. It depends
drivers' implementation. drivers' implementation.
@ -27,11 +27,11 @@ Xorm also support raw sql execution:
1. query sql, the returned results is []map[string][]byte 1. query sql, the returned results is []map[string][]byte
results, err := engine.Query("select * from user") results, err := engine.Query("select * from user")
2. exec sql, the returned results 2. exec sql, the returned results
affected, err := engine.Exec("update user set .... where ...") affected, err := engine.Exec("update user set .... where ...")
ORM Methods ORM Methods
@ -39,46 +39,46 @@ There are 7 major ORM methods and many helpful methods to use to operate databas
1. Insert one or multipe records to database 1. Insert one or multipe records to database
affected, err := engine.Insert(&struct) affected, err := engine.Insert(&struct)
// INSERT INTO struct () values () // INSERT INTO struct () values ()
affected, err := engine.Insert(&struct1, &struct2) affected, err := engine.Insert(&struct1, &struct2)
// INSERT INTO struct1 () values () // INSERT INTO struct1 () values ()
// INSERT INTO struct2 () values () // INSERT INTO struct2 () values ()
affected, err := engine.Insert(&sliceOfStruct) affected, err := engine.Insert(&sliceOfStruct)
// INSERT INTO struct () values (),(),() // INSERT INTO struct () values (),(),()
affected, err := engine.Insert(&struct1, &sliceOfStruct2) affected, err := engine.Insert(&struct1, &sliceOfStruct2)
// INSERT INTO struct1 () values () // INSERT INTO struct1 () values ()
// INSERT INTO struct2 () values (),(),() // INSERT INTO struct2 () values (),(),()
2. Query one record from database 2. Query one record from database
has, err := engine.Get(&user) has, err := engine.Get(&user)
// SELECT * FROM user LIMIT 1 // SELECT * FROM user LIMIT 1
3. Query multiple records from database 3. Query multiple records from database
err := engine.Find(...) err := engine.Find(...)
// SELECT * FROM user // SELECT * FROM user
4. Query multiple records and record by record handle 4. Query multiple records and record by record handle
err := engine.Iterate(...) err := engine.Iterate(...)
// SELECT * FROM user // SELECT * FROM user
5. Update one or more records 5. Update one or more records
affected, err := engine.Update(&user) affected, err := engine.Update(&user)
// UPDATE user SET // UPDATE user SET
6. Delete one or more records 6. Delete one or more records
affected, err := engine.Delete(&user) affected, err := engine.Delete(&user)
// DELETE FROM user Where ... // DELETE FROM user Where ...
7. Count records 7. Count records
counts, err := engine.Count(&user) counts, err := engine.Count(&user)
// SELECT count(*) AS total FROM user // SELECT count(*) AS total FROM user
Conditions Conditions
@ -86,49 +86,49 @@ The above 7 methods could use with condition methods.
1. Id, In 1. Id, In
engine.Id(1).Get(&user) engine.Id(1).Get(&user)
// SELECT * FROM user WHERE id = 1 // SELECT * FROM user WHERE id = 1
engine.In("id", 1, 2, 3).Find(&users) engine.In("id", 1, 2, 3).Find(&users)
// SELECT * FROM user WHERE id IN (1, 2, 3) // SELECT * FROM user WHERE id IN (1, 2, 3)
2. Where, And, Or 2. Where, And, Or
engine.Where().And().Or().Find() engine.Where().And().Or().Find()
// SELECT * FROM user WHERE (.. AND ..) OR ... // SELECT * FROM user WHERE (.. AND ..) OR ...
3. OrderBy, Asc, Desc 3. OrderBy, Asc, Desc
engine.Asc().Desc().Find() engine.Asc().Desc().Find()
// SELECT * FROM user ORDER BY .. ASC, .. DESC // SELECT * FROM user ORDER BY .. ASC, .. DESC
engine.OrderBy().Find() engine.OrderBy().Find()
// SELECT * FROM user ORDER BY .. // SELECT * FROM user ORDER BY ..
4. Limit, Top 4. Limit, Top
engine.Limit().Find() engine.Limit().Find()
// SELECT * FROM user LIMIT .. OFFSET .. // SELECT * FROM user LIMIT .. OFFSET ..
engine.Top().Find() engine.Top().Find()
// SELECT * FROM user LIMIT .. // SELECT * FROM user LIMIT ..
5. Sql 5. Sql
engine.Sql("select * from user").Find() engine.Sql("select * from user").Find()
6. Cols, Omit, Distinct 6. Cols, Omit, Distinct
engine.Cols("col1, col2").Find() engine.Cols("col1, col2").Find()
// SELECT col1, col2 FROM user // SELECT col1, col2 FROM user
engine.Omit("col1").Find() engine.Omit("col1").Find()
// SELECT col2, col3 FROM user // SELECT col2, col3 FROM user
engine.Distinct("col1").Find() engine.Distinct("col1").Find()
// SELECT DISTINCT col1 FROM user // SELECT DISTINCT col1 FROM user
7. Join, GroupBy, Having 7. Join, GroupBy, Having
engine.GroupBy("name").Having("name='xlw'").Find() engine.GroupBy("name").Having("name='xlw'").Find()
//SELECT * FROM user GROUP BY name HAVING name='xlw' //SELECT * FROM user GROUP BY name HAVING name='xlw'
engine.Join("LEFT", "userdetail", "user.id=userdetail.id").Find() engine.Join("LEFT", "userdetail", "user.id=userdetail.id").Find()
//SELECT * FROM user LEFT JOIN userdetail ON user.id=userdetail.id //SELECT * FROM user LEFT JOIN userdetail ON user.id=userdetail.id
More usage, please visit https://github.com/lunny/xorm/blob/master/docs/QuickStartEn.md More usage, please visit https://github.com/lunny/xorm/blob/master/docs/QuickStartEn.md
*/ */

1250
engine.go

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +1,15 @@
package xorm package xorm
import ( import (
"errors" "errors"
) )
var ( var (
ErrParamsType error = errors.New("Params type error") ErrParamsType error = errors.New("Params type error")
ErrTableNotFound error = errors.New("Not found table") ErrTableNotFound error = errors.New("Not found table")
ErrUnSupportedType error = errors.New("Unsupported type error") ErrUnSupportedType error = errors.New("Unsupported type error")
ErrNotExist error = errors.New("Not exist error") ErrNotExist error = errors.New("Not exist error")
ErrCacheFailed error = errors.New("Cache failed") ErrCacheFailed error = errors.New("Cache failed")
ErrNeedDeletedCond error = errors.New("Delete need at least one condition") ErrNeedDeletedCond error = errors.New("Delete need at least one condition")
ErrNotImplemented error = errors.New("Not implemented.") ErrNotImplemented error = errors.New("Not implemented.")
) )

View File

@ -1,109 +1,109 @@
package main package main
import ( import (
"fmt" "fmt"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
. "xorm" . "xorm"
) )
type User struct { type User struct {
Id int64 Id int64
Name string Name string
} }
func main() { func main() {
f := "cache.db" f := "cache.db"
os.Remove(f) os.Remove(f)
Orm, err := NewEngine("sqlite3", f) Orm, err := NewEngine("sqlite3", f)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
Orm.ShowSQL = true Orm.ShowSQL = true
cacher := NewLRUCacher(NewMemoryStore(), 1000) cacher := NewLRUCacher(NewMemoryStore(), 1000)
Orm.SetDefaultCacher(cacher) Orm.SetDefaultCacher(cacher)
err = Orm.CreateTables(&User{}) err = Orm.CreateTables(&User{})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
_, err = Orm.Insert(&User{Name: "xlw"}) _, err = Orm.Insert(&User{Name: "xlw"})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
users := make([]User, 0) users := make([]User, 0)
err = Orm.Find(&users) err = Orm.Find(&users)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println("users:", users) fmt.Println("users:", users)
users2 := make([]User, 0) users2 := make([]User, 0)
err = Orm.Find(&users2) err = Orm.Find(&users2)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println("users2:", users2) fmt.Println("users2:", users2)
users3 := make([]User, 0) users3 := make([]User, 0)
err = Orm.Find(&users3) err = Orm.Find(&users3)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println("users3:", users3) fmt.Println("users3:", users3)
user4 := new(User) user4 := new(User)
has, err := Orm.Id(1).Get(user4) has, err := Orm.Id(1).Get(user4)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println("user4:", has, user4) fmt.Println("user4:", has, user4)
user4.Name = "xiaolunwen" user4.Name = "xiaolunwen"
_, err = Orm.Id(1).Update(user4) _, err = Orm.Id(1).Update(user4)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println("user4:", user4) fmt.Println("user4:", user4)
user5 := new(User) user5 := new(User)
has, err = Orm.Id(1).Get(user5) has, err = Orm.Id(1).Get(user5)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println("user5:", has, user5) fmt.Println("user5:", has, user5)
_, err = Orm.Id(1).Delete(new(User)) _, err = Orm.Id(1).Delete(new(User))
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
for { for {
user6 := new(User) user6 := new(User)
has, err = Orm.Id(1).Get(user6) has, err = Orm.Id(1).Get(user6)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println("user6:", has, user6) fmt.Println("user6:", has, user6)
} }
} }

View File

@ -1,113 +1,113 @@
package main package main
import ( import (
//xorm "github.com/lunny/xorm" //xorm "github.com/lunny/xorm"
"fmt" "fmt"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
//"time" //"time"
//"sync/atomic" //"sync/atomic"
xorm "xorm" xorm "xorm"
) )
type User struct { type User struct {
Id int64 Id int64
Name string Name string
} }
func sqliteEngine() (*xorm.Engine, error) { func sqliteEngine() (*xorm.Engine, error) {
os.Remove("./test.db") os.Remove("./test.db")
return xorm.NewEngine("sqlite3", "./goroutine.db") return xorm.NewEngine("sqlite3", "./goroutine.db")
} }
func mysqlEngine() (*xorm.Engine, error) { func mysqlEngine() (*xorm.Engine, error) {
return xorm.NewEngine("mysql", "root:@/test?charset=utf8") return xorm.NewEngine("mysql", "root:@/test?charset=utf8")
} }
var u *User = &User{} var u *User = &User{}
func test(engine *xorm.Engine) { func test(engine *xorm.Engine) {
err := engine.CreateTables(u) err := engine.CreateTables(u)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
size := 500 size := 500
queue := make(chan int, size) queue := make(chan int, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
go func(x int) { go func(x int) {
//x := i //x := i
err := engine.Test() err := engine.Test()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} else { } else {
err = engine.Map(u) err = engine.Map(u)
if err != nil { if err != nil {
fmt.Println("Map user failed") fmt.Println("Map user failed")
} else { } else {
for j := 0; j < 10; j++ { for j := 0; j < 10; j++ {
if x+j < 2 { if x+j < 2 {
_, err = engine.Get(u) _, err = engine.Get(u)
} else if x+j < 4 { } else if x+j < 4 {
users := make([]User, 0) users := make([]User, 0)
err = engine.Find(&users) err = engine.Find(&users)
} else if x+j < 8 { } else if x+j < 8 {
_, err = engine.Count(u) _, err = engine.Count(u)
} else if x+j < 16 { } else if x+j < 16 {
_, err = engine.Insert(&User{Name: "xlw"}) _, err = engine.Insert(&User{Name: "xlw"})
} else if x+j < 32 { } else if x+j < 32 {
//_, err = engine.Id(1).Delete(u) //_, err = engine.Id(1).Delete(u)
_, err = engine.Delete(u) _, err = engine.Delete(u)
} }
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
queue <- x queue <- x
return return
} }
} }
fmt.Printf("%v success!\n", x) fmt.Printf("%v success!\n", x)
} }
} }
queue <- x queue <- x
}(i) }(i)
} }
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
<-queue <-queue
} }
//conns := atomic.LoadInt32(&xorm.ConnectionNum) //conns := atomic.LoadInt32(&xorm.ConnectionNum)
//fmt.Println("connection number:", conns) //fmt.Println("connection number:", conns)
fmt.Println("end") fmt.Println("end")
} }
func main() { func main() {
fmt.Println("-----start sqlite go routines-----") fmt.Println("-----start sqlite go routines-----")
engine, err := sqliteEngine() engine, err := sqliteEngine()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
engine.ShowSQL = true engine.ShowSQL = true
cacher := xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000) cacher := xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)
engine.SetDefaultCacher(cacher) engine.SetDefaultCacher(cacher)
fmt.Println(engine) fmt.Println(engine)
test(engine) test(engine)
fmt.Println("test end") fmt.Println("test end")
engine.Close() engine.Close()
fmt.Println("-----start mysql go routines-----") fmt.Println("-----start mysql go routines-----")
engine, err = mysqlEngine() engine, err = mysqlEngine()
engine.ShowSQL = true engine.ShowSQL = true
cacher = xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000) cacher = xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)
engine.SetDefaultCacher(cacher) engine.SetDefaultCacher(cacher)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
defer engine.Close() defer engine.Close()
test(engine) test(engine)
} }

View File

@ -1,76 +1,76 @@
package main package main
import ( import (
"errors" "errors"
"fmt" "fmt"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
. "xorm" . "xorm"
) )
type Status struct { type Status struct {
Name string Name string
Color string Color string
} }
var ( var (
Registed Status = Status{"Registed", "white"} Registed Status = Status{"Registed", "white"}
Approved Status = Status{"Approved", "green"} Approved Status = Status{"Approved", "green"}
Removed Status = Status{"Removed", "red"} Removed Status = Status{"Removed", "red"}
Statuses map[string]Status = map[string]Status{ Statuses map[string]Status = map[string]Status{
Registed.Name: Registed, Registed.Name: Registed,
Approved.Name: Approved, Approved.Name: Approved,
Removed.Name: Removed, Removed.Name: Removed,
} }
) )
func (s *Status) FromDB(bytes []byte) error { func (s *Status) FromDB(bytes []byte) error {
if r, ok := Statuses[string(bytes)]; ok { if r, ok := Statuses[string(bytes)]; ok {
*s = r *s = r
return nil return nil
} else { } else {
return errors.New("no this data") return errors.New("no this data")
} }
} }
func (s *Status) ToDB() ([]byte, error) { func (s *Status) ToDB() ([]byte, error) {
return []byte(s.Name), nil return []byte(s.Name), nil
} }
type User struct { type User struct {
Id int64 Id int64
Name string Name string
Status Status `xorm:"varchar(40)"` Status Status `xorm:"varchar(40)"`
} }
func main() { func main() {
f := "conversion.db" f := "conversion.db"
os.Remove(f) os.Remove(f)
Orm, err := NewEngine("sqlite3", f) Orm, err := NewEngine("sqlite3", f)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
Orm.ShowSQL = true Orm.ShowSQL = true
err = Orm.CreateTables(&User{}) err = Orm.CreateTables(&User{})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
_, err = Orm.Insert(&User{1, "xlw", Registed}) _, err = Orm.Insert(&User{1, "xlw", Registed})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
users := make([]User, 0) users := make([]User, 0)
err = Orm.Find(&users) err = Orm.Find(&users)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println(users) fmt.Println(users)
} }

View File

@ -1,66 +1,66 @@
package main package main
import ( import (
"fmt" "fmt"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
. "xorm" . "xorm"
) )
type User struct { type User struct {
Id int64 Id int64
Name string Name string
} }
type LoginInfo struct { type LoginInfo struct {
Id int64 Id int64
IP string IP string
UserId int64 UserId int64
} }
type LoginInfo1 struct { type LoginInfo1 struct {
LoginInfo `xorm:"extends"` LoginInfo `xorm:"extends"`
UserName string UserName string
} }
func main() { func main() {
f := "derive.db" f := "derive.db"
os.Remove(f) os.Remove(f)
Orm, err := NewEngine("sqlite3", f) Orm, err := NewEngine("sqlite3", f)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
defer Orm.Close() defer Orm.Close()
Orm.ShowSQL = true Orm.ShowSQL = true
err = Orm.CreateTables(&User{}, &LoginInfo{}) err = Orm.CreateTables(&User{}, &LoginInfo{})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
_, err = Orm.Insert(&User{1, "xlw"}, &LoginInfo{1, "127.0.0.1", 1}) _, err = Orm.Insert(&User{1, "xlw"}, &LoginInfo{1, "127.0.0.1", 1})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
info := LoginInfo{} info := LoginInfo{}
_, err = Orm.Id(1).Get(&info) _, err = Orm.Id(1).Get(&info)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println(info) fmt.Println(info)
infos := make([]LoginInfo1, 0) infos := make([]LoginInfo1, 0)
err = Orm.Sql(`select *, (select name from user where id = login_info.user_id) as user_name from err = Orm.Sql(`select *, (select name from user where id = login_info.user_id) as user_name from
login_info limit 10`).Find(&infos) login_info limit 10`).Find(&infos)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println(infos) fmt.Println(infos)
} }

View File

@ -1,109 +1,109 @@
package main package main
import ( import (
//xorm "github.com/lunny/xorm" //xorm "github.com/lunny/xorm"
"fmt" "fmt"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
//"time" //"time"
//"sync/atomic" //"sync/atomic"
"runtime" "runtime"
xorm "xorm" xorm "xorm"
) )
type User struct { type User struct {
Id int64 Id int64
Name string Name string
} }
func sqliteEngine() (*xorm.Engine, error) { func sqliteEngine() (*xorm.Engine, error) {
os.Remove("./test.db") os.Remove("./test.db")
return xorm.NewEngine("sqlite3", "./goroutine.db") return xorm.NewEngine("sqlite3", "./goroutine.db")
} }
func mysqlEngine() (*xorm.Engine, error) { func mysqlEngine() (*xorm.Engine, error) {
return xorm.NewEngine("mysql", "root:@/test?charset=utf8") return xorm.NewEngine("mysql", "root:@/test?charset=utf8")
} }
var u *User = &User{} var u *User = &User{}
func test(engine *xorm.Engine) { func test(engine *xorm.Engine) {
err := engine.CreateTables(u) err := engine.CreateTables(u)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
size := 500 size := 500
queue := make(chan int, size) queue := make(chan int, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
go func(x int) { go func(x int) {
//x := i //x := i
err := engine.Test() err := engine.Test()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} else { } else {
err = engine.Map(u) err = engine.Map(u)
if err != nil { if err != nil {
fmt.Println("Map user failed") fmt.Println("Map user failed")
} else { } else {
for j := 0; j < 10; j++ { for j := 0; j < 10; j++ {
if x+j < 2 { if x+j < 2 {
_, err = engine.Get(u) _, err = engine.Get(u)
} else if x+j < 4 { } else if x+j < 4 {
users := make([]User, 0) users := make([]User, 0)
err = engine.Find(&users) err = engine.Find(&users)
} else if x+j < 8 { } else if x+j < 8 {
_, err = engine.Count(u) _, err = engine.Count(u)
} else if x+j < 16 { } else if x+j < 16 {
_, err = engine.Insert(&User{Name: "xlw"}) _, err = engine.Insert(&User{Name: "xlw"})
} else if x+j < 32 { } else if x+j < 32 {
_, err = engine.Id(1).Delete(u) _, err = engine.Id(1).Delete(u)
} }
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
queue <- x queue <- x
return return
} }
} }
fmt.Printf("%v success!\n", x) fmt.Printf("%v success!\n", x)
} }
} }
queue <- x queue <- x
}(i) }(i)
} }
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
<-queue <-queue
} }
//conns := atomic.LoadInt32(&xorm.ConnectionNum) //conns := atomic.LoadInt32(&xorm.ConnectionNum)
//fmt.Println("connection number:", conns) //fmt.Println("connection number:", conns)
fmt.Println("end") fmt.Println("end")
} }
func main() { func main() {
runtime.GOMAXPROCS(2) runtime.GOMAXPROCS(2)
fmt.Println("-----start sqlite go routines-----") fmt.Println("-----start sqlite go routines-----")
engine, err := sqliteEngine() engine, err := sqliteEngine()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
engine.ShowSQL = true engine.ShowSQL = true
fmt.Println(engine) fmt.Println(engine)
test(engine) test(engine)
fmt.Println("test end") fmt.Println("test end")
engine.Close() engine.Close()
fmt.Println("-----start mysql go routines-----") fmt.Println("-----start mysql go routines-----")
engine, err = mysqlEngine() engine, err = mysqlEngine()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
defer engine.Close() defer engine.Close()
test(engine) test(engine)
} }

View File

@ -1,108 +1,108 @@
package main package main
import ( import (
"fmt" "fmt"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
xorm "github.com/lunny/xorm" xorm "github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
//"time" //"time"
//"sync/atomic" //"sync/atomic"
"runtime" "runtime"
//xorm "xorm" //xorm "xorm"
) )
type User struct { type User struct {
Id int64 Id int64
Name string Name string
} }
func sqliteEngine() (*xorm.Engine, error) { func sqliteEngine() (*xorm.Engine, error) {
os.Remove("./test.db") os.Remove("./test.db")
return xorm.NewEngine("sqlite3", "./goroutine.db") return xorm.NewEngine("sqlite3", "./goroutine.db")
} }
func mysqlEngine() (*xorm.Engine, error) { func mysqlEngine() (*xorm.Engine, error) {
return xorm.NewEngine("mysql", "root:@/test?charset=utf8") return xorm.NewEngine("mysql", "root:@/test?charset=utf8")
} }
var u *User = &User{} var u *User = &User{}
func test(engine *xorm.Engine) { func test(engine *xorm.Engine) {
err := engine.CreateTables(u) err := engine.CreateTables(u)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
engine.ShowSQL = true engine.ShowSQL = true
engine.Pool.SetMaxConns(5) engine.Pool.SetMaxConns(5)
size := 1000 size := 1000
queue := make(chan int, size) queue := make(chan int, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
go func(x int) { go func(x int) {
//x := i //x := i
err := engine.Test() err := engine.Test()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} else { } else {
err = engine.Map(u) err = engine.Map(u)
if err != nil { if err != nil {
fmt.Println("Map user failed") fmt.Println("Map user failed")
} else { } else {
for j := 0; j < 10; j++ { for j := 0; j < 10; j++ {
if x+j < 2 { if x+j < 2 {
_, err = engine.Get(u) _, err = engine.Get(u)
} else if x+j < 4 { } else if x+j < 4 {
users := make([]User, 0) users := make([]User, 0)
err = engine.Find(&users) err = engine.Find(&users)
} else if x+j < 8 { } else if x+j < 8 {
_, err = engine.Count(u) _, err = engine.Count(u)
} else if x+j < 16 { } else if x+j < 16 {
_, err = engine.Insert(&User{Name: "xlw"}) _, err = engine.Insert(&User{Name: "xlw"})
} else if x+j < 32 { } else if x+j < 32 {
_, err = engine.Id(1).Delete(u) _, err = engine.Id(1).Delete(u)
} }
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
queue <- x queue <- x
return return
} }
} }
fmt.Printf("%v success!\n", x) fmt.Printf("%v success!\n", x)
} }
} }
queue <- x queue <- x
}(i) }(i)
} }
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
<-queue <-queue
} }
fmt.Println("end") fmt.Println("end")
} }
func main() { func main() {
runtime.GOMAXPROCS(2) runtime.GOMAXPROCS(2)
fmt.Println("create engine") fmt.Println("create engine")
engine, err := sqliteEngine() engine, err := sqliteEngine()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
engine.ShowSQL = true engine.ShowSQL = true
fmt.Println(engine) fmt.Println(engine)
test(engine) test(engine)
fmt.Println("------------------------") fmt.Println("------------------------")
engine.Close() engine.Close()
engine, err = mysqlEngine() engine, err = mysqlEngine()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
defer engine.Close() defer engine.Close()
test(engine) test(engine)
} }

View File

@ -1,45 +1,45 @@
package main package main
import ( import (
"fmt" "fmt"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
. "xorm" . "xorm"
) )
type User struct { type User struct {
Id int64 Id int64
Name string Name string
} }
func main() { func main() {
f := "pool.db" f := "pool.db"
os.Remove(f) os.Remove(f)
Orm, err := NewEngine("sqlite3", f) Orm, err := NewEngine("sqlite3", f)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
err = Orm.SetPool(NewSimpleConnectPool()) err = Orm.SetPool(NewSimpleConnectPool())
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
Orm.ShowSQL = true Orm.ShowSQL = true
err = Orm.CreateTables(&User{}) err = Orm.CreateTables(&User{})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
_, err = Orm.Get(&User{}) _, err = Orm.Get(&User{})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
break break
} }
} }
} }

View File

@ -1,54 +1,54 @@
package main package main
import ( import (
"fmt" "fmt"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
. "xorm" . "xorm"
) )
type User struct { type User struct {
Id int64 Id int64
Name string Name string
} }
type LoginInfo struct { type LoginInfo struct {
Id int64 Id int64
IP string IP string
UserId int64 UserId int64
// timestamp should be updated by database, so only allow get from db // timestamp should be updated by database, so only allow get from db
TimeStamp string `xorm:"<-"` TimeStamp string `xorm:"<-"`
// assume // assume
Nonuse int `xorm:"->"` Nonuse int `xorm:"->"`
} }
func main() { func main() {
f := "singleMapping.db" f := "singleMapping.db"
os.Remove(f) os.Remove(f)
Orm, err := NewEngine("sqlite3", f) Orm, err := NewEngine("sqlite3", f)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
Orm.ShowSQL = true Orm.ShowSQL = true
err = Orm.CreateTables(&User{}, &LoginInfo{}) err = Orm.CreateTables(&User{}, &LoginInfo{})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
_, err = Orm.Insert(&User{1, "xlw"}, &LoginInfo{1, "127.0.0.1", 1, "", 23}) _, err = Orm.Insert(&User{1, "xlw"}, &LoginInfo{1, "127.0.0.1", 1, "", 23})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
info := LoginInfo{} info := LoginInfo{}
_, err = Orm.Id(1).Get(&info) _, err = Orm.Id(1).Get(&info)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println(info) fmt.Println(info)
} }

View File

@ -1,92 +1,92 @@
package main package main
import ( import (
"fmt" "fmt"
_ "github.com/bylevel/pq" _ "github.com/bylevel/pq"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"xorm" "xorm"
) )
type SyncUser2 struct { type SyncUser2 struct {
Id int64 Id int64
Name string `xorm:"unique"` Name string `xorm:"unique"`
Age int `xorm:"index"` Age int `xorm:"index"`
Title string Title string
Address string Address string
Genre string Genre string
Area string Area string
Date int Date int
} }
type SyncLoginInfo2 struct { type SyncLoginInfo2 struct {
Id int64 Id int64
IP string `xorm:"index"` IP string `xorm:"index"`
UserId int64 UserId int64
AddedCol int AddedCol int
// timestamp should be updated by database, so only allow get from db // timestamp should be updated by database, so only allow get from db
TimeStamp string TimeStamp string
// assume // assume
Nonuse int `xorm:"unique"` Nonuse int `xorm:"unique"`
Newa string `xorm:"index"` Newa string `xorm:"index"`
} }
func sync(engine *xorm.Engine) error { func sync(engine *xorm.Engine) error {
return engine.Sync(&SyncLoginInfo2{}, &SyncUser2{}) return engine.Sync(&SyncLoginInfo2{}, &SyncUser2{})
} }
func sqliteEngine() (*xorm.Engine, error) { func sqliteEngine() (*xorm.Engine, error) {
f := "sync.db" f := "sync.db"
//os.Remove(f) //os.Remove(f)
return xorm.NewEngine("sqlite3", f) return xorm.NewEngine("sqlite3", f)
} }
func mysqlEngine() (*xorm.Engine, error) { func mysqlEngine() (*xorm.Engine, error) {
return xorm.NewEngine("mysql", "root:@/test?charset=utf8") return xorm.NewEngine("mysql", "root:@/test?charset=utf8")
} }
func postgresEngine() (*xorm.Engine, error) { func postgresEngine() (*xorm.Engine, error) {
return xorm.NewEngine("postgres", "dbname=xorm_test sslmode=disable") return xorm.NewEngine("postgres", "dbname=xorm_test sslmode=disable")
} }
type engineFunc func() (*xorm.Engine, error) type engineFunc func() (*xorm.Engine, error)
func main() { func main() {
//engines := []engineFunc{sqliteEngine, mysqlEngine, postgresEngine} //engines := []engineFunc{sqliteEngine, mysqlEngine, postgresEngine}
//engines := []engineFunc{sqliteEngine} //engines := []engineFunc{sqliteEngine}
//engines := []engineFunc{mysqlEngine} //engines := []engineFunc{mysqlEngine}
engines := []engineFunc{postgresEngine} engines := []engineFunc{postgresEngine}
for _, enginefunc := range engines { for _, enginefunc := range engines {
Orm, err := enginefunc() Orm, err := enginefunc()
fmt.Println("--------", Orm.DriverName, "----------") fmt.Println("--------", Orm.DriverName, "----------")
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
Orm.ShowSQL = true Orm.ShowSQL = true
err = sync(Orm) err = sync(Orm)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} }
_, err = Orm.Where("id > 0").Delete(&SyncUser2{}) _, err = Orm.Where("id > 0").Delete(&SyncUser2{})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} }
user := &SyncUser2{ user := &SyncUser2{
Name: "testsdf", Name: "testsdf",
Age: 15, Age: 15,
Title: "newsfds", Title: "newsfds",
Address: "fasfdsafdsaf", Address: "fasfdsafdsaf",
Genre: "fsafd", Genre: "fsafd",
Area: "fafdsafd", Area: "fafdsafd",
Date: 1000, Date: 1000,
} }
_, err = Orm.Insert(user) _, err = Orm.Insert(user)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} }
} }
} }

View File

@ -1,13 +1,13 @@
package xorm package xorm
import ( import (
"fmt" "fmt"
"strings" "strings"
) )
// Filter is an interface to filter SQL // Filter is an interface to filter SQL
type Filter interface { type Filter interface {
Do(sql string, session *Session) string Do(sql string, session *Session) string
} }
// PgSeqFilter filter SQL replace ?, ? ... to $1, $2 ... // PgSeqFilter filter SQL replace ?, ? ... to $1, $2 ...
@ -15,16 +15,16 @@ type PgSeqFilter struct {
} }
func (s *PgSeqFilter) Do(sql string, session *Session) string { func (s *PgSeqFilter) Do(sql string, session *Session) string {
segs := strings.Split(sql, "?") segs := strings.Split(sql, "?")
size := len(segs) size := len(segs)
res := "" res := ""
for i, c := range segs { for i, c := range segs {
if i < size-1 { if i < size-1 {
res += c + fmt.Sprintf("$%v", i+1) res += c + fmt.Sprintf("$%v", i+1)
} }
} }
res += segs[size-1] res += segs[size-1]
return res return res
} }
// QuoteFilter filter SQL replace ` to database's own quote character // QuoteFilter filter SQL replace ` to database's own quote character
@ -32,7 +32,7 @@ type QuoteFilter struct {
} }
func (s *QuoteFilter) Do(sql string, session *Session) string { func (s *QuoteFilter) Do(sql string, session *Session) string {
return strings.Replace(sql, "`", session.Engine.QuoteStr(), -1) return strings.Replace(sql, "`", session.Engine.QuoteStr(), -1)
} }
// IdFilter filter SQL replace (id) to primary key column name // IdFilter filter SQL replace (id) to primary key column name
@ -40,10 +40,10 @@ type IdFilter struct {
} }
func (i *IdFilter) Do(sql string, session *Session) string { func (i *IdFilter) Do(sql string, session *Session) string {
if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" {
sql = strings.Replace(sql, "`(id)`", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) sql = strings.Replace(sql, "`(id)`", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1)
sql = strings.Replace(sql, session.Engine.Quote("(id)"), session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) sql = strings.Replace(sql, session.Engine.Quote("(id)"), session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1)
return strings.Replace(sql, "(id)", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) return strings.Replace(sql, "(id)", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1)
} }
return sql return sql
} }

View File

@ -1,63 +1,63 @@
package xorm package xorm
import ( import (
"reflect" "reflect"
"strings" "strings"
) )
func indexNoCase(s, sep string) int { func indexNoCase(s, sep string) int {
return strings.Index(strings.ToLower(s), strings.ToLower(sep)) return strings.Index(strings.ToLower(s), strings.ToLower(sep))
} }
func splitNoCase(s, sep string) []string { func splitNoCase(s, sep string) []string {
idx := indexNoCase(s, sep) idx := indexNoCase(s, sep)
if idx < 0 { if idx < 0 {
return []string{s} return []string{s}
} }
return strings.Split(s, s[idx:idx+len(sep)]) return strings.Split(s, s[idx:idx+len(sep)])
} }
func splitNNoCase(s, sep string, n int) []string { func splitNNoCase(s, sep string, n int) []string {
idx := indexNoCase(s, sep) idx := indexNoCase(s, sep)
if idx < 0 { if idx < 0 {
return []string{s} return []string{s}
} }
return strings.SplitN(s, s[idx:idx+len(sep)], n) return strings.SplitN(s, s[idx:idx+len(sep)], n)
} }
func makeArray(elem string, count int) []string { func makeArray(elem string, count int) []string {
res := make([]string, count) res := make([]string, count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
res[i] = elem res[i] = elem
} }
return res return res
} }
func rType(bean interface{}) reflect.Type { func rType(bean interface{}) reflect.Type {
sliceValue := reflect.Indirect(reflect.ValueOf(bean)) sliceValue := reflect.Indirect(reflect.ValueOf(bean))
return reflect.TypeOf(sliceValue.Interface()) return reflect.TypeOf(sliceValue.Interface())
} }
func structName(v reflect.Type) string { func structName(v reflect.Type) string {
for v.Kind() == reflect.Ptr { for v.Kind() == reflect.Ptr {
v = v.Elem() v = v.Elem()
} }
return v.Name() return v.Name()
} }
func sliceEq(left, right []string) bool { func sliceEq(left, right []string) bool {
for _, l := range left { for _, l := range left {
var find bool var find bool
for _, r := range right { for _, r := range right {
if l == r { if l == r {
find = true find = true
break break
} }
} }
if !find { if !find {
return false return false
} }
} }
return true return true
} }

118
mapper.go
View File

@ -7,8 +7,8 @@ import (
// name translation between struct, fields names and table, column names // name translation between struct, fields names and table, column names
type IMapper interface { type IMapper interface {
Obj2Table(string) string Obj2Table(string) string
Table2Obj(string) string Table2Obj(string) string
} }
// SameMapper implements IMapper and provides same name between struct and // SameMapper implements IMapper and provides same name between struct and
@ -17,11 +17,11 @@ type SameMapper struct {
} }
func (m SameMapper) Obj2Table(o string) string { func (m SameMapper) Obj2Table(o string) string {
return o return o
} }
func (m SameMapper) Table2Obj(t string) string { func (m SameMapper) Table2Obj(t string) string {
return t return t
} }
// SnakeMapper implements IMapper and provides name transaltion between // SnakeMapper implements IMapper and provides name transaltion between
@ -30,101 +30,101 @@ type SnakeMapper struct {
} }
func snakeCasedName(name string) string { func snakeCasedName(name string) string {
newstr := make([]rune, 0) newstr := make([]rune, 0)
for idx, chr := range name { for idx, chr := range name {
if isUpper := 'A' <= chr && chr <= 'Z'; isUpper { if isUpper := 'A' <= chr && chr <= 'Z'; isUpper {
if idx > 0 { if idx > 0 {
newstr = append(newstr, '_') newstr = append(newstr, '_')
} }
chr -= ('A' - 'a') chr -= ('A' - 'a')
} }
newstr = append(newstr, chr) newstr = append(newstr, chr)
} }
return string(newstr) return string(newstr)
} }
/*func pascal2Sql(s string) (d string) { /*func pascal2Sql(s string) (d string) {
d = "" d = ""
lastIdx := 0 lastIdx := 0
for i := 0; i < len(s); i++ { for i := 0; i < len(s); i++ {
if s[i] >= 'A' && s[i] <= 'Z' { if s[i] >= 'A' && s[i] <= 'Z' {
if lastIdx < i { if lastIdx < i {
d += s[lastIdx+1 : i] d += s[lastIdx+1 : i]
} }
if i != 0 { if i != 0 {
d += "_" d += "_"
} }
d += string(s[i] + 32) d += string(s[i] + 32)
lastIdx = i lastIdx = i
} }
} }
d += s[lastIdx+1:] d += s[lastIdx+1:]
return return
}*/ }*/
func (mapper SnakeMapper) Obj2Table(name string) string { func (mapper SnakeMapper) Obj2Table(name string) string {
return snakeCasedName(name) return snakeCasedName(name)
} }
func titleCasedName(name string) string { func titleCasedName(name string) string {
newstr := make([]rune, 0) newstr := make([]rune, 0)
upNextChar := true upNextChar := true
for _, chr := range name { for _, chr := range name {
switch { switch {
case upNextChar: case upNextChar:
upNextChar = false upNextChar = false
if 'a' <= chr && chr <= 'z' { if 'a' <= chr && chr <= 'z' {
chr -= ('a' - 'A') chr -= ('a' - 'A')
} }
case chr == '_': case chr == '_':
upNextChar = true upNextChar = true
continue continue
} }
newstr = append(newstr, chr) newstr = append(newstr, chr)
} }
return string(newstr) return string(newstr)
} }
func (mapper SnakeMapper) Table2Obj(name string) string { func (mapper SnakeMapper) Table2Obj(name string) string {
return titleCasedName(name) return titleCasedName(name)
} }
// provide prefix table name support // provide prefix table name support
type PrefixMapper struct { type PrefixMapper struct {
Mapper IMapper Mapper IMapper
Prefix string Prefix string
} }
func (mapper PrefixMapper) Obj2Table(name string) string { func (mapper PrefixMapper) Obj2Table(name string) string {
return mapper.Prefix + mapper.Mapper.Obj2Table(name) return mapper.Prefix + mapper.Mapper.Obj2Table(name)
} }
func (mapper PrefixMapper) Table2Obj(name string) string { func (mapper PrefixMapper) Table2Obj(name string) string {
return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):]) return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):])
} }
func NewPrefixMapper(mapper IMapper, prefix string) PrefixMapper { func NewPrefixMapper(mapper IMapper, prefix string) PrefixMapper {
return PrefixMapper{mapper, prefix} return PrefixMapper{mapper, prefix}
} }
// provide suffix table name support // provide suffix table name support
type SuffixMapper struct { type SuffixMapper struct {
Mapper IMapper Mapper IMapper
Suffix string Suffix string
} }
func (mapper SuffixMapper) Obj2Table(name string) string { func (mapper SuffixMapper) Obj2Table(name string) string {
return mapper.Suffix + mapper.Mapper.Obj2Table(name) return mapper.Suffix + mapper.Mapper.Obj2Table(name)
} }
func (mapper SuffixMapper) Table2Obj(name string) string { func (mapper SuffixMapper) Table2Obj(name string) string {
return mapper.Mapper.Table2Obj(name[len(mapper.Suffix):]) return mapper.Mapper.Table2Obj(name[len(mapper.Suffix):])
} }
func NewSuffixMapper(mapper IMapper, suffix string) SuffixMapper { func NewSuffixMapper(mapper IMapper, suffix string) SuffixMapper {
return SuffixMapper{mapper, suffix} return SuffixMapper{mapper, suffix}
} }

View File

@ -1,66 +1,66 @@
package xorm package xorm
import ( import (
"errors" "errors"
"strings" "strings"
"time" "time"
) )
type mymysql struct { type mymysql struct {
mysql mysql
proto string proto string
raddr string raddr string
laddr string laddr string
timeout time.Duration timeout time.Duration
db string db string
user string user string
passwd string passwd string
} }
func (db *mymysql) Init(drivername, uri string) error { func (db *mymysql) Init(drivername, uri string) error {
db.mysql.base.init(drivername, uri) db.mysql.base.init(drivername, uri)
pd := strings.SplitN(uri, "*", 2) pd := strings.SplitN(uri, "*", 2)
if len(pd) == 2 { if len(pd) == 2 {
// Parse protocol part of URI // Parse protocol part of URI
p := strings.SplitN(pd[0], ":", 2) p := strings.SplitN(pd[0], ":", 2)
if len(p) != 2 { if len(p) != 2 {
return errors.New("Wrong protocol part of URI") return errors.New("Wrong protocol part of URI")
} }
db.proto = p[0] db.proto = p[0]
options := strings.Split(p[1], ",") options := strings.Split(p[1], ",")
db.raddr = options[0] db.raddr = options[0]
for _, o := range options[1:] { for _, o := range options[1:] {
kv := strings.SplitN(o, "=", 2) kv := strings.SplitN(o, "=", 2)
var k, v string var k, v string
if len(kv) == 2 { if len(kv) == 2 {
k, v = kv[0], kv[1] k, v = kv[0], kv[1]
} else { } else {
k, v = o, "true" k, v = o, "true"
} }
switch k { switch k {
case "laddr": case "laddr":
db.laddr = v db.laddr = v
case "timeout": case "timeout":
to, err := time.ParseDuration(v) to, err := time.ParseDuration(v)
if err != nil { if err != nil {
return err return err
} }
db.timeout = to db.timeout = to
default: default:
return errors.New("Unknown option: " + k) return errors.New("Unknown option: " + k)
} }
} }
// Remove protocol part // Remove protocol part
pd = pd[1:] pd = pd[1:]
} }
// Parse database part of URI // Parse database part of URI
dup := strings.SplitN(pd[0], "/", 3) dup := strings.SplitN(pd[0], "/", 3)
if len(dup) != 3 { if len(dup) != 3 {
return errors.New("Wrong database part of URI") return errors.New("Wrong database part of URI")
} }
db.dbname = dup[0] db.dbname = dup[0]
db.user = dup[1] db.user = dup[1]
db.passwd = dup[2] db.passwd = dup[2]
return nil return nil
} }

View File

@ -1,8 +1,8 @@
package xorm package xorm
import ( import (
_ "github.com/ziutek/mymysql/godrv" _ "github.com/ziutek/mymysql/godrv"
"testing" "testing"
) )
/* /*
@ -13,145 +13,145 @@ utf8 COLLATE utf8_general_ci;
var showTestSql bool = true var showTestSql bool = true
func TestMyMysql(t *testing.T) { func TestMyMysql(t *testing.T) {
err := mymysqlDdlImport() err := mymysqlDdlImport()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine, err := NewEngine("mymysql", "xorm_test/root/") engine, err := NewEngine("mymysql", "xorm_test/root/")
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
engine.ShowDebug = showTestSql engine.ShowDebug = showTestSql
sqlResults, _ := engine.Import("tests/mysql_ddl.sql") sqlResults, _ := engine.Import("tests/mysql_ddl.sql")
engine.LogDebug("sql results: %v", sqlResults) engine.LogDebug("sql results: %v", sqlResults)
testAll(engine, t) testAll(engine, t)
testAll2(engine, t) testAll2(engine, t)
testAll3(engine, t) testAll3(engine, t)
} }
func TestMyMysqlWithCache(t *testing.T) { func TestMyMysqlWithCache(t *testing.T) {
err := mymysqlDdlImport() err := mymysqlDdlImport()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine, err := NewEngine("mymysql", "xorm_test2/root/") engine, err := NewEngine("mymysql", "xorm_test2/root/")
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
engine.ShowDebug = showTestSql engine.ShowDebug = showTestSql
sqlResults, _ := engine.Import("tests/mysql_ddl.sql") sqlResults, _ := engine.Import("tests/mysql_ddl.sql")
engine.LogDebug("sql results: %v", sqlResults) engine.LogDebug("sql results: %v", sqlResults)
testAll(engine, t) testAll(engine, t)
testAll2(engine, t) testAll2(engine, t)
} }
func newMyMysqlEngine() (*Engine, error) { func newMyMysqlEngine() (*Engine, error) {
return NewEngine("mymysql", "xorm_test2/root/") return NewEngine("mymysql", "xorm_test2/root/")
} }
func mymysqlDdlImport() error { func mymysqlDdlImport() error {
engine, err := NewEngine("mymysql", "/root/") engine, err := NewEngine("mymysql", "/root/")
if err != nil { if err != nil {
return err return err
} }
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
engine.ShowDebug = showTestSql engine.ShowDebug = showTestSql
sqlResults, _ := engine.Import("tests/mysql_ddl.sql") sqlResults, _ := engine.Import("tests/mysql_ddl.sql")
engine.LogDebug("sql results: %v", sqlResults) engine.LogDebug("sql results: %v", sqlResults)
engine.Close() engine.Close()
return nil return nil
} }
func BenchmarkMyMysqlNoCacheInsert(t *testing.B) { func BenchmarkMyMysqlNoCacheInsert(t *testing.B) {
engine, err := newMyMysqlEngine() engine, err := newMyMysqlEngine()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer engine.Close() defer engine.Close()
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
func BenchmarkMyMysqlNoCacheFind(t *testing.B) { func BenchmarkMyMysqlNoCacheFind(t *testing.B) {
engine, err := newMyMysqlEngine() engine, err := newMyMysqlEngine()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer engine.Close() defer engine.Close()
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchFind(engine, t) doBenchFind(engine, t)
} }
func BenchmarkMyMysqlNoCacheFindPtr(t *testing.B) { func BenchmarkMyMysqlNoCacheFindPtr(t *testing.B) {
engine, err := newMyMysqlEngine() engine, err := newMyMysqlEngine()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer engine.Close() defer engine.Close()
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }
func BenchmarkMyMysqlCacheInsert(t *testing.B) { func BenchmarkMyMysqlCacheInsert(t *testing.B) {
engine, err := newMyMysqlEngine() engine, err := newMyMysqlEngine()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer engine.Close() defer engine.Close()
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
func BenchmarkMyMysqlCacheFind(t *testing.B) { func BenchmarkMyMysqlCacheFind(t *testing.B) {
engine, err := newMyMysqlEngine() engine, err := newMyMysqlEngine()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer engine.Close() defer engine.Close()
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchFind(engine, t) doBenchFind(engine, t)
} }
func BenchmarkMyMysqlCacheFindPtr(t *testing.B) { func BenchmarkMyMysqlCacheFindPtr(t *testing.B) {
engine, err := newMyMysqlEngine() engine, err := newMyMysqlEngine()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer engine.Close() defer engine.Close()
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }

488
mysql.go
View File

@ -1,311 +1,311 @@
package xorm package xorm
import ( import (
"crypto/tls" "crypto/tls"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
) )
type base struct { type base struct {
drivername string drivername string
dataSourceName string dataSourceName string
} }
func (b *base) init(drivername, dataSourceName string) { func (b *base) init(drivername, dataSourceName string) {
b.drivername, b.dataSourceName = drivername, dataSourceName b.drivername, b.dataSourceName = drivername, dataSourceName
} }
type mysql struct { type mysql struct {
base base
user string user string
passwd string passwd string
net string net string
addr string addr string
dbname string dbname string
params map[string]string params map[string]string
loc *time.Location loc *time.Location
timeout time.Duration timeout time.Duration
tls *tls.Config tls *tls.Config
allowAllFiles bool allowAllFiles bool
allowOldPasswords bool allowOldPasswords bool
clientFoundRows bool clientFoundRows bool
} }
/*func readBool(input string) (value bool, valid bool) { /*func readBool(input string) (value bool, valid bool) {
switch input { switch input {
case "1", "true", "TRUE", "True": case "1", "true", "TRUE", "True":
return true, true return true, true
case "0", "false", "FALSE", "False": case "0", "false", "FALSE", "False":
return false, true return false, true
} }
// Not a valid bool value // Not a valid bool value
return return
}*/ }*/
func (cfg *mysql) parseDSN(dsn string) (err error) { func (cfg *mysql) parseDSN(dsn string) (err error) {
//cfg.params = make(map[string]string) //cfg.params = make(map[string]string)
dsnPattern := regexp.MustCompile( dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@] `^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]] `(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname `\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN] `(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
matches := dsnPattern.FindStringSubmatch(dsn) matches := dsnPattern.FindStringSubmatch(dsn)
//tlsConfigRegister := make(map[string]*tls.Config) //tlsConfigRegister := make(map[string]*tls.Config)
names := dsnPattern.SubexpNames() names := dsnPattern.SubexpNames()
for i, match := range matches { for i, match := range matches {
switch names[i] { switch names[i] {
case "dbname": case "dbname":
cfg.dbname = match cfg.dbname = match
} }
} }
return return
} }
func (db *mysql) Init(drivername, uri string) error { func (db *mysql) Init(drivername, uri string) error {
db.base.init(drivername, uri) db.base.init(drivername, uri)
return db.parseDSN(uri) return db.parseDSN(uri)
} }
func (db *mysql) SqlType(c *Column) string { func (db *mysql) SqlType(c *Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case Bool: case Bool:
res = TinyInt res = TinyInt
case Serial: case Serial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = Int res = Int
case BigSerial: case BigSerial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = BigInt res = BigInt
case Bytea: case Bytea:
res = Blob res = Blob
case TimeStampz: case TimeStampz:
res = Char res = Char
c.Length = 64 c.Length = 64
default: default:
res = t res = t
} }
var hasLen1 bool = (c.Length > 0) var hasLen1 bool = (c.Length > 0)
var hasLen2 bool = (c.Length2 > 0) var hasLen2 bool = (c.Length2 > 0)
if hasLen1 { if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.Itoa(c.Length) + ")"
} else if hasLen2 { } else if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
} }
return res return res
} }
func (db *mysql) SupportInsertMany() bool { func (db *mysql) SupportInsertMany() bool {
return true return true
} }
func (db *mysql) QuoteStr() string { func (db *mysql) QuoteStr() string {
return "`" return "`"
} }
func (db *mysql) SupportEngine() bool { func (db *mysql) SupportEngine() bool {
return true return true
} }
func (db *mysql) AutoIncrStr() string { func (db *mysql) AutoIncrStr() string {
return "AUTO_INCREMENT" return "AUTO_INCREMENT"
} }
func (db *mysql) SupportCharset() bool { func (db *mysql) SupportCharset() bool {
return true return true
} }
func (db *mysql) IndexOnTable() bool { func (db *mysql) IndexOnTable() bool {
return true return true
} }
func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{db.dbname, tableName, idxName} args := []interface{}{db.dbname, tableName, idxName}
sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`" sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`"
sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?" sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?"
return sql, args return sql, args
} }
func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) { func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{db.dbname, tableName, colName} args := []interface{}{db.dbname, tableName, colName}
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
return sql, args return sql, args
} }
func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) { func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{db.dbname, tableName} args := []interface{}{db.dbname, tableName}
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
return sql, args return sql, args
} }
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{db.dbname, tableName} args := []interface{}{db.dbname, tableName}
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, record := range res { for _, record := range res {
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "COLUMN_NAME": case "COLUMN_NAME":
col.Name = strings.Trim(string(content), "` ") col.Name = strings.Trim(string(content), "` ")
case "IS_NULLABLE": case "IS_NULLABLE":
if "YES" == string(content) { if "YES" == string(content) {
col.Nullable = true col.Nullable = true
} }
case "COLUMN_DEFAULT": case "COLUMN_DEFAULT":
// add '' // add ''
col.Default = string(content) col.Default = string(content)
case "COLUMN_TYPE": case "COLUMN_TYPE":
cts := strings.Split(string(content), "(") cts := strings.Split(string(content), "(")
var len1, len2 int var len1, len2 int
if len(cts) == 2 { if len(cts) == 2 {
idx := strings.Index(cts[1], ")") idx := strings.Index(cts[1], ")")
lens := strings.Split(cts[1][0:idx], ",") lens := strings.Split(cts[1][0:idx], ",")
len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) len1, err = strconv.Atoi(strings.TrimSpace(lens[0]))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if len(lens) == 2 { if len(lens) == 2 {
len2, err = strconv.Atoi(lens[1]) len2, err = strconv.Atoi(lens[1])
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
} }
} }
colName := cts[0] colName := cts[0]
colType := strings.ToUpper(colName) colType := strings.ToUpper(colName)
col.Length = len1 col.Length = len1
col.Length2 = len2 col.Length2 = len2
if _, ok := sqlTypes[colType]; ok { if _, ok := sqlTypes[colType]; ok {
col.SQLType = SQLType{colType, len1, len2} col.SQLType = SQLType{colType, len1, len2}
} else { } else {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType)) return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType))
} }
case "COLUMN_KEY": case "COLUMN_KEY":
key := string(content) key := string(content)
if key == "PRI" { if key == "PRI" {
col.IsPrimaryKey = true col.IsPrimaryKey = true
} }
if key == "UNI" { if key == "UNI" {
//col.is //col.is
} }
case "EXTRA": case "EXTRA":
extra := string(content) extra := string(content)
if extra == "auto_increment" { if extra == "auto_increment" {
col.IsAutoIncrement = true col.IsAutoIncrement = true
} }
} }
} }
if col.SQLType.IsText() { if col.SQLType.IsText() {
if col.Default != "" { if col.Default != "" {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
} }
} }
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name) colSeq = append(colSeq, col.Name)
} }
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *mysql) GetTables() ([]*Table, error) { func (db *mysql) GetTables() ([]*Table, error) {
args := []interface{}{db.dbname} args := []interface{}{db.dbname}
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?" s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*Table, 0)
for _, record := range res { for _, record := range res {
table := new(Table) table := new(Table)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "TABLE_NAME": case "TABLE_NAME":
table.Name = strings.Trim(string(content), "` ") table.Name = strings.Trim(string(content), "` ")
case "ENGINE": case "ENGINE":
} }
} }
tables = append(tables, table) tables = append(tables, table)
} }
return tables, nil return tables, nil
} }
func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{db.dbname, tableName} args := []interface{}{db.dbname, tableName}
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*Index, 0)
for _, record := range res { for _, record := range res {
var indexType int var indexType int
var indexName, colName string var indexName, colName string
for name, content := range record { for name, content := range record {
switch name { switch name {
case "NON_UNIQUE": case "NON_UNIQUE":
if "YES" == string(content) || string(content) == "1" { if "YES" == string(content) || string(content) == "1" {
indexType = IndexType indexType = IndexType
} else { } else {
indexType = UniqueType indexType = UniqueType
} }
case "INDEX_NAME": case "INDEX_NAME":
indexName = string(content) indexName = string(content)
case "COLUMN_NAME": case "COLUMN_NAME":
colName = strings.Trim(string(content), "` ") colName = strings.Trim(string(content), "` ")
} }
} }
if indexName == "PRIMARY" { if indexName == "PRIMARY" {
continue continue
} }
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
indexName = indexName[5+len(tableName) : len(indexName)] indexName = indexName[5+len(tableName) : len(indexName)]
} }
var index *Index var index *Index
var ok bool var ok bool
if index, ok = indexes[indexName]; !ok { if index, ok = indexes[indexName]; !ok {
index = new(Index) index = new(Index)
index.Type = indexType index.Type = indexType
index.Name = indexName index.Name = indexName
indexes[indexName] = index indexes[indexName] = index
} }
index.AddColumn(colName) index.AddColumn(colName)
} }
return indexes, nil return indexes, nil
} }

View File

@ -1,8 +1,8 @@
package xorm package xorm
import ( import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"testing" "testing"
) )
/* /*
@ -13,136 +13,136 @@ utf8 COLLATE utf8_general_ci;
var mysqlShowTestSql bool = true var mysqlShowTestSql bool = true
func TestMysql(t *testing.T) { func TestMysql(t *testing.T) {
err := mysqlDdlImport() err := mysqlDdlImport()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8") engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8")
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.ShowSQL = mysqlShowTestSql engine.ShowSQL = mysqlShowTestSql
engine.ShowErr = mysqlShowTestSql engine.ShowErr = mysqlShowTestSql
engine.ShowWarn = mysqlShowTestSql engine.ShowWarn = mysqlShowTestSql
engine.ShowDebug = mysqlShowTestSql engine.ShowDebug = mysqlShowTestSql
testAll(engine, t) testAll(engine, t)
testAll2(engine, t) testAll2(engine, t)
testAll3(engine, t) testAll3(engine, t)
} }
func TestMysqlWithCache(t *testing.T) { func TestMysqlWithCache(t *testing.T) {
err := mysqlDdlImport() err := mysqlDdlImport()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8") engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8")
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
engine.ShowSQL = mysqlShowTestSql engine.ShowSQL = mysqlShowTestSql
engine.ShowErr = mysqlShowTestSql engine.ShowErr = mysqlShowTestSql
engine.ShowWarn = mysqlShowTestSql engine.ShowWarn = mysqlShowTestSql
engine.ShowDebug = mysqlShowTestSql engine.ShowDebug = mysqlShowTestSql
testAll(engine, t) testAll(engine, t)
testAll2(engine, t) testAll2(engine, t)
} }
func newMysqlEngine() (*Engine, error) { func newMysqlEngine() (*Engine, error) {
return NewEngine("mysql", "root:@/xorm_test?charset=utf8") return NewEngine("mysql", "root:@/xorm_test?charset=utf8")
} }
func mysqlDdlImport() error { func mysqlDdlImport() error {
engine, err := NewEngine("mysql", "root:@/?charset=utf8") engine, err := NewEngine("mysql", "root:@/?charset=utf8")
if err != nil { if err != nil {
return err return err
} }
engine.ShowSQL = mysqlShowTestSql engine.ShowSQL = mysqlShowTestSql
engine.ShowErr = mysqlShowTestSql engine.ShowErr = mysqlShowTestSql
engine.ShowWarn = mysqlShowTestSql engine.ShowWarn = mysqlShowTestSql
engine.ShowDebug = mysqlShowTestSql engine.ShowDebug = mysqlShowTestSql
sqlResults, _ := engine.Import("tests/mysql_ddl.sql") sqlResults, _ := engine.Import("tests/mysql_ddl.sql")
engine.LogDebug("sql results: %v", sqlResults) engine.LogDebug("sql results: %v", sqlResults)
engine.Close() engine.Close()
return nil return nil
} }
func BenchmarkMysqlNoCacheInsert(t *testing.B) { func BenchmarkMysqlNoCacheInsert(t *testing.B) {
engine, err := newMysqlEngine() engine, err := newMysqlEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
func BenchmarkMysqlNoCacheFind(t *testing.B) { func BenchmarkMysqlNoCacheFind(t *testing.B) {
engine, err := newMysqlEngine() engine, err := newMysqlEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchFind(engine, t) doBenchFind(engine, t)
} }
func BenchmarkMysqlNoCacheFindPtr(t *testing.B) { func BenchmarkMysqlNoCacheFindPtr(t *testing.B) {
engine, err := newMysqlEngine() engine, err := newMysqlEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }
func BenchmarkMysqlCacheInsert(t *testing.B) { func BenchmarkMysqlCacheInsert(t *testing.B) {
engine, err := newMysqlEngine() engine, err := newMysqlEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
func BenchmarkMysqlCacheFind(t *testing.B) { func BenchmarkMysqlCacheFind(t *testing.B) {
engine, err := newMysqlEngine() engine, err := newMysqlEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchFind(engine, t) doBenchFind(engine, t)
} }
func BenchmarkMysqlCacheFindPtr(t *testing.B) { func BenchmarkMysqlCacheFindPtr(t *testing.B) {
engine, err := newMysqlEngine() engine, err := newMysqlEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }

290
pool.go
View File

@ -1,13 +1,13 @@
package xorm package xorm
import ( import (
"database/sql" "database/sql"
//"fmt" //"fmt"
"sync" "sync"
//"sync/atomic" //"sync/atomic"
"container/list" "container/list"
"reflect" "reflect"
"time" "time"
) )
// Interface IConnecPool is a connection pool interface, all implements should implement // Interface IConnecPool is a connection pool interface, all implements should implement
@ -17,14 +17,14 @@ import (
// ReleaseDB for releasing a db connection; // ReleaseDB for releasing a db connection;
// Close for invoking when engine.Close // Close for invoking when engine.Close
type IConnectPool interface { type IConnectPool interface {
Init(engine *Engine) error Init(engine *Engine) error
RetrieveDB(engine *Engine) (*sql.DB, error) RetrieveDB(engine *Engine) (*sql.DB, error)
ReleaseDB(engine *Engine, db *sql.DB) ReleaseDB(engine *Engine, db *sql.DB)
Close(engine *Engine) error Close(engine *Engine) error
SetMaxIdleConns(conns int) SetMaxIdleConns(conns int)
MaxIdleConns() int MaxIdleConns() int
SetMaxConns(conns int) SetMaxConns(conns int)
MaxConns() int MaxConns() int
} }
// Struct NoneConnectPool is a implement for IConnectPool. It provides directly invoke driver's // Struct NoneConnectPool is a implement for IConnectPool. It provides directly invoke driver's
@ -34,35 +34,35 @@ type NoneConnectPool struct {
// NewNoneConnectPool new a NoneConnectPool. // NewNoneConnectPool new a NoneConnectPool.
func NewNoneConnectPool() IConnectPool { func NewNoneConnectPool() IConnectPool {
return &NoneConnectPool{} return &NoneConnectPool{}
} }
// Init do nothing // Init do nothing
func (p *NoneConnectPool) Init(engine *Engine) error { func (p *NoneConnectPool) Init(engine *Engine) error {
return nil return nil
} }
// RetrieveDB directly open a connection // RetrieveDB directly open a connection
func (p *NoneConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) { func (p *NoneConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) {
db, err = engine.OpenDB() db, err = engine.OpenDB()
return return
} }
// ReleaseDB directly close a connection // ReleaseDB directly close a connection
func (p *NoneConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { func (p *NoneConnectPool) ReleaseDB(engine *Engine, db *sql.DB) {
db.Close() db.Close()
} }
// Close do nothing // Close do nothing
func (p *NoneConnectPool) Close(engine *Engine) error { func (p *NoneConnectPool) Close(engine *Engine) error {
return nil return nil
} }
func (p *NoneConnectPool) SetMaxIdleConns(conns int) { func (p *NoneConnectPool) SetMaxIdleConns(conns int) {
} }
func (p *NoneConnectPool) MaxIdleConns() int { func (p *NoneConnectPool) MaxIdleConns() int {
return 0 return 0
} }
// not implemented // not implemented
@ -71,133 +71,133 @@ func (p *NoneConnectPool) SetMaxConns(conns int) {
// not implemented // not implemented
func (p *NoneConnectPool) MaxConns() int { func (p *NoneConnectPool) MaxConns() int {
return -1 return -1
} }
// Struct SysConnectPool is a simple wrapper for using system default connection pool. // Struct SysConnectPool is a simple wrapper for using system default connection pool.
// About the system connection pool, you can review the code database/sql/sql.go // About the system connection pool, you can review the code database/sql/sql.go
// It's currently default Pool implments. // It's currently default Pool implments.
type SysConnectPool struct { type SysConnectPool struct {
db *sql.DB db *sql.DB
maxIdleConns int maxIdleConns int
maxConns int maxConns int
curConns int curConns int
mutex *sync.Mutex mutex *sync.Mutex
queue *list.List queue *list.List
} }
// NewSysConnectPool new a SysConnectPool. // NewSysConnectPool new a SysConnectPool.
func NewSysConnectPool() IConnectPool { func NewSysConnectPool() IConnectPool {
return &SysConnectPool{} return &SysConnectPool{}
} }
// Init create a db immediately and keep it util engine closed. // Init create a db immediately and keep it util engine closed.
func (s *SysConnectPool) Init(engine *Engine) error { func (s *SysConnectPool) Init(engine *Engine) error {
db, err := engine.OpenDB() db, err := engine.OpenDB()
if err != nil { if err != nil {
return err return err
} }
s.db = db s.db = db
s.maxIdleConns = 2 s.maxIdleConns = 2
s.maxConns = -1 s.maxConns = -1
s.curConns = 0 s.curConns = 0
s.mutex = &sync.Mutex{} s.mutex = &sync.Mutex{}
s.queue = list.New() s.queue = list.New()
return nil return nil
} }
type node struct { type node struct {
mutex sync.Mutex mutex sync.Mutex
cond *sync.Cond cond *sync.Cond
} }
func newCondNode() *node { func newCondNode() *node {
n := &node{} n := &node{}
n.cond = sync.NewCond(&n.mutex) n.cond = sync.NewCond(&n.mutex)
return n return n
} }
// RetrieveDB just return the only db // RetrieveDB just return the only db
func (s *SysConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) { func (s *SysConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) {
/*if s.maxConns > 0 { /*if s.maxConns > 0 {
fmt.Println("before retrieve") fmt.Println("before retrieve")
s.mutex.Lock() s.mutex.Lock()
for s.curConns >= s.maxConns { for s.curConns >= s.maxConns {
fmt.Println("before waiting...", s.curConns, s.queue.Len()) fmt.Println("before waiting...", s.curConns, s.queue.Len())
s.mutex.Unlock() s.mutex.Unlock()
n := NewNode() n := NewNode()
n.cond.L.Lock() n.cond.L.Lock()
s.queue.PushBack(n) s.queue.PushBack(n)
n.cond.Wait() n.cond.Wait()
n.cond.L.Unlock() n.cond.L.Unlock()
s.mutex.Lock() s.mutex.Lock()
fmt.Println("after waiting...", s.curConns, s.queue.Len()) fmt.Println("after waiting...", s.curConns, s.queue.Len())
} }
s.curConns += 1 s.curConns += 1
s.mutex.Unlock() s.mutex.Unlock()
fmt.Println("after retrieve") fmt.Println("after retrieve")
}*/ }*/
return s.db, nil return s.db, nil
} }
// ReleaseDB do nothing // ReleaseDB do nothing
func (s *SysConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { func (s *SysConnectPool) ReleaseDB(engine *Engine, db *sql.DB) {
/*if s.maxConns > 0 { /*if s.maxConns > 0 {
s.mutex.Lock() s.mutex.Lock()
fmt.Println("before release", s.queue.Len()) fmt.Println("before release", s.queue.Len())
s.curConns -= 1 s.curConns -= 1
if e := s.queue.Front(); e != nil { if e := s.queue.Front(); e != nil {
n := e.Value.(*node) n := e.Value.(*node)
//n.cond.L.Lock() //n.cond.L.Lock()
n.cond.Signal() n.cond.Signal()
fmt.Println("signaled...") fmt.Println("signaled...")
s.queue.Remove(e) s.queue.Remove(e)
//n.cond.L.Unlock() //n.cond.L.Unlock()
} }
fmt.Println("after released", s.queue.Len()) fmt.Println("after released", s.queue.Len())
s.mutex.Unlock() s.mutex.Unlock()
}*/ }*/
} }
// Close closed the only db // Close closed the only db
func (p *SysConnectPool) Close(engine *Engine) error { func (p *SysConnectPool) Close(engine *Engine) error {
return p.db.Close() return p.db.Close()
} }
func (p *SysConnectPool) SetMaxIdleConns(conns int) { func (p *SysConnectPool) SetMaxIdleConns(conns int) {
p.db.SetMaxIdleConns(conns) p.db.SetMaxIdleConns(conns)
p.maxIdleConns = conns p.maxIdleConns = conns
} }
func (p *SysConnectPool) MaxIdleConns() int { func (p *SysConnectPool) MaxIdleConns() int {
return p.maxIdleConns return p.maxIdleConns
} }
// not implemented // not implemented
func (p *SysConnectPool) SetMaxConns(conns int) { func (p *SysConnectPool) SetMaxConns(conns int) {
p.maxConns = conns p.maxConns = conns
// if support SetMaxOpenConns, go 1.2+, then set // if support SetMaxOpenConns, go 1.2+, then set
if reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").IsValid() { if reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").IsValid() {
reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").Call([]reflect.Value{reflect.ValueOf(conns)}) reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").Call([]reflect.Value{reflect.ValueOf(conns)})
} }
//p.db.SetMaxOpenConns(conns) //p.db.SetMaxOpenConns(conns)
} }
// not implemented // not implemented
func (p *SysConnectPool) MaxConns() int { func (p *SysConnectPool) MaxConns() int {
return p.maxConns return p.maxConns
} }
// NewSimpleConnectPool new a SimpleConnectPool // NewSimpleConnectPool new a SimpleConnectPool
func NewSimpleConnectPool() IConnectPool { func NewSimpleConnectPool() IConnectPool {
return &SimpleConnectPool{releasedConnects: make([]*sql.DB, 10), return &SimpleConnectPool{releasedConnects: make([]*sql.DB, 10),
usingConnects: map[*sql.DB]time.Time{}, usingConnects: map[*sql.DB]time.Time{},
cur: -1, cur: -1,
maxWaitTimeOut: 14400, maxWaitTimeOut: 14400,
maxIdleConns: 10, maxIdleConns: 10,
mutex: &sync.Mutex{}, mutex: &sync.Mutex{},
} }
} }
// Struct SimpleConnectPool is a simple implementation for IConnectPool. // Struct SimpleConnectPool is a simple implementation for IConnectPool.
@ -205,75 +205,75 @@ func NewSimpleConnectPool() IConnectPool {
// Opening or Closing a database connection must be enter a lock. // Opening or Closing a database connection must be enter a lock.
// This implements will be improved in furture. // This implements will be improved in furture.
type SimpleConnectPool struct { type SimpleConnectPool struct {
releasedConnects []*sql.DB releasedConnects []*sql.DB
cur int cur int
usingConnects map[*sql.DB]time.Time usingConnects map[*sql.DB]time.Time
maxWaitTimeOut int maxWaitTimeOut int
mutex *sync.Mutex mutex *sync.Mutex
maxIdleConns int maxIdleConns int
} }
func (s *SimpleConnectPool) Init(engine *Engine) error { func (s *SimpleConnectPool) Init(engine *Engine) error {
return nil return nil
} }
// RetrieveDB get a connection from connection pool // RetrieveDB get a connection from connection pool
func (p *SimpleConnectPool) RetrieveDB(engine *Engine) (*sql.DB, error) { func (p *SimpleConnectPool) RetrieveDB(engine *Engine) (*sql.DB, error) {
p.mutex.Lock() p.mutex.Lock()
defer p.mutex.Unlock() defer p.mutex.Unlock()
var db *sql.DB = nil var db *sql.DB = nil
var err error = nil var err error = nil
//fmt.Printf("%x, rbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) //fmt.Printf("%x, rbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects))
if p.cur < 0 { if p.cur < 0 {
db, err = engine.OpenDB() db, err = engine.OpenDB()
if err != nil { if err != nil {
return nil, err return nil, err
} }
p.usingConnects[db] = time.Now() p.usingConnects[db] = time.Now()
} else { } else {
db = p.releasedConnects[p.cur] db = p.releasedConnects[p.cur]
p.usingConnects[db] = time.Now() p.usingConnects[db] = time.Now()
p.releasedConnects[p.cur] = nil p.releasedConnects[p.cur] = nil
p.cur = p.cur - 1 p.cur = p.cur - 1
} }
//fmt.Printf("%x, rend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) //fmt.Printf("%x, rend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects))
return db, nil return db, nil
} }
// ReleaseDB release a db from connection pool // ReleaseDB release a db from connection pool
func (p *SimpleConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { func (p *SimpleConnectPool) ReleaseDB(engine *Engine, db *sql.DB) {
p.mutex.Lock() p.mutex.Lock()
defer p.mutex.Unlock() defer p.mutex.Unlock()
//fmt.Printf("%x, lbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) //fmt.Printf("%x, lbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects))
if p.cur >= p.maxIdleConns-1 { if p.cur >= p.maxIdleConns-1 {
db.Close() db.Close()
} else { } else {
p.cur = p.cur + 1 p.cur = p.cur + 1
p.releasedConnects[p.cur] = db p.releasedConnects[p.cur] = db
} }
delete(p.usingConnects, db) delete(p.usingConnects, db)
//fmt.Printf("%x, lend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) //fmt.Printf("%x, lend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects))
} }
// Close release all db // Close release all db
func (p *SimpleConnectPool) Close(engine *Engine) error { func (p *SimpleConnectPool) Close(engine *Engine) error {
p.mutex.Lock() p.mutex.Lock()
defer p.mutex.Unlock() defer p.mutex.Unlock()
for len(p.releasedConnects) > 0 { for len(p.releasedConnects) > 0 {
p.releasedConnects[0].Close() p.releasedConnects[0].Close()
p.releasedConnects = p.releasedConnects[1:] p.releasedConnects = p.releasedConnects[1:]
} }
return nil return nil
} }
func (p *SimpleConnectPool) SetMaxIdleConns(conns int) { func (p *SimpleConnectPool) SetMaxIdleConns(conns int) {
p.maxIdleConns = conns p.maxIdleConns = conns
} }
func (p *SimpleConnectPool) MaxIdleConns() int { func (p *SimpleConnectPool) MaxIdleConns() int {
return p.maxIdleConns return p.maxIdleConns
} }
// not implemented // not implemented
@ -282,5 +282,5 @@ func (p *SimpleConnectPool) SetMaxConns(conns int) {
// not implemented // not implemented
func (p *SimpleConnectPool) MaxConns() int { func (p *SimpleConnectPool) MaxConns() int {
return -1 return -1
} }

View File

@ -1,300 +1,300 @@
package xorm package xorm
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
) )
type postgres struct { type postgres struct {
base base
dbname string dbname string
} }
type values map[string]string type values map[string]string
func (vs values) Set(k, v string) { func (vs values) Set(k, v string) {
vs[k] = v vs[k] = v
} }
func (vs values) Get(k string) (v string) { func (vs values) Get(k string) (v string) {
return vs[k] return vs[k]
} }
func errorf(s string, args ...interface{}) { func errorf(s string, args ...interface{}) {
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
} }
func parseOpts(name string, o values) { func parseOpts(name string, o values) {
if len(name) == 0 { if len(name) == 0 {
return return
} }
name = strings.TrimSpace(name) name = strings.TrimSpace(name)
ps := strings.Split(name, " ") ps := strings.Split(name, " ")
for _, p := range ps { for _, p := range ps {
kv := strings.Split(p, "=") kv := strings.Split(p, "=")
if len(kv) < 2 { if len(kv) < 2 {
errorf("invalid option: %q", p) errorf("invalid option: %q", p)
} }
o.Set(kv[0], kv[1]) o.Set(kv[0], kv[1])
} }
} }
func (db *postgres) Init(drivername, uri string) error { func (db *postgres) Init(drivername, uri string) error {
db.base.init(drivername, uri) db.base.init(drivername, uri)
o := make(values) o := make(values)
parseOpts(uri, o) parseOpts(uri, o)
db.dbname = o.Get("dbname") db.dbname = o.Get("dbname")
if db.dbname == "" { if db.dbname == "" {
return errors.New("dbname is empty") return errors.New("dbname is empty")
} }
return nil return nil
} }
func (db *postgres) SqlType(c *Column) string { func (db *postgres) SqlType(c *Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case TinyInt: case TinyInt:
res = SmallInt res = SmallInt
case MediumInt, Int, Integer: case MediumInt, Int, Integer:
return Integer return Integer
case Serial, BigSerial: case Serial, BigSerial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.Nullable = false c.Nullable = false
res = t res = t
case Binary, VarBinary: case Binary, VarBinary:
return Bytea return Bytea
case DateTime: case DateTime:
res = TimeStamp res = TimeStamp
case TimeStampz: case TimeStampz:
return "timestamp with time zone" return "timestamp with time zone"
case Float: case Float:
res = Real res = Real
case TinyText, MediumText, LongText: case TinyText, MediumText, LongText:
res = Text res = Text
case Blob, TinyBlob, MediumBlob, LongBlob: case Blob, TinyBlob, MediumBlob, LongBlob:
return Bytea return Bytea
case Double: case Double:
return "DOUBLE PRECISION" return "DOUBLE PRECISION"
default: default:
if c.IsAutoIncrement { if c.IsAutoIncrement {
return Serial return Serial
} }
res = t res = t
} }
var hasLen1 bool = (c.Length > 0) var hasLen1 bool = (c.Length > 0)
var hasLen2 bool = (c.Length2 > 0) var hasLen2 bool = (c.Length2 > 0)
if hasLen1 { if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.Itoa(c.Length) + ")"
} else if hasLen2 { } else if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
} }
return res return res
} }
func (db *postgres) SupportInsertMany() bool { func (db *postgres) SupportInsertMany() bool {
return true return true
} }
func (db *postgres) QuoteStr() string { func (db *postgres) QuoteStr() string {
return "\"" return "\""
} }
func (db *postgres) AutoIncrStr() string { func (db *postgres) AutoIncrStr() string {
return "" return ""
} }
func (db *postgres) SupportEngine() bool { func (db *postgres) SupportEngine() bool {
return false return false
} }
func (db *postgres) SupportCharset() bool { func (db *postgres) SupportCharset() bool {
return false return false
} }
func (db *postgres) IndexOnTable() bool { func (db *postgres) IndexOnTable() bool {
return false return false
} }
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{tableName, idxName} args := []interface{}{tableName, idxName}
return `SELECT indexname FROM pg_indexes ` + return `SELECT indexname FROM pg_indexes ` +
`WHERE tablename = ? AND indexname = ?`, args `WHERE tablename = ? AND indexname = ?`, args
} }
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
} }
func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) { func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName, colName} args := []interface{}{tableName, colName}
return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" +
" AND column_name = ?", args " AND column_name = ?", args
} }
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" +
", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, record := range res { for _, record := range res {
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "column_name": case "column_name":
col.Name = strings.Trim(string(content), `" `) col.Name = strings.Trim(string(content), `" `)
case "column_default": case "column_default":
if strings.HasPrefix(string(content), "nextval") { if strings.HasPrefix(string(content), "nextval") {
col.IsPrimaryKey = true col.IsPrimaryKey = true
} else { } else {
col.Default = string(content) col.Default = string(content)
} }
case "is_nullable": case "is_nullable":
if string(content) == "YES" { if string(content) == "YES" {
col.Nullable = true col.Nullable = true
} else { } else {
col.Nullable = false col.Nullable = false
} }
case "data_type": case "data_type":
ct := string(content) ct := string(content)
switch ct { switch ct {
case "character varying", "character": case "character varying", "character":
col.SQLType = SQLType{Varchar, 0, 0} col.SQLType = SQLType{Varchar, 0, 0}
case "timestamp without time zone": case "timestamp without time zone":
col.SQLType = SQLType{DateTime, 0, 0} col.SQLType = SQLType{DateTime, 0, 0}
case "timestamp with time zone": case "timestamp with time zone":
col.SQLType = SQLType{TimeStampz, 0, 0} col.SQLType = SQLType{TimeStampz, 0, 0}
case "double precision": case "double precision":
col.SQLType = SQLType{Double, 0, 0} col.SQLType = SQLType{Double, 0, 0}
case "boolean": case "boolean":
col.SQLType = SQLType{Bool, 0, 0} col.SQLType = SQLType{Bool, 0, 0}
case "time without time zone": case "time without time zone":
col.SQLType = SQLType{Time, 0, 0} col.SQLType = SQLType{Time, 0, 0}
default: default:
col.SQLType = SQLType{strings.ToUpper(ct), 0, 0} col.SQLType = SQLType{strings.ToUpper(ct), 0, 0}
} }
if _, ok := sqlTypes[col.SQLType.Name]; !ok { if _, ok := sqlTypes[col.SQLType.Name]; !ok {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct)) return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct))
} }
case "character_maximum_length": case "character_maximum_length":
i, err := strconv.Atoi(string(content)) i, err := strconv.Atoi(string(content))
if err != nil { if err != nil {
return nil, nil, errors.New("retrieve length error") return nil, nil, errors.New("retrieve length error")
} }
col.Length = i col.Length = i
case "numeric_precision": case "numeric_precision":
case "numeric_precision_radix": case "numeric_precision_radix":
} }
} }
if col.SQLType.IsText() { if col.SQLType.IsText() {
if col.Default != "" { if col.Default != "" {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
} }
} }
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name) colSeq = append(colSeq, col.Name)
} }
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *postgres) GetTables() ([]*Table, error) { func (db *postgres) GetTables() ([]*Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT tablename FROM pg_tables where schemaname = 'public'" s := "SELECT tablename FROM pg_tables where schemaname = 'public'"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*Table, 0)
for _, record := range res { for _, record := range res {
table := new(Table) table := new(Table)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "tablename": case "tablename":
table.Name = string(content) table.Name = string(content)
} }
} }
tables = append(tables, table) tables = append(tables, table)
} }
return tables, nil return tables, nil
} }
func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1" s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*Index, 0)
for _, record := range res { for _, record := range res {
var indexType int var indexType int
var indexName string var indexName string
var colNames []string var colNames []string
for name, content := range record { for name, content := range record {
switch name { switch name {
case "indexname": case "indexname":
indexName = strings.Trim(string(content), `" `) indexName = strings.Trim(string(content), `" `)
case "indexdef": case "indexdef":
c := string(content) c := string(content)
if strings.HasPrefix(c, "CREATE UNIQUE INDEX") { if strings.HasPrefix(c, "CREATE UNIQUE INDEX") {
indexType = UniqueType indexType = UniqueType
} else { } else {
indexType = IndexType indexType = IndexType
} }
cs := strings.Split(c, "(") cs := strings.Split(c, "(")
colNames = strings.Split(cs[1][0:len(cs[1])-1], ",") colNames = strings.Split(cs[1][0:len(cs[1])-1], ",")
} }
} }
if strings.HasSuffix(indexName, "_pkey") { if strings.HasSuffix(indexName, "_pkey") {
continue continue
} }
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
newIdxName := indexName[5+len(tableName) : len(indexName)] newIdxName := indexName[5+len(tableName) : len(indexName)]
if newIdxName != "" { if newIdxName != "" {
indexName = newIdxName indexName = newIdxName
} }
} }
index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
for _, colName := range colNames { for _, colName := range colNames {
index.Cols = append(index.Cols, strings.Trim(colName, `" `)) index.Cols = append(index.Cols, strings.Trim(colName, `" `))
} }
indexes[index.Name] = index indexes[index.Name] = index
} }
return indexes, nil return indexes, nil
} }

View File

@ -1,259 +1,259 @@
package xorm package xorm
import ( import (
//"fmt" //"fmt"
//_ "github.com/bylevel/pq" //_ "github.com/bylevel/pq"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"testing" "testing"
) )
func newPostgresEngine() (*Engine, error) { func newPostgresEngine() (*Engine, error) {
return NewEngine("postgres", "dbname=xorm_test sslmode=disable") return NewEngine("postgres", "dbname=xorm_test sslmode=disable")
} }
func TestPostgres(t *testing.T) { func TestPostgres(t *testing.T) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer engine.Close() defer engine.Close()
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
engine.ShowDebug = showTestSql engine.ShowDebug = showTestSql
testAll(engine, t) testAll(engine, t)
testAll2(engine, t) testAll2(engine, t)
testAll3(engine, t) testAll3(engine, t)
} }
func TestPostgresWithCache(t *testing.T) { func TestPostgresWithCache(t *testing.T) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
defer engine.Close() defer engine.Close()
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
engine.ShowDebug = showTestSql engine.ShowDebug = showTestSql
testAll(engine, t) testAll(engine, t)
testAll2(engine, t) testAll2(engine, t)
} }
/* /*
func TestPostgres2(t *testing.T) { func TestPostgres2(t *testing.T) {
engine, err := NewEngine("postgres", "dbname=xorm_test sslmode=disable") engine, err := NewEngine("postgres", "dbname=xorm_test sslmode=disable")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer engine.Close() defer engine.Close()
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.Mapper = SameMapper{} engine.Mapper = SameMapper{}
fmt.Println("-------------- directCreateTable --------------") fmt.Println("-------------- directCreateTable --------------")
directCreateTable(engine, t) directCreateTable(engine, t)
fmt.Println("-------------- mapper --------------") fmt.Println("-------------- mapper --------------")
mapper(engine, t) mapper(engine, t)
fmt.Println("-------------- insert --------------") fmt.Println("-------------- insert --------------")
insert(engine, t) insert(engine, t)
fmt.Println("-------------- querySameMapper --------------") fmt.Println("-------------- querySameMapper --------------")
querySameMapper(engine, t) querySameMapper(engine, t)
fmt.Println("-------------- execSameMapper --------------") fmt.Println("-------------- execSameMapper --------------")
execSameMapper(engine, t) execSameMapper(engine, t)
fmt.Println("-------------- insertAutoIncr --------------") fmt.Println("-------------- insertAutoIncr --------------")
insertAutoIncr(engine, t) insertAutoIncr(engine, t)
fmt.Println("-------------- insertMulti --------------") fmt.Println("-------------- insertMulti --------------")
insertMulti(engine, t) insertMulti(engine, t)
fmt.Println("-------------- insertTwoTable --------------") fmt.Println("-------------- insertTwoTable --------------")
insertTwoTable(engine, t) insertTwoTable(engine, t)
fmt.Println("-------------- updateSameMapper --------------") fmt.Println("-------------- updateSameMapper --------------")
updateSameMapper(engine, t) updateSameMapper(engine, t)
fmt.Println("-------------- testdelete --------------") fmt.Println("-------------- testdelete --------------")
testdelete(engine, t) testdelete(engine, t)
fmt.Println("-------------- get --------------") fmt.Println("-------------- get --------------")
get(engine, t) get(engine, t)
fmt.Println("-------------- cascadeGet --------------") fmt.Println("-------------- cascadeGet --------------")
cascadeGet(engine, t) cascadeGet(engine, t)
fmt.Println("-------------- find --------------") fmt.Println("-------------- find --------------")
find(engine, t) find(engine, t)
fmt.Println("-------------- find2 --------------") fmt.Println("-------------- find2 --------------")
find2(engine, t) find2(engine, t)
fmt.Println("-------------- findMap --------------") fmt.Println("-------------- findMap --------------")
findMap(engine, t) findMap(engine, t)
fmt.Println("-------------- findMap2 --------------") fmt.Println("-------------- findMap2 --------------")
findMap2(engine, t) findMap2(engine, t)
fmt.Println("-------------- count --------------") fmt.Println("-------------- count --------------")
count(engine, t) count(engine, t)
fmt.Println("-------------- where --------------") fmt.Println("-------------- where --------------")
where(engine, t) where(engine, t)
fmt.Println("-------------- in --------------") fmt.Println("-------------- in --------------")
in(engine, t) in(engine, t)
fmt.Println("-------------- limit --------------") fmt.Println("-------------- limit --------------")
limit(engine, t) limit(engine, t)
fmt.Println("-------------- orderSameMapper --------------") fmt.Println("-------------- orderSameMapper --------------")
orderSameMapper(engine, t) orderSameMapper(engine, t)
fmt.Println("-------------- joinSameMapper --------------") fmt.Println("-------------- joinSameMapper --------------")
joinSameMapper(engine, t) joinSameMapper(engine, t)
fmt.Println("-------------- havingSameMapper --------------") fmt.Println("-------------- havingSameMapper --------------")
havingSameMapper(engine, t) havingSameMapper(engine, t)
fmt.Println("-------------- combineTransactionSameMapper --------------") fmt.Println("-------------- combineTransactionSameMapper --------------")
combineTransactionSameMapper(engine, t) combineTransactionSameMapper(engine, t)
fmt.Println("-------------- table --------------") fmt.Println("-------------- table --------------")
table(engine, t) table(engine, t)
fmt.Println("-------------- createMultiTables --------------") fmt.Println("-------------- createMultiTables --------------")
createMultiTables(engine, t) createMultiTables(engine, t)
fmt.Println("-------------- tableOp --------------") fmt.Println("-------------- tableOp --------------")
tableOp(engine, t) tableOp(engine, t)
fmt.Println("-------------- testColsSameMapper --------------") fmt.Println("-------------- testColsSameMapper --------------")
testColsSameMapper(engine, t) testColsSameMapper(engine, t)
fmt.Println("-------------- testCharst --------------") fmt.Println("-------------- testCharst --------------")
testCharst(engine, t) testCharst(engine, t)
fmt.Println("-------------- testStoreEngine --------------") fmt.Println("-------------- testStoreEngine --------------")
testStoreEngine(engine, t) testStoreEngine(engine, t)
fmt.Println("-------------- testExtends --------------") fmt.Println("-------------- testExtends --------------")
testExtends(engine, t) testExtends(engine, t)
fmt.Println("-------------- testColTypes --------------") fmt.Println("-------------- testColTypes --------------")
testColTypes(engine, t) testColTypes(engine, t)
fmt.Println("-------------- testCustomType --------------") fmt.Println("-------------- testCustomType --------------")
testCustomType(engine, t) testCustomType(engine, t)
fmt.Println("-------------- testCreatedAndUpdated --------------") fmt.Println("-------------- testCreatedAndUpdated --------------")
testCreatedAndUpdated(engine, t) testCreatedAndUpdated(engine, t)
fmt.Println("-------------- testIndexAndUnique --------------") fmt.Println("-------------- testIndexAndUnique --------------")
testIndexAndUnique(engine, t) testIndexAndUnique(engine, t)
fmt.Println("-------------- testMetaInfo --------------") fmt.Println("-------------- testMetaInfo --------------")
testMetaInfo(engine, t) testMetaInfo(engine, t)
fmt.Println("-------------- testIterate --------------") fmt.Println("-------------- testIterate --------------")
testIterate(engine, t) testIterate(engine, t)
fmt.Println("-------------- testStrangeName --------------") fmt.Println("-------------- testStrangeName --------------")
testStrangeName(engine, t) testStrangeName(engine, t)
fmt.Println("-------------- testVersion --------------") fmt.Println("-------------- testVersion --------------")
testVersion(engine, t) testVersion(engine, t)
fmt.Println("-------------- testDistinct --------------") fmt.Println("-------------- testDistinct --------------")
testDistinct(engine, t) testDistinct(engine, t)
fmt.Println("-------------- testUseBool --------------") fmt.Println("-------------- testUseBool --------------")
testUseBool(engine, t) testUseBool(engine, t)
fmt.Println("-------------- transaction --------------") fmt.Println("-------------- transaction --------------")
transaction(engine, t) transaction(engine, t)
}*/ }*/
/* /*
func BenchmarkPostgresDriverInsert(t *testing.B) { func BenchmarkPostgresDriverInsert(t *testing.B) {
t.StopTimer() t.StopTimer()
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer engine.Close() defer engine.Close()
db, err := engine.OpenDB() db, err := engine.OpenDB()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer db.Close() defer db.Close()
doBenchDriverInsert(engine, db, t) doBenchDriverInsert(engine, db, t)
} }
func BenchmarkPostgresDriverFind(t *testing.B) { func BenchmarkPostgresDriverFind(t *testing.B) {
t.StopTimer() t.StopTimer()
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer engine.Close() defer engine.Close()
db, err := engine.OpenDB() db, err := engine.OpenDB()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//defer db.Close() //defer db.Close()
doBenchDriverFind(engine, db, t) doBenchDriverFind(engine, db, t)
}*/ }*/
func BenchmarkPostgresNoCacheInsert(t *testing.B) { func BenchmarkPostgresNoCacheInsert(t *testing.B) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
func BenchmarkPostgresNoCacheFind(t *testing.B) { func BenchmarkPostgresNoCacheFind(t *testing.B) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchFind(engine, t) doBenchFind(engine, t)
} }
func BenchmarkPostgresNoCacheFindPtr(t *testing.B) { func BenchmarkPostgresNoCacheFindPtr(t *testing.B) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }
func BenchmarkPostgresCacheInsert(t *testing.B) { func BenchmarkPostgresCacheInsert(t *testing.B) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
func BenchmarkPostgresCacheFind(t *testing.B) { func BenchmarkPostgresCacheFind(t *testing.B) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchFind(engine, t) doBenchFind(engine, t)
} }
func BenchmarkPostgresCacheFindPtr(t *testing.B) { func BenchmarkPostgresCacheFindPtr(t *testing.B) {
engine, err := newPostgresEngine() engine, err := newPostgresEngine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }

View File

@ -2,38 +2,38 @@ package xorm
// Executed before an object is initially persisted to the database // Executed before an object is initially persisted to the database
type BeforeInsertProcessor interface { type BeforeInsertProcessor interface {
BeforeInsert() BeforeInsert()
} }
// Executed before an object is updated // Executed before an object is updated
type BeforeUpdateProcessor interface { type BeforeUpdateProcessor interface {
BeforeUpdate() BeforeUpdate()
} }
// Executed before an object is deleted // Executed before an object is deleted
type BeforeDeleteProcessor interface { type BeforeDeleteProcessor interface {
BeforeDelete() BeforeDelete()
} }
// !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations // !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations
//// Executed before an object is validated //// Executed before an object is validated
//type BeforeValidateProcessor interface { //type BeforeValidateProcessor interface {
// BeforeValidate() // BeforeValidate()
//} //}
// -- // --
// Executed after an object is persisted to the database // Executed after an object is persisted to the database
type AfterInsertProcessor interface { type AfterInsertProcessor interface {
AfterInsert() AfterInsert()
} }
// Executed after an object has been updated // Executed after an object has been updated
type AfterUpdateProcessor interface { type AfterUpdateProcessor interface {
AfterUpdate() AfterUpdate()
} }
// Executed after an object has been deleted // Executed after an object has been deleted
type AfterDeleteProcessor interface { type AfterDeleteProcessor interface {
AfterDelete() AfterDelete()
} }

4430
session.go

File diff suppressed because it is too large Load Diff

View File

@ -1,223 +1,223 @@
package xorm package xorm
import ( import (
"database/sql" "database/sql"
"strings" "strings"
) )
type sqlite3 struct { type sqlite3 struct {
base base
} }
func (db *sqlite3) Init(drivername, dataSourceName string) error { func (db *sqlite3) Init(drivername, dataSourceName string) error {
db.base.init(drivername, dataSourceName) db.base.init(drivername, dataSourceName)
return nil return nil
} }
func (db *sqlite3) SqlType(c *Column) string { func (db *sqlite3) SqlType(c *Column) string {
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case Date, DateTime, TimeStamp, Time: case Date, DateTime, TimeStamp, Time:
return Numeric return Numeric
case TimeStampz: case TimeStampz:
return Text return Text
case Char, Varchar, TinyText, Text, MediumText, LongText: case Char, Varchar, TinyText, Text, MediumText, LongText:
return Text return Text
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool: case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool:
return Integer return Integer
case Float, Double, Real: case Float, Double, Real:
return Real return Real
case Decimal, Numeric: case Decimal, Numeric:
return Numeric return Numeric
case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary: case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary:
return Blob return Blob
case Serial, BigSerial: case Serial, BigSerial:
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.Nullable = false c.Nullable = false
return Integer return Integer
default: default:
return t return t
} }
} }
func (db *sqlite3) SupportInsertMany() bool { func (db *sqlite3) SupportInsertMany() bool {
return true return true
} }
func (db *sqlite3) QuoteStr() string { func (db *sqlite3) QuoteStr() string {
return "`" return "`"
} }
func (db *sqlite3) AutoIncrStr() string { func (db *sqlite3) AutoIncrStr() string {
return "AUTOINCREMENT" return "AUTOINCREMENT"
} }
func (db *sqlite3) SupportEngine() bool { func (db *sqlite3) SupportEngine() bool {
return false return false
} }
func (db *sqlite3) SupportCharset() bool { func (db *sqlite3) SupportCharset() bool {
return false return false
} }
func (db *sqlite3) IndexOnTable() bool { func (db *sqlite3) IndexOnTable() bool {
return false return false
} }
func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{idxName} args := []interface{}{idxName}
return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args
} }
func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) { func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
} }
func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))"
return sql, args return sql, args
} }
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
var sql string var sql string
for _, record := range res { for _, record := range res {
for name, content := range record { for name, content := range record {
if name == "sql" { if name == "sql" {
sql = string(content) sql = string(content)
} }
} }
} }
nStart := strings.Index(sql, "(") nStart := strings.Index(sql, "(")
nEnd := strings.Index(sql, ")") nEnd := strings.Index(sql, ")")
colCreates := strings.Split(sql[nStart+1:nEnd], ",") colCreates := strings.Split(sql[nStart+1:nEnd], ",")
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, colStr := range colCreates { for _, colStr := range colCreates {
fields := strings.Fields(strings.TrimSpace(colStr)) fields := strings.Fields(strings.TrimSpace(colStr))
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
col.Nullable = true col.Nullable = true
for idx, field := range fields { for idx, field := range fields {
if idx == 0 { if idx == 0 {
col.Name = strings.Trim(field, "`[] ") col.Name = strings.Trim(field, "`[] ")
continue continue
} else if idx == 1 { } else if idx == 1 {
col.SQLType = SQLType{field, 0, 0} col.SQLType = SQLType{field, 0, 0}
} }
switch field { switch field {
case "PRIMARY": case "PRIMARY":
col.IsPrimaryKey = true col.IsPrimaryKey = true
case "AUTOINCREMENT": case "AUTOINCREMENT":
col.IsAutoIncrement = true col.IsAutoIncrement = true
case "NULL": case "NULL":
if fields[idx-1] == "NOT" { if fields[idx-1] == "NOT" {
col.Nullable = false col.Nullable = false
} else { } else {
col.Nullable = true col.Nullable = true
} }
} }
} }
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name) colSeq = append(colSeq, col.Name)
} }
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *sqlite3) GetTables() ([]*Table, error) { func (db *sqlite3) GetTables() ([]*Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT name FROM sqlite_master WHERE type='table'" s := "SELECT name FROM sqlite_master WHERE type='table'"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*Table, 0)
for _, record := range res { for _, record := range res {
table := new(Table) table := new(Table)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "name": case "name":
table.Name = string(content) table.Name = string(content)
} }
} }
if table.Name == "sqlite_sequence" { if table.Name == "sqlite_sequence" {
continue continue
} }
tables = append(tables, table) tables = append(tables, table)
} }
return tables, nil return tables, nil
} }
func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*Index, 0)
for _, record := range res { for _, record := range res {
var sql string var sql string
index := new(Index) index := new(Index)
for name, content := range record { for name, content := range record {
if name == "sql" { if name == "sql" {
sql = string(content) sql = string(content)
} }
} }
nNStart := strings.Index(sql, "INDEX") nNStart := strings.Index(sql, "INDEX")
nNEnd := strings.Index(sql, "ON") nNEnd := strings.Index(sql, "ON")
indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []")
//fmt.Println(indexName) //fmt.Println(indexName)
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
index.Name = indexName[5+len(tableName) : len(indexName)] index.Name = indexName[5+len(tableName) : len(indexName)]
} else { } else {
index.Name = indexName index.Name = indexName
} }
if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") {
index.Type = UniqueType index.Type = UniqueType
} else { } else {
index.Type = IndexType index.Type = IndexType
} }
nStart := strings.Index(sql, "(") nStart := strings.Index(sql, "(")
nEnd := strings.Index(sql, ")") nEnd := strings.Index(sql, ")")
colIndexes := strings.Split(sql[nStart+1:nEnd], ",") colIndexes := strings.Split(sql[nStart+1:nEnd], ",")
index.Cols = make([]string, 0) index.Cols = make([]string, 0)
for _, col := range colIndexes { for _, col := range colIndexes {
index.Cols = append(index.Cols, strings.Trim(col, "` []")) index.Cols = append(index.Cols, strings.Trim(col, "` []"))
} }
indexes[index.Name] = index indexes[index.Name] = index
} }
return indexes, nil return indexes, nil
} }

View File

@ -1,199 +1,199 @@
package xorm package xorm
import ( import (
//"database/sql" //"database/sql"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
"testing" "testing"
) )
func newSqlite3Engine() (*Engine, error) { func newSqlite3Engine() (*Engine, error) {
os.Remove("./test.db") os.Remove("./test.db")
return NewEngine("sqlite3", "./test.db") return NewEngine("sqlite3", "./test.db")
} }
func TestSqlite3(t *testing.T) { func TestSqlite3(t *testing.T) {
engine, err := newSqlite3Engine() engine, err := newSqlite3Engine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
engine.ShowDebug = showTestSql engine.ShowDebug = showTestSql
testAll(engine, t) testAll(engine, t)
testAll2(engine, t) testAll2(engine, t)
testAll3(engine, t) testAll3(engine, t)
} }
func TestSqlite3WithCache(t *testing.T) { func TestSqlite3WithCache(t *testing.T) {
engine, err := newSqlite3Engine() engine, err := newSqlite3Engine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
engine.ShowSQL = showTestSql engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql engine.ShowWarn = showTestSql
engine.ShowDebug = showTestSql engine.ShowDebug = showTestSql
testAll(engine, t) testAll(engine, t)
testAll2(engine, t) testAll2(engine, t)
} }
/*func BenchmarkSqlite3DriverInsert(t *testing.B) { /*func BenchmarkSqlite3DriverInsert(t *testing.B) {
t.StopTimer() t.StopTimer()
engine, err := newSqlite3Engine() engine, err := newSqlite3Engine()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
err = engine.CreateTables(&BigStruct{}) err = engine.CreateTables(&BigStruct{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.Close() engine.Close()
db, err := sql.Open("sqlite3", "./test.db") db, err := sql.Open("sqlite3", "./test.db")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
doBenchDriverInsertS(db, t) doBenchDriverInsertS(db, t)
db.Close() db.Close()
engine, err = newSqlite3Engine() engine, err = newSqlite3Engine()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
err = engine.DropTables(&BigStruct{}) err = engine.DropTables(&BigStruct{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer engine.Close() defer engine.Close()
} }
func BenchmarkSqlite3DriverFind(t *testing.B) { func BenchmarkSqlite3DriverFind(t *testing.B) {
t.StopTimer() t.StopTimer()
engine, err := newSqlite3Engine() engine, err := newSqlite3Engine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
err = engine.CreateTables(&BigStruct{}) err = engine.CreateTables(&BigStruct{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.Close() engine.Close()
db, err := sql.Open("sqlite3", "./test.db") db, err := sql.Open("sqlite3", "./test.db")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer db.Close() defer db.Close()
doBenchDriverFindS(db, t) doBenchDriverFindS(db, t)
db.Close() db.Close()
engine, err = newSqlite3Engine() engine, err = newSqlite3Engine()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
err = engine.DropTables(&BigStruct{}) err = engine.DropTables(&BigStruct{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer engine.Close() defer engine.Close()
}*/ }*/
func BenchmarkSqlite3NoCacheInsert(t *testing.B) { func BenchmarkSqlite3NoCacheInsert(t *testing.B) {
t.StopTimer() t.StopTimer()
engine, err := newSqlite3Engine() engine, err := newSqlite3Engine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
func BenchmarkSqlite3NoCacheFind(t *testing.B) { func BenchmarkSqlite3NoCacheFind(t *testing.B) {
t.StopTimer() t.StopTimer()
engine, err := newSqlite3Engine() engine, err := newSqlite3Engine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchFind(engine, t) doBenchFind(engine, t)
} }
func BenchmarkSqlite3NoCacheFindPtr(t *testing.B) { func BenchmarkSqlite3NoCacheFindPtr(t *testing.B) {
t.StopTimer() t.StopTimer()
engine, err := newSqlite3Engine() engine, err := newSqlite3Engine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
//engine.ShowSQL = true //engine.ShowSQL = true
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }
func BenchmarkSqlite3CacheInsert(t *testing.B) { func BenchmarkSqlite3CacheInsert(t *testing.B) {
t.StopTimer() t.StopTimer()
engine, err := newSqlite3Engine() engine, err := newSqlite3Engine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchInsert(engine, t) doBenchInsert(engine, t)
} }
func BenchmarkSqlite3CacheFind(t *testing.B) { func BenchmarkSqlite3CacheFind(t *testing.B) {
t.StopTimer() t.StopTimer()
engine, err := newSqlite3Engine() engine, err := newSqlite3Engine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchFind(engine, t) doBenchFind(engine, t)
} }
func BenchmarkSqlite3CacheFindPtr(t *testing.B) { func BenchmarkSqlite3CacheFindPtr(t *testing.B) {
t.StopTimer() t.StopTimer()
engine, err := newSqlite3Engine() engine, err := newSqlite3Engine()
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
doBenchFindPtr(engine, t) doBenchFindPtr(engine, t)
} }

File diff suppressed because it is too large Load Diff

590
table.go
View File

@ -1,416 +1,416 @@
package xorm package xorm
import ( import (
"reflect" "reflect"
"sort" "sort"
"strings" "strings"
"time" "time"
) )
// xorm SQL types // xorm SQL types
type SQLType struct { type SQLType struct {
Name string Name string
DefaultLength int DefaultLength int
DefaultLength2 int DefaultLength2 int
} }
func (s *SQLType) IsText() bool { func (s *SQLType) IsText() bool {
return s.Name == Char || s.Name == Varchar || s.Name == TinyText || return s.Name == Char || s.Name == Varchar || s.Name == TinyText ||
s.Name == Text || s.Name == MediumText || s.Name == LongText s.Name == Text || s.Name == MediumText || s.Name == LongText
} }
func (s *SQLType) IsBlob() bool { func (s *SQLType) IsBlob() bool {
return (s.Name == TinyBlob) || (s.Name == Blob) || return (s.Name == TinyBlob) || (s.Name == Blob) ||
s.Name == MediumBlob || s.Name == LongBlob || s.Name == MediumBlob || s.Name == LongBlob ||
s.Name == Binary || s.Name == VarBinary || s.Name == Bytea s.Name == Binary || s.Name == VarBinary || s.Name == Bytea
} }
const () const ()
var ( var (
Bit = "BIT" Bit = "BIT"
TinyInt = "TINYINT" TinyInt = "TINYINT"
SmallInt = "SMALLINT" SmallInt = "SMALLINT"
MediumInt = "MEDIUMINT" MediumInt = "MEDIUMINT"
Int = "INT" Int = "INT"
Integer = "INTEGER" Integer = "INTEGER"
BigInt = "BIGINT" BigInt = "BIGINT"
Char = "CHAR" Char = "CHAR"
Varchar = "VARCHAR" Varchar = "VARCHAR"
TinyText = "TINYTEXT" TinyText = "TINYTEXT"
Text = "TEXT" Text = "TEXT"
MediumText = "MEDIUMTEXT" MediumText = "MEDIUMTEXT"
LongText = "LONGTEXT" LongText = "LONGTEXT"
Binary = "BINARY" Binary = "BINARY"
VarBinary = "VARBINARY" VarBinary = "VARBINARY"
Date = "DATE" Date = "DATE"
DateTime = "DATETIME" DateTime = "DATETIME"
Time = "TIME" Time = "TIME"
TimeStamp = "TIMESTAMP" TimeStamp = "TIMESTAMP"
TimeStampz = "TIMESTAMPZ" TimeStampz = "TIMESTAMPZ"
Decimal = "DECIMAL" Decimal = "DECIMAL"
Numeric = "NUMERIC" Numeric = "NUMERIC"
Real = "REAL" Real = "REAL"
Float = "FLOAT" Float = "FLOAT"
Double = "DOUBLE" Double = "DOUBLE"
TinyBlob = "TINYBLOB" TinyBlob = "TINYBLOB"
Blob = "BLOB" Blob = "BLOB"
MediumBlob = "MEDIUMBLOB" MediumBlob = "MEDIUMBLOB"
LongBlob = "LONGBLOB" LongBlob = "LONGBLOB"
Bytea = "BYTEA" Bytea = "BYTEA"
Bool = "BOOL" Bool = "BOOL"
Serial = "SERIAL" Serial = "SERIAL"
BigSerial = "BIGSERIAL" BigSerial = "BIGSERIAL"
sqlTypes = map[string]bool{ sqlTypes = map[string]bool{
Bit: true, Bit: true,
TinyInt: true, TinyInt: true,
SmallInt: true, SmallInt: true,
MediumInt: true, MediumInt: true,
Int: true, Int: true,
Integer: true, Integer: true,
BigInt: true, BigInt: true,
Char: true, Char: true,
Varchar: true, Varchar: true,
TinyText: true, TinyText: true,
Text: true, Text: true,
MediumText: true, MediumText: true,
LongText: true, LongText: true,
Binary: true, Binary: true,
VarBinary: true, VarBinary: true,
Date: true, Date: true,
DateTime: true, DateTime: true,
Time: true, Time: true,
TimeStamp: true, TimeStamp: true,
TimeStampz: true, TimeStampz: true,
Decimal: true, Decimal: true,
Numeric: true, Numeric: true,
Real: true, Real: true,
Float: true, Float: true,
Double: true, Double: true,
TinyBlob: true, TinyBlob: true,
Blob: true, Blob: true,
MediumBlob: true, MediumBlob: true,
LongBlob: true, LongBlob: true,
Bytea: true, Bytea: true,
Bool: true, Bool: true,
Serial: true, Serial: true,
BigSerial: true, BigSerial: true,
} }
intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"} intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"}
uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"} uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"}
) )
var b byte var b byte
var tm time.Time var tm time.Time
func Type2SQLType(t reflect.Type) (st SQLType) { func Type2SQLType(t reflect.Type) (st SQLType) {
switch k := t.Kind(); k { switch k := t.Kind(); k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
st = SQLType{Int, 0, 0} st = SQLType{Int, 0, 0}
case reflect.Int64, reflect.Uint64: case reflect.Int64, reflect.Uint64:
st = SQLType{BigInt, 0, 0} st = SQLType{BigInt, 0, 0}
case reflect.Float32: case reflect.Float32:
st = SQLType{Float, 0, 0} st = SQLType{Float, 0, 0}
case reflect.Float64: case reflect.Float64:
st = SQLType{Double, 0, 0} st = SQLType{Double, 0, 0}
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
st = SQLType{Varchar, 64, 0} st = SQLType{Varchar, 64, 0}
case reflect.Array, reflect.Slice, reflect.Map: case reflect.Array, reflect.Slice, reflect.Map:
if t.Elem() == reflect.TypeOf(b) { if t.Elem() == reflect.TypeOf(b) {
st = SQLType{Blob, 0, 0} st = SQLType{Blob, 0, 0}
} else { } else {
st = SQLType{Text, 0, 0} st = SQLType{Text, 0, 0}
} }
case reflect.Bool: case reflect.Bool:
st = SQLType{Bool, 0, 0} st = SQLType{Bool, 0, 0}
case reflect.String: case reflect.String:
st = SQLType{Varchar, 255, 0} st = SQLType{Varchar, 255, 0}
case reflect.Struct: case reflect.Struct:
if t == reflect.TypeOf(tm) { if t == reflect.TypeOf(tm) {
st = SQLType{DateTime, 0, 0} st = SQLType{DateTime, 0, 0}
} else { } else {
st = SQLType{Text, 0, 0} st = SQLType{Text, 0, 0}
} }
case reflect.Ptr: case reflect.Ptr:
st, _ = ptrType2SQLType(t) st, _ = ptrType2SQLType(t)
default: default:
st = SQLType{Text, 0, 0} st = SQLType{Text, 0, 0}
} }
return return
} }
func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) { func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) {
typeStr := t.String() typeStr := t.String()
has = true has = true
switch typeStr { switch typeStr {
case "*string": case "*string":
st = SQLType{Varchar, 255, 0} st = SQLType{Varchar, 255, 0}
case "*bool": case "*bool":
st = SQLType{Bool, 0, 0} st = SQLType{Bool, 0, 0}
case "*complex64", "*complex128": case "*complex64", "*complex128":
st = SQLType{Varchar, 64, 0} st = SQLType{Varchar, 64, 0}
case "*float32": case "*float32":
st = SQLType{Float, 0, 0} st = SQLType{Float, 0, 0}
case "*float64": case "*float64":
st = SQLType{Double, 0, 0} st = SQLType{Double, 0, 0}
case "*int64", "*uint64": case "*int64", "*uint64":
st = SQLType{BigInt, 0, 0} st = SQLType{BigInt, 0, 0}
case "*time.Time": case "*time.Time":
st = SQLType{DateTime, 0, 0} st = SQLType{DateTime, 0, 0}
case "*int", "*int16", "*int32", "*int8", "*uint", "*uint16", "*uint32", "*uint8": case "*int", "*int16", "*int32", "*int8", "*uint", "*uint16", "*uint32", "*uint8":
st = SQLType{Int, 0, 0} st = SQLType{Int, 0, 0}
default: default:
has = false has = false
} }
return return
} }
// default sql type change to go types // default sql type change to go types
func SQLType2Type(st SQLType) reflect.Type { func SQLType2Type(st SQLType) reflect.Type {
name := strings.ToUpper(st.Name) name := strings.ToUpper(st.Name)
switch name { switch name {
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial: case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial:
return reflect.TypeOf(1) return reflect.TypeOf(1)
case BigInt, BigSerial: case BigInt, BigSerial:
return reflect.TypeOf(int64(1)) return reflect.TypeOf(int64(1))
case Float, Real: case Float, Real:
return reflect.TypeOf(float32(1)) return reflect.TypeOf(float32(1))
case Double: case Double:
return reflect.TypeOf(float64(1)) return reflect.TypeOf(float64(1))
case Char, Varchar, TinyText, Text, MediumText, LongText: case Char, Varchar, TinyText, Text, MediumText, LongText:
return reflect.TypeOf("") return reflect.TypeOf("")
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary: case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary:
return reflect.TypeOf([]byte{}) return reflect.TypeOf([]byte{})
case Bool: case Bool:
return reflect.TypeOf(true) return reflect.TypeOf(true)
case DateTime, Date, Time, TimeStamp, TimeStampz: case DateTime, Date, Time, TimeStamp, TimeStampz:
return reflect.TypeOf(tm) return reflect.TypeOf(tm)
case Decimal, Numeric: case Decimal, Numeric:
return reflect.TypeOf("") return reflect.TypeOf("")
default: default:
return reflect.TypeOf("") return reflect.TypeOf("")
} }
} }
const ( const (
IndexType = iota + 1 IndexType = iota + 1
UniqueType UniqueType
) )
// database index // database index
type Index struct { type Index struct {
Name string Name string
Type int Type int
Cols []string Cols []string
} }
// add columns which will be composite index // add columns which will be composite index
func (index *Index) AddColumn(cols ...string) { func (index *Index) AddColumn(cols ...string) {
for _, col := range cols { for _, col := range cols {
index.Cols = append(index.Cols, col) index.Cols = append(index.Cols, col)
} }
} }
// new an index // new an index
func NewIndex(name string, indexType int) *Index { func NewIndex(name string, indexType int) *Index {
return &Index{name, indexType, make([]string, 0)} return &Index{name, indexType, make([]string, 0)}
} }
const ( const (
TWOSIDES = iota + 1 TWOSIDES = iota + 1
ONLYTODB ONLYTODB
ONLYFROMDB ONLYFROMDB
) )
// database column // database column
type Column struct { type Column struct {
Name string Name string
FieldName string FieldName string
SQLType SQLType SQLType SQLType
Length int Length int
Length2 int Length2 int
Nullable bool Nullable bool
Default string Default string
Indexes map[string]bool Indexes map[string]bool
IsPrimaryKey bool IsPrimaryKey bool
IsAutoIncrement bool IsAutoIncrement bool
MapType int MapType int
IsCreated bool IsCreated bool
IsUpdated bool IsUpdated bool
IsCascade bool IsCascade bool
IsVersion bool IsVersion bool
} }
// generate column description string according dialect // generate column description string according dialect
func (col *Column) String(d dialect) string { func (col *Column) String(d dialect) string {
sql := d.QuoteStr() + col.Name + d.QuoteStr() + " " sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
sql += d.SqlType(col) + " " sql += d.SqlType(col) + " "
if col.IsPrimaryKey { if col.IsPrimaryKey {
sql += "PRIMARY KEY " sql += "PRIMARY KEY "
} }
if col.IsAutoIncrement { if col.IsAutoIncrement {
sql += d.AutoIncrStr() + " " sql += d.AutoIncrStr() + " "
} }
if col.Nullable { if col.Nullable {
sql += "NULL " sql += "NULL "
} else { } else {
sql += "NOT NULL " sql += "NOT NULL "
} }
if col.Default != "" { if col.Default != "" {
sql += "DEFAULT " + col.Default + " " sql += "DEFAULT " + col.Default + " "
} }
return sql return sql
} }
// return col's filed of struct's value // return col's filed of struct's value
func (col *Column) ValueOf(bean interface{}) reflect.Value { func (col *Column) ValueOf(bean interface{}) reflect.Value {
var fieldValue reflect.Value var fieldValue reflect.Value
if strings.Contains(col.FieldName, ".") { if strings.Contains(col.FieldName, ".") {
fields := strings.Split(col.FieldName, ".") fields := strings.Split(col.FieldName, ".")
if len(fields) > 2 { if len(fields) > 2 {
return reflect.ValueOf(nil) return reflect.ValueOf(nil)
} }
fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(fields[0]) fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(fields[0])
fieldValue = fieldValue.FieldByName(fields[1]) fieldValue = fieldValue.FieldByName(fields[1])
} else { } else {
fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
} }
return fieldValue return fieldValue
} }
// database table // database table
type Table struct { type Table struct {
Name string Name string
Type reflect.Type Type reflect.Type
ColumnsSeq []string ColumnsSeq []string
Columns map[string]*Column Columns map[string]*Column
Indexes map[string]*Index Indexes map[string]*Index
PrimaryKey string PrimaryKey string
Created map[string]bool Created map[string]bool
Updated string Updated string
Version string Version string
Cacher Cacher Cacher Cacher
} }
/* /*
func NewTable(name string, t reflect.Type) *Table { func NewTable(name string, t reflect.Type) *Table {
return &Table{Name: name, Type: t, return &Table{Name: name, Type: t,
ColumnsSeq: make([]string, 0), ColumnsSeq: make([]string, 0),
Columns: make(map[string]*Column), Columns: make(map[string]*Column),
Indexes: make(map[string]*Index), Indexes: make(map[string]*Index),
Created: make(map[string]bool), Created: make(map[string]bool),
} }
}*/ }*/
// if has primary key, return column // if has primary key, return column
func (table *Table) PKColumn() *Column { func (table *Table) PKColumn() *Column {
return table.Columns[table.PrimaryKey] return table.Columns[table.PrimaryKey]
} }
func (table *Table) VersionColumn() *Column { func (table *Table) VersionColumn() *Column {
return table.Columns[table.Version] return table.Columns[table.Version]
} }
// add a column to table // add a column to table
func (table *Table) AddColumn(col *Column) { func (table *Table) AddColumn(col *Column) {
table.ColumnsSeq = append(table.ColumnsSeq, col.Name) table.ColumnsSeq = append(table.ColumnsSeq, col.Name)
table.Columns[col.Name] = col table.Columns[col.Name] = col
if col.IsPrimaryKey { if col.IsPrimaryKey {
table.PrimaryKey = col.Name table.PrimaryKey = col.Name
} }
if col.IsCreated { if col.IsCreated {
table.Created[col.Name] = true table.Created[col.Name] = true
} }
if col.IsUpdated { if col.IsUpdated {
table.Updated = col.Name table.Updated = col.Name
} }
if col.IsVersion { if col.IsVersion {
table.Version = col.Name table.Version = col.Name
} }
} }
// add an index or an unique to table // add an index or an unique to table
func (table *Table) AddIndex(index *Index) { func (table *Table) AddIndex(index *Index) {
table.Indexes[index.Name] = index table.Indexes[index.Name] = index
} }
func (table *Table) genCols(session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) { func (table *Table) genCols(session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) {
colNames := make([]string, 0) colNames := make([]string, 0)
args := make([]interface{}, 0) args := make([]interface{}, 0)
for _, col := range table.Columns { for _, col := range table.Columns {
if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated { if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated {
if _, ok := session.Statement.columnMap[col.Name]; !ok { if _, ok := session.Statement.columnMap[col.Name]; !ok {
continue continue
} }
} }
if col.MapType == ONLYFROMDB { if col.MapType == ONLYFROMDB {
continue continue
} }
fieldValue := col.ValueOf(bean) fieldValue := col.ValueOf(bean)
if col.IsAutoIncrement && fieldValue.Int() == 0 { if col.IsAutoIncrement && fieldValue.Int() == 0 {
continue continue
} }
if session.Statement.ColumnStr != "" { if session.Statement.ColumnStr != "" {
if _, ok := session.Statement.columnMap[col.Name]; !ok { if _, ok := session.Statement.columnMap[col.Name]; !ok {
continue continue
} }
} }
if session.Statement.OmitStr != "" { if session.Statement.OmitStr != "" {
if _, ok := session.Statement.columnMap[col.Name]; ok { if _, ok := session.Statement.columnMap[col.Name]; ok {
continue continue
} }
} }
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
args = append(args, time.Now()) args = append(args, time.Now())
} else if col.IsVersion && session.Statement.checkVersion { } else if col.IsVersion && session.Statement.checkVersion {
args = append(args, 1) args = append(args, 1)
} else { } else {
arg, err := session.value2Interface(col, fieldValue) arg, err := session.value2Interface(col, fieldValue)
if err != nil { if err != nil {
return colNames, args, err return colNames, args, err
} }
args = append(args, arg) args = append(args, arg)
} }
if includeQuote { if includeQuote {
colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?") colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?")
} else { } else {
colNames = append(colNames, col.Name) colNames = append(colNames, col.Name)
} }
} }
return colNames, args, nil return colNames, args, nil
} }
// Conversion is an interface. A type implements Conversion will according // Conversion is an interface. A type implements Conversion will according
// the custom method to fill into database and retrieve from database. // the custom method to fill into database and retrieve from database.
type Conversion interface { type Conversion interface {
FromDB([]byte) error FromDB([]byte) error
ToDB() ([]byte, error) ToDB() ([]byte, error)
} }

78
xorm.go
View File

@ -1,58 +1,58 @@
package xorm package xorm
import ( import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
"runtime" "runtime"
"sync" "sync"
) )
const ( const (
version string = "0.2.2" version string = "0.2.2"
) )
func close(engine *Engine) { func close(engine *Engine) {
engine.Close() engine.Close()
} }
// new a db manager according to the parameter. Currently support four // new a db manager according to the parameter. Currently support four
// drivers // drivers
func NewEngine(driverName string, dataSourceName string) (*Engine, error) { func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
engine := &Engine{DriverName: driverName, engine := &Engine{DriverName: driverName,
DataSourceName: dataSourceName, Filters: make([]Filter, 0)} DataSourceName: dataSourceName, Filters: make([]Filter, 0)}
engine.SetMapper(SnakeMapper{}) engine.SetMapper(SnakeMapper{})
if driverName == SQLITE { if driverName == SQLITE {
engine.dialect = &sqlite3{} engine.dialect = &sqlite3{}
} else if driverName == MYSQL { } else if driverName == MYSQL {
engine.dialect = &mysql{} engine.dialect = &mysql{}
} else if driverName == POSTGRES { } else if driverName == POSTGRES {
engine.dialect = &postgres{} engine.dialect = &postgres{}
engine.Filters = append(engine.Filters, &PgSeqFilter{}) engine.Filters = append(engine.Filters, &PgSeqFilter{})
engine.Filters = append(engine.Filters, &QuoteFilter{}) engine.Filters = append(engine.Filters, &QuoteFilter{})
} else if driverName == MYMYSQL { } else if driverName == MYMYSQL {
engine.dialect = &mymysql{} engine.dialect = &mymysql{}
} else { } else {
return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName)) return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName))
} }
err := engine.dialect.Init(driverName, dataSourceName) err := engine.dialect.Init(driverName, dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
engine.Tables = make(map[reflect.Type]*Table) engine.Tables = make(map[reflect.Type]*Table)
engine.mutex = &sync.Mutex{} engine.mutex = &sync.Mutex{}
engine.TagIdentifier = "xorm" engine.TagIdentifier = "xorm"
engine.Filters = append(engine.Filters, &IdFilter{}) engine.Filters = append(engine.Filters, &IdFilter{})
engine.Logger = os.Stdout engine.Logger = os.Stdout
//engine.Pool = NewSimpleConnectPool() //engine.Pool = NewSimpleConnectPool()
//engine.Pool = NewNoneConnectPool() //engine.Pool = NewNoneConnectPool()
//engine.Cacher = NewLRUCacher() //engine.Cacher = NewLRUCacher()
err = engine.SetPool(NewSysConnectPool()) err = engine.SetPool(NewSysConnectPool())
runtime.SetFinalizer(engine, close) runtime.SetFinalizer(engine, close)
return engine, err return engine, err
} }

View File

@ -1,65 +1,65 @@
package main package main
import ( import (
//"fmt" //"fmt"
"github.com/lunny/xorm" "github.com/lunny/xorm"
"strings" "strings"
"text/template" "text/template"
) )
var ( var (
CPlusTmpl LangTmpl = LangTmpl{ CPlusTmpl LangTmpl = LangTmpl{
template.FuncMap{"Mapper": mapper.Table2Obj, template.FuncMap{"Mapper": mapper.Table2Obj,
"Type": cPlusTypeStr, "Type": cPlusTypeStr,
"UnTitle": unTitle, "UnTitle": unTitle,
}, },
nil, nil,
genCPlusImports, genCPlusImports,
} }
) )
func cPlusTypeStr(col *xorm.Column) string { func cPlusTypeStr(col *xorm.Column) string {
tp := col.SQLType tp := col.SQLType
name := strings.ToUpper(tp.Name) name := strings.ToUpper(tp.Name)
switch name { switch name {
case xorm.Bit, xorm.TinyInt, xorm.SmallInt, xorm.MediumInt, xorm.Int, xorm.Integer, xorm.Serial: case xorm.Bit, xorm.TinyInt, xorm.SmallInt, xorm.MediumInt, xorm.Int, xorm.Integer, xorm.Serial:
return "int" return "int"
case xorm.BigInt, xorm.BigSerial: case xorm.BigInt, xorm.BigSerial:
return "__int64" return "__int64"
case xorm.Char, xorm.Varchar, xorm.TinyText, xorm.Text, xorm.MediumText, xorm.LongText: case xorm.Char, xorm.Varchar, xorm.TinyText, xorm.Text, xorm.MediumText, xorm.LongText:
return "tstring" return "tstring"
case xorm.Date, xorm.DateTime, xorm.Time, xorm.TimeStamp: case xorm.Date, xorm.DateTime, xorm.Time, xorm.TimeStamp:
return "time_t" return "time_t"
case xorm.Decimal, xorm.Numeric: case xorm.Decimal, xorm.Numeric:
return "tstring" return "tstring"
case xorm.Real, xorm.Float: case xorm.Real, xorm.Float:
return "float" return "float"
case xorm.Double: case xorm.Double:
return "double" return "double"
case xorm.TinyBlob, xorm.Blob, xorm.MediumBlob, xorm.LongBlob, xorm.Bytea: case xorm.TinyBlob, xorm.Blob, xorm.MediumBlob, xorm.LongBlob, xorm.Bytea:
return "tstring" return "tstring"
case xorm.Bool: case xorm.Bool:
return "bool" return "bool"
default: default:
return "tstring" return "tstring"
} }
return "" return ""
} }
func genCPlusImports(tables []*xorm.Table) map[string]string { func genCPlusImports(tables []*xorm.Table) map[string]string {
imports := make(map[string]string) imports := make(map[string]string)
for _, table := range tables { for _, table := range tables {
for _, col := range table.Columns { for _, col := range table.Columns {
switch cPlusTypeStr(col) { switch cPlusTypeStr(col) {
case "time_t": case "time_t":
imports[`<time.h>`] = `<time.h>` imports[`<time.h>`] = `<time.h>`
case "tstring": case "tstring":
imports["<string>"] = "<string>" imports["<string>"] = "<string>"
//case "__int64": //case "__int64":
// imports[""] = "" // imports[""] = ""
} }
} }
} }
return imports return imports
} }

View File

@ -1,78 +1,78 @@
package main package main
import ( import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
) )
// A Command is an implementation of a go command // A Command is an implementation of a go command
// like go build or go fix. // like go build or go fix.
type Command struct { type Command struct {
// Run runs the command. // Run runs the command.
// The args are the arguments after the command name. // The args are the arguments after the command name.
Run func(cmd *Command, args []string) Run func(cmd *Command, args []string)
// UsageLine is the one-line usage message. // UsageLine is the one-line usage message.
// The first word in the line is taken to be the command name. // The first word in the line is taken to be the command name.
UsageLine string UsageLine string
// Short is the short description shown in the 'go help' output. // Short is the short description shown in the 'go help' output.
Short string Short string
// Long is the long message shown in the 'go help <this-command>' output. // Long is the long message shown in the 'go help <this-command>' output.
Long string Long string
// Flag is a set of flags specific to this command. // Flag is a set of flags specific to this command.
Flags map[string]bool Flags map[string]bool
} }
// Name returns the command's name: the first word in the usage line. // Name returns the command's name: the first word in the usage line.
func (c *Command) Name() string { func (c *Command) Name() string {
name := c.UsageLine name := c.UsageLine
i := strings.Index(name, " ") i := strings.Index(name, " ")
if i >= 0 { if i >= 0 {
name = name[:i] name = name[:i]
} }
return name return name
} }
func (c *Command) Usage() { func (c *Command) Usage() {
fmt.Fprintf(os.Stderr, "usage: %s\n\n", c.UsageLine) fmt.Fprintf(os.Stderr, "usage: %s\n\n", c.UsageLine)
fmt.Fprintf(os.Stderr, "%s\n", strings.TrimSpace(c.Long)) fmt.Fprintf(os.Stderr, "%s\n", strings.TrimSpace(c.Long))
os.Exit(2) os.Exit(2)
} }
// Runnable reports whether the command can be run; otherwise // Runnable reports whether the command can be run; otherwise
// it is a documentation pseudo-command such as importpath. // it is a documentation pseudo-command such as importpath.
func (c *Command) Runnable() bool { func (c *Command) Runnable() bool {
return c.Run != nil return c.Run != nil
} }
// checkFlags checks if the flag exists with correct format. // checkFlags checks if the flag exists with correct format.
func checkFlags(flags map[string]bool, args []string, print func(string)) int { func checkFlags(flags map[string]bool, args []string, print func(string)) int {
num := 0 // Number of valid flags, use to cut out. num := 0 // Number of valid flags, use to cut out.
for i, f := range args { for i, f := range args {
// Check flag prefix '-'. // Check flag prefix '-'.
if !strings.HasPrefix(f, "-") { if !strings.HasPrefix(f, "-") {
// Not a flag, finish check process. // Not a flag, finish check process.
break break
} }
// Check if it a valid flag. // Check if it a valid flag.
if v, ok := flags[f]; ok { if v, ok := flags[f]; ok {
flags[f] = !v flags[f] = !v
if !v { if !v {
print(f) print(f)
} else { } else {
fmt.Println("DISABLE: " + f) fmt.Println("DISABLE: " + f)
} }
} else { } else {
fmt.Printf("[ERRO] Unknown flag: %s.\n", f) fmt.Printf("[ERRO] Unknown flag: %s.\n", f)
return -1 return -1
} }
num = i + 1 num = i + 1
} }
return num return num
} }

View File

@ -1,261 +1,261 @@
package main package main
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/lunny/xorm" "github.com/lunny/xorm"
"go/format" "go/format"
"reflect" "reflect"
"strings" "strings"
"text/template" "text/template"
) )
var ( var (
GoLangTmpl LangTmpl = LangTmpl{ GoLangTmpl LangTmpl = LangTmpl{
template.FuncMap{"Mapper": mapper.Table2Obj, template.FuncMap{"Mapper": mapper.Table2Obj,
"Type": typestring, "Type": typestring,
"Tag": tag, "Tag": tag,
"UnTitle": unTitle, "UnTitle": unTitle,
"gt": gt, "gt": gt,
"getCol": getCol, "getCol": getCol,
}, },
formatGo, formatGo,
genGoImports, genGoImports,
} }
) )
var ( var (
errBadComparisonType = errors.New("invalid type for comparison") errBadComparisonType = errors.New("invalid type for comparison")
errBadComparison = errors.New("incompatible types for comparison") errBadComparison = errors.New("incompatible types for comparison")
errNoComparison = errors.New("missing argument for comparison") errNoComparison = errors.New("missing argument for comparison")
) )
type kind int type kind int
const ( const (
invalidKind kind = iota invalidKind kind = iota
boolKind boolKind
complexKind complexKind
intKind intKind
floatKind floatKind
integerKind integerKind
stringKind stringKind
uintKind uintKind
) )
func basicKind(v reflect.Value) (kind, error) { func basicKind(v reflect.Value) (kind, error) {
switch v.Kind() { switch v.Kind() {
case reflect.Bool: case reflect.Bool:
return boolKind, nil return boolKind, nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return intKind, nil return intKind, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return uintKind, nil return uintKind, nil
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
return floatKind, nil return floatKind, nil
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
return complexKind, nil return complexKind, nil
case reflect.String: case reflect.String:
return stringKind, nil return stringKind, nil
} }
return invalidKind, errBadComparisonType return invalidKind, errBadComparisonType
} }
// eq evaluates the comparison a == b || a == c || ... // eq evaluates the comparison a == b || a == c || ...
func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) { func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
v1 := reflect.ValueOf(arg1) v1 := reflect.ValueOf(arg1)
k1, err := basicKind(v1) k1, err := basicKind(v1)
if err != nil { if err != nil {
return false, err return false, err
} }
if len(arg2) == 0 { if len(arg2) == 0 {
return false, errNoComparison return false, errNoComparison
} }
for _, arg := range arg2 { for _, arg := range arg2 {
v2 := reflect.ValueOf(arg) v2 := reflect.ValueOf(arg)
k2, err := basicKind(v2) k2, err := basicKind(v2)
if err != nil { if err != nil {
return false, err return false, err
} }
if k1 != k2 { if k1 != k2 {
return false, errBadComparison return false, errBadComparison
} }
truth := false truth := false
switch k1 { switch k1 {
case boolKind: case boolKind:
truth = v1.Bool() == v2.Bool() truth = v1.Bool() == v2.Bool()
case complexKind: case complexKind:
truth = v1.Complex() == v2.Complex() truth = v1.Complex() == v2.Complex()
case floatKind: case floatKind:
truth = v1.Float() == v2.Float() truth = v1.Float() == v2.Float()
case intKind: case intKind:
truth = v1.Int() == v2.Int() truth = v1.Int() == v2.Int()
case stringKind: case stringKind:
truth = v1.String() == v2.String() truth = v1.String() == v2.String()
case uintKind: case uintKind:
truth = v1.Uint() == v2.Uint() truth = v1.Uint() == v2.Uint()
default: default:
panic("invalid kind") panic("invalid kind")
} }
if truth { if truth {
return true, nil return true, nil
} }
} }
return false, nil return false, nil
} }
// lt evaluates the comparison a < b. // lt evaluates the comparison a < b.
func lt(arg1, arg2 interface{}) (bool, error) { func lt(arg1, arg2 interface{}) (bool, error) {
v1 := reflect.ValueOf(arg1) v1 := reflect.ValueOf(arg1)
k1, err := basicKind(v1) k1, err := basicKind(v1)
if err != nil { if err != nil {
return false, err return false, err
} }
v2 := reflect.ValueOf(arg2) v2 := reflect.ValueOf(arg2)
k2, err := basicKind(v2) k2, err := basicKind(v2)
if err != nil { if err != nil {
return false, err return false, err
} }
if k1 != k2 { if k1 != k2 {
return false, errBadComparison return false, errBadComparison
} }
truth := false truth := false
switch k1 { switch k1 {
case boolKind, complexKind: case boolKind, complexKind:
return false, errBadComparisonType return false, errBadComparisonType
case floatKind: case floatKind:
truth = v1.Float() < v2.Float() truth = v1.Float() < v2.Float()
case intKind: case intKind:
truth = v1.Int() < v2.Int() truth = v1.Int() < v2.Int()
case stringKind: case stringKind:
truth = v1.String() < v2.String() truth = v1.String() < v2.String()
case uintKind: case uintKind:
truth = v1.Uint() < v2.Uint() truth = v1.Uint() < v2.Uint()
default: default:
panic("invalid kind") panic("invalid kind")
} }
return truth, nil return truth, nil
} }
// le evaluates the comparison <= b. // le evaluates the comparison <= b.
func le(arg1, arg2 interface{}) (bool, error) { func le(arg1, arg2 interface{}) (bool, error) {
// <= is < or ==. // <= is < or ==.
lessThan, err := lt(arg1, arg2) lessThan, err := lt(arg1, arg2)
if lessThan || err != nil { if lessThan || err != nil {
return lessThan, err return lessThan, err
} }
return eq(arg1, arg2) return eq(arg1, arg2)
} }
// gt evaluates the comparison a > b. // gt evaluates the comparison a > b.
func gt(arg1, arg2 interface{}) (bool, error) { func gt(arg1, arg2 interface{}) (bool, error) {
// > is the inverse of <=. // > is the inverse of <=.
lessOrEqual, err := le(arg1, arg2) lessOrEqual, err := le(arg1, arg2)
if err != nil { if err != nil {
return false, err return false, err
} }
return !lessOrEqual, nil return !lessOrEqual, nil
} }
func getCol(cols map[string]*xorm.Column, name string) *xorm.Column { func getCol(cols map[string]*xorm.Column, name string) *xorm.Column {
return cols[name] return cols[name]
} }
func formatGo(src string) (string, error) { func formatGo(src string) (string, error) {
source, err := format.Source([]byte(src)) source, err := format.Source([]byte(src))
if err != nil { if err != nil {
return "", err return "", err
} }
return string(source), nil return string(source), nil
} }
func genGoImports(tables []*xorm.Table) map[string]string { func genGoImports(tables []*xorm.Table) map[string]string {
imports := make(map[string]string) imports := make(map[string]string)
for _, table := range tables { for _, table := range tables {
for _, col := range table.Columns { for _, col := range table.Columns {
if typestring(col) == "time.Time" { if typestring(col) == "time.Time" {
imports["time"] = "time" imports["time"] = "time"
} }
} }
} }
return imports return imports
} }
func typestring(col *xorm.Column) string { func typestring(col *xorm.Column) string {
st := col.SQLType st := col.SQLType
if col.IsPrimaryKey { if col.IsPrimaryKey {
return "int64" return "int64"
} }
t := xorm.SQLType2Type(st) t := xorm.SQLType2Type(st)
s := t.String() s := t.String()
if s == "[]uint8" { if s == "[]uint8" {
return "[]byte" return "[]byte"
} }
return s return s
} }
func tag(table *xorm.Table, col *xorm.Column) string { func tag(table *xorm.Table, col *xorm.Column) string {
isNameId := (mapper.Table2Obj(col.Name) == "Id") isNameId := (mapper.Table2Obj(col.Name) == "Id")
res := make([]string, 0) res := make([]string, 0)
if !col.Nullable { if !col.Nullable {
if !isNameId { if !isNameId {
res = append(res, "not null") res = append(res, "not null")
} }
} }
if col.IsPrimaryKey { if col.IsPrimaryKey {
if !isNameId { if !isNameId {
res = append(res, "pk") res = append(res, "pk")
} }
} }
if col.Default != "" { if col.Default != "" {
res = append(res, "default "+col.Default) res = append(res, "default "+col.Default)
} }
if col.IsAutoIncrement { if col.IsAutoIncrement {
if !isNameId { if !isNameId {
res = append(res, "autoincr") res = append(res, "autoincr")
} }
} }
if col.IsCreated { if col.IsCreated {
res = append(res, "created") res = append(res, "created")
} }
if col.IsUpdated { if col.IsUpdated {
res = append(res, "updated") res = append(res, "updated")
} }
for name, _ := range col.Indexes { for name, _ := range col.Indexes {
index := table.Indexes[name] index := table.Indexes[name]
var uistr string var uistr string
if index.Type == xorm.UniqueType { if index.Type == xorm.UniqueType {
uistr = "unique" uistr = "unique"
} else if index.Type == xorm.IndexType { } else if index.Type == xorm.IndexType {
uistr = "index" uistr = "index"
} }
if len(index.Cols) > 1 { if len(index.Cols) > 1 {
uistr += "(" + index.Name + ")" uistr += "(" + index.Name + ")"
} }
res = append(res, uistr) res = append(res, uistr)
} }
nstr := col.SQLType.Name nstr := col.SQLType.Name
if col.Length != 0 { if col.Length != 0 {
if col.Length2 != 0 { if col.Length2 != 0 {
nstr += fmt.Sprintf("(%v, %v)", col.Length, col.Length2) nstr += fmt.Sprintf("(%v, %v)", col.Length, col.Length2)
} else { } else {
nstr += fmt.Sprintf("(%v)", col.Length) nstr += fmt.Sprintf("(%v)", col.Length)
} }
} }
res = append(res, nstr) res = append(res, nstr)
var tags []string var tags []string
if genJson { if genJson {
tags = append(tags, "json:\""+col.Name+"\"") tags = append(tags, "json:\""+col.Name+"\"")
} }
if len(res) > 0 { if len(res) > 0 {
tags = append(tags, "xorm:\""+strings.Join(res, " ")+"\"") tags = append(tags, "xorm:\""+strings.Join(res, " ")+"\"")
} }
if len(tags) > 0 { if len(tags) > 0 {
return "`" + strings.Join(tags, " ") + "`" return "`" + strings.Join(tags, " ") + "`"
} else { } else {
return "" return ""
} }
} }

View File

@ -1,51 +1,51 @@
package main package main
import ( import (
"github.com/lunny/xorm" "github.com/lunny/xorm"
"io/ioutil" "io/ioutil"
"strings" "strings"
"text/template" "text/template"
) )
type LangTmpl struct { type LangTmpl struct {
Funcs template.FuncMap Funcs template.FuncMap
Formater func(string) (string, error) Formater func(string) (string, error)
GenImports func([]*xorm.Table) map[string]string GenImports func([]*xorm.Table) map[string]string
} }
var ( var (
mapper = &xorm.SnakeMapper{} mapper = &xorm.SnakeMapper{}
langTmpls = map[string]LangTmpl{ langTmpls = map[string]LangTmpl{
"go": GoLangTmpl, "go": GoLangTmpl,
"c++": CPlusTmpl, "c++": CPlusTmpl,
} }
) )
func loadConfig(f string) map[string]string { func loadConfig(f string) map[string]string {
bts, err := ioutil.ReadFile(f) bts, err := ioutil.ReadFile(f)
if err != nil { if err != nil {
return nil return nil
} }
configs := make(map[string]string) configs := make(map[string]string)
lines := strings.Split(string(bts), "\n") lines := strings.Split(string(bts), "\n")
for _, line := range lines { for _, line := range lines {
line = strings.TrimRight(line, "\r") line = strings.TrimRight(line, "\r")
vs := strings.Split(line, "=") vs := strings.Split(line, "=")
if len(vs) == 2 { if len(vs) == 2 {
configs[strings.TrimSpace(vs[0])] = strings.TrimSpace(vs[1]) configs[strings.TrimSpace(vs[0])] = strings.TrimSpace(vs[1])
} }
} }
return configs return configs
} }
func unTitle(src string) string { func unTitle(src string) string {
if src == "" { if src == "" {
return "" return ""
} }
if len(src) == 1 { if len(src) == 1 {
return strings.ToLower(string(src[0])) return strings.ToLower(string(src[0]))
} else { } else {
return strings.ToLower(string(src[0])) + src[1:] return strings.ToLower(string(src[0])) + src[1:]
} }
} }

View File

@ -1,268 +1,268 @@
package main package main
import ( import (
"bytes" "bytes"
"fmt" "fmt"
_ "github.com/bylevel/pq" _ "github.com/bylevel/pq"
"github.com/dvirsky/go-pylog/logging" "github.com/dvirsky/go-pylog/logging"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/lunny/xorm" "github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv" _ "github.com/ziutek/mymysql/godrv"
"io/ioutil" "io/ioutil"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"strconv" "strconv"
"text/template" "text/template"
) )
var CmdReverse = &Command{ var CmdReverse = &Command{
UsageLine: "reverse [-m] driverName datasourceName tmplPath [generatedPath]", UsageLine: "reverse [-m] driverName datasourceName tmplPath [generatedPath]",
Short: "reverse a db to codes", Short: "reverse a db to codes",
Long: ` Long: `
according database's tables and columns to generate codes for Go, C++ and etc. according database's tables and columns to generate codes for Go, C++ and etc.
-m Generated one go file for every table -m Generated one go file for every table
driverName Database driver name, now supported four: mysql mymysql sqlite3 postgres driverName Database driver name, now supported four: mysql mymysql sqlite3 postgres
datasourceName Database connection uri, for detail infomation please visit driver's project page datasourceName Database connection uri, for detail infomation please visit driver's project page
tmplPath Template dir for generated. the default templates dir has provide 1 template tmplPath Template dir for generated. the default templates dir has provide 1 template
generatedPath This parameter is optional, if blank, the default value is model, then will generatedPath This parameter is optional, if blank, the default value is model, then will
generated all codes in model dir generated all codes in model dir
`, `,
} }
func init() { func init() {
CmdReverse.Run = runReverse CmdReverse.Run = runReverse
CmdReverse.Flags = map[string]bool{ CmdReverse.Flags = map[string]bool{
"-s": false, "-s": false,
"-l": false, "-l": false,
} }
} }
var ( var (
genJson bool = false genJson bool = false
) )
func printReversePrompt(flag string) { func printReversePrompt(flag string) {
} }
type Tmpl struct { type Tmpl struct {
Tables []*xorm.Table Tables []*xorm.Table
Imports map[string]string Imports map[string]string
Model string Model string
} }
func dirExists(dir string) bool { func dirExists(dir string) bool {
d, e := os.Stat(dir) d, e := os.Stat(dir)
switch { switch {
case e != nil: case e != nil:
return false return false
case !d.IsDir(): case !d.IsDir():
return false return false
} }
return true return true
} }
func runReverse(cmd *Command, args []string) { func runReverse(cmd *Command, args []string) {
num := checkFlags(cmd.Flags, args, printReversePrompt) num := checkFlags(cmd.Flags, args, printReversePrompt)
if num == -1 { if num == -1 {
return return
} }
args = args[num:] args = args[num:]
if len(args) < 3 { if len(args) < 3 {
fmt.Println("params error, please see xorm help reverse") fmt.Println("params error, please see xorm help reverse")
return return
} }
var isMultiFile bool = true var isMultiFile bool = true
if use, ok := cmd.Flags["-s"]; ok { if use, ok := cmd.Flags["-s"]; ok {
isMultiFile = !use isMultiFile = !use
} }
curPath, err := os.Getwd() curPath, err := os.Getwd()
if err != nil { if err != nil {
fmt.Println(curPath) fmt.Println(curPath)
return return
} }
var genDir string var genDir string
var model string var model string
if len(args) == 4 { if len(args) == 4 {
genDir, err = filepath.Abs(args[3]) genDir, err = filepath.Abs(args[3])
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
model = path.Base(genDir) model = path.Base(genDir)
} else { } else {
model = "model" model = "model"
genDir = path.Join(curPath, model) genDir = path.Join(curPath, model)
} }
dir, err := filepath.Abs(args[2]) dir, err := filepath.Abs(args[2])
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return return
} }
if !dirExists(dir) { if !dirExists(dir) {
logging.Error("Template %v path is not exist", dir) logging.Error("Template %v path is not exist", dir)
return return
} }
var langTmpl LangTmpl var langTmpl LangTmpl
var ok bool var ok bool
var lang string = "go" var lang string = "go"
cfgPath := path.Join(dir, "config") cfgPath := path.Join(dir, "config")
info, err := os.Stat(cfgPath) info, err := os.Stat(cfgPath)
var configs map[string]string var configs map[string]string
if err == nil && !info.IsDir() { if err == nil && !info.IsDir() {
configs = loadConfig(cfgPath) configs = loadConfig(cfgPath)
if l, ok := configs["lang"]; ok { if l, ok := configs["lang"]; ok {
lang = l lang = l
} }
if j, ok := configs["genJson"]; ok { if j, ok := configs["genJson"]; ok {
genJson, err = strconv.ParseBool(j) genJson, err = strconv.ParseBool(j)
} }
} }
if langTmpl, ok = langTmpls[lang]; !ok { if langTmpl, ok = langTmpls[lang]; !ok {
fmt.Println("Unsupported programing language", lang) fmt.Println("Unsupported programing language", lang)
return return
} }
os.MkdirAll(genDir, os.ModePerm) os.MkdirAll(genDir, os.ModePerm)
Orm, err := xorm.NewEngine(args[0], args[1]) Orm, err := xorm.NewEngine(args[0], args[1])
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return return
} }
tables, err := Orm.DBMetas() tables, err := Orm.DBMetas()
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return return
} }
filepath.Walk(dir, func(f string, info os.FileInfo, err error) error { filepath.Walk(dir, func(f string, info os.FileInfo, err error) error {
if info.IsDir() { if info.IsDir() {
return nil return nil
} }
if info.Name() == "config" { if info.Name() == "config" {
return nil return nil
} }
bs, err := ioutil.ReadFile(f) bs, err := ioutil.ReadFile(f)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
t := template.New(f) t := template.New(f)
t.Funcs(langTmpl.Funcs) t.Funcs(langTmpl.Funcs)
tmpl, err := t.Parse(string(bs)) tmpl, err := t.Parse(string(bs))
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
var w *os.File var w *os.File
fileName := info.Name() fileName := info.Name()
newFileName := fileName[:len(fileName)-4] newFileName := fileName[:len(fileName)-4]
ext := path.Ext(newFileName) ext := path.Ext(newFileName)
if !isMultiFile { if !isMultiFile {
w, err = os.OpenFile(path.Join(genDir, newFileName), os.O_RDWR|os.O_CREATE, 0600) w, err = os.OpenFile(path.Join(genDir, newFileName), os.O_RDWR|os.O_CREATE, 0600)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
imports := langTmpl.GenImports(tables) imports := langTmpl.GenImports(tables)
tbls := make([]*xorm.Table, 0) tbls := make([]*xorm.Table, 0)
for _, table := range tables { for _, table := range tables {
tbls = append(tbls, table) tbls = append(tbls, table)
} }
newbytes := bytes.NewBufferString("") newbytes := bytes.NewBufferString("")
t := &Tmpl{Tables: tbls, Imports: imports, Model: model} t := &Tmpl{Tables: tbls, Imports: imports, Model: model}
err = tmpl.Execute(newbytes, t) err = tmpl.Execute(newbytes, t)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
tplcontent, err := ioutil.ReadAll(newbytes) tplcontent, err := ioutil.ReadAll(newbytes)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
var source string var source string
if langTmpl.Formater != nil { if langTmpl.Formater != nil {
source, err = langTmpl.Formater(string(tplcontent)) source, err = langTmpl.Formater(string(tplcontent))
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
} else { } else {
source = string(tplcontent) source = string(tplcontent)
} }
w.WriteString(source) w.WriteString(source)
w.Close() w.Close()
} else { } else {
for _, table := range tables { for _, table := range tables {
// imports // imports
tbs := []*xorm.Table{table} tbs := []*xorm.Table{table}
imports := langTmpl.GenImports(tbs) imports := langTmpl.GenImports(tbs)
w, err := os.OpenFile(path.Join(genDir, unTitle(mapper.Table2Obj(table.Name))+ext), os.O_RDWR|os.O_CREATE, 0600) w, err := os.OpenFile(path.Join(genDir, unTitle(mapper.Table2Obj(table.Name))+ext), os.O_RDWR|os.O_CREATE, 0600)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
newbytes := bytes.NewBufferString("") newbytes := bytes.NewBufferString("")
t := &Tmpl{Tables: tbs, Imports: imports, Model: model} t := &Tmpl{Tables: tbs, Imports: imports, Model: model}
err = tmpl.Execute(newbytes, t) err = tmpl.Execute(newbytes, t)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
tplcontent, err := ioutil.ReadAll(newbytes) tplcontent, err := ioutil.ReadAll(newbytes)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
var source string var source string
if langTmpl.Formater != nil { if langTmpl.Formater != nil {
source, err = langTmpl.Formater(string(tplcontent)) source, err = langTmpl.Formater(string(tplcontent))
if err != nil { if err != nil {
logging.Error("%v-%v", err, string(tplcontent)) logging.Error("%v-%v", err, string(tplcontent))
return err return err
} }
} else { } else {
source = string(tplcontent) source = string(tplcontent)
} }
w.WriteString(source) w.WriteString(source)
w.Close() w.Close()
} }
} }
return nil return nil
}) })
} }

View File

@ -1,147 +1,147 @@
package main package main
import ( import (
"fmt" "fmt"
"github.com/lunny/xorm" "github.com/lunny/xorm"
"strings" "strings"
) )
var CmdShell = &Command{ var CmdShell = &Command{
UsageLine: "shell driverName datasourceName", UsageLine: "shell driverName datasourceName",
Short: "a general shell to operate all kinds of database", Short: "a general shell to operate all kinds of database",
Long: ` Long: `
general database's shell for sqlite3, mysql, postgres. general database's shell for sqlite3, mysql, postgres.
driverName Database driver name, now supported four: mysql mymysql sqlite3 postgres driverName Database driver name, now supported four: mysql mymysql sqlite3 postgres
datasourceName Database connection uri, for detail infomation please visit driver's project page datasourceName Database connection uri, for detail infomation please visit driver's project page
`, `,
} }
func init() { func init() {
CmdShell.Run = runShell CmdShell.Run = runShell
CmdShell.Flags = map[string]bool{} CmdShell.Flags = map[string]bool{}
} }
var engine *xorm.Engine var engine *xorm.Engine
func help() { func help() {
fmt.Println(` fmt.Println(`
show tables show all tables show tables show all tables
columns <table_name> show table's column info columns <table_name> show table's column info
indexes <table_name> show table's index info indexes <table_name> show table's index info
exit exit shell exit exit shell
source <sql_file> exec sql file to current database source <sql_file> exec sql file to current database
dump [-nodata] <sql_file> dump structs or records to sql file dump [-nodata] <sql_file> dump structs or records to sql file
help show this document help show this document
<statement> SQL statement <statement> SQL statement
`) `)
} }
func runShell(cmd *Command, args []string) { func runShell(cmd *Command, args []string) {
if len(args) != 2 { if len(args) != 2 {
fmt.Println("params error, please see xorm help shell") fmt.Println("params error, please see xorm help shell")
return return
} }
var err error var err error
engine, err = xorm.NewEngine(args[0], args[1]) engine, err = xorm.NewEngine(args[0], args[1])
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
err = engine.Ping() err = engine.Ping()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
var scmd string var scmd string
fmt.Print("xorm$ ") fmt.Print("xorm$ ")
for { for {
var input string var input string
_, err := fmt.Scan(&input) _, err := fmt.Scan(&input)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
continue continue
} }
if strings.ToLower(input) == "exit" { if strings.ToLower(input) == "exit" {
fmt.Println("bye") fmt.Println("bye")
return return
} }
if !strings.HasSuffix(input, ";") { if !strings.HasSuffix(input, ";") {
scmd = scmd + " " + input scmd = scmd + " " + input
continue continue
} }
scmd = scmd + " " + input scmd = scmd + " " + input
lcmd := strings.TrimSpace(strings.ToLower(scmd)) lcmd := strings.TrimSpace(strings.ToLower(scmd))
if strings.HasPrefix(lcmd, "select") { if strings.HasPrefix(lcmd, "select") {
res, err := engine.Query(scmd + "\n") res, err := engine.Query(scmd + "\n")
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} else { } else {
if len(res) <= 0 { if len(res) <= 0 {
fmt.Println("no records") fmt.Println("no records")
} else { } else {
columns := make(map[string]int) columns := make(map[string]int)
for k, _ := range res[0] { for k, _ := range res[0] {
columns[k] = len(k) columns[k] = len(k)
} }
for _, m := range res { for _, m := range res {
for k, s := range m { for k, s := range m {
l := len(string(s)) l := len(string(s))
if l > columns[k] { if l > columns[k] {
columns[k] = l columns[k] = l
} }
} }
} }
var maxlen = 0 var maxlen = 0
for _, l := range columns { for _, l := range columns {
maxlen = maxlen + l + 3 maxlen = maxlen + l + 3
} }
maxlen = maxlen + 1 maxlen = maxlen + 1
fmt.Println(strings.Repeat("-", maxlen)) fmt.Println(strings.Repeat("-", maxlen))
fmt.Print("|") fmt.Print("|")
slice := make([]string, 0) slice := make([]string, 0)
for k, l := range columns { for k, l := range columns {
fmt.Print(" " + k + " ") fmt.Print(" " + k + " ")
fmt.Print(strings.Repeat(" ", l-len(k))) fmt.Print(strings.Repeat(" ", l-len(k)))
fmt.Print("|") fmt.Print("|")
slice = append(slice, k) slice = append(slice, k)
} }
fmt.Print("\n") fmt.Print("\n")
for _, r := range res { for _, r := range res {
fmt.Print("|") fmt.Print("|")
for _, k := range slice { for _, k := range slice {
fmt.Print(" " + string(r[k]) + " ") fmt.Print(" " + string(r[k]) + " ")
fmt.Print(strings.Repeat(" ", columns[k]-len(string(r[k])))) fmt.Print(strings.Repeat(" ", columns[k]-len(string(r[k]))))
fmt.Print("|") fmt.Print("|")
} }
fmt.Print("\n") fmt.Print("\n")
} }
fmt.Println(strings.Repeat("-", maxlen)) fmt.Println(strings.Repeat("-", maxlen))
//fmt.Println(res) //fmt.Println(res)
} }
} }
} else if lcmd == "show tables;" { } else if lcmd == "show tables;" {
tables, err := engine.DBMetas() tables, err := engine.DBMetas()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} else { } else {
} }
} else { } else {
cnt, err := engine.Exec(scmd) cnt, err := engine.Exec(scmd)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} else { } else {
fmt.Printf("%d records changed.\n", cnt) fmt.Printf("%d records changed.\n", cnt)
} }
} }
scmd = "" scmd = ""
fmt.Print("xorm$ ") fmt.Print("xorm$ ")
} }
} }

View File

@ -1,16 +1,16 @@
package main package main
import ( import (
"fmt" "fmt"
"github.com/dvirsky/go-pylog/logging" "github.com/dvirsky/go-pylog/logging"
"io" "io"
"os" "os"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"text/template" "text/template"
"unicode" "unicode"
"unicode/utf8" "unicode/utf8"
) )
// +build go1.1 // +build go1.1
@ -21,59 +21,59 @@ const go11tag = true
// Commands lists the available commands and help topics. // Commands lists the available commands and help topics.
// The order here is the order in which they are printed by 'gopm help'. // The order here is the order in which they are printed by 'gopm help'.
var commands = []*Command{ var commands = []*Command{
CmdReverse, CmdReverse,
CmdShell, CmdShell,
} }
func init() { func init() {
runtime.GOMAXPROCS(runtime.NumCPU()) runtime.GOMAXPROCS(runtime.NumCPU())
} }
func main() { func main() {
logging.SetLevel(logging.ALL) logging.SetLevel(logging.ALL)
// Check length of arguments. // Check length of arguments.
args := os.Args[1:] args := os.Args[1:]
if len(args) < 1 { if len(args) < 1 {
usage() usage()
return return
} }
// Show help documentation. // Show help documentation.
if args[0] == "help" { if args[0] == "help" {
help(args[1:]) help(args[1:])
return return
} }
// Check commands and run. // Check commands and run.
for _, comm := range commands { for _, comm := range commands {
if comm.Name() == args[0] && comm.Run != nil { if comm.Name() == args[0] && comm.Run != nil {
comm.Run(comm, args[1:]) comm.Run(comm, args[1:])
exit() exit()
return return
} }
} }
fmt.Fprintf(os.Stderr, "xorm: unknown subcommand %q\nRun 'xorm help' for usage.\n", args[0]) fmt.Fprintf(os.Stderr, "xorm: unknown subcommand %q\nRun 'xorm help' for usage.\n", args[0])
setExitStatus(2) setExitStatus(2)
exit() exit()
} }
var exitStatus = 0 var exitStatus = 0
var exitMu sync.Mutex var exitMu sync.Mutex
func setExitStatus(n int) { func setExitStatus(n int) {
exitMu.Lock() exitMu.Lock()
if exitStatus < n { if exitStatus < n {
exitStatus = n exitStatus = n
} }
exitMu.Unlock() exitMu.Unlock()
} }
var usageTemplate = `xorm is a database tool based xorm package. var usageTemplate = `xorm is a database tool based xorm package.
Usage: Usage:
xorm command [arguments] xorm command [arguments]
The commands are: The commands are:
{{range .}}{{if .Runnable}} {{range .}}{{if .Runnable}}
@ -96,66 +96,66 @@ var helpTemplate = `{{if .Runnable}}usage: xorm {{.UsageLine}}
// tmpl executes the given template text on data, writing the result to w. // tmpl executes the given template text on data, writing the result to w.
func tmpl(w io.Writer, text string, data interface{}) { func tmpl(w io.Writer, text string, data interface{}) {
t := template.New("top") t := template.New("top")
t.Funcs(template.FuncMap{"trim": strings.TrimSpace, "capitalize": capitalize}) t.Funcs(template.FuncMap{"trim": strings.TrimSpace, "capitalize": capitalize})
template.Must(t.Parse(text)) template.Must(t.Parse(text))
if err := t.Execute(w, data); err != nil { if err := t.Execute(w, data); err != nil {
panic(err) panic(err)
} }
} }
func capitalize(s string) string { func capitalize(s string) string {
if s == "" { if s == "" {
return s return s
} }
r, n := utf8.DecodeRuneInString(s) r, n := utf8.DecodeRuneInString(s)
return string(unicode.ToTitle(r)) + s[n:] return string(unicode.ToTitle(r)) + s[n:]
} }
func printUsage(w io.Writer) { func printUsage(w io.Writer) {
tmpl(w, usageTemplate, commands) tmpl(w, usageTemplate, commands)
} }
func usage() { func usage() {
printUsage(os.Stderr) printUsage(os.Stderr)
os.Exit(2) os.Exit(2)
} }
// help implements the 'help' command. // help implements the 'help' command.
func help(args []string) { func help(args []string) {
if len(args) == 0 { if len(args) == 0 {
printUsage(os.Stdout) printUsage(os.Stdout)
// not exit 2: succeeded at 'gopm help'. // not exit 2: succeeded at 'gopm help'.
return return
} }
if len(args) != 1 { if len(args) != 1 {
fmt.Fprintf(os.Stderr, "usage: xorm help command\n\nToo many arguments given.\n") fmt.Fprintf(os.Stderr, "usage: xorm help command\n\nToo many arguments given.\n")
os.Exit(2) // failed at 'gopm help' os.Exit(2) // failed at 'gopm help'
} }
arg := args[0] arg := args[0]
for _, cmd := range commands { for _, cmd := range commands {
if cmd.Name() == arg { if cmd.Name() == arg {
tmpl(os.Stdout, helpTemplate, cmd) tmpl(os.Stdout, helpTemplate, cmd)
// not exit 2: succeeded at 'gopm help cmd'. // not exit 2: succeeded at 'gopm help cmd'.
return return
} }
} }
fmt.Fprintf(os.Stderr, "Unknown help topic %#q. Run 'xorm help'.\n", arg) fmt.Fprintf(os.Stderr, "Unknown help topic %#q. Run 'xorm help'.\n", arg)
os.Exit(2) // failed at 'gopm help cmd' os.Exit(2) // failed at 'gopm help cmd'
} }
var atexitFuncs []func() var atexitFuncs []func()
func atexit(f func()) { func atexit(f func()) {
atexitFuncs = append(atexitFuncs, f) atexitFuncs = append(atexitFuncs, f)
} }
func exit() { func exit() {
for _, f := range atexitFuncs { for _, f := range atexitFuncs {
f() f()
} }
os.Exit(exitStatus) os.Exit(exitStatus)
} }