From 004b30f2fb6987d8ec9d6a31c385a592d73b7621 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 9 Jan 2014 17:55:33 +0800 Subject: [PATCH] samemapper support --- base_test.go | 71 +++++++++++++--------- mysql_test.go | 28 ++++++++- postgres_test.go | 150 +++++++++++++++++------------------------------ sqlite3_test.go | 36 ++++++++++++ statement.go | 24 +++++--- 5 files changed, 175 insertions(+), 134 deletions(-) diff --git a/base_test.go b/base_test.go index 9b66890c..cc53fe1c 100644 --- a/base_test.go +++ b/base_test.go @@ -150,7 +150,7 @@ func exec(engine *Engine, t *testing.T) { fmt.Println(res) } -func querySameMapper(engine *Engine, t *testing.T) { +func testQuerySameMapper(engine *Engine, t *testing.T) { sql := "select * from `Userinfo`" results, err := engine.Query(sql) if err != nil { @@ -274,7 +274,7 @@ type Condi map[string]interface{} func update(engine *Engine, t *testing.T) { // update by id user := Userinfo{Username: "xxx", Height: 1.2} - cnt, err := engine.Id(1).Update(&user) + cnt, err := engine.Id(4).Update(&user) if err != nil { t.Error(err) panic(err) @@ -286,8 +286,8 @@ func update(engine *Engine, t *testing.T) { return } - condi := Condi{"username": "zzz", "height": 0.0, "departname": ""} - cnt, err = engine.Table(&user).Id(1).Update(&condi) + condi := Condi{"username": "zzz", "departname": ""} + cnt, err = engine.Table(&user).Id(4).Update(&condi) if err != nil { t.Error(err) panic(err) @@ -321,7 +321,7 @@ func update(engine *Engine, t *testing.T) { func updateSameMapper(engine *Engine, t *testing.T) { // update by id user := Userinfo{Username: "xxx", Height: 1.2} - cnt, err := engine.Id(1).Update(&user) + cnt, err := engine.Id(4).Update(&user) if err != nil { t.Error(err) panic(err) @@ -333,15 +333,15 @@ func updateSameMapper(engine *Engine, t *testing.T) { return } - condi := Condi{"Username": "zzz", "Height": 0.0, "Departname": ""} - cnt, err = engine.Table(&user).Id(1).Update(&condi) + condi := Condi{"Username": "zzz", "Departname": ""} + cnt, err = engine.Table(&user).Id(4).Update(&condi) if err != nil { t.Error(err) panic(err) } if cnt != 1 { - err = errors.New("insert not returned 1") + err = errors.New("update not returned 1") t.Error(err) panic(err) return @@ -354,7 +354,7 @@ func updateSameMapper(engine *Engine, t *testing.T) { } if cnt != 1 { - err = errors.New("insert not returned 1") + err = errors.New("update not returned 1") t.Error(err) panic(err) return @@ -376,6 +376,7 @@ func testDelete(engine *Engine, t *testing.T) { } user.Uid = 0 + user.IsMan = true has, err := engine.Id(3).Get(&user) if err != nil { t.Error(err) @@ -423,7 +424,9 @@ func get(engine *Engine, t *testing.T) { panic(err) } - _, err = engine.Where("`user` = ?", "xlw").Delete(&NoIdUser{}) + userCol := engine.columnMapper.Obj2Table("User") + + _, err = engine.Where("`"+userCol+"` = ?", "xlw").Delete(&NoIdUser{}) if err != nil { t.Error(err) panic(err) @@ -442,7 +445,7 @@ func get(engine *Engine, t *testing.T) { } noIdUser := new(NoIdUser) - has, err = engine.Where("`user` = ?", "xlw").Get(noIdUser) + has, err = engine.Where("`"+userCol+"` = ?", "xlw").Get(noIdUser) if err != nil { t.Error(err) panic(err) @@ -484,7 +487,8 @@ func find(engine *Engine, t *testing.T) { } users2 := make([]Userinfo, 0) - err = engine.Sql("select * from userinfo").Find(&users2) + userinfo := engine.tableMapper.Obj2Table("Userinfo") + err = engine.Sql("select * from " + engine.Quote(userinfo)).Find(&users2) if err != nil { t.Error(err) panic(err) @@ -574,7 +578,10 @@ func in(engine *Engine, t *testing.T) { } fmt.Println(users) - err = engine.In("(id)", 1).In("(id)", 2).In("departname", "dev").Find(&users) + department := engine.columnMapper.Obj2Table("Departname") + dev := engine.columnMapper.Obj2Table("Dev") + + err = engine.In("(id)", 1).In("(id)", 2).In(department, dev).Find(&users) if err != nil { t.Error(err) panic(err) @@ -3782,18 +3789,12 @@ func testAll(engine *Engine, t *testing.T) { directCreateTable(engine, t) fmt.Println("-------------- insert --------------") insert(engine, t) - fmt.Println("-------------- query --------------") - testQuery(engine, t) - fmt.Println("-------------- exec --------------") - exec(engine, t) fmt.Println("-------------- insertAutoIncr --------------") insertAutoIncr(engine, t) fmt.Println("-------------- insertMulti --------------") insertMulti(engine, t) fmt.Println("-------------- insertTwoTable --------------") insertTwoTable(engine, t) - fmt.Println("-------------- update --------------") - update(engine, t) fmt.Println("-------------- testDelete --------------") testDelete(engine, t) fmt.Println("-------------- get --------------") @@ -3816,12 +3817,6 @@ func testAll(engine *Engine, t *testing.T) { in(engine, t) fmt.Println("-------------- limit --------------") limit(engine, t) - fmt.Println("-------------- order --------------") - order(engine, t) - fmt.Println("-------------- join --------------") - join(engine, t) - fmt.Println("-------------- having --------------") - having(engine, t) } func testAll2(engine *Engine, t *testing.T) { @@ -3904,9 +3899,31 @@ func testAll3(engine *Engine, t *testing.T) { } func testAllSnakeMapper(engine *Engine, t *testing.T) { - + fmt.Println("-------------- query --------------") + testQuery(engine, t) + fmt.Println("-------------- exec --------------") + exec(engine, t) + fmt.Println("-------------- update --------------") + update(engine, t) + fmt.Println("-------------- order --------------") + order(engine, t) + fmt.Println("-------------- join --------------") + join(engine, t) + fmt.Println("-------------- having --------------") + having(engine, t) } func testAllSameMapper(engine *Engine, t *testing.T) { - + fmt.Println("-------------- query --------------") + testQuerySameMapper(engine, t) + fmt.Println("-------------- exec --------------") + execSameMapper(engine, t) + fmt.Println("-------------- update --------------") + updateSameMapper(engine, t) + fmt.Println("-------------- order --------------") + orderSameMapper(engine, t) + fmt.Println("-------------- join --------------") + joinSameMapper(engine, t) + fmt.Println("-------------- having --------------") + havingSameMapper(engine, t) } diff --git a/mysql_test.go b/mysql_test.go index e0c3deac..f26095cb 100644 --- a/mysql_test.go +++ b/mysql_test.go @@ -42,7 +42,7 @@ func TestMysqlSameMapper(t *testing.T) { return } - engine, err := NewEngine("mysql", "root:@/xorm_test3?charset=utf8") + engine, err := NewEngine("mysql", "root:@/xorm_test1?charset=utf8") defer engine.Close() if err != nil { t.Error(err) @@ -66,7 +66,7 @@ func TestMysqlWithCache(t *testing.T) { return } - engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8") + engine, err := NewEngine("mysql", "root:@/xorm_test2?charset=utf8") defer engine.Close() if err != nil { t.Error(err) @@ -82,6 +82,30 @@ func TestMysqlWithCache(t *testing.T) { testAll2(engine, t) } +func TestMysqlWithCacheSameMapper(t *testing.T) { + err := mysqlDdlImport() + if err != nil { + t.Error(err) + return + } + + engine, err := NewEngine("mysql", "root:@/xorm_test3?charset=utf8") + defer engine.Close() + if err != nil { + t.Error(err) + return + } + engine.SetMapper(SameMapper{}) + engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) + engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql + + testAll(engine, t) + testAll2(engine, t) +} + func newMysqlEngine() (*Engine, error) { return NewEngine("mysql", "root:@/xorm_test?charset=utf8") } diff --git a/postgres_test.go b/postgres_test.go index 9657cffd..32fb8b17 100644 --- a/postgres_test.go +++ b/postgres_test.go @@ -12,7 +12,22 @@ import ( var connStr string = "dbname=xorm_test sslmode=disable" func newPostgresEngine() (*Engine, error) { - return NewEngine("postgres", connStr) + orm, err := NewEngine("postgres", connStr) + if err != nil { + return nil, err + } + tables, err := orm.DBMetas() + if err != nil { + return nil, err + } + for _, table := range tables { + _, err = orm.Exec("drop table \"" + table.Name + "\"") + if err != nil { + return nil, err + } + } + + return orm, err } func newPostgresDriverDB() (*sql.DB, error) { @@ -32,6 +47,7 @@ func TestPostgres(t *testing.T) { engine.ShowDebug = showTestSql testAll(engine, t) + testAllSnakeMapper(engine, t) testAll2(engine, t) testAll3(engine, t) } @@ -50,105 +66,47 @@ func TestPostgresWithCache(t *testing.T) { engine.ShowDebug = showTestSql testAll(engine, t) + testAllSnakeMapper(engine, t) testAll2(engine, t) } -/* -func TestPostgres2(t *testing.T) { - engine, err := NewEngine("postgres", "dbname=xorm_test sslmode=disable") - if err != nil { - t.Error(err) - return - } - defer engine.Close() - engine.ShowSQL = showTestSql - engine.Mapper = SameMapper{} +func TestPostgresSameMapper(t *testing.T) { + engine, err := newPostgresEngine() + if err != nil { + t.Error(err) + return + } + defer engine.Close() + engine.SetMapper(SameMapper{}) + engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql - fmt.Println("-------------- directCreateTable --------------") - directCreateTable(engine, t) - fmt.Println("-------------- mapper --------------") - mapper(engine, t) - fmt.Println("-------------- insert --------------") - insert(engine, t) - fmt.Println("-------------- querySameMapper --------------") - querySameMapper(engine, t) - fmt.Println("-------------- execSameMapper --------------") - execSameMapper(engine, t) - fmt.Println("-------------- insertAutoIncr --------------") - insertAutoIncr(engine, t) - fmt.Println("-------------- insertMulti --------------") - insertMulti(engine, t) - fmt.Println("-------------- insertTwoTable --------------") - insertTwoTable(engine, t) - fmt.Println("-------------- updateSameMapper --------------") - updateSameMapper(engine, t) - fmt.Println("-------------- testdelete --------------") - testdelete(engine, t) - fmt.Println("-------------- get --------------") - get(engine, t) - fmt.Println("-------------- cascadeGet --------------") - cascadeGet(engine, t) - fmt.Println("-------------- find --------------") - find(engine, t) - fmt.Println("-------------- find2 --------------") - find2(engine, t) - fmt.Println("-------------- findMap --------------") - findMap(engine, t) - fmt.Println("-------------- findMap2 --------------") - findMap2(engine, t) - fmt.Println("-------------- count --------------") - count(engine, t) - fmt.Println("-------------- where --------------") - where(engine, t) - fmt.Println("-------------- in --------------") - in(engine, t) - fmt.Println("-------------- limit --------------") - limit(engine, t) - fmt.Println("-------------- orderSameMapper --------------") - orderSameMapper(engine, t) - fmt.Println("-------------- joinSameMapper --------------") - joinSameMapper(engine, t) - fmt.Println("-------------- havingSameMapper --------------") - havingSameMapper(engine, t) - fmt.Println("-------------- combineTransactionSameMapper --------------") - combineTransactionSameMapper(engine, t) - fmt.Println("-------------- table --------------") - table(engine, t) - fmt.Println("-------------- createMultiTables --------------") - createMultiTables(engine, t) - fmt.Println("-------------- tableOp --------------") - tableOp(engine, t) - fmt.Println("-------------- testColsSameMapper --------------") - testColsSameMapper(engine, t) - fmt.Println("-------------- testCharst --------------") - testCharst(engine, t) - fmt.Println("-------------- testStoreEngine --------------") - testStoreEngine(engine, t) - fmt.Println("-------------- testExtends --------------") - testExtends(engine, t) - fmt.Println("-------------- testColTypes --------------") - testColTypes(engine, t) - fmt.Println("-------------- testCustomType --------------") - testCustomType(engine, t) - fmt.Println("-------------- testCreatedAndUpdated --------------") - testCreatedAndUpdated(engine, t) - fmt.Println("-------------- testIndexAndUnique --------------") - testIndexAndUnique(engine, t) - fmt.Println("-------------- testMetaInfo --------------") - testMetaInfo(engine, t) - fmt.Println("-------------- testIterate --------------") - testIterate(engine, t) - fmt.Println("-------------- testStrangeName --------------") - testStrangeName(engine, t) - fmt.Println("-------------- testVersion --------------") - testVersion(engine, t) - fmt.Println("-------------- testDistinct --------------") - testDistinct(engine, t) - fmt.Println("-------------- testUseBool --------------") - testUseBool(engine, t) - fmt.Println("-------------- transaction --------------") - transaction(engine, t) -}*/ + testAll(engine, t) + testAllSameMapper(engine, t) + testAll2(engine, t) + testAll3(engine, t) +} + +func TestPostgresWithCacheSameMapper(t *testing.T) { + engine, err := newPostgresEngine() + if err != nil { + t.Error(err) + return + } + engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) + defer engine.Close() + engine.SetMapper(SameMapper{}) + engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql + + testAll(engine, t) + testAllSameMapper(engine, t) + testAll2(engine, t) +} const ( createTablePostgres = `CREATE TABLE IF NOT EXISTS "big_struct" ("id" SERIAL PRIMARY KEY NOT NULL, "name" VARCHAR(255) NULL, "title" VARCHAR(255) NULL, "age" VARCHAR(255) NULL, "alias" VARCHAR(255) NULL, "nick_name" VARCHAR(255) NULL);` diff --git a/sqlite3_test.go b/sqlite3_test.go index b55702b0..62922462 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -52,6 +52,42 @@ func TestSqlite3WithCache(t *testing.T) { testAll2(engine, t) } +func TestSqlite3SameMapper(t *testing.T) { + engine, err := newSqlite3Engine() + defer engine.Close() + if err != nil { + t.Error(err) + return + } + engine.SetMapper(SameMapper{}) + engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql + + testAll(engine, t) + testAll2(engine, t) + testAll3(engine, t) +} + +func TestSqlite3WithCacheSameMapper(t *testing.T) { + engine, err := newSqlite3Engine() + defer engine.Close() + if err != nil { + t.Error(err) + return + } + engine.SetMapper(SameMapper{}) + engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) + engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql + + testAll(engine, t) + testAll2(engine, t) +} + const ( createTableSqlite3 = "CREATE TABLE IF NOT EXISTS `big_struct` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, `title` TEXT NULL, `age` TEXT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL);" dropTableSqlite3 = "DROP TABLE IF EXISTS `big_struct`;" diff --git a/statement.go b/statement.go index 49811110..8b5ff430 100644 --- a/statement.go +++ b/statement.go @@ -12,6 +12,11 @@ import ( "github.com/lunny/xorm/core" ) +type inParam struct { + colName string + args []interface{} +} + // statement save all the sql info for executing SQL type Statement struct { RefTable *core.Table @@ -43,7 +48,7 @@ type Statement struct { allUseBool bool checkVersion bool boolColumnMap map[string]bool - inColumns map[string][]interface{} + inColumns map[string]*inParam } // init @@ -72,7 +77,7 @@ func (statement *Statement) Init() { statement.allUseBool = false statement.boolColumnMap = make(map[string]bool) statement.checkVersion = true - statement.inColumns = make(map[string][]interface{}) + statement.inColumns = make(map[string]*inParam) } // add the raw sql statement @@ -456,10 +461,10 @@ func (statement *Statement) Id(id interface{}) *Statement { // Generate "Where column IN (?) " statment func (statement *Statement) In(column string, args ...interface{}) *Statement { k := strings.ToLower(column) - if params, ok := statement.inColumns[k]; ok { - statement.inColumns[k] = append(params, args...) + if _, ok := statement.inColumns[k]; ok { + statement.inColumns[k].args = append(statement.inColumns[k].args, args...) } else { - statement.inColumns[k] = args + statement.inColumns[k] = &inParam{column, args} } return statement } @@ -471,10 +476,11 @@ func (statement *Statement) genInSql() (string, []interface{}) { inStrs := make([]string, 0, len(statement.inColumns)) args := make([]interface{}, 0) - for column, params := range statement.inColumns { - inStrs = append(inStrs, fmt.Sprintf("(%v IN (%v))", statement.Engine.Quote(column), - strings.Join(makeArray("?", len(params)), ","))) - args = append(args, params...) + for _, params := range statement.inColumns { + inStrs = append(inStrs, fmt.Sprintf("(%v IN (%v))", + statement.Engine.Quote(params.colName), + strings.Join(makeArray("?", len(params.args)), ","))) + args = append(args, params.args...) } if len(statement.inColumns) == 1 {