From 48fa4c6fbc3f763ca22bf99c12576c25f376e2df Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 26 Sep 2013 15:19:39 +0800 Subject: [PATCH] many bugs fixed --- VERSION | 2 +- base_test.go | 247 ++++++++++++++++++++++++++++++++++++++++++----- cache.go | 8 +- engine.go | 1 + helpers.go | 25 +++++ mymysql_test.go | 1 + mysql.go | 4 + mysql_test.go | 1 + pool.go | 11 ++- postgres.go | 4 + postgres_test.go | 84 +++++++++++++++- session.go | 76 +++++++-------- sqlite3.go | 4 + sqlite3_test.go | 5 +- statement.go | 35 ++++++- table.go | 45 ++++++++- xorm.go | 7 +- 17 files changed, 474 insertions(+), 86 deletions(-) create mode 100644 helpers.go diff --git a/VERSION b/VERSION index a992b005..a7805434 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -xorm v0.1.9 +xorm v0.2 diff --git a/base_test.go b/base_test.go index 7f617b46..61ae50d2 100644 --- a/base_test.go +++ b/base_test.go @@ -112,9 +112,29 @@ func exec(engine *Engine, t *testing.T) { fmt.Println(res) } +func querySameMapper(engine *Engine, t *testing.T) { + sql := "select * from `Userinfo`" + results, err := engine.Query(sql) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(results) +} + +func execSameMapper(engine *Engine, t *testing.T) { + sql := "update `Userinfo` set `Username`=? where (id)=?" + res, err := engine.Exec(sql, "xiaolun", 1) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(res) +} + func insertAutoIncr(engine *Engine, t *testing.T) { // auto increment insert - user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now(), + user := Userinfo{Username: "xiaolunwen2", Departname: "dev", Alias: "lunny", Created: time.Now(), Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true} _, err := engine.Insert(&user) fmt.Println(user.Uid) @@ -175,6 +195,29 @@ func update(engine *Engine, t *testing.T) { } } +func updateSameMapper(engine *Engine, t *testing.T) { + // update by id + user := Userinfo{Username: "xxx", Height: 1.2} + _, err := engine.Id(1).Update(&user) + if err != nil { + t.Error(err) + panic(err) + } + + condi := Condi{"Username": "zzz", "Height": 0.0, "Departname": ""} + _, err = engine.Table(&user).Id(1).Update(&condi) + if err != nil { + t.Error(err) + panic(err) + } + + _, err = engine.Update(&Userinfo{Username: "yyy"}, &user) + if err != nil { + t.Error(err) + panic(err) + } +} + func testdelete(engine *Engine, t *testing.T) { user := Userinfo{Uid: 1} _, err := engine.Delete(&user) @@ -243,12 +286,12 @@ func count(engine *Engine, t *testing.T) { t.Error(err) panic(err) } - fmt.Printf("Total %d records!!!", total) + fmt.Printf("Total %d records!!!\n", total) } func where(engine *Engine, t *testing.T) { users := make([]Userinfo, 0) - err := engine.Where("id > ?", 2).Find(&users) + err := engine.Where("(id) > ?", 2).Find(&users) if err != nil { t.Error(err) panic(err) @@ -258,7 +301,7 @@ func where(engine *Engine, t *testing.T) { func in(engine *Engine, t *testing.T) { users := make([]Userinfo, 0) - err := engine.In("id", 1, 2, 3).Find(&users) + err := engine.In("(id)", 1, 2, 3).Find(&users) if err != nil { t.Error(err) panic(err) @@ -266,7 +309,7 @@ func in(engine *Engine, t *testing.T) { fmt.Println(users) ids := []interface{}{1, 2, 3} - err = engine.Where("id > ?", 2).In("id", ids...).Find(&users) + err = engine.Where("(id) > ?", 2).In("(id)", ids...).Find(&users) if err != nil { t.Error(err) panic(err) @@ -321,6 +364,43 @@ func having(engine *Engine, t *testing.T) { fmt.Println(users) } +func orderSameMapper(engine *Engine, t *testing.T) { + users := make([]Userinfo, 0) + err := engine.OrderBy("(id) desc").Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) + + users2 := make([]Userinfo, 0) + err = engine.Asc("(id)", "Username").Desc("Height").Find(&users2) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users2) +} + +func joinSameMapper(engine *Engine, t *testing.T) { + users := make([]Userinfo, 0) + err := engine.Join("LEFT", `"Userdetail"`, `"Userinfo"."id"="Userdetail"."Id"`).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } +} + +func havingSameMapper(engine *Engine, t *testing.T) { + users := make([]Userinfo, 0) + err := engine.GroupBy("Username").Having(`"Username"='xlw'`).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) +} + func transaction(engine *Engine, t *testing.T) { counter := func() { total, err := engine.Count(&Userinfo{}) @@ -349,7 +429,7 @@ func transaction(engine *Engine, t *testing.T) { panic(err) } user2 := Userinfo{Username: "yyy"} - _, err = session.Where("uid = ?", 0).Update(&user2) + _, err = session.Where("(id) = ?", 0).Update(&user2) if err != nil { session.Rollback() fmt.Println(err) @@ -421,6 +501,55 @@ func combineTransaction(engine *Engine, t *testing.T) { } } +func combineTransactionSameMapper(engine *Engine, t *testing.T) { + counter := func() { + total, err := engine.Count(&Userinfo{}) + if err != nil { + t.Error(err) + } + fmt.Printf("----now total %v records\n", total) + } + + counter() + defer counter() + session := engine.NewSession() + defer session.Close() + + err := session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + //session.IsAutoRollback = false + user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} + _, err = session.Insert(&user1) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } + user2 := Userinfo{Username: "zzz"} + _, err = session.Where("(id) = ?", 0).Update(&user2) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } + + _, err = session.Exec("delete from `Userinfo` where `Username` = ?", user2.Username) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } + + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } +} + func table(engine *Engine, t *testing.T) { err := engine.DropTables("user_user") if err != nil { @@ -554,6 +683,41 @@ func testCols(engine *Engine, t *testing.T) { panic(err) } fmt.Println(tmpUsers) + + user := &Userinfo{Uid: 1, Alias: "", Height: 0} + affected, err := engine.Cols("departname, height").Id(1).Update(user) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println("===================", user, affected) +} + +func testColsSameMapper(engine *Engine, t *testing.T) { + users := []Userinfo{} + err := engine.Cols("(id), Username").Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + + fmt.Println(users) + + tmpUsers := []tempUser{} + err = engine.Table("Userinfo").Cols("(id), Username").Find(&tmpUsers) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(tmpUsers) + + user := &Userinfo{Uid: 1, Alias: "", Height: 0} + affected, err := engine.Cols("Departname, Height").Update(user) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println("===================", user, affected) } type tempUser2 struct { @@ -708,8 +872,8 @@ func testCustomType(engine *Engine, t *testing.T) { i.UIA32 = []uint32{4, 5} i.UIA64 = []uint64{6, 7, 9} i.UIA8 = []uint8{1, 2, 3, 4} - i.NameArray = []string{"ssss fsdf", "lllll, ss"} - i.MSS = map[string]string{"s": "sfds,ss ", "x": "lfjljsl"} + i.NameArray = []string{"ssss", "fsdf", "lllll, ss"} + i.MSS = map[string]string{"s": "sfds,ss", "x": "lfjljsl"} _, err = engine.Insert(&i) if err != nil { t.Error(err) @@ -804,13 +968,13 @@ func testIndexAndUnique(engine *Engine, t *testing.T) { err := engine.DropTables(&IndexOrUnique{}) if err != nil { t.Error(err) - panic(err) + //panic(err) } err = engine.CreateTables(&IndexOrUnique{}) if err != nil { t.Error(err) - panic(err) + //panic(err) } } @@ -853,38 +1017,75 @@ func testInt32Id(engine *Engine, t *testing.T) { } func testAll(engine *Engine, t *testing.T) { + fmt.Println("-------------- directCreateTable --------------") directCreateTable(engine, t) + fmt.Println("-------------- mapper --------------") mapper(engine, t) + fmt.Println("-------------- insert --------------") insert(engine, t) + fmt.Println("-------------- query --------------") query(engine, t) + fmt.Println("-------------- exec --------------") exec(engine, t) + fmt.Println("-------------- insertAutoIncr --------------") insertAutoIncr(engine, t) + fmt.Println("-------------- insertMulti --------------") insertMulti(engine, t) + fmt.Println("-------------- insertTwoTable --------------") insertTwoTable(engine, t) + fmt.Println("-------------- update --------------") update(engine, t) + fmt.Println("-------------- testdelete --------------") testdelete(engine, t) + fmt.Println("-------------- get --------------") get(engine, t) + fmt.Println("-------------- cascadeGet --------------") cascadeGet(engine, t) + fmt.Println("-------------- find --------------") find(engine, t) + fmt.Println("-------------- findMap --------------") findMap(engine, t) + fmt.Println("-------------- count --------------") count(engine, t) + fmt.Println("-------------- where --------------") where(engine, t) + fmt.Println("-------------- in --------------") in(engine, t) + fmt.Println("-------------- limit --------------") limit(engine, t) + fmt.Println("-------------- order --------------") order(engine, t) + fmt.Println("-------------- join --------------") join(engine, t) + fmt.Println("-------------- having --------------") having(engine, t) - transaction(engine, t) - combineTransaction(engine, t) - table(engine, t) - createMultiTables(engine, t) - tableOp(engine, t) - testCols(engine, t) - testCharst(engine, t) - testStoreEngine(engine, t) - testExtends(engine, t) - testColTypes(engine, t) - testCustomType(engine, t) - testCreatedAndUpdated(engine, t) - testIndexAndUnique(engine, t) +} + +func testAll2(engine *Engine, t *testing.T) { + fmt.Println("-------------- combineTransaction --------------") + combineTransaction(engine, t) + fmt.Println("-------------- table --------------") + table(engine, t) + fmt.Println("-------------- createMultiTables --------------") + createMultiTables(engine, t) + fmt.Println("-------------- tableOp --------------") + tableOp(engine, t) + fmt.Println("-------------- testCols --------------") + testCols(engine, t) + fmt.Println("-------------- testCharst --------------") + testCharst(engine, t) + fmt.Println("-------------- testStoreEngine --------------") + testStoreEngine(engine, t) + fmt.Println("-------------- testExtends --------------") + testExtends(engine, t) + fmt.Println("-------------- testColTypes --------------") + testColTypes(engine, t) + fmt.Println("-------------- testCustomType --------------") + testCustomType(engine, t) + fmt.Println("-------------- testCreatedAndUpdated --------------") + testCreatedAndUpdated(engine, t) + fmt.Println("-------------- testIndexAndUnique --------------") + testIndexAndUnique(engine, t) + fmt.Println("-------------- transaction --------------") + transaction(engine, t) } diff --git a/cache.go b/cache.go index 53194841..6b8efc09 100644 --- a/cache.go +++ b/cache.go @@ -29,14 +29,12 @@ func (s *MemoryStore) Put(key, value interface{}) error { s.mutex.Lock() defer s.mutex.Unlock() s.store[key] = value - //fmt.Println("after put store:", s.store) return nil } func (s *MemoryStore) Get(key interface{}) (interface{}, error) { - s.mutex.Rlock() - defer s.mutex.UnRlock() - //fmt.Println("before get store:", s.store) + s.mutex.RLock() + defer s.mutex.RUnlock() if v, ok := s.store[key]; ok { return v, nil } @@ -47,9 +45,7 @@ func (s *MemoryStore) Get(key interface{}) (interface{}, error) { func (s *MemoryStore) Del(key interface{}) error { s.mutex.Lock() defer s.mutex.Unlock() - //fmt.Println("before del store:", s.store) delete(s.store, key) - //fmt.Println("after del store:", s.store) return nil } diff --git a/engine.go b/engine.go index f8a95079..9fea8bcd 100644 --- a/engine.go +++ b/engine.go @@ -24,6 +24,7 @@ type dialect interface { AutoIncrStr() string SupportEngine() bool SupportCharset() bool + IndexOnTable() bool } type Engine struct { diff --git a/helpers.go b/helpers.go new file mode 100644 index 00000000..eb24e341 --- /dev/null +++ b/helpers.go @@ -0,0 +1,25 @@ +package xorm + +import ( + "strings" +) + +func IndexNoCase(s, sep string) int { + return strings.Index(strings.ToLower(s), strings.ToLower(sep)) +} + +func SplitNoCase(s, sep string) []string { + idx := IndexNoCase(s, sep) + if idx < 0 { + return []string{s} + } + return strings.Split(s, s[idx:idx+len(sep)]) +} + +func SplitNNoCase(s, sep string, n int) []string { + idx := IndexNoCase(s, sep) + if idx < 0 { + return []string{s} + } + return strings.SplitN(s, s[idx:idx+len(sep)], n) +} diff --git a/mymysql_test.go b/mymysql_test.go index 8164e29e..2f179d07 100644 --- a/mymysql_test.go +++ b/mymysql_test.go @@ -20,4 +20,5 @@ func TestMyMysql(t *testing.T) { engine.ShowSQL = true testAll(engine, t) + testAll2(engine, t) } diff --git a/mysql.go b/mysql.go index b70bc4c4..fdddfd2d 100644 --- a/mysql.go +++ b/mysql.go @@ -57,3 +57,7 @@ func (db *mysql) SupportEngine() bool { func (db *mysql) SupportCharset() bool { return true } + +func (db *mysql) IndexOnTable() bool { + return true +} diff --git a/mysql_test.go b/mysql_test.go index 4b446c94..085144ac 100644 --- a/mysql_test.go +++ b/mysql_test.go @@ -20,4 +20,5 @@ func TestMysql(t *testing.T) { engine.ShowSQL = true testAll(engine, t) + testAll2(engine, t) } diff --git a/pool.go b/pool.go index ea88af11..c8d9e1fb 100644 --- a/pool.go +++ b/pool.go @@ -2,7 +2,7 @@ package xorm import ( "database/sql" - "fmt" + //"fmt" "sync" //"sync/atomic" "container/list" @@ -118,7 +118,7 @@ func NewNode() *node { // RetrieveDB just return the only db func (s *SysConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) { - if s.maxConns > 0 { + /*if s.maxConns > 0 { fmt.Println("before retrieve") s.mutex.Lock() for s.curConns >= s.maxConns { @@ -135,13 +135,13 @@ func (s *SysConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) { s.curConns += 1 s.mutex.Unlock() fmt.Println("after retrieve") - } + }*/ return s.db, nil } // ReleaseDB do nothing func (s *SysConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { - if s.maxConns > 0 { + /*if s.maxConns > 0 { s.mutex.Lock() fmt.Println("before release", s.queue.Len()) s.curConns -= 1 @@ -156,7 +156,7 @@ func (s *SysConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { } fmt.Println("after released", s.queue.Len()) s.mutex.Unlock() - } + }*/ } // Close closed the only db @@ -176,6 +176,7 @@ func (p *SysConnectPool) MaxIdleConns() int { // not implemented func (p *SysConnectPool) SetMaxConns(conns int) { p.maxConns = conns + //p.db.SetMaxOpenConns(conns) } // not implemented diff --git a/postgres.go b/postgres.go index 4dc6555d..1287dfc6 100644 --- a/postgres.go +++ b/postgres.go @@ -64,3 +64,7 @@ func (db *postgres) SupportEngine() bool { func (db *postgres) SupportCharset() bool { return false } + +func (db *postgres) IndexOnTable() bool { + return false +} diff --git a/postgres_test.go b/postgres_test.go index c994bb1b..7aa72d69 100644 --- a/postgres_test.go +++ b/postgres_test.go @@ -1,18 +1,100 @@ package xorm import ( + "fmt" _ "github.com/bylevel/pq" "testing" ) func TestPostgres(t *testing.T) { engine, err := NewEngine("postgres", "dbname=xorm_test sslmode=disable") - defer engine.Close() if err != nil { t.Error(err) return } + defer engine.Close() engine.ShowSQL = true testAll(engine, t) + testAll2(engine, t) +} + +func TestPostgres2(t *testing.T) { + engine, err := NewEngine("postgres", "dbname=xorm_test sslmode=disable") + if err != nil { + t.Error(err) + return + } + defer engine.Close() + engine.ShowSQL = true + engine.Mapper = SameMapper{} + + fmt.Println("-------------- directCreateTable --------------") + directCreateTable(engine, t) + fmt.Println("-------------- mapper --------------") + mapper(engine, t) + fmt.Println("-------------- insert --------------") + insert(engine, t) + fmt.Println("-------------- querySameMapper --------------") + querySameMapper(engine, t) + fmt.Println("-------------- execSameMapper --------------") + execSameMapper(engine, t) + fmt.Println("-------------- insertAutoIncr --------------") + insertAutoIncr(engine, t) + fmt.Println("-------------- insertMulti --------------") + insertMulti(engine, t) + fmt.Println("-------------- insertTwoTable --------------") + insertTwoTable(engine, t) + fmt.Println("-------------- updateSameMapper --------------") + updateSameMapper(engine, t) + fmt.Println("-------------- testdelete --------------") + testdelete(engine, t) + fmt.Println("-------------- get --------------") + get(engine, t) + fmt.Println("-------------- cascadeGet --------------") + cascadeGet(engine, t) + fmt.Println("-------------- find --------------") + find(engine, t) + fmt.Println("-------------- findMap --------------") + findMap(engine, t) + fmt.Println("-------------- count --------------") + count(engine, t) + fmt.Println("-------------- where --------------") + where(engine, t) + fmt.Println("-------------- in --------------") + in(engine, t) + fmt.Println("-------------- limit --------------") + limit(engine, t) + fmt.Println("-------------- orderSameMapper --------------") + orderSameMapper(engine, t) + fmt.Println("-------------- joinSameMapper --------------") + joinSameMapper(engine, t) + fmt.Println("-------------- havingSameMapper --------------") + havingSameMapper(engine, t) + fmt.Println("-------------- transaction --------------") + transaction(engine, t) + fmt.Println("-------------- combineTransactionSameMapper --------------") + combineTransactionSameMapper(engine, t) + fmt.Println("-------------- table --------------") + table(engine, t) + fmt.Println("-------------- createMultiTables --------------") + createMultiTables(engine, t) + fmt.Println("-------------- tableOp --------------") + tableOp(engine, t) + fmt.Println("-------------- testColsSameMapper --------------") + testColsSameMapper(engine, t) + fmt.Println("-------------- testCharst --------------") + testCharst(engine, t) + fmt.Println("-------------- testStoreEngine --------------") + testStoreEngine(engine, t) + fmt.Println("-------------- testExtends --------------") + testExtends(engine, t) + fmt.Println("-------------- testColTypes --------------") + testColTypes(engine, t) + fmt.Println("-------------- testCustomType --------------") + testCustomType(engine, t) + fmt.Println("-------------- testCreatedAndUpdated --------------") + testCreatedAndUpdated(engine, t) + fmt.Println("-------------- testIndexAndUnique --------------") + testIndexAndUnique(engine, t) } diff --git a/session.go b/session.go index 85226d48..3128d4fa 100644 --- a/session.go +++ b/session.go @@ -337,6 +337,7 @@ func (session *Session) CreateAll() error { return nil } +// DropTable drop a table and all indexes of the table func (session *Session) DropTable(bean interface{}) error { err := session.newDb() if err != nil { @@ -354,6 +355,14 @@ func (session *Session) DropTable(bean interface{}) error { session.Statement.AltTableName = bean.(string) } else if t.Kind() == reflect.Struct { session.Statement.RefTable = session.Engine.AutoMap(bean) + + sqls := session.Statement.genDelIndexSQL() + for _, sql := range sqls { + _, err = session.exec(sql) + if err != nil { + return err + } + } } else { return errors.New("Unsupported type") } @@ -1209,42 +1218,16 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( func (session *Session) innerInsert(bean interface{}) (int64, error) { table := session.Engine.AutoMap(bean) - session.Statement.RefTable = table - colNames := make([]string, 0) - colPlaces := make([]string, 0) - var args = make([]interface{}, 0) - for _, col := range table.Columns { - if col.MapType == ONLYFROMDB { - continue - } - - fieldValue := col.ValueOf(bean) - if col.IsAutoIncrement && fieldValue.Int() == 0 { - continue - } - - if session.Statement.ColumnStr != "" { - if _, ok := session.Statement.columnMap[col.Name]; !ok { - continue - } - } - - if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { - args = append(args, time.Now()) - } else { - arg, err := session.value2Interface(col, fieldValue) - if err != nil { - return 0, err - } - args = append(args, arg) - } - - colNames = append(colNames, col.Name) - colPlaces = append(colPlaces, "?") + colNames, args, err := table.GenCols(session, bean, false, false) + if err != nil { + return 0, err } + colPlaces := strings.Repeat("?, ", len(colNames)) + colPlaces = colPlaces[0 : len(colPlaces)-2] + sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v);", session.Engine.QuoteStr(), session.Statement.TableName(), @@ -1252,7 +1235,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.Engine.QuoteStr(), strings.Join(colNames, session.Engine.Quote(", ")), session.Engine.QuoteStr(), - strings.Join(colPlaces, ", ")) + colPlaces) res, err := session.exec(sql, args...) if err != nil { @@ -1402,13 +1385,14 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { for _, id := range ids { if bean := cacher.GetBean(tableName, id); bean != nil { - sqls := strings.SplitN(strings.ToLower(sql), "where", 2) + sqls := SplitNNoCase(sql, "where", 2) if len(sqls) != 2 { - return nil + return ErrCacheFailed } - sqls = strings.SplitN(sqls[0], "set", 2) + + sqls = SplitNNoCase(sqls[0], "set", 2) if len(sqls) != 2 { - return nil + return ErrCacheFailed } kvs := strings.Split(strings.TrimSpace(sqls[1]), ",") for idx, kv := range kvs { @@ -1419,13 +1403,14 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1)) } else if strings.Contains(colName, session.Engine.QuoteStr()) { colName = strings.TrimSpace(strings.Replace(colName, session.Engine.QuoteStr(), "", -1)) + } else { + session.Engine.LogDebug("[xorm:cacheUpdate] cannot find column", tableName, colName) + return ErrCacheFailed } - //fmt.Println("find", colName) if col, ok := table.Columns[colName]; ok { fieldValue := col.ValueOf(bean) session.Engine.LogDebug("[xorm:cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) - //session.bytes2Value(col, fieldValue, []byte(args[idx])) fieldValue.Set(reflect.ValueOf(args[idx])) } } @@ -1457,14 +1442,21 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 table = session.Engine.AutoMap(bean) session.Statement.RefTable = table - colNames, args = BuildConditions(session.Engine, table, bean) + if session.Statement.ColumnStr == "" { + colNames, args = BuildConditions(session.Engine, table, bean) + } else { + colNames, args, err = table.GenCols(session, bean, true, true) + if err != nil { + return 0, err + } + } if session.Statement.UseAutoTime && table.Updated != "" { colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?") args = append(args, time.Now()) } } else if t.Kind() == reflect.Map { if session.Statement.RefTable == nil { - return -1, ErrTableNotFound + return 0, ErrTableNotFound } table = session.Statement.RefTable colNames = make([]string, 0) @@ -1480,7 +1472,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 args = append(args, time.Now()) } } else { - return -1, ErrParamsType + return 0, ErrParamsType } var condiColNames []string diff --git a/sqlite3.go b/sqlite3.go index 8f47cf03..525e8ebc 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -46,3 +46,7 @@ func (db *sqlite3) SupportEngine() bool { func (db *sqlite3) SupportCharset() bool { return false } + +func (db *sqlite3) IndexOnTable() bool { + return false +} diff --git a/sqlite3_test.go b/sqlite3_test.go index 41bdb707..f1636f87 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -2,12 +2,12 @@ package xorm import ( _ "github.com/mattn/go-sqlite3" - "os" + //"os" "testing" ) func TestSqlite3(t *testing.T) { - os.Remove("./test.db") + //os.Remove("./test.db") engine, err := NewEngine("sqlite3", "./test.db") defer engine.Close() if err != nil { @@ -17,4 +17,5 @@ func TestSqlite3(t *testing.T) { engine.ShowSQL = true testAll(engine, t) + testAll2(engine, t) } diff --git a/statement.go b/statement.go index 2343ccb6..d23b5757 100644 --- a/statement.go +++ b/statement.go @@ -207,10 +207,18 @@ func (statement *Statement) In(column string, args ...interface{}) { } func (statement *Statement) Cols(columns ...string) { - statement.ColumnStr = strings.Join(columns, statement.Engine.Quote(", ")) - for _, column := range columns { - statement.columnMap[column] = true + newColumns := make([]string, 0) + for _, col := range columns { + strings.Replace(col, "`", "", -1) + strings.Replace(col, `"`, "", -1) + ccols := strings.Split(col, ",") + for _, c := range ccols { + nc := strings.TrimSpace(c) + statement.columnMap[nc] = true + newColumns = append(newColumns, nc) + } } + statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) } func (statement *Statement) Limit(limit int, start ...int) { @@ -284,13 +292,32 @@ func (statement *Statement) genIndexSQL() []string { func (statement *Statement) genUniqueSQL() []string { var sqls []string = make([]string, 0) for indexName, cols := range statement.RefTable.Uniques { - sql := fmt.Sprintf("CREATE UNIQUE INDEX UQE_%v_%v ON %v (%v);", statement.TableName(), indexName, + sql := fmt.Sprintf("CREATE UNIQUE INDEX `UQE_%v_%v` ON %v (%v);", statement.TableName(), indexName, statement.Engine.Quote(statement.TableName()), statement.Engine.Quote(strings.Join(cols, statement.Engine.Quote(",")))) sqls = append(sqls, sql) } return sqls } +func (statement *Statement) genDelIndexSQL() []string { + var sqls []string = make([]string, 0) + for indexName, _ := range statement.RefTable.Uniques { + sql := fmt.Sprintf("DROP INDEX `UQE_%v_%v`", statement.TableName(), indexName) + if statement.Engine.Dialect.IndexOnTable() { + sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName())) + } + sqls = append(sqls, sql) + } + for indexName, _ := range statement.RefTable.Indexes { + sql := fmt.Sprintf("DROP INDEX IDX_%v_%v", statement.TableName(), indexName) + if statement.Engine.Dialect.IndexOnTable() { + sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName())) + } + sqls = append(sqls, sql) + } + return sqls +} + func (statement *Statement) genDropSQL() string { sql := "DROP TABLE IF EXISTS " + statement.Engine.Quote(statement.TableName()) + ";" return sql diff --git a/table.go b/table.go index f9316f24..de4807af 100644 --- a/table.go +++ b/table.go @@ -2,7 +2,6 @@ package xorm import ( "reflect" - //"strconv" "strings" "time" ) @@ -244,6 +243,50 @@ func (table *Table) AddColumn(col *Column) { table.Columns[col.Name] = col } +func (table *Table) GenCols(session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) { + colNames := make([]string, 0) + args := make([]interface{}, 0) + + for _, col := range table.Columns { + if useCol { + if _, ok := session.Statement.columnMap[col.Name]; !ok { + continue + } + } + if col.MapType == ONLYFROMDB { + continue + } + + fieldValue := col.ValueOf(bean) + if col.IsAutoIncrement && fieldValue.Int() == 0 { + continue + } + + if session.Statement.ColumnStr != "" { + if _, ok := session.Statement.columnMap[col.Name]; !ok { + continue + } + } + + if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { + args = append(args, time.Now()) + } else { + arg, err := session.value2Interface(col, fieldValue) + if err != nil { + return colNames, args, err + } + args = append(args, arg) + } + + if includeQuote { + colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?") + } else { + colNames = append(colNames, col.Name) + } + } + return colNames, args, nil +} + type Conversion interface { FromDB([]byte) error ToDB() ([]byte, error) diff --git a/xorm.go b/xorm.go index 84859207..c85d8f55 100644 --- a/xorm.go +++ b/xorm.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "reflect" + "runtime" "sync" ) @@ -12,6 +13,10 @@ const ( version string = "0.1.9" ) +func close(engine *Engine) { + engine.Close() +} + // new a db manager according to the parameter. Currently support three // driver func NewEngine(driverName string, dataSourceName string) (*Engine, error) { @@ -42,6 +47,6 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { //engine.Pool = NewNoneConnectPool() //engine.Cacher = NewLRUCacher() err := engine.SetPool(NewSysConnectPool()) - + runtime.SetFinalizer(engine, close) return engine, err }