resolved merger from upstream

This commit is contained in:
Nash Tsai 2013-12-18 15:26:48 +08:00
commit c2fe9ee0d5
18 changed files with 7898 additions and 7878 deletions

1
.gitignore vendored
View File

@ -25,3 +25,4 @@ vendor
*.log *.log
.vendor .vendor
temp_test.go

File diff suppressed because it is too large Load Diff

528
cache.go
View File

@ -1,131 +1,131 @@
package xorm package xorm
import ( import (
"container/list" "container/list"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
) )
const ( const (
// default cache expired time // default cache expired time
CacheExpired = 60 * time.Minute CacheExpired = 60 * time.Minute
// not use now // not use now
CacheMaxMemory = 256 CacheMaxMemory = 256
// evey ten minutes to clear all expired nodes // evey ten minutes to clear all expired nodes
CacheGcInterval = 10 * time.Minute CacheGcInterval = 10 * time.Minute
// each time when gc to removed max nodes // each time when gc to removed max nodes
CacheGcMaxRemoved = 20 CacheGcMaxRemoved = 20
) )
// CacheStore is a interface to store cache // CacheStore is a interface to store cache
type CacheStore interface { type CacheStore interface {
Put(key, value interface{}) error Put(key, value interface{}) error
Get(key interface{}) (interface{}, error) Get(key interface{}) (interface{}, error)
Del(key interface{}) error Del(key interface{}) error
} }
// MemoryStore implements CacheStore provide local machine // MemoryStore implements CacheStore provide local machine
// memory store // memory store
type MemoryStore struct { type MemoryStore struct {
store map[interface{}]interface{} store map[interface{}]interface{}
mutex sync.RWMutex mutex sync.RWMutex
} }
func NewMemoryStore() *MemoryStore { func NewMemoryStore() *MemoryStore {
return &MemoryStore{store: make(map[interface{}]interface{})} return &MemoryStore{store: make(map[interface{}]interface{})}
} }
func (s *MemoryStore) Put(key, value interface{}) error { func (s *MemoryStore) Put(key, value interface{}) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
s.store[key] = value s.store[key] = value
return nil return nil
} }
func (s *MemoryStore) Get(key interface{}) (interface{}, error) { func (s *MemoryStore) Get(key interface{}) (interface{}, error) {
s.mutex.RLock() s.mutex.RLock()
defer s.mutex.RUnlock() defer s.mutex.RUnlock()
if v, ok := s.store[key]; ok { if v, ok := s.store[key]; ok {
return v, nil return v, nil
} }
return nil, ErrNotExist return nil, ErrNotExist
} }
func (s *MemoryStore) Del(key interface{}) error { func (s *MemoryStore) Del(key interface{}) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
delete(s.store, key) delete(s.store, key)
return nil return nil
} }
// Cacher is an interface to provide cache // Cacher is an interface to provide cache
type Cacher interface { type Cacher interface {
GetIds(tableName, sql string) interface{} GetIds(tableName, sql string) interface{}
GetBean(tableName string, id int64) interface{} GetBean(tableName string, id int64) interface{}
PutIds(tableName, sql string, ids interface{}) PutIds(tableName, sql string, ids interface{})
PutBean(tableName string, id int64, obj interface{}) PutBean(tableName string, id int64, obj interface{})
DelIds(tableName, sql string) DelIds(tableName, sql string)
DelBean(tableName string, id int64) DelBean(tableName string, id int64)
ClearIds(tableName string) ClearIds(tableName string)
ClearBeans(tableName string) ClearBeans(tableName string)
} }
type idNode struct { type idNode struct {
tbName string tbName string
id int64 id int64
lastVisit time.Time lastVisit time.Time
} }
type sqlNode struct { type sqlNode struct {
tbName string tbName string
sql string sql string
lastVisit time.Time lastVisit time.Time
} }
func newIdNode(tbName string, id int64) *idNode { func newIdNode(tbName string, id int64) *idNode {
return &idNode{tbName, id, time.Now()} return &idNode{tbName, id, time.Now()}
} }
func newSqlNode(tbName, sql string) *sqlNode { func newSqlNode(tbName, sql string) *sqlNode {
return &sqlNode{tbName, sql, time.Now()} return &sqlNode{tbName, sql, time.Now()}
} }
// LRUCacher implements Cacher according to LRU algorithm // LRUCacher implements Cacher according to LRU algorithm
type LRUCacher struct { type LRUCacher struct {
idList *list.List idList *list.List
sqlList *list.List sqlList *list.List
idIndex map[string]map[interface{}]*list.Element idIndex map[string]map[interface{}]*list.Element
sqlIndex map[string]map[interface{}]*list.Element sqlIndex map[string]map[interface{}]*list.Element
store CacheStore store CacheStore
Max int Max int
mutex sync.Mutex mutex sync.Mutex
Expired time.Duration Expired time.Duration
maxSize int maxSize int
GcInterval time.Duration GcInterval time.Duration
} }
func newLRUCacher(store CacheStore, expired time.Duration, maxSize int, max int) *LRUCacher { func newLRUCacher(store CacheStore, expired time.Duration, maxSize int, max int) *LRUCacher {
cacher := &LRUCacher{store: store, idList: list.New(), cacher := &LRUCacher{store: store, idList: list.New(),
sqlList: list.New(), Expired: expired, maxSize: maxSize, sqlList: list.New(), Expired: expired, maxSize: maxSize,
GcInterval: CacheGcInterval, Max: max, GcInterval: CacheGcInterval, Max: max,
sqlIndex: make(map[string]map[interface{}]*list.Element), sqlIndex: make(map[string]map[interface{}]*list.Element),
idIndex: make(map[string]map[interface{}]*list.Element), idIndex: make(map[string]map[interface{}]*list.Element),
} }
cacher.RunGC() cacher.RunGC()
return cacher return cacher
} }
func NewLRUCacher(store CacheStore, max int) *LRUCacher { func NewLRUCacher(store CacheStore, max int) *LRUCacher {
return newLRUCacher(store, CacheExpired, CacheMaxMemory, max) return newLRUCacher(store, CacheExpired, CacheMaxMemory, max)
} }
func NewLRUCacher2(store CacheStore, expired time.Duration, max int) *LRUCacher { func NewLRUCacher2(store CacheStore, expired time.Duration, max int) *LRUCacher {
return newLRUCacher(store, expired, 0, max) return newLRUCacher(store, expired, 0, max)
} }
//func NewLRUCacher3(store CacheStore, expired time.Duration, maxSize int) *LRUCacher { //func NewLRUCacher3(store CacheStore, expired time.Duration, maxSize int) *LRUCacher {
@ -134,262 +134,262 @@ func NewLRUCacher2(store CacheStore, expired time.Duration, max int) *LRUCacher
// RunGC run once every m.GcInterval // RunGC run once every m.GcInterval
func (m *LRUCacher) RunGC() { func (m *LRUCacher) RunGC() {
time.AfterFunc(m.GcInterval, func() { time.AfterFunc(m.GcInterval, func() {
m.RunGC() m.RunGC()
m.GC() m.GC()
}) })
} }
// GC check ids lit and sql list to remove all element expired // GC check ids lit and sql list to remove all element expired
func (m *LRUCacher) GC() { func (m *LRUCacher) GC() {
//fmt.Println("begin gc ...") //fmt.Println("begin gc ...")
//defer fmt.Println("end gc ...") //defer fmt.Println("end gc ...")
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
var removedNum int var removedNum int
for e := m.idList.Front(); e != nil; { for e := m.idList.Front(); e != nil; {
if removedNum <= CacheGcMaxRemoved && if removedNum <= CacheGcMaxRemoved &&
time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired { time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired {
removedNum++ removedNum++
next := e.Next() next := e.Next()
//fmt.Println("removing ...", e.Value) //fmt.Println("removing ...", e.Value)
node := e.Value.(*idNode) node := e.Value.(*idNode)
m.delBean(node.tbName, node.id) m.delBean(node.tbName, node.id)
e = next e = next
} else { } else {
//fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.idList.Len()) //fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.idList.Len())
break break
} }
} }
removedNum = 0 removedNum = 0
for e := m.sqlList.Front(); e != nil; { for e := m.sqlList.Front(); e != nil; {
if removedNum <= CacheGcMaxRemoved && if removedNum <= CacheGcMaxRemoved &&
time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired { time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired {
removedNum++ removedNum++
next := e.Next() next := e.Next()
//fmt.Println("removing ...", e.Value) //fmt.Println("removing ...", e.Value)
node := e.Value.(*sqlNode) node := e.Value.(*sqlNode)
m.delIds(node.tbName, node.sql) m.delIds(node.tbName, node.sql)
e = next e = next
} else { } else {
//fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.sqlList.Len()) //fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.sqlList.Len())
break break
} }
} }
} }
// Get all bean's ids according to sql and parameter from cache // Get all bean's ids according to sql and parameter from cache
func (m *LRUCacher) GetIds(tableName, sql string) interface{} { func (m *LRUCacher) GetIds(tableName, sql string) interface{} {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if _, ok := m.sqlIndex[tableName]; !ok { if _, ok := m.sqlIndex[tableName]; !ok {
m.sqlIndex[tableName] = make(map[interface{}]*list.Element) m.sqlIndex[tableName] = make(map[interface{}]*list.Element)
} }
if v, err := m.store.Get(sql); err == nil { if v, err := m.store.Get(sql); err == nil {
if el, ok := m.sqlIndex[tableName][sql]; !ok { if el, ok := m.sqlIndex[tableName][sql]; !ok {
el = m.sqlList.PushBack(newSqlNode(tableName, sql)) el = m.sqlList.PushBack(newSqlNode(tableName, sql))
m.sqlIndex[tableName][sql] = el m.sqlIndex[tableName][sql] = el
} else { } else {
lastTime := el.Value.(*sqlNode).lastVisit lastTime := el.Value.(*sqlNode).lastVisit
// if expired, remove the node and return nil // if expired, remove the node and return nil
if time.Now().Sub(lastTime) > m.Expired { if time.Now().Sub(lastTime) > m.Expired {
m.delIds(tableName, sql) m.delIds(tableName, sql)
return nil return nil
} }
m.sqlList.MoveToBack(el) m.sqlList.MoveToBack(el)
el.Value.(*sqlNode).lastVisit = time.Now() el.Value.(*sqlNode).lastVisit = time.Now()
} }
return v return v
} else { } else {
m.delIds(tableName, sql) m.delIds(tableName, sql)
} }
return nil return nil
} }
// Get bean according tableName and id from cache // Get bean according tableName and id from cache
func (m *LRUCacher) GetBean(tableName string, id int64) interface{} { func (m *LRUCacher) GetBean(tableName string, id int64) interface{} {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if _, ok := m.idIndex[tableName]; !ok { if _, ok := m.idIndex[tableName]; !ok {
m.idIndex[tableName] = make(map[interface{}]*list.Element) m.idIndex[tableName] = make(map[interface{}]*list.Element)
} }
tid := genId(tableName, id) tid := genId(tableName, id)
if v, err := m.store.Get(tid); err == nil { if v, err := m.store.Get(tid); err == nil {
if el, ok := m.idIndex[tableName][id]; ok { if el, ok := m.idIndex[tableName][id]; ok {
lastTime := el.Value.(*idNode).lastVisit lastTime := el.Value.(*idNode).lastVisit
// if expired, remove the node and return nil // if expired, remove the node and return nil
if time.Now().Sub(lastTime) > m.Expired { if time.Now().Sub(lastTime) > m.Expired {
m.delBean(tableName, id) m.delBean(tableName, id)
//m.clearIds(tableName) //m.clearIds(tableName)
return nil return nil
} }
m.idList.MoveToBack(el) m.idList.MoveToBack(el)
el.Value.(*idNode).lastVisit = time.Now() el.Value.(*idNode).lastVisit = time.Now()
} else { } else {
el = m.idList.PushBack(newIdNode(tableName, id)) el = m.idList.PushBack(newIdNode(tableName, id))
m.idIndex[tableName][id] = el m.idIndex[tableName][id] = el
} }
return v return v
} else { } else {
// store bean is not exist, then remove memory's index // store bean is not exist, then remove memory's index
m.delBean(tableName, id) m.delBean(tableName, id)
//m.clearIds(tableName) //m.clearIds(tableName)
return nil return nil
} }
} }
// Clear all sql-ids mapping on table tableName from cache // Clear all sql-ids mapping on table tableName from cache
func (m *LRUCacher) clearIds(tableName string) { func (m *LRUCacher) clearIds(tableName string) {
if tis, ok := m.sqlIndex[tableName]; ok { if tis, ok := m.sqlIndex[tableName]; ok {
for sql, v := range tis { for sql, v := range tis {
m.sqlList.Remove(v) m.sqlList.Remove(v)
m.store.Del(sql) m.store.Del(sql)
} }
} }
m.sqlIndex[tableName] = make(map[interface{}]*list.Element) m.sqlIndex[tableName] = make(map[interface{}]*list.Element)
} }
func (m *LRUCacher) ClearIds(tableName string) { func (m *LRUCacher) ClearIds(tableName string) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.clearIds(tableName) m.clearIds(tableName)
} }
func (m *LRUCacher) clearBeans(tableName string) { func (m *LRUCacher) clearBeans(tableName string) {
if tis, ok := m.idIndex[tableName]; ok { if tis, ok := m.idIndex[tableName]; ok {
for id, v := range tis { for id, v := range tis {
m.idList.Remove(v) m.idList.Remove(v)
tid := genId(tableName, id.(int64)) tid := genId(tableName, id.(int64))
m.store.Del(tid) m.store.Del(tid)
} }
} }
m.idIndex[tableName] = make(map[interface{}]*list.Element) m.idIndex[tableName] = make(map[interface{}]*list.Element)
} }
func (m *LRUCacher) ClearBeans(tableName string) { func (m *LRUCacher) ClearBeans(tableName string) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.clearBeans(tableName) m.clearBeans(tableName)
} }
func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) { func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if _, ok := m.sqlIndex[tableName]; !ok { if _, ok := m.sqlIndex[tableName]; !ok {
m.sqlIndex[tableName] = make(map[interface{}]*list.Element) m.sqlIndex[tableName] = make(map[interface{}]*list.Element)
} }
if el, ok := m.sqlIndex[tableName][sql]; !ok { if el, ok := m.sqlIndex[tableName][sql]; !ok {
el = m.sqlList.PushBack(newSqlNode(tableName, sql)) el = m.sqlList.PushBack(newSqlNode(tableName, sql))
m.sqlIndex[tableName][sql] = el m.sqlIndex[tableName][sql] = el
} else { } else {
el.Value.(*sqlNode).lastVisit = time.Now() el.Value.(*sqlNode).lastVisit = time.Now()
} }
m.store.Put(sql, ids) m.store.Put(sql, ids)
if m.sqlList.Len() > m.Max { if m.sqlList.Len() > m.Max {
e := m.sqlList.Front() e := m.sqlList.Front()
node := e.Value.(*sqlNode) node := e.Value.(*sqlNode)
m.delIds(node.tbName, node.sql) m.delIds(node.tbName, node.sql)
} }
} }
func (m *LRUCacher) PutBean(tableName string, id int64, obj interface{}) { func (m *LRUCacher) PutBean(tableName string, id int64, obj interface{}) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
var el *list.Element var el *list.Element
var ok bool var ok bool
if el, ok = m.idIndex[tableName][id]; !ok { if el, ok = m.idIndex[tableName][id]; !ok {
el = m.idList.PushBack(newIdNode(tableName, id)) el = m.idList.PushBack(newIdNode(tableName, id))
m.idIndex[tableName][id] = el m.idIndex[tableName][id] = el
} else { } else {
el.Value.(*idNode).lastVisit = time.Now() el.Value.(*idNode).lastVisit = time.Now()
} }
m.store.Put(genId(tableName, id), obj) m.store.Put(genId(tableName, id), obj)
if m.idList.Len() > m.Max { if m.idList.Len() > m.Max {
e := m.idList.Front() e := m.idList.Front()
node := e.Value.(*idNode) node := e.Value.(*idNode)
m.delBean(node.tbName, node.id) m.delBean(node.tbName, node.id)
} }
} }
func (m *LRUCacher) delIds(tableName, sql string) { func (m *LRUCacher) delIds(tableName, sql string) {
if _, ok := m.sqlIndex[tableName]; ok { if _, ok := m.sqlIndex[tableName]; ok {
if el, ok := m.sqlIndex[tableName][sql]; ok { if el, ok := m.sqlIndex[tableName][sql]; ok {
delete(m.sqlIndex[tableName], sql) delete(m.sqlIndex[tableName], sql)
m.sqlList.Remove(el) m.sqlList.Remove(el)
} }
} }
m.store.Del(sql) m.store.Del(sql)
} }
func (m *LRUCacher) DelIds(tableName, sql string) { func (m *LRUCacher) DelIds(tableName, sql string) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.delIds(tableName, sql) m.delIds(tableName, sql)
} }
func (m *LRUCacher) delBean(tableName string, id int64) { func (m *LRUCacher) delBean(tableName string, id int64) {
tid := genId(tableName, id) tid := genId(tableName, id)
if el, ok := m.idIndex[tableName][id]; ok { if el, ok := m.idIndex[tableName][id]; ok {
delete(m.idIndex[tableName], id) delete(m.idIndex[tableName], id)
m.idList.Remove(el) m.idList.Remove(el)
m.clearIds(tableName) m.clearIds(tableName)
} }
m.store.Del(tid) m.store.Del(tid)
} }
func (m *LRUCacher) DelBean(tableName string, id int64) { func (m *LRUCacher) DelBean(tableName string, id int64) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.delBean(tableName, id) m.delBean(tableName, id)
} }
func encodeIds(ids []int64) (s string) { func encodeIds(ids []int64) (s string) {
s = "[" s = "["
for _, id := range ids { for _, id := range ids {
s += fmt.Sprintf("%v,", id) s += fmt.Sprintf("%v,", id)
} }
s = s[:len(s)-1] + "]" s = s[:len(s)-1] + "]"
return return
} }
func decodeIds(s string) []int64 { func decodeIds(s string) []int64 {
res := make([]int64, 0) res := make([]int64, 0)
if len(s) >= 2 { if len(s) >= 2 {
ss := strings.Split(s[1:len(s)-1], ",") ss := strings.Split(s[1:len(s)-1], ",")
for _, s := range ss { for _, s := range ss {
i, err := strconv.ParseInt(s, 10, 64) i, err := strconv.ParseInt(s, 10, 64)
if err != nil { if err != nil {
return res return res
} }
res = append(res, i) res = append(res, i)
} }
} }
return res return res
} }
func getCacheSql(m Cacher, tableName, sql string, args interface{}) ([]int64, error) { func getCacheSql(m Cacher, tableName, sql string, args interface{}) ([]int64, error) {
bytes := m.GetIds(tableName, genSqlKey(sql, args)) bytes := m.GetIds(tableName, genSqlKey(sql, args))
if bytes == nil { if bytes == nil {
return nil, errors.New("Not Exist") return nil, errors.New("Not Exist")
} }
objs := decodeIds(bytes.(string)) objs := decodeIds(bytes.(string))
return objs, nil return objs, nil
} }
func putCacheSql(m Cacher, ids []int64, tableName, sql string, args interface{}) error { func putCacheSql(m Cacher, ids []int64, tableName, sql string, args interface{}) error {
bytes := encodeIds(ids) bytes := encodeIds(ids)
m.PutIds(tableName, genSqlKey(sql, args), bytes) m.PutIds(tableName, genSqlKey(sql, args), bytes)
return nil return nil
} }
func genSqlKey(sql string, args interface{}) string { func genSqlKey(sql string, args interface{}) string {
return fmt.Sprintf("%v-%v", sql, args) return fmt.Sprintf("%v-%v", sql, args)
} }
func genId(prefix string, id int64) string { func genId(prefix string, id int64) string {
return fmt.Sprintf("%v-%v", prefix, id) return fmt.Sprintf("%v-%v", prefix, id)
} }

1236
engine.go

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +1,15 @@
package xorm package xorm
import ( import (
"errors" "errors"
) )
var ( var (
ErrParamsType error = errors.New("Params type error") ErrParamsType error = errors.New("Params type error")
ErrTableNotFound error = errors.New("Not found table") ErrTableNotFound error = errors.New("Not found table")
ErrUnSupportedType error = errors.New("Unsupported type error") ErrUnSupportedType error = errors.New("Unsupported type error")
ErrNotExist error = errors.New("Not exist error") ErrNotExist error = errors.New("Not exist error")
ErrCacheFailed error = errors.New("Cache failed") ErrCacheFailed error = errors.New("Cache failed")
ErrNeedDeletedCond error = errors.New("Delete need at least one condition") ErrNeedDeletedCond error = errors.New("Delete need at least one condition")
ErrNotImplemented error = errors.New("Not implemented.") ErrNotImplemented error = errors.New("Not implemented.")
) )

View File

@ -1,13 +1,13 @@
package xorm package xorm
import ( import (
"fmt" "fmt"
"strings" "strings"
) )
// Filter is an interface to filter SQL // Filter is an interface to filter SQL
type Filter interface { type Filter interface {
Do(sql string, session *Session) string Do(sql string, session *Session) string
} }
// PgSeqFilter filter SQL replace ?, ? ... to $1, $2 ... // PgSeqFilter filter SQL replace ?, ? ... to $1, $2 ...
@ -15,16 +15,16 @@ type PgSeqFilter struct {
} }
func (s *PgSeqFilter) Do(sql string, session *Session) string { func (s *PgSeqFilter) Do(sql string, session *Session) string {
segs := strings.Split(sql, "?") segs := strings.Split(sql, "?")
size := len(segs) size := len(segs)
res := "" res := ""
for i, c := range segs { for i, c := range segs {
if i < size-1 { if i < size-1 {
res += c + fmt.Sprintf("$%v", i+1) res += c + fmt.Sprintf("$%v", i+1)
} }
} }
res += segs[size-1] res += segs[size-1]
return res return res
} }
// QuoteFilter filter SQL replace ` to database's own quote character // QuoteFilter filter SQL replace ` to database's own quote character
@ -32,7 +32,7 @@ type QuoteFilter struct {
} }
func (s *QuoteFilter) Do(sql string, session *Session) string { func (s *QuoteFilter) Do(sql string, session *Session) string {
return strings.Replace(sql, "`", session.Engine.QuoteStr(), -1) return strings.Replace(sql, "`", session.Engine.QuoteStr(), -1)
} }
// IdFilter filter SQL replace (id) to primary key column name // IdFilter filter SQL replace (id) to primary key column name
@ -40,10 +40,10 @@ type IdFilter struct {
} }
func (i *IdFilter) Do(sql string, session *Session) string { func (i *IdFilter) Do(sql string, session *Session) string {
if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" {
sql = strings.Replace(sql, "`(id)`", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) sql = strings.Replace(sql, "`(id)`", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1)
sql = strings.Replace(sql, session.Engine.Quote("(id)"), session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) sql = strings.Replace(sql, session.Engine.Quote("(id)"), session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1)
return strings.Replace(sql, "(id)", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) return strings.Replace(sql, "(id)", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1)
} }
return sql return sql
} }

View File

@ -1,63 +1,63 @@
package xorm package xorm
import ( import (
"reflect" "reflect"
"strings" "strings"
) )
func indexNoCase(s, sep string) int { func indexNoCase(s, sep string) int {
return strings.Index(strings.ToLower(s), strings.ToLower(sep)) return strings.Index(strings.ToLower(s), strings.ToLower(sep))
} }
func splitNoCase(s, sep string) []string { func splitNoCase(s, sep string) []string {
idx := indexNoCase(s, sep) idx := indexNoCase(s, sep)
if idx < 0 { if idx < 0 {
return []string{s} return []string{s}
} }
return strings.Split(s, s[idx:idx+len(sep)]) return strings.Split(s, s[idx:idx+len(sep)])
} }
func splitNNoCase(s, sep string, n int) []string { func splitNNoCase(s, sep string, n int) []string {
idx := indexNoCase(s, sep) idx := indexNoCase(s, sep)
if idx < 0 { if idx < 0 {
return []string{s} return []string{s}
} }
return strings.SplitN(s, s[idx:idx+len(sep)], n) return strings.SplitN(s, s[idx:idx+len(sep)], n)
} }
func makeArray(elem string, count int) []string { func makeArray(elem string, count int) []string {
res := make([]string, count) res := make([]string, count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
res[i] = elem res[i] = elem
} }
return res return res
} }
func rType(bean interface{}) reflect.Type { func rType(bean interface{}) reflect.Type {
sliceValue := reflect.Indirect(reflect.ValueOf(bean)) sliceValue := reflect.Indirect(reflect.ValueOf(bean))
return reflect.TypeOf(sliceValue.Interface()) return reflect.TypeOf(sliceValue.Interface())
} }
func structName(v reflect.Type) string { func structName(v reflect.Type) string {
for v.Kind() == reflect.Ptr { for v.Kind() == reflect.Ptr {
v = v.Elem() v = v.Elem()
} }
return v.Name() return v.Name()
} }
func sliceEq(left, right []string) bool { func sliceEq(left, right []string) bool {
for _, l := range left { for _, l := range left {
var find bool var find bool
for _, r := range right { for _, r := range right {
if l == r { if l == r {
find = true find = true
break break
} }
} }
if !find { if !find {
return false return false
} }
} }
return true return true
} }

View File

@ -1,13 +1,13 @@
package xorm package xorm
import ( import (
"strings" "strings"
) )
// name translation between struct, fields names and table, column names // name translation between struct, fields names and table, column names
type IMapper interface { type IMapper interface {
Obj2Table(string) string Obj2Table(string) string
Table2Obj(string) string Table2Obj(string) string
} }
// SameMapper implements IMapper and provides same name between struct and // SameMapper implements IMapper and provides same name between struct and
@ -16,11 +16,11 @@ type SameMapper struct {
} }
func (m SameMapper) Obj2Table(o string) string { func (m SameMapper) Obj2Table(o string) string {
return o return o
} }
func (m SameMapper) Table2Obj(t string) string { func (m SameMapper) Table2Obj(t string) string {
return t return t
} }
// SnakeMapper implements IMapper and provides name transaltion between // SnakeMapper implements IMapper and provides name transaltion between
@ -29,18 +29,18 @@ type SnakeMapper struct {
} }
func snakeCasedName(name string) string { func snakeCasedName(name string) string {
newstr := make([]rune, 0) newstr := make([]rune, 0)
for idx, chr := range name { for idx, chr := range name {
if isUpper := 'A' <= chr && chr <= 'Z'; isUpper { if isUpper := 'A' <= chr && chr <= 'Z'; isUpper {
if idx > 0 { if idx > 0 {
newstr = append(newstr, '_') newstr = append(newstr, '_')
} }
chr -= ('A' - 'a') chr -= ('A' - 'a')
} }
newstr = append(newstr, chr) newstr = append(newstr, chr)
} }
return string(newstr) return string(newstr)
} }
/*func pascal2Sql(s string) (d string) { /*func pascal2Sql(s string) (d string) {
@ -63,69 +63,69 @@ func snakeCasedName(name string) string {
}*/ }*/
func (mapper SnakeMapper) Obj2Table(name string) string { func (mapper SnakeMapper) Obj2Table(name string) string {
return snakeCasedName(name) return snakeCasedName(name)
} }
func titleCasedName(name string) string { func titleCasedName(name string) string {
newstr := make([]rune, 0) newstr := make([]rune, 0)
upNextChar := true upNextChar := true
name = strings.ToLower(name) name = strings.ToLower(name)
for _, chr := range name { for _, chr := range name {
switch { switch {
case upNextChar: case upNextChar:
upNextChar = false upNextChar = false
if 'a' <= chr && chr <= 'z' { if 'a' <= chr && chr <= 'z' {
chr -= ('a' - 'A') chr -= ('a' - 'A')
} }
case chr == '_': case chr == '_':
upNextChar = true upNextChar = true
continue continue
} }
newstr = append(newstr, chr) newstr = append(newstr, chr)
} }
return string(newstr) return string(newstr)
} }
func (mapper SnakeMapper) Table2Obj(name string) string { func (mapper SnakeMapper) Table2Obj(name string) string {
return titleCasedName(name) return titleCasedName(name)
} }
// provide prefix table name support // provide prefix table name support
type PrefixMapper struct { type PrefixMapper struct {
Mapper IMapper Mapper IMapper
Prefix string Prefix string
} }
func (mapper PrefixMapper) Obj2Table(name string) string { func (mapper PrefixMapper) Obj2Table(name string) string {
return mapper.Prefix + mapper.Mapper.Obj2Table(name) return mapper.Prefix + mapper.Mapper.Obj2Table(name)
} }
func (mapper PrefixMapper) Table2Obj(name string) string { func (mapper PrefixMapper) Table2Obj(name string) string {
return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):]) return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):])
} }
func NewPrefixMapper(mapper IMapper, prefix string) PrefixMapper { func NewPrefixMapper(mapper IMapper, prefix string) PrefixMapper {
return PrefixMapper{mapper, prefix} return PrefixMapper{mapper, prefix}
} }
// provide suffix table name support // provide suffix table name support
type SuffixMapper struct { type SuffixMapper struct {
Mapper IMapper Mapper IMapper
Suffix string Suffix string
} }
func (mapper SuffixMapper) Obj2Table(name string) string { func (mapper SuffixMapper) Obj2Table(name string) string {
return mapper.Suffix + mapper.Mapper.Obj2Table(name) return mapper.Suffix + mapper.Mapper.Obj2Table(name)
} }
func (mapper SuffixMapper) Table2Obj(name string) string { func (mapper SuffixMapper) Table2Obj(name string) string {
return mapper.Mapper.Table2Obj(name[len(mapper.Suffix):]) return mapper.Mapper.Table2Obj(name[len(mapper.Suffix):])
} }
func NewSuffixMapper(mapper IMapper, suffix string) SuffixMapper { func NewSuffixMapper(mapper IMapper, suffix string) SuffixMapper {
return SuffixMapper{mapper, suffix} return SuffixMapper{mapper, suffix}
} }

View File

@ -1,66 +1,67 @@
package xorm package xorm
import ( import (
"errors" "errors"
"strings" "strings"
"time" "time"
) )
type mymysql struct { type mymysql struct {
mysql mysql
proto string }
raddr string
laddr string type mymysqlParser struct {
timeout time.Duration }
db string
user string func (p *mymysqlParser) parse(driverName, dataSourceName string) (*uri, error) {
passwd string db := &uri{dbType: MYSQL}
pd := strings.SplitN(dataSourceName, "*", 2)
if len(pd) == 2 {
// Parse protocol part of URI
p := strings.SplitN(pd[0], ":", 2)
if len(p) != 2 {
return nil, errors.New("Wrong protocol part of URI")
}
db.proto = p[0]
options := strings.Split(p[1], ",")
db.raddr = options[0]
for _, o := range options[1:] {
kv := strings.SplitN(o, "=", 2)
var k, v string
if len(kv) == 2 {
k, v = kv[0], kv[1]
} else {
k, v = o, "true"
}
switch k {
case "laddr":
db.laddr = v
case "timeout":
to, err := time.ParseDuration(v)
if err != nil {
return nil, err
}
db.timeout = to
default:
return nil, errors.New("Unknown option: " + k)
}
}
// Remove protocol part
pd = pd[1:]
}
// Parse database part of URI
dup := strings.SplitN(pd[0], "/", 3)
if len(dup) != 3 {
return nil, errors.New("Wrong database part of URI")
}
db.dbName = dup[0]
db.user = dup[1]
db.passwd = dup[2]
return db, nil
} }
func (db *mymysql) Init(drivername, uri string) error { func (db *mymysql) Init(drivername, uri string) error {
db.mysql.base.init(drivername, uri) return db.mysql.base.init(&mymysqlParser{}, drivername, uri)
pd := strings.SplitN(uri, "*", 2)
if len(pd) == 2 {
// Parse protocol part of URI
p := strings.SplitN(pd[0], ":", 2)
if len(p) != 2 {
return errors.New("Wrong protocol part of URI")
}
db.proto = p[0]
options := strings.Split(p[1], ",")
db.raddr = options[0]
for _, o := range options[1:] {
kv := strings.SplitN(o, "=", 2)
var k, v string
if len(kv) == 2 {
k, v = kv[0], kv[1]
} else {
k, v = o, "true"
}
switch k {
case "laddr":
db.laddr = v
case "timeout":
to, err := time.ParseDuration(v)
if err != nil {
return err
}
db.timeout = to
default:
return errors.New("Unknown option: " + k)
}
}
// Remove protocol part
pd = pd[1:]
}
// Parse database part of URI
dup := strings.SplitN(pd[0], "/", 3)
if len(dup) != 3 {
return errors.New("Wrong database part of URI")
}
db.dbname = dup[0]
db.user = dup[1]
db.passwd = dup[2]
return nil
} }

520
mysql.go
View File

@ -1,311 +1,323 @@
package xorm package xorm
import ( import (
"crypto/tls" "crypto/tls"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
) )
type base struct { type uri struct {
drivername string dbType string
dataSourceName string proto string
host string
port string
dbName string
user string
passwd string
charset string
laddr string
raddr string
timeout time.Duration
} }
func (b *base) init(drivername, dataSourceName string) { type parser interface {
b.drivername, b.dataSourceName = drivername, dataSourceName parse(driverName, dataSourceName string) (*uri, error)
}
type mysqlParser struct {
}
func (p *mysqlParser) parse(driverName, dataSourceName string) (*uri, error) {
//cfg.params = make(map[string]string)
dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
matches := dsnPattern.FindStringSubmatch(dataSourceName)
//tlsConfigRegister := make(map[string]*tls.Config)
names := dsnPattern.SubexpNames()
uri := &uri{dbType: MYSQL}
for i, match := range matches {
switch names[i] {
case "dbname":
uri.dbName = match
}
}
return uri, nil
}
type base struct {
parser parser
driverName string
dataSourceName string
*uri
}
func (b *base) init(parser parser, drivername, dataSourceName string) (err error) {
b.parser = parser
b.driverName, b.dataSourceName = drivername, dataSourceName
b.uri, err = b.parser.parse(b.driverName, b.dataSourceName)
return
} }
type mysql struct { type mysql struct {
base base
user string net string
passwd string addr string
net string params map[string]string
addr string loc *time.Location
dbname string timeout time.Duration
params map[string]string tls *tls.Config
loc *time.Location allowAllFiles bool
timeout time.Duration allowOldPasswords bool
tls *tls.Config clientFoundRows bool
allowAllFiles bool
allowOldPasswords bool
clientFoundRows bool
}
/*func readBool(input string) (value bool, valid bool) {
switch input {
case "1", "true", "TRUE", "True":
return true, true
case "0", "false", "FALSE", "False":
return false, true
}
// Not a valid bool value
return
}*/
func (cfg *mysql) parseDSN(dsn string) (err error) {
//cfg.params = make(map[string]string)
dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
matches := dsnPattern.FindStringSubmatch(dsn)
//tlsConfigRegister := make(map[string]*tls.Config)
names := dsnPattern.SubexpNames()
for i, match := range matches {
switch names[i] {
case "dbname":
cfg.dbname = match
}
}
return
} }
func (db *mysql) Init(drivername, uri string) error { func (db *mysql) Init(drivername, uri string) error {
db.base.init(drivername, uri) return db.base.init(&mysqlParser{}, drivername, uri)
return db.parseDSN(uri)
} }
func (db *mysql) SqlType(c *Column) string { func (db *mysql) SqlType(c *Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case Bool: case Bool:
res = TinyInt res = TinyInt
case Serial: case Serial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = Int res = Int
case BigSerial: case BigSerial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = BigInt res = BigInt
case Bytea: case Bytea:
res = Blob res = Blob
case TimeStampz: case TimeStampz:
res = Char res = Char
c.Length = 64 c.Length = 64
default: default:
res = t res = t
} }
var hasLen1 bool = (c.Length > 0) var hasLen1 bool = (c.Length > 0)
var hasLen2 bool = (c.Length2 > 0) var hasLen2 bool = (c.Length2 > 0)
if hasLen1 { if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.Itoa(c.Length) + ")"
} else if hasLen2 { } else if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
} }
return res return res
} }
func (db *mysql) SupportInsertMany() bool { func (db *mysql) SupportInsertMany() bool {
return true return true
} }
func (db *mysql) QuoteStr() string { func (db *mysql) QuoteStr() string {
return "`" return "`"
} }
func (db *mysql) SupportEngine() bool { func (db *mysql) SupportEngine() bool {
return true return true
} }
func (db *mysql) AutoIncrStr() string { func (db *mysql) AutoIncrStr() string {
return "AUTO_INCREMENT" return "AUTO_INCREMENT"
} }
func (db *mysql) SupportCharset() bool { func (db *mysql) SupportCharset() bool {
return true return true
} }
func (db *mysql) IndexOnTable() bool { func (db *mysql) IndexOnTable() bool {
return true return true
} }
func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{db.dbname, tableName, idxName} args := []interface{}{db.dbName, tableName, idxName}
sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`" sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`"
sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?" sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?"
return sql, args return sql, args
} }
func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) { func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{db.dbname, tableName, colName} args := []interface{}{db.dbName, tableName, colName}
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
return sql, args return sql, args
} }
func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) { func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{db.dbname, tableName} args := []interface{}{db.dbName, tableName}
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
return sql, args return sql, args
} }
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{db.dbname, tableName} args := []interface{}{db.dbName, tableName}
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, record := range res { for _, record := range res {
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "COLUMN_NAME": case "COLUMN_NAME":
col.Name = strings.Trim(string(content), "` ") col.Name = strings.Trim(string(content), "` ")
case "IS_NULLABLE": case "IS_NULLABLE":
if "YES" == string(content) { if "YES" == string(content) {
col.Nullable = true col.Nullable = true
} }
case "COLUMN_DEFAULT": case "COLUMN_DEFAULT":
// add '' // add ''
col.Default = string(content) col.Default = string(content)
case "COLUMN_TYPE": case "COLUMN_TYPE":
cts := strings.Split(string(content), "(") cts := strings.Split(string(content), "(")
var len1, len2 int var len1, len2 int
if len(cts) == 2 { if len(cts) == 2 {
idx := strings.Index(cts[1], ")") idx := strings.Index(cts[1], ")")
lens := strings.Split(cts[1][0:idx], ",") lens := strings.Split(cts[1][0:idx], ",")
len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) len1, err = strconv.Atoi(strings.TrimSpace(lens[0]))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if len(lens) == 2 { if len(lens) == 2 {
len2, err = strconv.Atoi(lens[1]) len2, err = strconv.Atoi(lens[1])
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
} }
} }
colName := cts[0] colName := cts[0]
colType := strings.ToUpper(colName) colType := strings.ToUpper(colName)
col.Length = len1 col.Length = len1
col.Length2 = len2 col.Length2 = len2
if _, ok := sqlTypes[colType]; ok { if _, ok := sqlTypes[colType]; ok {
col.SQLType = SQLType{colType, len1, len2} col.SQLType = SQLType{colType, len1, len2}
} else { } else {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType)) return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType))
} }
case "COLUMN_KEY": case "COLUMN_KEY":
key := string(content) key := string(content)
if key == "PRI" { if key == "PRI" {
col.IsPrimaryKey = true col.IsPrimaryKey = true
} }
if key == "UNI" { if key == "UNI" {
//col.is //col.is
} }
case "EXTRA": case "EXTRA":
extra := string(content) extra := string(content)
if extra == "auto_increment" { if extra == "auto_increment" {
col.IsAutoIncrement = true col.IsAutoIncrement = true
} }
} }
} }
if col.SQLType.IsText() { if col.SQLType.IsText() {
if col.Default != "" { if col.Default != "" {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
} }
} }
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name) colSeq = append(colSeq, col.Name)
} }
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *mysql) GetTables() ([]*Table, error) { func (db *mysql) GetTables() ([]*Table, error) {
args := []interface{}{db.dbname} args := []interface{}{db.dbName}
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?" s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*Table, 0)
for _, record := range res { for _, record := range res {
table := new(Table) table := new(Table)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "TABLE_NAME": case "TABLE_NAME":
table.Name = strings.Trim(string(content), "` ") table.Name = strings.Trim(string(content), "` ")
case "ENGINE": case "ENGINE":
} }
} }
tables = append(tables, table) tables = append(tables, table)
} }
return tables, nil return tables, nil
} }
func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{db.dbname, tableName} args := []interface{}{db.dbName, tableName}
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*Index, 0)
for _, record := range res { for _, record := range res {
var indexType int var indexType int
var indexName, colName string var indexName, colName string
for name, content := range record { for name, content := range record {
switch name { switch name {
case "NON_UNIQUE": case "NON_UNIQUE":
if "YES" == string(content) || string(content) == "1" { if "YES" == string(content) || string(content) == "1" {
indexType = IndexType indexType = IndexType
} else { } else {
indexType = UniqueType indexType = UniqueType
} }
case "INDEX_NAME": case "INDEX_NAME":
indexName = string(content) indexName = string(content)
case "COLUMN_NAME": case "COLUMN_NAME":
colName = strings.Trim(string(content), "` ") colName = strings.Trim(string(content), "` ")
} }
} }
if indexName == "PRIMARY" { if indexName == "PRIMARY" {
continue continue
} }
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
indexName = indexName[5+len(tableName) : len(indexName)] indexName = indexName[5+len(tableName) : len(indexName)]
} }
var index *Index var index *Index
var ok bool var ok bool
if index, ok = indexes[indexName]; !ok { if index, ok = indexes[indexName]; !ok {
index = new(Index) index = new(Index)
index.Type = indexType index.Type = indexType
index.Name = indexName index.Name = indexName
indexes[indexName] = index indexes[indexName] = index
} }
index.AddColumn(colName) index.AddColumn(colName)
} }
return indexes, nil return indexes, nil
} }

290
pool.go
View File

@ -1,13 +1,13 @@
package xorm package xorm
import ( import (
"database/sql" "database/sql"
//"fmt" //"fmt"
"sync" "sync"
//"sync/atomic" //"sync/atomic"
"container/list" "container/list"
"reflect" "reflect"
"time" "time"
) )
// Interface IConnecPool is a connection pool interface, all implements should implement // Interface IConnecPool is a connection pool interface, all implements should implement
@ -17,14 +17,14 @@ import (
// ReleaseDB for releasing a db connection; // ReleaseDB for releasing a db connection;
// Close for invoking when engine.Close // Close for invoking when engine.Close
type IConnectPool interface { type IConnectPool interface {
Init(engine *Engine) error Init(engine *Engine) error
RetrieveDB(engine *Engine) (*sql.DB, error) RetrieveDB(engine *Engine) (*sql.DB, error)
ReleaseDB(engine *Engine, db *sql.DB) ReleaseDB(engine *Engine, db *sql.DB)
Close(engine *Engine) error Close(engine *Engine) error
SetMaxIdleConns(conns int) SetMaxIdleConns(conns int)
MaxIdleConns() int MaxIdleConns() int
SetMaxConns(conns int) SetMaxConns(conns int)
MaxConns() int MaxConns() int
} }
// Struct NoneConnectPool is a implement for IConnectPool. It provides directly invoke driver's // Struct NoneConnectPool is a implement for IConnectPool. It provides directly invoke driver's
@ -34,35 +34,35 @@ type NoneConnectPool struct {
// NewNoneConnectPool new a NoneConnectPool. // NewNoneConnectPool new a NoneConnectPool.
func NewNoneConnectPool() IConnectPool { func NewNoneConnectPool() IConnectPool {
return &NoneConnectPool{} return &NoneConnectPool{}
} }
// Init do nothing // Init do nothing
func (p *NoneConnectPool) Init(engine *Engine) error { func (p *NoneConnectPool) Init(engine *Engine) error {
return nil return nil
} }
// RetrieveDB directly open a connection // RetrieveDB directly open a connection
func (p *NoneConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) { func (p *NoneConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) {
db, err = engine.OpenDB() db, err = engine.OpenDB()
return return
} }
// ReleaseDB directly close a connection // ReleaseDB directly close a connection
func (p *NoneConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { func (p *NoneConnectPool) ReleaseDB(engine *Engine, db *sql.DB) {
db.Close() db.Close()
} }
// Close do nothing // Close do nothing
func (p *NoneConnectPool) Close(engine *Engine) error { func (p *NoneConnectPool) Close(engine *Engine) error {
return nil return nil
} }
func (p *NoneConnectPool) SetMaxIdleConns(conns int) { func (p *NoneConnectPool) SetMaxIdleConns(conns int) {
} }
func (p *NoneConnectPool) MaxIdleConns() int { func (p *NoneConnectPool) MaxIdleConns() int {
return 0 return 0
} }
// not implemented // not implemented
@ -71,133 +71,133 @@ func (p *NoneConnectPool) SetMaxConns(conns int) {
// not implemented // not implemented
func (p *NoneConnectPool) MaxConns() int { func (p *NoneConnectPool) MaxConns() int {
return -1 return -1
} }
// Struct SysConnectPool is a simple wrapper for using system default connection pool. // Struct SysConnectPool is a simple wrapper for using system default connection pool.
// About the system connection pool, you can review the code database/sql/sql.go // About the system connection pool, you can review the code database/sql/sql.go
// It's currently default Pool implments. // It's currently default Pool implments.
type SysConnectPool struct { type SysConnectPool struct {
db *sql.DB db *sql.DB
maxIdleConns int maxIdleConns int
maxConns int maxConns int
curConns int curConns int
mutex *sync.Mutex mutex *sync.Mutex
queue *list.List queue *list.List
} }
// NewSysConnectPool new a SysConnectPool. // NewSysConnectPool new a SysConnectPool.
func NewSysConnectPool() IConnectPool { func NewSysConnectPool() IConnectPool {
return &SysConnectPool{} return &SysConnectPool{}
} }
// Init create a db immediately and keep it util engine closed. // Init create a db immediately and keep it util engine closed.
func (s *SysConnectPool) Init(engine *Engine) error { func (s *SysConnectPool) Init(engine *Engine) error {
db, err := engine.OpenDB() db, err := engine.OpenDB()
if err != nil { if err != nil {
return err return err
} }
s.db = db s.db = db
s.maxIdleConns = 2 s.maxIdleConns = 2
s.maxConns = -1 s.maxConns = -1
s.curConns = 0 s.curConns = 0
s.mutex = &sync.Mutex{} s.mutex = &sync.Mutex{}
s.queue = list.New() s.queue = list.New()
return nil return nil
} }
type node struct { type node struct {
mutex sync.Mutex mutex sync.Mutex
cond *sync.Cond cond *sync.Cond
} }
func newCondNode() *node { func newCondNode() *node {
n := &node{} n := &node{}
n.cond = sync.NewCond(&n.mutex) n.cond = sync.NewCond(&n.mutex)
return n return n
} }
// RetrieveDB just return the only db // RetrieveDB just return the only db
func (s *SysConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) { func (s *SysConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) {
/*if s.maxConns > 0 { /*if s.maxConns > 0 {
fmt.Println("before retrieve") fmt.Println("before retrieve")
s.mutex.Lock() s.mutex.Lock()
for s.curConns >= s.maxConns { for s.curConns >= s.maxConns {
fmt.Println("before waiting...", s.curConns, s.queue.Len()) fmt.Println("before waiting...", s.curConns, s.queue.Len())
s.mutex.Unlock() s.mutex.Unlock()
n := NewNode() n := NewNode()
n.cond.L.Lock() n.cond.L.Lock()
s.queue.PushBack(n) s.queue.PushBack(n)
n.cond.Wait() n.cond.Wait()
n.cond.L.Unlock() n.cond.L.Unlock()
s.mutex.Lock() s.mutex.Lock()
fmt.Println("after waiting...", s.curConns, s.queue.Len()) fmt.Println("after waiting...", s.curConns, s.queue.Len())
} }
s.curConns += 1 s.curConns += 1
s.mutex.Unlock() s.mutex.Unlock()
fmt.Println("after retrieve") fmt.Println("after retrieve")
}*/ }*/
return s.db, nil return s.db, nil
} }
// ReleaseDB do nothing // ReleaseDB do nothing
func (s *SysConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { func (s *SysConnectPool) ReleaseDB(engine *Engine, db *sql.DB) {
/*if s.maxConns > 0 { /*if s.maxConns > 0 {
s.mutex.Lock() s.mutex.Lock()
fmt.Println("before release", s.queue.Len()) fmt.Println("before release", s.queue.Len())
s.curConns -= 1 s.curConns -= 1
if e := s.queue.Front(); e != nil { if e := s.queue.Front(); e != nil {
n := e.Value.(*node) n := e.Value.(*node)
//n.cond.L.Lock() //n.cond.L.Lock()
n.cond.Signal() n.cond.Signal()
fmt.Println("signaled...") fmt.Println("signaled...")
s.queue.Remove(e) s.queue.Remove(e)
//n.cond.L.Unlock() //n.cond.L.Unlock()
} }
fmt.Println("after released", s.queue.Len()) fmt.Println("after released", s.queue.Len())
s.mutex.Unlock() s.mutex.Unlock()
}*/ }*/
} }
// Close closed the only db // Close closed the only db
func (p *SysConnectPool) Close(engine *Engine) error { func (p *SysConnectPool) Close(engine *Engine) error {
return p.db.Close() return p.db.Close()
} }
func (p *SysConnectPool) SetMaxIdleConns(conns int) { func (p *SysConnectPool) SetMaxIdleConns(conns int) {
p.db.SetMaxIdleConns(conns) p.db.SetMaxIdleConns(conns)
p.maxIdleConns = conns p.maxIdleConns = conns
} }
func (p *SysConnectPool) MaxIdleConns() int { func (p *SysConnectPool) MaxIdleConns() int {
return p.maxIdleConns return p.maxIdleConns
} }
// not implemented // not implemented
func (p *SysConnectPool) SetMaxConns(conns int) { func (p *SysConnectPool) SetMaxConns(conns int) {
p.maxConns = conns p.maxConns = conns
// if support SetMaxOpenConns, go 1.2+, then set // if support SetMaxOpenConns, go 1.2+, then set
if reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").IsValid() { if reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").IsValid() {
reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").Call([]reflect.Value{reflect.ValueOf(conns)}) reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").Call([]reflect.Value{reflect.ValueOf(conns)})
} }
//p.db.SetMaxOpenConns(conns) //p.db.SetMaxOpenConns(conns)
} }
// not implemented // not implemented
func (p *SysConnectPool) MaxConns() int { func (p *SysConnectPool) MaxConns() int {
return p.maxConns return p.maxConns
} }
// NewSimpleConnectPool new a SimpleConnectPool // NewSimpleConnectPool new a SimpleConnectPool
func NewSimpleConnectPool() IConnectPool { func NewSimpleConnectPool() IConnectPool {
return &SimpleConnectPool{releasedConnects: make([]*sql.DB, 10), return &SimpleConnectPool{releasedConnects: make([]*sql.DB, 10),
usingConnects: map[*sql.DB]time.Time{}, usingConnects: map[*sql.DB]time.Time{},
cur: -1, cur: -1,
maxWaitTimeOut: 14400, maxWaitTimeOut: 14400,
maxIdleConns: 10, maxIdleConns: 10,
mutex: &sync.Mutex{}, mutex: &sync.Mutex{},
} }
} }
// Struct SimpleConnectPool is a simple implementation for IConnectPool. // Struct SimpleConnectPool is a simple implementation for IConnectPool.
@ -205,75 +205,75 @@ func NewSimpleConnectPool() IConnectPool {
// Opening or Closing a database connection must be enter a lock. // Opening or Closing a database connection must be enter a lock.
// This implements will be improved in furture. // This implements will be improved in furture.
type SimpleConnectPool struct { type SimpleConnectPool struct {
releasedConnects []*sql.DB releasedConnects []*sql.DB
cur int cur int
usingConnects map[*sql.DB]time.Time usingConnects map[*sql.DB]time.Time
maxWaitTimeOut int maxWaitTimeOut int
mutex *sync.Mutex mutex *sync.Mutex
maxIdleConns int maxIdleConns int
} }
func (s *SimpleConnectPool) Init(engine *Engine) error { func (s *SimpleConnectPool) Init(engine *Engine) error {
return nil return nil
} }
// RetrieveDB get a connection from connection pool // RetrieveDB get a connection from connection pool
func (p *SimpleConnectPool) RetrieveDB(engine *Engine) (*sql.DB, error) { func (p *SimpleConnectPool) RetrieveDB(engine *Engine) (*sql.DB, error) {
p.mutex.Lock() p.mutex.Lock()
defer p.mutex.Unlock() defer p.mutex.Unlock()
var db *sql.DB = nil var db *sql.DB = nil
var err error = nil var err error = nil
//fmt.Printf("%x, rbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) //fmt.Printf("%x, rbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects))
if p.cur < 0 { if p.cur < 0 {
db, err = engine.OpenDB() db, err = engine.OpenDB()
if err != nil { if err != nil {
return nil, err return nil, err
} }
p.usingConnects[db] = time.Now() p.usingConnects[db] = time.Now()
} else { } else {
db = p.releasedConnects[p.cur] db = p.releasedConnects[p.cur]
p.usingConnects[db] = time.Now() p.usingConnects[db] = time.Now()
p.releasedConnects[p.cur] = nil p.releasedConnects[p.cur] = nil
p.cur = p.cur - 1 p.cur = p.cur - 1
} }
//fmt.Printf("%x, rend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) //fmt.Printf("%x, rend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects))
return db, nil return db, nil
} }
// ReleaseDB release a db from connection pool // ReleaseDB release a db from connection pool
func (p *SimpleConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { func (p *SimpleConnectPool) ReleaseDB(engine *Engine, db *sql.DB) {
p.mutex.Lock() p.mutex.Lock()
defer p.mutex.Unlock() defer p.mutex.Unlock()
//fmt.Printf("%x, lbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) //fmt.Printf("%x, lbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects))
if p.cur >= p.maxIdleConns-1 { if p.cur >= p.maxIdleConns-1 {
db.Close() db.Close()
} else { } else {
p.cur = p.cur + 1 p.cur = p.cur + 1
p.releasedConnects[p.cur] = db p.releasedConnects[p.cur] = db
} }
delete(p.usingConnects, db) delete(p.usingConnects, db)
//fmt.Printf("%x, lend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) //fmt.Printf("%x, lend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects))
} }
// Close release all db // Close release all db
func (p *SimpleConnectPool) Close(engine *Engine) error { func (p *SimpleConnectPool) Close(engine *Engine) error {
p.mutex.Lock() p.mutex.Lock()
defer p.mutex.Unlock() defer p.mutex.Unlock()
for len(p.releasedConnects) > 0 { for len(p.releasedConnects) > 0 {
p.releasedConnects[0].Close() p.releasedConnects[0].Close()
p.releasedConnects = p.releasedConnects[1:] p.releasedConnects = p.releasedConnects[1:]
} }
return nil return nil
} }
func (p *SimpleConnectPool) SetMaxIdleConns(conns int) { func (p *SimpleConnectPool) SetMaxIdleConns(conns int) {
p.maxIdleConns = conns p.maxIdleConns = conns
} }
func (p *SimpleConnectPool) MaxIdleConns() int { func (p *SimpleConnectPool) MaxIdleConns() int {
return p.maxIdleConns return p.maxIdleConns
} }
// not implemented // not implemented
@ -282,5 +282,5 @@ func (p *SimpleConnectPool) SetMaxConns(conns int) {
// not implemented // not implemented
func (p *SimpleConnectPool) MaxConns() int { func (p *SimpleConnectPool) MaxConns() int {
return -1 return -1
} }

View File

@ -1,300 +1,305 @@
package xorm package xorm
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
) )
type postgres struct { type postgres struct {
base base
dbname string
} }
type values map[string]string type values map[string]string
func (vs values) Set(k, v string) { func (vs values) Set(k, v string) {
vs[k] = v vs[k] = v
} }
func (vs values) Get(k string) (v string) { func (vs values) Get(k string) (v string) {
return vs[k] return vs[k]
} }
func errorf(s string, args ...interface{}) { func errorf(s string, args ...interface{}) {
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
} }
func parseOpts(name string, o values) { func parseOpts(name string, o values) {
if len(name) == 0 { if len(name) == 0 {
return return
} }
name = strings.TrimSpace(name) name = strings.TrimSpace(name)
ps := strings.Split(name, " ") ps := strings.Split(name, " ")
for _, p := range ps { for _, p := range ps {
kv := strings.Split(p, "=") kv := strings.Split(p, "=")
if len(kv) < 2 { if len(kv) < 2 {
errorf("invalid option: %q", p) errorf("invalid option: %q", p)
} }
o.Set(kv[0], kv[1]) o.Set(kv[0], kv[1])
} }
}
type postgresParser struct {
}
func (p *postgresParser) parse(driverName, dataSourceName string) (*uri, error) {
db := &uri{dbType: POSTGRES}
o := make(values)
parseOpts(dataSourceName, o)
db.dbName = o.Get("dbname")
if db.dbName == "" {
return nil, errors.New("dbname is empty")
}
return db, nil
} }
func (db *postgres) Init(drivername, uri string) error { func (db *postgres) Init(drivername, uri string) error {
db.base.init(drivername, uri) return db.base.init(&postgresParser{}, drivername, uri)
o := make(values)
parseOpts(uri, o)
db.dbname = o.Get("dbname")
if db.dbname == "" {
return errors.New("dbname is empty")
}
return nil
} }
func (db *postgres) SqlType(c *Column) string { func (db *postgres) SqlType(c *Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case TinyInt: case TinyInt:
res = SmallInt res = SmallInt
case MediumInt, Int, Integer: case MediumInt, Int, Integer:
return Integer return Integer
case Serial, BigSerial: case Serial, BigSerial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.Nullable = false c.Nullable = false
res = t res = t
case Binary, VarBinary: case Binary, VarBinary:
return Bytea return Bytea
case DateTime: case DateTime:
res = TimeStamp res = TimeStamp
case TimeStampz: case TimeStampz:
return "timestamp with time zone" return "timestamp with time zone"
case Float: case Float:
res = Real res = Real
case TinyText, MediumText, LongText: case TinyText, MediumText, LongText:
res = Text res = Text
case Blob, TinyBlob, MediumBlob, LongBlob: case Blob, TinyBlob, MediumBlob, LongBlob:
return Bytea return Bytea
case Double: case Double:
return "DOUBLE PRECISION" return "DOUBLE PRECISION"
default: default:
if c.IsAutoIncrement { if c.IsAutoIncrement {
return Serial return Serial
} }
res = t res = t
} }
var hasLen1 bool = (c.Length > 0) var hasLen1 bool = (c.Length > 0)
var hasLen2 bool = (c.Length2 > 0) var hasLen2 bool = (c.Length2 > 0)
if hasLen1 { if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.Itoa(c.Length) + ")"
} else if hasLen2 { } else if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
} }
return res return res
} }
func (db *postgres) SupportInsertMany() bool { func (db *postgres) SupportInsertMany() bool {
return true return true
} }
func (db *postgres) QuoteStr() string { func (db *postgres) QuoteStr() string {
return "\"" return "\""
} }
func (db *postgres) AutoIncrStr() string { func (db *postgres) AutoIncrStr() string {
return "" return ""
} }
func (db *postgres) SupportEngine() bool { func (db *postgres) SupportEngine() bool {
return false return false
} }
func (db *postgres) SupportCharset() bool { func (db *postgres) SupportCharset() bool {
return false return false
} }
func (db *postgres) IndexOnTable() bool { func (db *postgres) IndexOnTable() bool {
return false return false
} }
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{tableName, idxName} args := []interface{}{tableName, idxName}
return `SELECT indexname FROM pg_indexes ` + return `SELECT indexname FROM pg_indexes ` +
`WHERE tablename = ? AND indexname = ?`, args `WHERE tablename = ? AND indexname = ?`, args
} }
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
} }
func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) { func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName, colName} args := []interface{}{tableName, colName}
return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" +
" AND column_name = ?", args " AND column_name = ?", args
} }
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" +
", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, record := range res { for _, record := range res {
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "column_name": case "column_name":
col.Name = strings.Trim(string(content), `" `) col.Name = strings.Trim(string(content), `" `)
case "column_default": case "column_default":
if strings.HasPrefix(string(content), "nextval") { if strings.HasPrefix(string(content), "nextval") {
col.IsPrimaryKey = true col.IsPrimaryKey = true
} else { } else {
col.Default = string(content) col.Default = string(content)
} }
case "is_nullable": case "is_nullable":
if string(content) == "YES" { if string(content) == "YES" {
col.Nullable = true col.Nullable = true
} else { } else {
col.Nullable = false col.Nullable = false
} }
case "data_type": case "data_type":
ct := string(content) ct := string(content)
switch ct { switch ct {
case "character varying", "character": case "character varying", "character":
col.SQLType = SQLType{Varchar, 0, 0} col.SQLType = SQLType{Varchar, 0, 0}
case "timestamp without time zone": case "timestamp without time zone":
col.SQLType = SQLType{DateTime, 0, 0} col.SQLType = SQLType{DateTime, 0, 0}
case "timestamp with time zone": case "timestamp with time zone":
col.SQLType = SQLType{TimeStampz, 0, 0} col.SQLType = SQLType{TimeStampz, 0, 0}
case "double precision": case "double precision":
col.SQLType = SQLType{Double, 0, 0} col.SQLType = SQLType{Double, 0, 0}
case "boolean": case "boolean":
col.SQLType = SQLType{Bool, 0, 0} col.SQLType = SQLType{Bool, 0, 0}
case "time without time zone": case "time without time zone":
col.SQLType = SQLType{Time, 0, 0} col.SQLType = SQLType{Time, 0, 0}
default: default:
col.SQLType = SQLType{strings.ToUpper(ct), 0, 0} col.SQLType = SQLType{strings.ToUpper(ct), 0, 0}
} }
if _, ok := sqlTypes[col.SQLType.Name]; !ok { if _, ok := sqlTypes[col.SQLType.Name]; !ok {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct)) return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct))
} }
case "character_maximum_length": case "character_maximum_length":
i, err := strconv.Atoi(string(content)) i, err := strconv.Atoi(string(content))
if err != nil { if err != nil {
return nil, nil, errors.New("retrieve length error") return nil, nil, errors.New("retrieve length error")
} }
col.Length = i col.Length = i
case "numeric_precision": case "numeric_precision":
case "numeric_precision_radix": case "numeric_precision_radix":
} }
} }
if col.SQLType.IsText() { if col.SQLType.IsText() {
if col.Default != "" { if col.Default != "" {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
} }
} }
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name) colSeq = append(colSeq, col.Name)
} }
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *postgres) GetTables() ([]*Table, error) { func (db *postgres) GetTables() ([]*Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT tablename FROM pg_tables where schemaname = 'public'" s := "SELECT tablename FROM pg_tables where schemaname = 'public'"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*Table, 0)
for _, record := range res { for _, record := range res {
table := new(Table) table := new(Table)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "tablename": case "tablename":
table.Name = string(content) table.Name = string(content)
} }
} }
tables = append(tables, table) tables = append(tables, table)
} }
return tables, nil return tables, nil
} }
func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1" s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*Index, 0)
for _, record := range res { for _, record := range res {
var indexType int var indexType int
var indexName string var indexName string
var colNames []string var colNames []string
for name, content := range record { for name, content := range record {
switch name { switch name {
case "indexname": case "indexname":
indexName = strings.Trim(string(content), `" `) indexName = strings.Trim(string(content), `" `)
case "indexdef": case "indexdef":
c := string(content) c := string(content)
if strings.HasPrefix(c, "CREATE UNIQUE INDEX") { if strings.HasPrefix(c, "CREATE UNIQUE INDEX") {
indexType = UniqueType indexType = UniqueType
} else { } else {
indexType = IndexType indexType = IndexType
} }
cs := strings.Split(c, "(") cs := strings.Split(c, "(")
colNames = strings.Split(cs[1][0:len(cs[1])-1], ",") colNames = strings.Split(cs[1][0:len(cs[1])-1], ",")
} }
} }
if strings.HasSuffix(indexName, "_pkey") { if strings.HasSuffix(indexName, "_pkey") {
continue continue
} }
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
newIdxName := indexName[5+len(tableName) : len(indexName)] newIdxName := indexName[5+len(tableName) : len(indexName)]
if newIdxName != "" { if newIdxName != "" {
indexName = newIdxName indexName = newIdxName
} }
} }
index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
for _, colName := range colNames { for _, colName := range colNames {
index.Cols = append(index.Cols, strings.Trim(colName, `" `)) index.Cols = append(index.Cols, strings.Trim(colName, `" `))
} }
indexes[index.Name] = index indexes[index.Name] = index
} }
return indexes, nil return indexes, nil
} }

View File

@ -2,17 +2,17 @@ package xorm
// Executed before an object is initially persisted to the database // Executed before an object is initially persisted to the database
type BeforeInsertProcessor interface { type BeforeInsertProcessor interface {
BeforeInsert() BeforeInsert()
} }
// Executed before an object is updated // Executed before an object is updated
type BeforeUpdateProcessor interface { type BeforeUpdateProcessor interface {
BeforeUpdate() BeforeUpdate()
} }
// Executed before an object is deleted // Executed before an object is deleted
type BeforeDeleteProcessor interface { type BeforeDeleteProcessor interface {
BeforeDelete() BeforeDelete()
} }
// !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations // !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations
@ -24,16 +24,15 @@ type BeforeDeleteProcessor interface {
// Executed after an object is persisted to the database // Executed after an object is persisted to the database
type AfterInsertProcessor interface { type AfterInsertProcessor interface {
AfterInsert() AfterInsert()
} }
// Executed after an object has been updated // Executed after an object has been updated
type AfterUpdateProcessor interface { type AfterUpdateProcessor interface {
AfterUpdate() AfterUpdate()
} }
// Executed after an object has been deleted // Executed after an object has been deleted
type AfterDeleteProcessor interface { type AfterDeleteProcessor interface {
AfterDelete() AfterDelete()
} }

4523
session.go

File diff suppressed because it is too large Load Diff

View File

@ -1,223 +1,229 @@
package xorm package xorm
import ( import (
"database/sql" "database/sql"
"strings" "strings"
) )
type sqlite3 struct { type sqlite3 struct {
base base
}
type sqlite3Parser struct {
}
func (p *sqlite3Parser) parse(driverName, dataSourceName string) (*uri, error) {
return &uri{dbType: SQLITE, dbName: dataSourceName}, nil
} }
func (db *sqlite3) Init(drivername, dataSourceName string) error { func (db *sqlite3) Init(drivername, dataSourceName string) error {
db.base.init(drivername, dataSourceName) return db.base.init(&sqlite3Parser{}, drivername, dataSourceName)
return nil
} }
func (db *sqlite3) SqlType(c *Column) string { func (db *sqlite3) SqlType(c *Column) string {
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case Date, DateTime, TimeStamp, Time: case Date, DateTime, TimeStamp, Time:
return Numeric return Numeric
case TimeStampz: case TimeStampz:
return Text return Text
case Char, Varchar, TinyText, Text, MediumText, LongText: case Char, Varchar, TinyText, Text, MediumText, LongText:
return Text return Text
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool: case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool:
return Integer return Integer
case Float, Double, Real: case Float, Double, Real:
return Real return Real
case Decimal, Numeric: case Decimal, Numeric:
return Numeric return Numeric
case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary: case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary:
return Blob return Blob
case Serial, BigSerial: case Serial, BigSerial:
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.Nullable = false c.Nullable = false
return Integer return Integer
default: default:
return t return t
} }
} }
func (db *sqlite3) SupportInsertMany() bool { func (db *sqlite3) SupportInsertMany() bool {
return true return true
} }
func (db *sqlite3) QuoteStr() string { func (db *sqlite3) QuoteStr() string {
return "`" return "`"
} }
func (db *sqlite3) AutoIncrStr() string { func (db *sqlite3) AutoIncrStr() string {
return "AUTOINCREMENT" return "AUTOINCREMENT"
} }
func (db *sqlite3) SupportEngine() bool { func (db *sqlite3) SupportEngine() bool {
return false return false
} }
func (db *sqlite3) SupportCharset() bool { func (db *sqlite3) SupportCharset() bool {
return false return false
} }
func (db *sqlite3) IndexOnTable() bool { func (db *sqlite3) IndexOnTable() bool {
return false return false
} }
func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{idxName} args := []interface{}{idxName}
return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args
} }
func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) { func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
} }
func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))"
return sql, args return sql, args
} }
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
var sql string var sql string
for _, record := range res { for _, record := range res {
for name, content := range record { for name, content := range record {
if name == "sql" { if name == "sql" {
sql = string(content) sql = string(content)
} }
} }
} }
nStart := strings.Index(sql, "(") nStart := strings.Index(sql, "(")
nEnd := strings.Index(sql, ")") nEnd := strings.Index(sql, ")")
colCreates := strings.Split(sql[nStart+1:nEnd], ",") colCreates := strings.Split(sql[nStart+1:nEnd], ",")
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, colStr := range colCreates { for _, colStr := range colCreates {
fields := strings.Fields(strings.TrimSpace(colStr)) fields := strings.Fields(strings.TrimSpace(colStr))
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
col.Nullable = true col.Nullable = true
for idx, field := range fields { for idx, field := range fields {
if idx == 0 { if idx == 0 {
col.Name = strings.Trim(field, "`[] ") col.Name = strings.Trim(field, "`[] ")
continue continue
} else if idx == 1 { } else if idx == 1 {
col.SQLType = SQLType{field, 0, 0} col.SQLType = SQLType{field, 0, 0}
} }
switch field { switch field {
case "PRIMARY": case "PRIMARY":
col.IsPrimaryKey = true col.IsPrimaryKey = true
case "AUTOINCREMENT": case "AUTOINCREMENT":
col.IsAutoIncrement = true col.IsAutoIncrement = true
case "NULL": case "NULL":
if fields[idx-1] == "NOT" { if fields[idx-1] == "NOT" {
col.Nullable = false col.Nullable = false
} else { } else {
col.Nullable = true col.Nullable = true
} }
} }
} }
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name) colSeq = append(colSeq, col.Name)
} }
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *sqlite3) GetTables() ([]*Table, error) { func (db *sqlite3) GetTables() ([]*Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT name FROM sqlite_master WHERE type='table'" s := "SELECT name FROM sqlite_master WHERE type='table'"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*Table, 0)
for _, record := range res { for _, record := range res {
table := new(Table) table := new(Table)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "name": case "name":
table.Name = string(content) table.Name = string(content)
} }
} }
if table.Name == "sqlite_sequence" { if table.Name == "sqlite_sequence" {
continue continue
} }
tables = append(tables, table) tables = append(tables, table)
} }
return tables, nil return tables, nil
} }
func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*Index, 0)
for _, record := range res { for _, record := range res {
var sql string var sql string
index := new(Index) index := new(Index)
for name, content := range record { for name, content := range record {
if name == "sql" { if name == "sql" {
sql = string(content) sql = string(content)
} }
} }
nNStart := strings.Index(sql, "INDEX") nNStart := strings.Index(sql, "INDEX")
nNEnd := strings.Index(sql, "ON") nNEnd := strings.Index(sql, "ON")
indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []")
//fmt.Println(indexName) //fmt.Println(indexName)
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
index.Name = indexName[5+len(tableName) : len(indexName)] index.Name = indexName[5+len(tableName) : len(indexName)]
} else { } else {
index.Name = indexName index.Name = indexName
} }
if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") {
index.Type = UniqueType index.Type = UniqueType
} else { } else {
index.Type = IndexType index.Type = IndexType
} }
nStart := strings.Index(sql, "(") nStart := strings.Index(sql, "(")
nEnd := strings.Index(sql, ")") nEnd := strings.Index(sql, ")")
colIndexes := strings.Split(sql[nStart+1:nEnd], ",") colIndexes := strings.Split(sql[nStart+1:nEnd], ",")
index.Cols = make([]string, 0) index.Cols = make([]string, 0)
for _, col := range colIndexes { for _, col := range colIndexes {
index.Cols = append(index.Cols, strings.Trim(col, "` []")) index.Cols = append(index.Cols, strings.Trim(col, "` []"))
} }
indexes[index.Name] = index indexes[index.Name] = index
} }
return indexes, nil return indexes, nil
} }

File diff suppressed because it is too large Load Diff

600
table.go
View File

@ -1,335 +1,335 @@
package xorm package xorm
import ( import (
"reflect" "reflect"
"sort" "sort"
"strings" "strings"
"time" "time"
) )
// xorm SQL types // xorm SQL types
type SQLType struct { type SQLType struct {
Name string Name string
DefaultLength int DefaultLength int
DefaultLength2 int DefaultLength2 int
} }
func (s *SQLType) IsText() bool { func (s *SQLType) IsText() bool {
return s.Name == Char || s.Name == Varchar || s.Name == TinyText || return s.Name == Char || s.Name == Varchar || s.Name == TinyText ||
s.Name == Text || s.Name == MediumText || s.Name == LongText s.Name == Text || s.Name == MediumText || s.Name == LongText
} }
func (s *SQLType) IsBlob() bool { func (s *SQLType) IsBlob() bool {
return (s.Name == TinyBlob) || (s.Name == Blob) || return (s.Name == TinyBlob) || (s.Name == Blob) ||
s.Name == MediumBlob || s.Name == LongBlob || s.Name == MediumBlob || s.Name == LongBlob ||
s.Name == Binary || s.Name == VarBinary || s.Name == Bytea s.Name == Binary || s.Name == VarBinary || s.Name == Bytea
} }
const () const ()
var ( var (
Bit = "BIT" Bit = "BIT"
TinyInt = "TINYINT" TinyInt = "TINYINT"
SmallInt = "SMALLINT" SmallInt = "SMALLINT"
MediumInt = "MEDIUMINT" MediumInt = "MEDIUMINT"
Int = "INT" Int = "INT"
Integer = "INTEGER" Integer = "INTEGER"
BigInt = "BIGINT" BigInt = "BIGINT"
Char = "CHAR" Char = "CHAR"
Varchar = "VARCHAR" Varchar = "VARCHAR"
TinyText = "TINYTEXT" TinyText = "TINYTEXT"
Text = "TEXT" Text = "TEXT"
MediumText = "MEDIUMTEXT" MediumText = "MEDIUMTEXT"
LongText = "LONGTEXT" LongText = "LONGTEXT"
Binary = "BINARY"
VarBinary = "VARBINARY"
Date = "DATE" Date = "DATE"
DateTime = "DATETIME" DateTime = "DATETIME"
Time = "TIME" Time = "TIME"
TimeStamp = "TIMESTAMP" TimeStamp = "TIMESTAMP"
TimeStampz = "TIMESTAMPZ" TimeStampz = "TIMESTAMPZ"
Decimal = "DECIMAL" Decimal = "DECIMAL"
Numeric = "NUMERIC" Numeric = "NUMERIC"
Real = "REAL" Real = "REAL"
Float = "FLOAT" Float = "FLOAT"
Double = "DOUBLE" Double = "DOUBLE"
TinyBlob = "TINYBLOB" Binary = "BINARY"
Blob = "BLOB" VarBinary = "VARBINARY"
MediumBlob = "MEDIUMBLOB" TinyBlob = "TINYBLOB"
LongBlob = "LONGBLOB" Blob = "BLOB"
Bytea = "BYTEA" MediumBlob = "MEDIUMBLOB"
LongBlob = "LONGBLOB"
Bytea = "BYTEA"
Bool = "BOOL" Bool = "BOOL"
Serial = "SERIAL" Serial = "SERIAL"
BigSerial = "BIGSERIAL" BigSerial = "BIGSERIAL"
sqlTypes = map[string]bool{ sqlTypes = map[string]bool{
Bit: true, Bit: true,
TinyInt: true, TinyInt: true,
SmallInt: true, SmallInt: true,
MediumInt: true, MediumInt: true,
Int: true, Int: true,
Integer: true, Integer: true,
BigInt: true, BigInt: true,
Char: true, Char: true,
Varchar: true, Varchar: true,
TinyText: true, TinyText: true,
Text: true, Text: true,
MediumText: true, MediumText: true,
LongText: true, LongText: true,
Binary: true,
VarBinary: true,
Date: true, Date: true,
DateTime: true, DateTime: true,
Time: true, Time: true,
TimeStamp: true, TimeStamp: true,
TimeStampz: true, TimeStampz: true,
Decimal: true, Decimal: true,
Numeric: true, Numeric: true,
Real: true, Binary: true,
Float: true, VarBinary: true,
Double: true, Real: true,
TinyBlob: true, Float: true,
Blob: true, Double: true,
MediumBlob: true, TinyBlob: true,
LongBlob: true, Blob: true,
Bytea: true, MediumBlob: true,
LongBlob: true,
Bytea: true,
Bool: true, Bool: true,
Serial: true, Serial: true,
BigSerial: true, BigSerial: true,
} }
intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"} intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"}
uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"} uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"}
) )
var b byte var b byte
var tm time.Time var tm time.Time
func Type2SQLType(t reflect.Type) (st SQLType) { func Type2SQLType(t reflect.Type) (st SQLType) {
switch k := t.Kind(); k { switch k := t.Kind(); k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
st = SQLType{Int, 0, 0} st = SQLType{Int, 0, 0}
case reflect.Int64, reflect.Uint64: case reflect.Int64, reflect.Uint64:
st = SQLType{BigInt, 0, 0} st = SQLType{BigInt, 0, 0}
case reflect.Float32: case reflect.Float32:
st = SQLType{Float, 0, 0} st = SQLType{Float, 0, 0}
case reflect.Float64: case reflect.Float64:
st = SQLType{Double, 0, 0} st = SQLType{Double, 0, 0}
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
st = SQLType{Varchar, 64, 0} st = SQLType{Varchar, 64, 0}
case reflect.Array, reflect.Slice, reflect.Map: case reflect.Array, reflect.Slice, reflect.Map:
if t.Elem() == reflect.TypeOf(b) { if t.Elem() == reflect.TypeOf(b) {
st = SQLType{Blob, 0, 0} st = SQLType{Blob, 0, 0}
} else { } else {
st = SQLType{Text, 0, 0} st = SQLType{Text, 0, 0}
} }
case reflect.Bool: case reflect.Bool:
st = SQLType{Bool, 0, 0} st = SQLType{Bool, 0, 0}
case reflect.String: case reflect.String:
st = SQLType{Varchar, 255, 0} st = SQLType{Varchar, 255, 0}
case reflect.Struct: case reflect.Struct:
if t == reflect.TypeOf(tm) { if t == reflect.TypeOf(tm) {
st = SQLType{DateTime, 0, 0} st = SQLType{DateTime, 0, 0}
} else { } else {
st = SQLType{Text, 0, 0} st = SQLType{Text, 0, 0}
} }
case reflect.Ptr: case reflect.Ptr:
st, _ = ptrType2SQLType(t) st, _ = ptrType2SQLType(t)
default: default:
st = SQLType{Text, 0, 0} st = SQLType{Text, 0, 0}
} }
return return
} }
func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) { func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) {
has = true has = true
switch t { switch t {
case reflect.TypeOf(&c_EMPTY_STRING): case reflect.TypeOf(&c_EMPTY_STRING):
st = SQLType{Varchar, 255, 0} st = SQLType{Varchar, 255, 0}
return return
case reflect.TypeOf(&c_BOOL_DEFAULT): case reflect.TypeOf(&c_BOOL_DEFAULT):
st = SQLType{Bool, 0, 0} st = SQLType{Bool, 0, 0}
case reflect.TypeOf(&c_COMPLEX64_DEFAULT), reflect.TypeOf(&c_COMPLEX128_DEFAULT): case reflect.TypeOf(&c_COMPLEX64_DEFAULT), reflect.TypeOf(&c_COMPLEX128_DEFAULT):
st = SQLType{Varchar, 64, 0} st = SQLType{Varchar, 64, 0}
case reflect.TypeOf(&c_FLOAT32_DEFAULT): case reflect.TypeOf(&c_FLOAT32_DEFAULT):
st = SQLType{Float, 0, 0} st = SQLType{Float, 0, 0}
case reflect.TypeOf(&c_FLOAT64_DEFAULT): case reflect.TypeOf(&c_FLOAT64_DEFAULT):
st = SQLType{Double, 0, 0} st = SQLType{Double, 0, 0}
case reflect.TypeOf(&c_INT64_DEFAULT), reflect.TypeOf(&c_UINT64_DEFAULT): case reflect.TypeOf(&c_INT64_DEFAULT), reflect.TypeOf(&c_UINT64_DEFAULT):
st = SQLType{BigInt, 0, 0} st = SQLType{BigInt, 0, 0}
case reflect.TypeOf(&c_TIME_DEFAULT): case reflect.TypeOf(&c_TIME_DEFAULT):
st = SQLType{DateTime, 0, 0} st = SQLType{DateTime, 0, 0}
case reflect.TypeOf(&c_INT_DEFAULT), reflect.TypeOf(&c_INT32_DEFAULT), reflect.TypeOf(&c_INT8_DEFAULT), reflect.TypeOf(&c_INT16_DEFAULT), reflect.TypeOf(&c_UINT_DEFAULT), reflect.TypeOf(&c_UINT32_DEFAULT), reflect.TypeOf(&c_UINT8_DEFAULT), reflect.TypeOf(&c_UINT16_DEFAULT): case reflect.TypeOf(&c_INT_DEFAULT), reflect.TypeOf(&c_INT32_DEFAULT), reflect.TypeOf(&c_INT8_DEFAULT), reflect.TypeOf(&c_INT16_DEFAULT), reflect.TypeOf(&c_UINT_DEFAULT), reflect.TypeOf(&c_UINT32_DEFAULT), reflect.TypeOf(&c_UINT8_DEFAULT), reflect.TypeOf(&c_UINT16_DEFAULT):
st = SQLType{Int, 0, 0} st = SQLType{Int, 0, 0}
default: default:
has = false has = false
} }
return return
} }
// default sql type change to go types // default sql type change to go types
func SQLType2Type(st SQLType) reflect.Type { func SQLType2Type(st SQLType) reflect.Type {
name := strings.ToUpper(st.Name) name := strings.ToUpper(st.Name)
switch name { switch name {
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial: case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial:
return reflect.TypeOf(1) return reflect.TypeOf(1)
case BigInt, BigSerial: case BigInt, BigSerial:
return reflect.TypeOf(int64(1)) return reflect.TypeOf(int64(1))
case Float, Real: case Float, Real:
return reflect.TypeOf(float32(1)) return reflect.TypeOf(float32(1))
case Double: case Double:
return reflect.TypeOf(float64(1)) return reflect.TypeOf(float64(1))
case Char, Varchar, TinyText, Text, MediumText, LongText: case Char, Varchar, TinyText, Text, MediumText, LongText:
return reflect.TypeOf("") return reflect.TypeOf("")
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary: case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary:
return reflect.TypeOf([]byte{}) return reflect.TypeOf([]byte{})
case Bool: case Bool:
return reflect.TypeOf(true) return reflect.TypeOf(true)
case DateTime, Date, Time, TimeStamp, TimeStampz: case DateTime, Date, Time, TimeStamp, TimeStampz:
return reflect.TypeOf(tm) return reflect.TypeOf(tm)
case Decimal, Numeric: case Decimal, Numeric:
return reflect.TypeOf("") return reflect.TypeOf("")
default: default:
return reflect.TypeOf("") return reflect.TypeOf("")
} }
} }
const ( const (
IndexType = iota + 1 IndexType = iota + 1
UniqueType UniqueType
) )
// database index // database index
type Index struct { type Index struct {
Name string Name string
Type int Type int
Cols []string Cols []string
} }
// add columns which will be composite index // add columns which will be composite index
func (index *Index) AddColumn(cols ...string) { func (index *Index) AddColumn(cols ...string) {
for _, col := range cols { for _, col := range cols {
index.Cols = append(index.Cols, col) index.Cols = append(index.Cols, col)
} }
} }
// new an index // new an index
func NewIndex(name string, indexType int) *Index { func NewIndex(name string, indexType int) *Index {
return &Index{name, indexType, make([]string, 0)} return &Index{name, indexType, make([]string, 0)}
} }
const ( const (
TWOSIDES = iota + 1 TWOSIDES = iota + 1
ONLYTODB ONLYTODB
ONLYFROMDB ONLYFROMDB
) )
// database column // database column
type Column struct { type Column struct {
Name string Name string
FieldName string FieldName string
SQLType SQLType SQLType SQLType
Length int Length int
Length2 int Length2 int
Nullable bool Nullable bool
Default string Default string
Indexes map[string]bool Indexes map[string]bool
IsPrimaryKey bool IsPrimaryKey bool
IsAutoIncrement bool IsAutoIncrement bool
MapType int MapType int
IsCreated bool IsCreated bool
IsUpdated bool IsUpdated bool
IsCascade bool IsCascade bool
IsVersion bool IsVersion bool
} }
// generate column description string according dialect // generate column description string according dialect
func (col *Column) String(d dialect) string { func (col *Column) String(d dialect) string {
sql := d.QuoteStr() + col.Name + d.QuoteStr() + " " sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
sql += d.SqlType(col) + " " sql += d.SqlType(col) + " "
if col.IsPrimaryKey { if col.IsPrimaryKey {
sql += "PRIMARY KEY " sql += "PRIMARY KEY "
if col.IsAutoIncrement { if col.IsAutoIncrement {
sql += d.AutoIncrStr() + " " sql += d.AutoIncrStr() + " "
} }
} }
if col.Nullable { if col.Nullable {
sql += "NULL " sql += "NULL "
} else { } else {
sql += "NOT NULL " sql += "NOT NULL "
} }
if col.Default != "" { if col.Default != "" {
sql += "DEFAULT " + col.Default + " " sql += "DEFAULT " + col.Default + " "
} }
return sql return sql
} }
func (col *Column) stringNoPk(d dialect) string { func (col *Column) stringNoPk(d dialect) string {
sql := d.QuoteStr() + col.Name + d.QuoteStr() + " " sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
sql += d.SqlType(col) + " " sql += d.SqlType(col) + " "
if col.Nullable { if col.Nullable {
sql += "NULL " sql += "NULL "
} else { } else {
sql += "NOT NULL " sql += "NOT NULL "
} }
if col.Default != "" { if col.Default != "" {
sql += "DEFAULT " + col.Default + " " sql += "DEFAULT " + col.Default + " "
} }
return sql return sql
} }
// return col's filed of struct's value // return col's filed of struct's value
func (col *Column) ValueOf(bean interface{}) reflect.Value { func (col *Column) ValueOf(bean interface{}) reflect.Value {
var fieldValue reflect.Value var fieldValue reflect.Value
if strings.Contains(col.FieldName, ".") { if strings.Contains(col.FieldName, ".") {
fields := strings.Split(col.FieldName, ".") fields := strings.Split(col.FieldName, ".")
if len(fields) > 2 { if len(fields) > 2 {
return reflect.ValueOf(nil) return reflect.ValueOf(nil)
} }
fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(fields[0]) fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(fields[0])
fieldValue = fieldValue.FieldByName(fields[1]) fieldValue = fieldValue.FieldByName(fields[1])
} else { } else {
fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
} }
return fieldValue return fieldValue
} }
// database table // database table
type Table struct { type Table struct {
Name string Name string
Type reflect.Type Type reflect.Type
ColumnsSeq []string ColumnsSeq []string
Columns map[string]*Column Columns map[string]*Column
Indexes map[string]*Index Indexes map[string]*Index
PrimaryKey string PrimaryKey string
Created map[string]bool Created map[string]bool
Updated string Updated string
Version string Version string
Cacher Cacher Cacher Cacher
} }
/* /*
@ -344,90 +344,90 @@ func NewTable(name string, t reflect.Type) *Table {
// if has primary key, return column // if has primary key, return column
func (table *Table) PKColumn() *Column { func (table *Table) PKColumn() *Column {
return table.Columns[table.PrimaryKey] return table.Columns[table.PrimaryKey]
} }
func (table *Table) VersionColumn() *Column { func (table *Table) VersionColumn() *Column {
return table.Columns[table.Version] return table.Columns[table.Version]
} }
// add a column to table // add a column to table
func (table *Table) AddColumn(col *Column) { func (table *Table) AddColumn(col *Column) {
table.ColumnsSeq = append(table.ColumnsSeq, col.Name) table.ColumnsSeq = append(table.ColumnsSeq, col.Name)
table.Columns[col.Name] = col table.Columns[col.Name] = col
if col.IsPrimaryKey { if col.IsPrimaryKey {
table.PrimaryKey = col.Name table.PrimaryKey = col.Name
} }
if col.IsCreated { if col.IsCreated {
table.Created[col.Name] = true table.Created[col.Name] = true
} }
if col.IsUpdated { if col.IsUpdated {
table.Updated = col.Name table.Updated = col.Name
} }
if col.IsVersion { if col.IsVersion {
table.Version = col.Name table.Version = col.Name
} }
} }
// add an index or an unique to table // add an index or an unique to table
func (table *Table) AddIndex(index *Index) { func (table *Table) AddIndex(index *Index) {
table.Indexes[index.Name] = index table.Indexes[index.Name] = index
} }
func (table *Table) genCols(session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) { func (table *Table) genCols(session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) {
colNames := make([]string, 0) colNames := make([]string, 0)
args := make([]interface{}, 0) args := make([]interface{}, 0)
for _, col := range table.Columns { for _, col := range table.Columns {
if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated { if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated {
if _, ok := session.Statement.columnMap[col.Name]; !ok { if _, ok := session.Statement.columnMap[col.Name]; !ok {
continue continue
} }
} }
if col.MapType == ONLYFROMDB { if col.MapType == ONLYFROMDB {
continue continue
} }
fieldValue := col.ValueOf(bean) fieldValue := col.ValueOf(bean)
if col.IsAutoIncrement && fieldValue.Int() == 0 { if col.IsAutoIncrement && fieldValue.Int() == 0 {
continue continue
} }
if session.Statement.ColumnStr != "" { if session.Statement.ColumnStr != "" {
if _, ok := session.Statement.columnMap[col.Name]; !ok { if _, ok := session.Statement.columnMap[col.Name]; !ok {
continue continue
} }
} }
if session.Statement.OmitStr != "" { if session.Statement.OmitStr != "" {
if _, ok := session.Statement.columnMap[col.Name]; ok { if _, ok := session.Statement.columnMap[col.Name]; ok {
continue continue
} }
} }
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
args = append(args, time.Now()) args = append(args, time.Now())
} else if col.IsVersion && session.Statement.checkVersion { } else if col.IsVersion && session.Statement.checkVersion {
args = append(args, 1) args = append(args, 1)
} else { } else {
arg, err := session.value2Interface(col, fieldValue) arg, err := session.value2Interface(col, fieldValue)
if err != nil { if err != nil {
return colNames, args, err return colNames, args, err
} }
args = append(args, arg) args = append(args, arg)
} }
if includeQuote { if includeQuote {
colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?") colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?")
} else { } else {
colNames = append(colNames, col.Name) colNames = append(colNames, col.Name)
} }
} }
return colNames, args, nil return colNames, args, nil
} }
// Conversion is an interface. A type implements Conversion will according // Conversion is an interface. A type implements Conversion will according
// the custom method to fill into database and retrieve from database. // the custom method to fill into database and retrieve from database.
type Conversion interface { type Conversion interface {
FromDB([]byte) error FromDB([]byte) error
ToDB() ([]byte, error) ToDB() ([]byte, error)
} }

78
xorm.go
View File

@ -1,58 +1,58 @@
package xorm package xorm
import ( import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
"runtime" "runtime"
"sync" "sync"
) )
const ( const (
version string = "0.2.2" version string = "0.2.3"
) )
func close(engine *Engine) { func close(engine *Engine) {
engine.Close() engine.Close()
} }
// new a db manager according to the parameter. Currently support four // new a db manager according to the parameter. Currently support four
// drivers // drivers
func NewEngine(driverName string, dataSourceName string) (*Engine, error) { func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
engine := &Engine{DriverName: driverName, engine := &Engine{DriverName: driverName,
DataSourceName: dataSourceName, Filters: make([]Filter, 0)} DataSourceName: dataSourceName, Filters: make([]Filter, 0)}
engine.SetMapper(SnakeMapper{}) engine.SetMapper(SnakeMapper{})
if driverName == SQLITE { if driverName == SQLITE {
engine.dialect = &sqlite3{} engine.dialect = &sqlite3{}
} else if driverName == MYSQL { } else if driverName == MYSQL {
engine.dialect = &mysql{} engine.dialect = &mysql{}
} else if driverName == POSTGRES { } else if driverName == POSTGRES {
engine.dialect = &postgres{} engine.dialect = &postgres{}
engine.Filters = append(engine.Filters, &PgSeqFilter{}) engine.Filters = append(engine.Filters, &PgSeqFilter{})
engine.Filters = append(engine.Filters, &QuoteFilter{}) engine.Filters = append(engine.Filters, &QuoteFilter{})
} else if driverName == MYMYSQL { } else if driverName == MYMYSQL {
engine.dialect = &mymysql{} engine.dialect = &mymysql{}
} else { } else {
return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName)) return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName))
} }
err := engine.dialect.Init(driverName, dataSourceName) err := engine.dialect.Init(driverName, dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
engine.Tables = make(map[reflect.Type]*Table) engine.Tables = make(map[reflect.Type]*Table)
engine.mutex = &sync.Mutex{} engine.mutex = &sync.Mutex{}
engine.TagIdentifier = "xorm" engine.TagIdentifier = "xorm"
engine.Filters = append(engine.Filters, &IdFilter{}) engine.Filters = append(engine.Filters, &IdFilter{})
engine.Logger = os.Stdout engine.Logger = os.Stdout
//engine.Pool = NewSimpleConnectPool() //engine.Pool = NewSimpleConnectPool()
//engine.Pool = NewNoneConnectPool() //engine.Pool = NewNoneConnectPool()
//engine.Cacher = NewLRUCacher() //engine.Cacher = NewLRUCacher()
err = engine.SetPool(NewSysConnectPool()) err = engine.SetPool(NewSysConnectPool())
runtime.SetFinalizer(engine, close) runtime.SetFinalizer(engine, close)
return engine, err return engine, err
} }