many bugs fixed

This commit is contained in:
Lunny Xiao 2013-09-26 15:19:39 +08:00
parent 9d676ebddd
commit 48fa4c6fbc
17 changed files with 474 additions and 86 deletions

View File

@ -1 +1 @@
xorm v0.1.9 xorm v0.2

View File

@ -112,9 +112,29 @@ func exec(engine *Engine, t *testing.T) {
fmt.Println(res) fmt.Println(res)
} }
func querySameMapper(engine *Engine, t *testing.T) {
sql := "select * from `Userinfo`"
results, err := engine.Query(sql)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(results)
}
func execSameMapper(engine *Engine, t *testing.T) {
sql := "update `Userinfo` set `Username`=? where (id)=?"
res, err := engine.Exec(sql, "xiaolun", 1)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(res)
}
func insertAutoIncr(engine *Engine, t *testing.T) { func insertAutoIncr(engine *Engine, t *testing.T) {
// auto increment insert // auto increment insert
user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now(), user := Userinfo{Username: "xiaolunwen2", Departname: "dev", Alias: "lunny", Created: time.Now(),
Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true} Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true}
_, err := engine.Insert(&user) _, err := engine.Insert(&user)
fmt.Println(user.Uid) fmt.Println(user.Uid)
@ -175,6 +195,29 @@ func update(engine *Engine, t *testing.T) {
} }
} }
func updateSameMapper(engine *Engine, t *testing.T) {
// update by id
user := Userinfo{Username: "xxx", Height: 1.2}
_, err := engine.Id(1).Update(&user)
if err != nil {
t.Error(err)
panic(err)
}
condi := Condi{"Username": "zzz", "Height": 0.0, "Departname": ""}
_, err = engine.Table(&user).Id(1).Update(&condi)
if err != nil {
t.Error(err)
panic(err)
}
_, err = engine.Update(&Userinfo{Username: "yyy"}, &user)
if err != nil {
t.Error(err)
panic(err)
}
}
func testdelete(engine *Engine, t *testing.T) { func testdelete(engine *Engine, t *testing.T) {
user := Userinfo{Uid: 1} user := Userinfo{Uid: 1}
_, err := engine.Delete(&user) _, err := engine.Delete(&user)
@ -243,12 +286,12 @@ func count(engine *Engine, t *testing.T) {
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
fmt.Printf("Total %d records!!!", total) fmt.Printf("Total %d records!!!\n", total)
} }
func where(engine *Engine, t *testing.T) { func where(engine *Engine, t *testing.T) {
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
err := engine.Where("id > ?", 2).Find(&users) err := engine.Where("(id) > ?", 2).Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -258,7 +301,7 @@ func where(engine *Engine, t *testing.T) {
func in(engine *Engine, t *testing.T) { func in(engine *Engine, t *testing.T) {
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
err := engine.In("id", 1, 2, 3).Find(&users) err := engine.In("(id)", 1, 2, 3).Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -266,7 +309,7 @@ func in(engine *Engine, t *testing.T) {
fmt.Println(users) fmt.Println(users)
ids := []interface{}{1, 2, 3} ids := []interface{}{1, 2, 3}
err = engine.Where("id > ?", 2).In("id", ids...).Find(&users) err = engine.Where("(id) > ?", 2).In("(id)", ids...).Find(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -321,6 +364,43 @@ func having(engine *Engine, t *testing.T) {
fmt.Println(users) fmt.Println(users)
} }
func orderSameMapper(engine *Engine, t *testing.T) {
users := make([]Userinfo, 0)
err := engine.OrderBy("(id) desc").Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(users)
users2 := make([]Userinfo, 0)
err = engine.Asc("(id)", "Username").Desc("Height").Find(&users2)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(users2)
}
func joinSameMapper(engine *Engine, t *testing.T) {
users := make([]Userinfo, 0)
err := engine.Join("LEFT", `"Userdetail"`, `"Userinfo"."id"="Userdetail"."Id"`).Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
}
func havingSameMapper(engine *Engine, t *testing.T) {
users := make([]Userinfo, 0)
err := engine.GroupBy("Username").Having(`"Username"='xlw'`).Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(users)
}
func transaction(engine *Engine, t *testing.T) { func transaction(engine *Engine, t *testing.T) {
counter := func() { counter := func() {
total, err := engine.Count(&Userinfo{}) total, err := engine.Count(&Userinfo{})
@ -349,7 +429,7 @@ func transaction(engine *Engine, t *testing.T) {
panic(err) panic(err)
} }
user2 := Userinfo{Username: "yyy"} user2 := Userinfo{Username: "yyy"}
_, err = session.Where("uid = ?", 0).Update(&user2) _, err = session.Where("(id) = ?", 0).Update(&user2)
if err != nil { if err != nil {
session.Rollback() session.Rollback()
fmt.Println(err) fmt.Println(err)
@ -421,6 +501,55 @@ func combineTransaction(engine *Engine, t *testing.T) {
} }
} }
func combineTransactionSameMapper(engine *Engine, t *testing.T) {
counter := func() {
total, err := engine.Count(&Userinfo{})
if err != nil {
t.Error(err)
}
fmt.Printf("----now total %v records\n", total)
}
counter()
defer counter()
session := engine.NewSession()
defer session.Close()
err := session.Begin()
if err != nil {
t.Error(err)
panic(err)
}
//session.IsAutoRollback = false
user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
}
user2 := Userinfo{Username: "zzz"}
_, err = session.Where("(id) = ?", 0).Update(&user2)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
}
_, err = session.Exec("delete from `Userinfo` where `Username` = ?", user2.Username)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
}
err = session.Commit()
if err != nil {
t.Error(err)
panic(err)
}
}
func table(engine *Engine, t *testing.T) { func table(engine *Engine, t *testing.T) {
err := engine.DropTables("user_user") err := engine.DropTables("user_user")
if err != nil { if err != nil {
@ -554,6 +683,41 @@ func testCols(engine *Engine, t *testing.T) {
panic(err) panic(err)
} }
fmt.Println(tmpUsers) fmt.Println(tmpUsers)
user := &Userinfo{Uid: 1, Alias: "", Height: 0}
affected, err := engine.Cols("departname, height").Id(1).Update(user)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println("===================", user, affected)
}
func testColsSameMapper(engine *Engine, t *testing.T) {
users := []Userinfo{}
err := engine.Cols("(id), Username").Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(users)
tmpUsers := []tempUser{}
err = engine.Table("Userinfo").Cols("(id), Username").Find(&tmpUsers)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(tmpUsers)
user := &Userinfo{Uid: 1, Alias: "", Height: 0}
affected, err := engine.Cols("Departname, Height").Update(user)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println("===================", user, affected)
} }
type tempUser2 struct { type tempUser2 struct {
@ -708,8 +872,8 @@ func testCustomType(engine *Engine, t *testing.T) {
i.UIA32 = []uint32{4, 5} i.UIA32 = []uint32{4, 5}
i.UIA64 = []uint64{6, 7, 9} i.UIA64 = []uint64{6, 7, 9}
i.UIA8 = []uint8{1, 2, 3, 4} i.UIA8 = []uint8{1, 2, 3, 4}
i.NameArray = []string{"ssss fsdf", "lllll, ss"} i.NameArray = []string{"ssss", "fsdf", "lllll, ss"}
i.MSS = map[string]string{"s": "sfds,ss ", "x": "lfjljsl"} i.MSS = map[string]string{"s": "sfds,ss", "x": "lfjljsl"}
_, err = engine.Insert(&i) _, err = engine.Insert(&i)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -804,13 +968,13 @@ func testIndexAndUnique(engine *Engine, t *testing.T) {
err := engine.DropTables(&IndexOrUnique{}) err := engine.DropTables(&IndexOrUnique{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) //panic(err)
} }
err = engine.CreateTables(&IndexOrUnique{}) err = engine.CreateTables(&IndexOrUnique{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) //panic(err)
} }
} }
@ -853,38 +1017,75 @@ func testInt32Id(engine *Engine, t *testing.T) {
} }
func testAll(engine *Engine, t *testing.T) { func testAll(engine *Engine, t *testing.T) {
fmt.Println("-------------- directCreateTable --------------")
directCreateTable(engine, t) directCreateTable(engine, t)
fmt.Println("-------------- mapper --------------")
mapper(engine, t) mapper(engine, t)
fmt.Println("-------------- insert --------------")
insert(engine, t) insert(engine, t)
fmt.Println("-------------- query --------------")
query(engine, t) query(engine, t)
fmt.Println("-------------- exec --------------")
exec(engine, t) exec(engine, t)
fmt.Println("-------------- insertAutoIncr --------------")
insertAutoIncr(engine, t) insertAutoIncr(engine, t)
fmt.Println("-------------- insertMulti --------------")
insertMulti(engine, t) insertMulti(engine, t)
fmt.Println("-------------- insertTwoTable --------------")
insertTwoTable(engine, t) insertTwoTable(engine, t)
fmt.Println("-------------- update --------------")
update(engine, t) update(engine, t)
fmt.Println("-------------- testdelete --------------")
testdelete(engine, t) testdelete(engine, t)
fmt.Println("-------------- get --------------")
get(engine, t) get(engine, t)
fmt.Println("-------------- cascadeGet --------------")
cascadeGet(engine, t) cascadeGet(engine, t)
fmt.Println("-------------- find --------------")
find(engine, t) find(engine, t)
fmt.Println("-------------- findMap --------------")
findMap(engine, t) findMap(engine, t)
fmt.Println("-------------- count --------------")
count(engine, t) count(engine, t)
fmt.Println("-------------- where --------------")
where(engine, t) where(engine, t)
fmt.Println("-------------- in --------------")
in(engine, t) in(engine, t)
fmt.Println("-------------- limit --------------")
limit(engine, t) limit(engine, t)
fmt.Println("-------------- order --------------")
order(engine, t) order(engine, t)
fmt.Println("-------------- join --------------")
join(engine, t) join(engine, t)
fmt.Println("-------------- having --------------")
having(engine, t) having(engine, t)
transaction(engine, t) }
combineTransaction(engine, t)
table(engine, t) func testAll2(engine *Engine, t *testing.T) {
createMultiTables(engine, t) fmt.Println("-------------- combineTransaction --------------")
tableOp(engine, t) combineTransaction(engine, t)
testCols(engine, t) fmt.Println("-------------- table --------------")
testCharst(engine, t) table(engine, t)
testStoreEngine(engine, t) fmt.Println("-------------- createMultiTables --------------")
testExtends(engine, t) createMultiTables(engine, t)
testColTypes(engine, t) fmt.Println("-------------- tableOp --------------")
testCustomType(engine, t) tableOp(engine, t)
testCreatedAndUpdated(engine, t) fmt.Println("-------------- testCols --------------")
testIndexAndUnique(engine, t) testCols(engine, t)
fmt.Println("-------------- testCharst --------------")
testCharst(engine, t)
fmt.Println("-------------- testStoreEngine --------------")
testStoreEngine(engine, t)
fmt.Println("-------------- testExtends --------------")
testExtends(engine, t)
fmt.Println("-------------- testColTypes --------------")
testColTypes(engine, t)
fmt.Println("-------------- testCustomType --------------")
testCustomType(engine, t)
fmt.Println("-------------- testCreatedAndUpdated --------------")
testCreatedAndUpdated(engine, t)
fmt.Println("-------------- testIndexAndUnique --------------")
testIndexAndUnique(engine, t)
fmt.Println("-------------- transaction --------------")
transaction(engine, t)
} }

View File

@ -29,14 +29,12 @@ func (s *MemoryStore) Put(key, value interface{}) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
s.store[key] = value s.store[key] = value
//fmt.Println("after put store:", s.store)
return nil return nil
} }
func (s *MemoryStore) Get(key interface{}) (interface{}, error) { func (s *MemoryStore) Get(key interface{}) (interface{}, error) {
s.mutex.Rlock() s.mutex.RLock()
defer s.mutex.UnRlock() defer s.mutex.RUnlock()
//fmt.Println("before get store:", s.store)
if v, ok := s.store[key]; ok { if v, ok := s.store[key]; ok {
return v, nil return v, nil
} }
@ -47,9 +45,7 @@ func (s *MemoryStore) Get(key interface{}) (interface{}, error) {
func (s *MemoryStore) Del(key interface{}) error { func (s *MemoryStore) Del(key interface{}) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
//fmt.Println("before del store:", s.store)
delete(s.store, key) delete(s.store, key)
//fmt.Println("after del store:", s.store)
return nil return nil
} }

View File

@ -24,6 +24,7 @@ type dialect interface {
AutoIncrStr() string AutoIncrStr() string
SupportEngine() bool SupportEngine() bool
SupportCharset() bool SupportCharset() bool
IndexOnTable() bool
} }
type Engine struct { type Engine struct {

25
helpers.go Normal file
View File

@ -0,0 +1,25 @@
package xorm
import (
"strings"
)
func IndexNoCase(s, sep string) int {
return strings.Index(strings.ToLower(s), strings.ToLower(sep))
}
func SplitNoCase(s, sep string) []string {
idx := IndexNoCase(s, sep)
if idx < 0 {
return []string{s}
}
return strings.Split(s, s[idx:idx+len(sep)])
}
func SplitNNoCase(s, sep string, n int) []string {
idx := IndexNoCase(s, sep)
if idx < 0 {
return []string{s}
}
return strings.SplitN(s, s[idx:idx+len(sep)], n)
}

View File

@ -20,4 +20,5 @@ func TestMyMysql(t *testing.T) {
engine.ShowSQL = true engine.ShowSQL = true
testAll(engine, t) testAll(engine, t)
testAll2(engine, t)
} }

View File

@ -57,3 +57,7 @@ func (db *mysql) SupportEngine() bool {
func (db *mysql) SupportCharset() bool { func (db *mysql) SupportCharset() bool {
return true return true
} }
func (db *mysql) IndexOnTable() bool {
return true
}

View File

@ -20,4 +20,5 @@ func TestMysql(t *testing.T) {
engine.ShowSQL = true engine.ShowSQL = true
testAll(engine, t) testAll(engine, t)
testAll2(engine, t)
} }

11
pool.go
View File

@ -2,7 +2,7 @@ package xorm
import ( import (
"database/sql" "database/sql"
"fmt" //"fmt"
"sync" "sync"
//"sync/atomic" //"sync/atomic"
"container/list" "container/list"
@ -118,7 +118,7 @@ func NewNode() *node {
// RetrieveDB just return the only db // RetrieveDB just return the only db
func (s *SysConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) { func (s *SysConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) {
if s.maxConns > 0 { /*if s.maxConns > 0 {
fmt.Println("before retrieve") fmt.Println("before retrieve")
s.mutex.Lock() s.mutex.Lock()
for s.curConns >= s.maxConns { for s.curConns >= s.maxConns {
@ -135,13 +135,13 @@ func (s *SysConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) {
s.curConns += 1 s.curConns += 1
s.mutex.Unlock() s.mutex.Unlock()
fmt.Println("after retrieve") fmt.Println("after retrieve")
} }*/
return s.db, nil return s.db, nil
} }
// ReleaseDB do nothing // ReleaseDB do nothing
func (s *SysConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { func (s *SysConnectPool) ReleaseDB(engine *Engine, db *sql.DB) {
if s.maxConns > 0 { /*if s.maxConns > 0 {
s.mutex.Lock() s.mutex.Lock()
fmt.Println("before release", s.queue.Len()) fmt.Println("before release", s.queue.Len())
s.curConns -= 1 s.curConns -= 1
@ -156,7 +156,7 @@ func (s *SysConnectPool) ReleaseDB(engine *Engine, db *sql.DB) {
} }
fmt.Println("after released", s.queue.Len()) fmt.Println("after released", s.queue.Len())
s.mutex.Unlock() s.mutex.Unlock()
} }*/
} }
// Close closed the only db // Close closed the only db
@ -176,6 +176,7 @@ func (p *SysConnectPool) MaxIdleConns() int {
// not implemented // not implemented
func (p *SysConnectPool) SetMaxConns(conns int) { func (p *SysConnectPool) SetMaxConns(conns int) {
p.maxConns = conns p.maxConns = conns
//p.db.SetMaxOpenConns(conns)
} }
// not implemented // not implemented

View File

@ -64,3 +64,7 @@ func (db *postgres) SupportEngine() bool {
func (db *postgres) SupportCharset() bool { func (db *postgres) SupportCharset() bool {
return false return false
} }
func (db *postgres) IndexOnTable() bool {
return false
}

View File

@ -1,18 +1,100 @@
package xorm package xorm
import ( import (
"fmt"
_ "github.com/bylevel/pq" _ "github.com/bylevel/pq"
"testing" "testing"
) )
func TestPostgres(t *testing.T) { func TestPostgres(t *testing.T) {
engine, err := NewEngine("postgres", "dbname=xorm_test sslmode=disable") engine, err := NewEngine("postgres", "dbname=xorm_test sslmode=disable")
defer engine.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer engine.Close()
engine.ShowSQL = true engine.ShowSQL = true
testAll(engine, t) testAll(engine, t)
testAll2(engine, t)
}
func TestPostgres2(t *testing.T) {
engine, err := NewEngine("postgres", "dbname=xorm_test sslmode=disable")
if err != nil {
t.Error(err)
return
}
defer engine.Close()
engine.ShowSQL = true
engine.Mapper = SameMapper{}
fmt.Println("-------------- directCreateTable --------------")
directCreateTable(engine, t)
fmt.Println("-------------- mapper --------------")
mapper(engine, t)
fmt.Println("-------------- insert --------------")
insert(engine, t)
fmt.Println("-------------- querySameMapper --------------")
querySameMapper(engine, t)
fmt.Println("-------------- execSameMapper --------------")
execSameMapper(engine, t)
fmt.Println("-------------- insertAutoIncr --------------")
insertAutoIncr(engine, t)
fmt.Println("-------------- insertMulti --------------")
insertMulti(engine, t)
fmt.Println("-------------- insertTwoTable --------------")
insertTwoTable(engine, t)
fmt.Println("-------------- updateSameMapper --------------")
updateSameMapper(engine, t)
fmt.Println("-------------- testdelete --------------")
testdelete(engine, t)
fmt.Println("-------------- get --------------")
get(engine, t)
fmt.Println("-------------- cascadeGet --------------")
cascadeGet(engine, t)
fmt.Println("-------------- find --------------")
find(engine, t)
fmt.Println("-------------- findMap --------------")
findMap(engine, t)
fmt.Println("-------------- count --------------")
count(engine, t)
fmt.Println("-------------- where --------------")
where(engine, t)
fmt.Println("-------------- in --------------")
in(engine, t)
fmt.Println("-------------- limit --------------")
limit(engine, t)
fmt.Println("-------------- orderSameMapper --------------")
orderSameMapper(engine, t)
fmt.Println("-------------- joinSameMapper --------------")
joinSameMapper(engine, t)
fmt.Println("-------------- havingSameMapper --------------")
havingSameMapper(engine, t)
fmt.Println("-------------- transaction --------------")
transaction(engine, t)
fmt.Println("-------------- combineTransactionSameMapper --------------")
combineTransactionSameMapper(engine, t)
fmt.Println("-------------- table --------------")
table(engine, t)
fmt.Println("-------------- createMultiTables --------------")
createMultiTables(engine, t)
fmt.Println("-------------- tableOp --------------")
tableOp(engine, t)
fmt.Println("-------------- testColsSameMapper --------------")
testColsSameMapper(engine, t)
fmt.Println("-------------- testCharst --------------")
testCharst(engine, t)
fmt.Println("-------------- testStoreEngine --------------")
testStoreEngine(engine, t)
fmt.Println("-------------- testExtends --------------")
testExtends(engine, t)
fmt.Println("-------------- testColTypes --------------")
testColTypes(engine, t)
fmt.Println("-------------- testCustomType --------------")
testCustomType(engine, t)
fmt.Println("-------------- testCreatedAndUpdated --------------")
testCreatedAndUpdated(engine, t)
fmt.Println("-------------- testIndexAndUnique --------------")
testIndexAndUnique(engine, t)
} }

View File

@ -337,6 +337,7 @@ func (session *Session) CreateAll() error {
return nil return nil
} }
// DropTable drop a table and all indexes of the table
func (session *Session) DropTable(bean interface{}) error { func (session *Session) DropTable(bean interface{}) error {
err := session.newDb() err := session.newDb()
if err != nil { if err != nil {
@ -354,6 +355,14 @@ func (session *Session) DropTable(bean interface{}) error {
session.Statement.AltTableName = bean.(string) session.Statement.AltTableName = bean.(string)
} else if t.Kind() == reflect.Struct { } else if t.Kind() == reflect.Struct {
session.Statement.RefTable = session.Engine.AutoMap(bean) session.Statement.RefTable = session.Engine.AutoMap(bean)
sqls := session.Statement.genDelIndexSQL()
for _, sql := range sqls {
_, err = session.exec(sql)
if err != nil {
return err
}
}
} else { } else {
return errors.New("Unsupported type") return errors.New("Unsupported type")
} }
@ -1209,41 +1218,15 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) (
func (session *Session) innerInsert(bean interface{}) (int64, error) { func (session *Session) innerInsert(bean interface{}) (int64, error) {
table := session.Engine.AutoMap(bean) table := session.Engine.AutoMap(bean)
session.Statement.RefTable = table session.Statement.RefTable = table
colNames := make([]string, 0)
colPlaces := make([]string, 0)
var args = make([]interface{}, 0)
for _, col := range table.Columns { colNames, args, err := table.GenCols(session, bean, false, false)
if col.MapType == ONLYFROMDB {
continue
}
fieldValue := col.ValueOf(bean)
if col.IsAutoIncrement && fieldValue.Int() == 0 {
continue
}
if session.Statement.ColumnStr != "" {
if _, ok := session.Statement.columnMap[col.Name]; !ok {
continue
}
}
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
args = append(args, time.Now())
} else {
arg, err := session.value2Interface(col, fieldValue)
if err != nil { if err != nil {
return 0, err return 0, err
} }
args = append(args, arg)
}
colNames = append(colNames, col.Name) colPlaces := strings.Repeat("?, ", len(colNames))
colPlaces = append(colPlaces, "?") colPlaces = colPlaces[0 : len(colPlaces)-2]
}
sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v);", sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v);",
session.Engine.QuoteStr(), session.Engine.QuoteStr(),
@ -1252,7 +1235,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
session.Engine.QuoteStr(), session.Engine.QuoteStr(),
strings.Join(colNames, session.Engine.Quote(", ")), strings.Join(colNames, session.Engine.Quote(", ")),
session.Engine.QuoteStr(), session.Engine.QuoteStr(),
strings.Join(colPlaces, ", ")) colPlaces)
res, err := session.exec(sql, args...) res, err := session.exec(sql, args...)
if err != nil { if err != nil {
@ -1402,13 +1385,14 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error {
for _, id := range ids { for _, id := range ids {
if bean := cacher.GetBean(tableName, id); bean != nil { if bean := cacher.GetBean(tableName, id); bean != nil {
sqls := strings.SplitN(strings.ToLower(sql), "where", 2) sqls := SplitNNoCase(sql, "where", 2)
if len(sqls) != 2 { if len(sqls) != 2 {
return nil return ErrCacheFailed
} }
sqls = strings.SplitN(sqls[0], "set", 2)
sqls = SplitNNoCase(sqls[0], "set", 2)
if len(sqls) != 2 { if len(sqls) != 2 {
return nil return ErrCacheFailed
} }
kvs := strings.Split(strings.TrimSpace(sqls[1]), ",") kvs := strings.Split(strings.TrimSpace(sqls[1]), ",")
for idx, kv := range kvs { for idx, kv := range kvs {
@ -1419,13 +1403,14 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error {
colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1)) colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1))
} else if strings.Contains(colName, session.Engine.QuoteStr()) { } else if strings.Contains(colName, session.Engine.QuoteStr()) {
colName = strings.TrimSpace(strings.Replace(colName, session.Engine.QuoteStr(), "", -1)) colName = strings.TrimSpace(strings.Replace(colName, session.Engine.QuoteStr(), "", -1))
} else {
session.Engine.LogDebug("[xorm:cacheUpdate] cannot find column", tableName, colName)
return ErrCacheFailed
} }
//fmt.Println("find", colName)
if col, ok := table.Columns[colName]; ok { if col, ok := table.Columns[colName]; ok {
fieldValue := col.ValueOf(bean) fieldValue := col.ValueOf(bean)
session.Engine.LogDebug("[xorm:cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) session.Engine.LogDebug("[xorm:cacheUpdate] set bean field", bean, colName, fieldValue.Interface())
//session.bytes2Value(col, fieldValue, []byte(args[idx]))
fieldValue.Set(reflect.ValueOf(args[idx])) fieldValue.Set(reflect.ValueOf(args[idx]))
} }
} }
@ -1457,14 +1442,21 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
table = session.Engine.AutoMap(bean) table = session.Engine.AutoMap(bean)
session.Statement.RefTable = table session.Statement.RefTable = table
if session.Statement.ColumnStr == "" {
colNames, args = BuildConditions(session.Engine, table, bean) colNames, args = BuildConditions(session.Engine, table, bean)
} else {
colNames, args, err = table.GenCols(session, bean, true, true)
if err != nil {
return 0, err
}
}
if session.Statement.UseAutoTime && table.Updated != "" { if session.Statement.UseAutoTime && table.Updated != "" {
colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?") colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?")
args = append(args, time.Now()) args = append(args, time.Now())
} }
} else if t.Kind() == reflect.Map { } else if t.Kind() == reflect.Map {
if session.Statement.RefTable == nil { if session.Statement.RefTable == nil {
return -1, ErrTableNotFound return 0, ErrTableNotFound
} }
table = session.Statement.RefTable table = session.Statement.RefTable
colNames = make([]string, 0) colNames = make([]string, 0)
@ -1480,7 +1472,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
args = append(args, time.Now()) args = append(args, time.Now())
} }
} else { } else {
return -1, ErrParamsType return 0, ErrParamsType
} }
var condiColNames []string var condiColNames []string

View File

@ -46,3 +46,7 @@ func (db *sqlite3) SupportEngine() bool {
func (db *sqlite3) SupportCharset() bool { func (db *sqlite3) SupportCharset() bool {
return false return false
} }
func (db *sqlite3) IndexOnTable() bool {
return false
}

View File

@ -2,12 +2,12 @@ package xorm
import ( import (
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" //"os"
"testing" "testing"
) )
func TestSqlite3(t *testing.T) { func TestSqlite3(t *testing.T) {
os.Remove("./test.db") //os.Remove("./test.db")
engine, err := NewEngine("sqlite3", "./test.db") engine, err := NewEngine("sqlite3", "./test.db")
defer engine.Close() defer engine.Close()
if err != nil { if err != nil {
@ -17,4 +17,5 @@ func TestSqlite3(t *testing.T) {
engine.ShowSQL = true engine.ShowSQL = true
testAll(engine, t) testAll(engine, t)
testAll2(engine, t)
} }

View File

@ -207,10 +207,18 @@ func (statement *Statement) In(column string, args ...interface{}) {
} }
func (statement *Statement) Cols(columns ...string) { func (statement *Statement) Cols(columns ...string) {
statement.ColumnStr = strings.Join(columns, statement.Engine.Quote(", ")) newColumns := make([]string, 0)
for _, column := range columns { for _, col := range columns {
statement.columnMap[column] = true strings.Replace(col, "`", "", -1)
strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
nc := strings.TrimSpace(c)
statement.columnMap[nc] = true
newColumns = append(newColumns, nc)
} }
}
statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
} }
func (statement *Statement) Limit(limit int, start ...int) { func (statement *Statement) Limit(limit int, start ...int) {
@ -284,13 +292,32 @@ func (statement *Statement) genIndexSQL() []string {
func (statement *Statement) genUniqueSQL() []string { func (statement *Statement) genUniqueSQL() []string {
var sqls []string = make([]string, 0) var sqls []string = make([]string, 0)
for indexName, cols := range statement.RefTable.Uniques { for indexName, cols := range statement.RefTable.Uniques {
sql := fmt.Sprintf("CREATE UNIQUE INDEX UQE_%v_%v ON %v (%v);", statement.TableName(), indexName, sql := fmt.Sprintf("CREATE UNIQUE INDEX `UQE_%v_%v` ON %v (%v);", statement.TableName(), indexName,
statement.Engine.Quote(statement.TableName()), statement.Engine.Quote(strings.Join(cols, statement.Engine.Quote(",")))) statement.Engine.Quote(statement.TableName()), statement.Engine.Quote(strings.Join(cols, statement.Engine.Quote(","))))
sqls = append(sqls, sql) sqls = append(sqls, sql)
} }
return sqls return sqls
} }
func (statement *Statement) genDelIndexSQL() []string {
var sqls []string = make([]string, 0)
for indexName, _ := range statement.RefTable.Uniques {
sql := fmt.Sprintf("DROP INDEX `UQE_%v_%v`", statement.TableName(), indexName)
if statement.Engine.Dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName()))
}
sqls = append(sqls, sql)
}
for indexName, _ := range statement.RefTable.Indexes {
sql := fmt.Sprintf("DROP INDEX IDX_%v_%v", statement.TableName(), indexName)
if statement.Engine.Dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName()))
}
sqls = append(sqls, sql)
}
return sqls
}
func (statement *Statement) genDropSQL() string { func (statement *Statement) genDropSQL() string {
sql := "DROP TABLE IF EXISTS " + statement.Engine.Quote(statement.TableName()) + ";" sql := "DROP TABLE IF EXISTS " + statement.Engine.Quote(statement.TableName()) + ";"
return sql return sql

View File

@ -2,7 +2,6 @@ package xorm
import ( import (
"reflect" "reflect"
//"strconv"
"strings" "strings"
"time" "time"
) )
@ -244,6 +243,50 @@ func (table *Table) AddColumn(col *Column) {
table.Columns[col.Name] = col table.Columns[col.Name] = col
} }
func (table *Table) GenCols(session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) {
colNames := make([]string, 0)
args := make([]interface{}, 0)
for _, col := range table.Columns {
if useCol {
if _, ok := session.Statement.columnMap[col.Name]; !ok {
continue
}
}
if col.MapType == ONLYFROMDB {
continue
}
fieldValue := col.ValueOf(bean)
if col.IsAutoIncrement && fieldValue.Int() == 0 {
continue
}
if session.Statement.ColumnStr != "" {
if _, ok := session.Statement.columnMap[col.Name]; !ok {
continue
}
}
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
args = append(args, time.Now())
} else {
arg, err := session.value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}
args = append(args, arg)
}
if includeQuote {
colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?")
} else {
colNames = append(colNames, col.Name)
}
}
return colNames, args, nil
}
type Conversion interface { type Conversion interface {
FromDB([]byte) error FromDB([]byte) error
ToDB() ([]byte, error) ToDB() ([]byte, error)

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
"runtime"
"sync" "sync"
) )
@ -12,6 +13,10 @@ const (
version string = "0.1.9" version string = "0.1.9"
) )
func close(engine *Engine) {
engine.Close()
}
// new a db manager according to the parameter. Currently support three // new a db manager according to the parameter. Currently support three
// driver // driver
func NewEngine(driverName string, dataSourceName string) (*Engine, error) { func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
@ -42,6 +47,6 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
//engine.Pool = NewNoneConnectPool() //engine.Pool = NewNoneConnectPool()
//engine.Cacher = NewLRUCacher() //engine.Cacher = NewLRUCacher()
err := engine.SetPool(NewSysConnectPool()) err := engine.SetPool(NewSysConnectPool())
runtime.SetFinalizer(engine, close)
return engine, err return engine, err
} }