diff --git a/circle.yml b/circle.yml index 69fc7164..adfd2a16 100644 --- a/circle.yml +++ b/circle.yml @@ -17,6 +17,7 @@ database: - createdb -p 5432 -e -U postgres xorm_test1 - createdb -p 5432 -e -U postgres xorm_test2 - createdb -p 5432 -e -U postgres xorm_test3 + - psql xorm_test postgres -c "create schema xorm" test: override: @@ -30,7 +31,9 @@ test: - go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic - go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -coverprofile=coverage4-1.txt -covermode=atomic - go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic - - gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt > coverage.txt + - go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -coverprofile=coverage5-1.txt -covermode=atomic + - go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -cache=true -coverprofile=coverage5-2.txt -covermode=atomic + - gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt coverage5-1.txt coverage5-2.txt > coverage.txt - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./sqlite3.sh - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./mysql.sh - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./postgres.sh diff --git a/dialect_postgres.go b/dialect_postgres.go index f2858f19..d907c68c 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -895,6 +895,7 @@ func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { args := []interface{}{tableName} return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args } + args := []interface{}{db.Schema, tableName} return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args } @@ -912,6 +913,9 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { quote := db.Quote idxName := index.Name + tableName = strings.Replace(tableName, `"`, "", -1) + tableName = strings.Replace(tableName, `.`, "_", -1) + if !strings.HasPrefix(idxName, "UQE_") && !strings.HasPrefix(idxName, "IDX_") { if index.Type == core.UniqueType { @@ -920,6 +924,9 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) } } + if db.Uri.Schema != "" { + idxName = db.Uri.Schema + "." + idxName + } return fmt.Sprintf("DROP INDEX %v", quote(idxName)) } @@ -960,7 +967,7 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att var f string if len(db.Schema) != 0 { args = append(args, db.Schema) - f = "AND s.table_schema = $2" + f = " AND s.table_schema = $2" } s = fmt.Sprintf(s, f) @@ -1085,11 +1092,11 @@ func (db *postgres) GetTables() ([]*core.Table, error) { func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { args := []interface{}{tableName} s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") - db.LogSQL(s, args) if len(db.Schema) != 0 { args = append(args, db.Schema) s = s + " AND schemaname=$2" } + db.LogSQL(s, args) rows, err := db.DB().Query(s, args...) if err != nil { diff --git a/engine.go b/engine.go index 29415cae..876004ba 100644 --- a/engine.go +++ b/engine.go @@ -536,46 +536,6 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D return nil } -func (engine *Engine) tableName(beanOrTableName interface{}) (string, error) { - v := rValue(beanOrTableName) - if v.Type().Kind() == reflect.String { - return beanOrTableName.(string), nil - } else if v.Type().Kind() == reflect.Struct { - return engine.tbName(v), nil - } - return "", errors.New("bean should be a struct or struct's point") -} - -func (engine *Engine) tbSchemaName(v string) string { - // Add schema name as prefix of table name. - // Only for postgres database. - if engine.dialect.DBType() == core.POSTGRES && - engine.dialect.URI().Schema != "" && - engine.dialect.URI().Schema != postgresPublicSchema && - strings.Index(v, ".") == -1 { - return engine.dialect.URI().Schema + "." + v - } - return v -} - -func (engine *Engine) tbName(v reflect.Value) string { - if tb, ok := v.Interface().(TableName); ok { - return engine.tbSchemaName(tb.TableName()) - - } - - if v.Type().Kind() == reflect.Ptr { - if tb, ok := reflect.Indirect(v).Interface().(TableName); ok { - return engine.tbSchemaName(tb.TableName()) - } - } else if v.CanAddr() { - if tb, ok := v.Addr().Interface().(TableName); ok { - return engine.tbSchemaName(tb.TableName()) - } - } - return engine.tbSchemaName(engine.TableMapper.Obj2Table(reflect.Indirect(v).Type().Name())) -} - // Cascade use cascade or not func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { session := engine.NewSession() @@ -859,7 +819,7 @@ func (engine *Engine) TableInfo(bean interface{}) *Table { if err != nil { engine.logger.Error(err) } - return &Table{tb, engine.tbName(v)} + return &Table{tb, engine.TableName(bean)} } func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) { @@ -895,20 +855,8 @@ var ( func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { t := v.Type() table := engine.newTable() - if tb, ok := v.Interface().(TableName); ok { - table.Name = tb.TableName() - } else { - if v.CanAddr() { - if tb, ok = v.Addr().Interface().(TableName); ok { - table.Name = tb.TableName() - } - } - if table.Name == "" { - table.Name = engine.TableMapper.Obj2Table(t.Name()) - } - } - table.Type = t + table.Name = engine.tbNameForMap(v) var idFieldColName string var hasCacheTag, hasNoCacheTag bool @@ -1186,7 +1134,7 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { if t.Kind() != reflect.Struct { return errors.New("error params") } - tableName := engine.tbName(v) + tableName := engine.TableName(bean) table, err := engine.autoMapType(v) if err != nil { return err @@ -1210,7 +1158,7 @@ func (engine *Engine) ClearCache(beans ...interface{}) error { if t.Kind() != reflect.Struct { return errors.New("error params") } - tableName := engine.tbName(v) + tableName := engine.TableName(bean) table, err := engine.autoMapType(v) if err != nil { return err @@ -1237,13 +1185,13 @@ func (engine *Engine) Sync(beans ...interface{}) error { for _, bean := range beans { v := rValue(bean) - tableName := engine.tbName(v) + tableNameNoSchema := engine.tbNameNoSchema(v.Interface()) table, err := engine.autoMapType(v) if err != nil { return err } - isExist, err := session.Table(bean).isTableExist(tableName) + isExist, err := session.Table(bean).isTableExist(tableNameNoSchema) if err != nil { return err } @@ -1269,12 +1217,12 @@ func (engine *Engine) Sync(beans ...interface{}) error { } } else { for _, col := range table.Columns() { - isExist, err := engine.dialect.IsColumnExist(tableName, col.Name) + isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name) if err != nil { return err } if !isExist { - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } err = session.addColumn(col.Name) @@ -1285,35 +1233,35 @@ func (engine *Engine) Sync(beans ...interface{}) error { } for name, index := range table.Indexes { - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } if index.Type == core.UniqueType { - isExist, err := session.isIndexExist2(tableName, index.Cols, true) + isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true) if err != nil { return err } if !isExist { - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } - err = session.addUnique(tableName, name) + err = session.addUnique(tableNameNoSchema, name) if err != nil { return err } } } else if index.Type == core.IndexType { - isExist, err := session.isIndexExist2(tableName, index.Cols, false) + isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false) if err != nil { return err } if !isExist { - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } - err = session.addIndex(tableName, name) + err = session.addIndex(tableNameNoSchema, name) if err != nil { return err } @@ -1649,6 +1597,11 @@ func (engine *Engine) SetTZDatabase(tz *time.Location) { engine.DatabaseTZ = tz } +// SetSchema sets the schema of database +func (engine *Engine) SetSchema(schema string) { + engine.dialect.URI().Schema = schema +} + // Unscoped always disable struct tag "deleted" func (engine *Engine) Unscoped() *Session { session := engine.NewSession() diff --git a/engine_table.go b/engine_table.go new file mode 100644 index 00000000..1319871f --- /dev/null +++ b/engine_table.go @@ -0,0 +1,109 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "fmt" + "reflect" + "strings" + + "github.com/go-xorm/core" +) + +// TableNameWithSchema will automatically add schema prefix on table name +func (engine *Engine) tbNameWithSchema(v string) string { + // Add schema name as prefix of table name. + // Only for postgres database. + if engine.dialect.DBType() == core.POSTGRES && + engine.dialect.URI().Schema != "" && + engine.dialect.URI().Schema != postgresPublicSchema && + strings.Index(v, ".") == -1 { + return engine.dialect.URI().Schema + "." + v + } + return v +} + +// TableName returns table name with schema prefix if has +func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { + tbName := engine.tbNameNoSchema(bean) + if len(includeSchema) > 0 && includeSchema[0] { + tbName = engine.tbNameWithSchema(tbName) + } + + return tbName +} + +// tbName get some table's table name +func (session *Session) tbNameNoSchema(table *core.Table) string { + if len(session.statement.AltTableName) > 0 { + return session.statement.AltTableName + } + + return table.Name +} + +func (engine *Engine) tbNameForMap(v reflect.Value) string { + t := v.Type() + if tb, ok := v.Interface().(TableName); ok { + return tb.TableName() + } + if v.CanAddr() { + if tb, ok := v.Addr().Interface().(TableName); ok { + return tb.TableName() + } + } + return engine.TableMapper.Obj2Table(t.Name()) +} + +func (engine *Engine) tbNameNoSchema(tablename interface{}) string { + switch tablename.(type) { + case []string: + t := tablename.([]string) + if len(t) > 1 { + return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1])) + } else if len(t) == 1 { + return engine.Quote(t[0]) + } + case []interface{}: + t := tablename.([]interface{}) + l := len(t) + var table string + if l > 0 { + f := t[0] + switch f.(type) { + case string: + table = f.(string) + case TableName: + table = f.(TableName).TableName() + default: + v := rValue(f) + t := v.Type() + if t.Kind() == reflect.Struct { + table = engine.tbNameForMap(v) + } else { + table = engine.Quote(fmt.Sprintf("%v", f)) + } + } + } + if l > 1 { + return fmt.Sprintf("%v AS %v", engine.Quote(table), + engine.Quote(fmt.Sprintf("%v", t[1]))) + } else if l == 1 { + return engine.Quote(table) + } + case TableName: + return tablename.(TableName).TableName() + case string: + return tablename.(string) + default: + v := rValue(tablename) + t := v.Type() + if t.Kind() == reflect.Struct { + return engine.tbNameForMap(v) + } + return engine.Quote(fmt.Sprintf("%v", tablename)) + } + return "" +} diff --git a/interface.go b/interface.go index 85a46a27..5ce49f48 100644 --- a/interface.go +++ b/interface.go @@ -87,6 +87,7 @@ type EngineInterface interface { SetDefaultCacher(core.Cacher) SetLogLevel(core.LogLevel) SetMapper(core.IMapper) + SetSchema(string) SetTZDatabase(tz *time.Location) SetTZLocation(tz *time.Location) ShowSQL(show ...bool) @@ -94,6 +95,7 @@ type EngineInterface interface { Sync2(...interface{}) error StoreEngine(storeEngine string) *Session TableInfo(bean interface{}) *Table + TableName(interface{}, ...bool) string UnMapType(reflect.Type) } diff --git a/rows.go b/rows.go index 31e29ae2..54ec7f37 100644 --- a/rows.go +++ b/rows.go @@ -32,7 +32,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { var args []interface{} var err error - if err = rows.session.statement.setRefValue(rValue(bean)); err != nil { + if err = rows.session.statement.setRefBean(bean); err != nil { return nil, err } @@ -94,8 +94,7 @@ func (rows *Rows) Scan(bean interface{}) error { return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) } - dataStruct := rValue(bean) - if err := rows.session.statement.setRefValue(dataStruct); err != nil { + if err := rows.session.statement.setRefBean(bean); err != nil { return err } @@ -104,6 +103,7 @@ func (rows *Rows) Scan(bean interface{}) error { return err } + dataStruct := rValue(bean) _, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable) if err != nil { return err diff --git a/rows_test.go b/rows_test.go index c48938a9..ee121c5e 100644 --- a/rows_test.go +++ b/rows_test.go @@ -54,7 +54,8 @@ func TestRows(t *testing.T) { } assert.EqualValues(t, 1, cnt) - rows2, err := testEngine.SQL("SELECT * FROM user_rows").Rows(new(UserRows)) + var tbName = testEngine.Quote(testEngine.TableName(user, true)) + rows2, err := testEngine.SQL("SELECT * FROM " + tbName).Rows(new(UserRows)) assert.NoError(t, err) defer rows2.Close() diff --git a/session.go b/session.go index 5c6cb5f9..15283624 100644 --- a/session.go +++ b/session.go @@ -828,15 +828,6 @@ func (session *Session) LastSQL() (string, []interface{}) { return session.lastSQL, session.lastSQLArgs } -// tbName get some table's table name -func (session *Session) tbNameNoSchema(table *core.Table) string { - if len(session.statement.AltTableName) > 0 { - return session.statement.AltTableName - } - - return table.Name -} - // Unscoped always disable struct tag "deleted" func (session *Session) Unscoped() *Session { session.statement.Unscoped() diff --git a/session_cond_test.go b/session_cond_test.go index a80e7d03..ae4c2f82 100644 --- a/session_cond_test.go +++ b/session_cond_test.go @@ -122,18 +122,11 @@ func TestIn(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 3, cnt) + department := "`" + testEngine.GetColumnMapper().Obj2Table("Departname") + "`" var usrs []Userinfo - err = testEngine.Limit(3).Find(&usrs) - if err != nil { - t.Error(err) - panic(err) - } - - if len(usrs) != 3 { - err = errors.New("there are not 3 records") - t.Error(err) - panic(err) - } + err = testEngine.Where(department+" = ?", "dev").Limit(3).Find(&usrs) + assert.NoError(t, err) + assert.EqualValues(t, 3, len(usrs)) var ids []int64 var idsStr string @@ -145,35 +138,20 @@ func TestIn(t *testing.T) { users := make([]Userinfo, 0) err = testEngine.In("(id)", ids[0], ids[1], ids[2]).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users) - if len(users) != 3 { - err = errors.New("in uses should be " + idsStr + " total 3") - t.Error(err) - panic(err) - } + assert.EqualValues(t, 3, len(users)) users = make([]Userinfo, 0) err = testEngine.In("(id)", ids).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users) - if len(users) != 3 { - err = errors.New("in uses should be " + idsStr + " total 3") - t.Error(err) - panic(err) - } + assert.EqualValues(t, 3, len(users)) for _, user := range users { if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] { err = errors.New("in uses should be " + idsStr + " total 3") - t.Error(err) - panic(err) + assert.NoError(t, err) } } @@ -183,87 +161,41 @@ func TestIn(t *testing.T) { idsInterface = append(idsInterface, id) } - department := "`" + testEngine.GetColumnMapper().Obj2Table("Departname") + "`" err = testEngine.Where(department+" = ?", "dev").In("(id)", idsInterface...).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users) - - if len(users) != 3 { - err = errors.New("in uses should be " + idsStr + " total 3") - t.Error(err) - panic(err) - } + assert.EqualValues(t, 3, len(users)) for _, user := range users { if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] { err = errors.New("in uses should be " + idsStr + " total 3") - t.Error(err) - panic(err) + assert.NoError(t, err) } } dev := testEngine.GetColumnMapper().Obj2Table("Dev") err = testEngine.In("(id)", 1).In("(id)", 2).In(department, dev).Find(&users) - - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users) cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev-"}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update records not 1") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) user := new(Userinfo) has, err := testEngine.ID(ids[0]).Get(user) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get record not 1") - t.Error(err) - panic(err) - } - if user.Departname != "dev-" { - err = errors.New("update not success") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "dev-", user.Departname) cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev"}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update records not 1") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) cnt, err = testEngine.In("(id)", ids[1]).Delete(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("deleted records not 1") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) } func TestFindAndCount(t *testing.T) { diff --git a/session_delete.go b/session_delete.go index 688b122c..eb91614c 100644 --- a/session_delete.go +++ b/session_delete.go @@ -79,7 +79,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { defer session.Close() } - if err := session.statement.setRefValue(rValue(bean)); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return 0, err } diff --git a/session_exist.go b/session_exist.go index 378a6483..74a660e8 100644 --- a/session_exist.go +++ b/session_exist.go @@ -57,7 +57,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { } if beanValue.Elem().Kind() == reflect.Struct { - if err := session.statement.setRefValue(beanValue.Elem()); err != nil { + if err := session.statement.setRefBean(bean[0]); err != nil { return false, err } } diff --git a/session_exist_test.go b/session_exist_test.go index 857bf4a1..9d985771 100644 --- a/session_exist_test.go +++ b/session_exist_test.go @@ -54,11 +54,11 @@ func TestExistStruct(t *testing.T) { assert.NoError(t, err) assert.False(t, has) - has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist() + has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test1").Exist() assert.NoError(t, err) assert.True(t, has) - has, err = testEngine.SQL("select * from record_exist where name = ?", "test2").Exist() + has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test2").Exist() assert.NoError(t, err) assert.False(t, has) diff --git a/session_find_test.go b/session_find_test.go index 04fdb030..4db7f9ce 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -96,21 +96,15 @@ func TestFind(t *testing.T) { users := make([]Userinfo, 0) err := testEngine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) for _, user := range users { fmt.Println(user) } users2 := make([]Userinfo, 0) - userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") - err = testEngine.SQL("select * from " + testEngine.Quote(userinfo)).Find(&users2) - if err != nil { - t.Error(err) - panic(err) - } + var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true)) + err = testEngine.SQL("select * from " + tbName).Find(&users2) + assert.NoError(t, err) } func TestFind2(t *testing.T) { @@ -238,14 +232,8 @@ func TestDistinct(t *testing.T) { users := make([]Userinfo, 0) departname := testEngine.GetTableMapper().Obj2Table("Departname") err = testEngine.Distinct(departname).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - if len(users) != 1 { - t.Error(err) - panic(errors.New("should be one record")) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users)) fmt.Println(users) @@ -255,11 +243,9 @@ func TestDistinct(t *testing.T) { users2 := make([]Depart, 0) err = testEngine.Distinct(departname).Table(new(Userinfo)).Find(&users2) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if len(users2) != 1 { + fmt.Println(len(users2)) t.Error(err) panic(errors.New("should be one record")) } @@ -272,18 +258,12 @@ func TestOrder(t *testing.T) { users := make([]Userinfo, 0) err := testEngine.OrderBy("id desc").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users) users2 := make([]Userinfo, 0) err = testEngine.Asc("id", "username").Desc("height").Find(&users2) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users2) } @@ -293,10 +273,7 @@ func TestHaving(t *testing.T) { users := make([]Userinfo, 0) err := testEngine.GroupBy("username").Having("username='xlw'").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users) /*users = make([]Userinfo, 0) @@ -324,18 +301,12 @@ func TestOrderSameMapper(t *testing.T) { users := make([]Userinfo, 0) err := testEngine.OrderBy("(id) desc").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users) users2 := make([]Userinfo, 0) err = testEngine.Asc("(id)", "Username").Desc("Height").Find(&users2) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users2) } diff --git a/session_get.go b/session_get.go index 68b37af7..58191de1 100644 --- a/session_get.go +++ b/session_get.go @@ -31,7 +31,7 @@ func (session *Session) get(bean interface{}) (bool, error) { } if beanValue.Elem().Kind() == reflect.Struct { - if err := session.statement.setRefValue(beanValue.Elem()); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return false, err } } diff --git a/session_get_test.go b/session_get_test.go index e27e6de9..6a2fd575 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -84,7 +84,7 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money)) var money2 float64 - has, err = testEngine.SQL("SELECT money FROM get_var LIMIT 1").Get(&money2) + has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2) assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2)) diff --git a/session_insert.go b/session_insert.go index 129ee230..8609b80c 100644 --- a/session_insert.go +++ b/session_insert.go @@ -298,7 +298,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { } func (session *Session) innerInsert(bean interface{}) (int64, error) { - if err := session.statement.setRefValue(rValue(bean)); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return 0, err } if len(session.statement.TableName()) <= 0 { diff --git a/session_insert_test.go b/session_insert_test.go index 080ace9b..c87ae3a9 100644 --- a/session_insert_test.go +++ b/session_insert_test.go @@ -716,8 +716,9 @@ func (MyUserinfo2) TableName() string { func TestInsertMulti4(t *testing.T) { assert.NoError(t, prepareEngine()) - testEngine.ShowSQL(true) + testEngine.ShowSQL(false) assertSync(t, new(MyUserinfo2)) + testEngine.ShowSQL(true) users := []MyUserinfo2{ {Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, diff --git a/session_pk_test.go b/session_pk_test.go index 3370b2ad..7b025acd 100644 --- a/session_pk_test.go +++ b/session_pk_test.go @@ -1118,13 +1118,28 @@ func TestCompositePK(t *testing.T) { } assert.NoError(t, prepareEngine()) - assertSync(t, new(TaskSolution)) - assert.NoError(t, testEngine.Sync2(new(TaskSolution))) - tables, err := testEngine.DBMetas() + tables1, err := testEngine.DBMetas() assert.NoError(t, err) - assert.EqualValues(t, 1, len(tables)) - pkCols := tables[0].PKColumns() + + assertSync(t, new(TaskSolution)) + assert.NoError(t, testEngine.Sync2(new(TaskSolution))) + + tables2, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1+len(tables1), len(tables2)) + + var table *core.Table + for _, t := range tables2 { + if t.Name == testEngine.GetTableMapper().Obj2Table("TaskSolution") { + table = t + break + } + } + + assert.NotEqual(t, nil, table) + + pkCols := table.PKColumns() assert.EqualValues(t, 2, len(pkCols)) assert.EqualValues(t, "uid", pkCols[0].Name) assert.EqualValues(t, "tid", pkCols[1].Name) diff --git a/session_query_test.go b/session_query_test.go index 7ea413bb..8b4aefad 100644 --- a/session_query_test.go +++ b/session_query_test.go @@ -36,7 +36,7 @@ func TestQueryString(t *testing.T) { _, err := testEngine.InsertOne(data) assert.NoError(t, err) - records, err := testEngine.QueryString("select * from get_var2") + records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var2", true)) assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 5, len(records[0])) @@ -62,7 +62,7 @@ func TestQueryString2(t *testing.T) { _, err := testEngine.Insert(data) assert.NoError(t, err) - records, err := testEngine.QueryString("select * from get_var3") + records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var3", true)) assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 2, len(records[0])) @@ -127,7 +127,7 @@ func TestQueryInterface(t *testing.T) { _, err := testEngine.InsertOne(data) assert.NoError(t, err) - records, err := testEngine.QueryInterface("select * from get_var_interface") + records, err := testEngine.QueryInterface("select * from " + testEngine.TableName("get_var_interface", true)) assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 5, len(records[0])) @@ -181,7 +181,7 @@ func TestQueryNoParams(t *testing.T) { assert.NoError(t, err) assertResult(t, results) - results, err = testEngine.SQL("select * from query_no_params").Query() + results, err = testEngine.SQL("select * from " + testEngine.TableName("query_no_params", true)).Query() assert.NoError(t, err) assertResult(t, results) } @@ -226,7 +226,7 @@ func TestQueryWithBuilder(t *testing.T) { assert.EqualValues(t, 3000, money) } - results, err := testEngine.Query(builder.Select("*").From("query_with_builder")) + results, err := testEngine.Query(builder.Select("*").From(testEngine.TableName("query_with_builder", true))) assert.NoError(t, err) assertResult(t, results) } diff --git a/session_raw_test.go b/session_raw_test.go index 32e8037c..766206a4 100644 --- a/session_raw_test.go +++ b/session_raw_test.go @@ -21,13 +21,13 @@ func TestExecAndQuery(t *testing.T) { assert.NoError(t, testEngine.Sync2(new(UserinfoQuery))) - res, err := testEngine.Exec("INSERT INTO `userinfo_query` (uid, name) VALUES (?, ?)", 1, "user") + res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_query`", true)+" (uid, name) VALUES (?, ?)", 1, "user") assert.NoError(t, err) cnt, err := res.RowsAffected() assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - results, err := testEngine.Query("select * from userinfo_query") + results, err := testEngine.Query("select * from " + testEngine.TableName("userinfo_query", true)) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) id, err := strconv.Atoi(string(results[0]["uid"])) diff --git a/session_schema.go b/session_schema.go index 9d9edca8..f0628661 100644 --- a/session_schema.go +++ b/session_schema.go @@ -6,9 +6,7 @@ package xorm import ( "database/sql" - "errors" "fmt" - "reflect" "strings" "github.com/go-xorm/core" @@ -34,8 +32,7 @@ func (session *Session) CreateTable(bean interface{}) error { } func (session *Session) createTable(bean interface{}) error { - v := rValue(bean) - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } @@ -54,8 +51,7 @@ func (session *Session) CreateIndexes(bean interface{}) error { } func (session *Session) createIndexes(bean interface{}) error { - v := rValue(bean) - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } @@ -78,8 +74,7 @@ func (session *Session) CreateUniques(bean interface{}) error { } func (session *Session) createUniques(bean interface{}) error { - v := rValue(bean) - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } @@ -103,8 +98,7 @@ func (session *Session) DropIndexes(bean interface{}) error { } func (session *Session) dropIndexes(bean interface{}) error { - v := rValue(bean) - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } @@ -128,11 +122,7 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { } func (session *Session) dropTable(beanOrTableName interface{}) error { - tableName, err := session.engine.tableName(beanOrTableName) - if err != nil { - return err - } - + tableName := session.engine.tbNameNoSchema(beanOrTableName) var needDrop = true if !session.engine.dialect.SupportDropIfExists() { sqlStr, args := session.engine.dialect.TableCheckSql(tableName) @@ -144,8 +134,8 @@ func (session *Session) dropTable(beanOrTableName interface{}) error { } if needDrop { - sqlStr := session.engine.Dialect().DropTableSql(tableName) - _, err = session.exec(sqlStr) + sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true)) + _, err := session.exec(sqlStr) return err } return nil @@ -157,10 +147,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) defer session.Close() } - tableName, err := session.engine.tableName(beanOrTableName) - if err != nil { - return false, err - } + tableName := session.engine.tbNameNoSchema(beanOrTableName) return session.isTableExist(tableName) } @@ -173,24 +160,15 @@ func (session *Session) isTableExist(tableName string) (bool, error) { // IsTableEmpty if table have any records func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { - v := rValue(bean) - t := v.Type() - - if t.Kind() == reflect.String { - if session.isAutoClose { - defer session.Close() - } - return session.isTableEmpty(bean.(string)) - } else if t.Kind() == reflect.Struct { - rows, err := session.Count(bean) - return rows == 0, err + if session.isAutoClose { + defer session.Close() } - return false, errors.New("bean should be a struct or struct's point") + return session.isTableEmpty(session.engine.tbNameNoSchema(bean)) } func (session *Session) isTableEmpty(tableName string) (bool, error) { var total int64 - sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName)) + sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true))) err := session.queryRow(sqlStr).Scan(&total) if err != nil { if err == sql.ErrNoRows { @@ -270,7 +248,8 @@ func (session *Session) Sync2(beans ...interface{}) error { return err } structTables = append(structTables, table) - var tbName = session.tbNameNoSchema(table) + tbName := session.tbNameNoSchema(table) + tbNameWithSchema := engine.TableName(tbName, true) var oriTable *core.Table for _, tb := range tables { @@ -315,32 +294,32 @@ func (session *Session) Sync2(beans ...interface{}) error { if engine.dialect.DBType() == core.MYSQL || engine.dialect.DBType() == core.POSTGRES { engine.logger.Infof("Table %s column %s change type from %s to %s\n", - tbName, col.Name, curType, expectedType) - _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) + tbNameWithSchema, col.Name, curType, expectedType) + _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) } else { engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", - tbName, col.Name, curType, expectedType) + tbNameWithSchema, col.Name, curType, expectedType) } } else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) { if engine.dialect.DBType() == core.MYSQL { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", - tbName, col.Name, oriCol.Length, col.Length) - _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) + tbNameWithSchema, col.Name, oriCol.Length, col.Length) + _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) } } } else { if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", - tbName, col.Name, curType, expectedType) + tbNameWithSchema, col.Name, curType, expectedType) } } } else if expectedType == core.Varchar { if engine.dialect.DBType() == core.MYSQL { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", - tbName, col.Name, oriCol.Length, col.Length) - _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) + tbNameWithSchema, col.Name, oriCol.Length, col.Length) + _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) } } } @@ -354,7 +333,7 @@ func (session *Session) Sync2(beans ...interface{}) error { } } else { session.statement.RefTable = table - session.statement.tableName = tbName + session.statement.tableName = tbNameWithSchema err = session.addColumn(col.Name) } if err != nil { @@ -377,7 +356,7 @@ func (session *Session) Sync2(beans ...interface{}) error { if oriIndex != nil { if oriIndex.Type != index.Type { - sql := engine.dialect.DropIndexSql(tbName, oriIndex) + sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex) _, err = session.exec(sql) if err != nil { return err @@ -393,7 +372,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for name2, index2 := range oriTable.Indexes { if _, ok := foundIndexNames[name2]; !ok { - sql := engine.dialect.DropIndexSql(tbName, index2) + sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2) _, err = session.exec(sql) if err != nil { return err @@ -404,12 +383,12 @@ func (session *Session) Sync2(beans ...interface{}) error { for name, index := range addedNames { if index.Type == core.UniqueType { session.statement.RefTable = table - session.statement.tableName = tbName - err = session.addUnique(tbName, name) + session.statement.tableName = tbNameWithSchema + err = session.addUnique(tbNameWithSchema, name) } else if index.Type == core.IndexType { session.statement.RefTable = table - session.statement.tableName = tbName - err = session.addIndex(tbName, name) + session.statement.tableName = tbNameWithSchema + err = session.addIndex(tbNameWithSchema, name) } if err != nil { return err @@ -434,7 +413,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for _, colName := range table.ColumnsSeq() { if oriTable.GetColumn(colName) == nil { - engine.logger.Warnf("Table %s has column %s but struct has not related field", table.Name, colName) + engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(table.Name, true), colName) } } } diff --git a/session_stats_test.go b/session_stats_test.go index ec5cace1..564fd99a 100644 --- a/session_stats_test.go +++ b/session_stats_test.go @@ -153,7 +153,7 @@ func TestSQLCount(t *testing.T) { assertSync(t, new(UserinfoCount2), new(UserinfoBooks)) - total, err := testEngine.SQL("SELECT count(id) FROM userinfo_count2"). + total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableName("userinfo_count2", true)). Count() assert.NoError(t, err) assert.EqualValues(t, 0, total) diff --git a/session_tx_test.go b/session_tx_test.go index 7102f5c7..568ed0f1 100644 --- a/session_tx_test.go +++ b/session_tx_test.go @@ -32,45 +32,21 @@ func TestTransaction(t *testing.T) { defer session.Close() err := session.Begin() - if err != nil { - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} _, err = session.Insert(&user1) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) user2 := Userinfo{Username: "yyy"} _, err = session.Where("(id) = ?", 0).Update(&user2) - if err != nil { - session.Rollback() - fmt.Println(err) - //t.Error(err) - return - } + assert.NoError(t, err) _, err = session.Delete(&user2) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - return - } - // panic(err) !nashtsai! should remove this + assert.NoError(t, err) } func TestCombineTransaction(t *testing.T) { @@ -91,38 +67,21 @@ func TestCombineTransaction(t *testing.T) { defer session.Close() err := session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} _, err = session.Insert(&user1) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } + assert.NoError(t, err) + user2 := Userinfo{Username: "zzz"} _, err = session.Where("id = ?", 0).Update(&user2) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } + assert.NoError(t, err) - _, err = session.Exec("delete from userinfo where username = ?", user2.Username) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } + _, err = session.Exec("delete from "+testEngine.TableName("userinfo", true)+" where username = ?", user2.Username) + assert.NoError(t, err) err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } func TestCombineTransactionSameMapper(t *testing.T) { @@ -148,45 +107,24 @@ func TestCombineTransactionSameMapper(t *testing.T) { counter() defer counter() + session := testEngine.NewSession() defer session.Close() err := session.Begin() - if err != nil { - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} _, err = session.Insert(&user1) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) user2 := Userinfo{Username: "zzz"} _, err = session.Where("(id) = ?", 0).Update(&user2) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) - _, err = session.Exec("delete from `Userinfo` where `Username` = ?", user2.Username) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - return - } + _, err = session.Exec("delete from "+testEngine.TableName("`Userinfo`", true)+" where `Username` = ?", user2.Username) + assert.NoError(t, err) err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } diff --git a/session_update.go b/session_update.go index f5587456..11264a61 100644 --- a/session_update.go +++ b/session_update.go @@ -167,7 +167,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var isMap = t.Kind() == reflect.Map var isStruct = t.Kind() == reflect.Struct if isStruct { - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return 0, err } diff --git a/session_update_test.go b/session_update_test.go index d1bc47bc..79205671 100644 --- a/session_update_test.go +++ b/session_update_test.go @@ -462,30 +462,18 @@ func TestUpdate1(t *testing.T) { col1 := &UpdateAllCols{Ptr: &s} err = testEngine.Sync(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col2 := &UpdateAllCols{col1.Id, true, "", nil} _, err = testEngine.ID(col2.Id).AllCols().Update(col2) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col3 := &UpdateAllCols{} has, err = testEngine.ID(col2.Id).Get(col3) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if !has { err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) @@ -759,7 +747,7 @@ func TestUpdateUpdated(t *testing.T) { func TestUpdateSameMapper(t *testing.T) { assert.NoError(t, prepareEngine()) - oldMapper := testEngine.GetColumnMapper() + oldMapper := testEngine.GetTableMapper() testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(rValue(new(Condi)).Type()) testEngine.UnMapType(rValue(new(Article)).Type()) @@ -786,81 +774,38 @@ func TestUpdateSameMapper(t *testing.T) { var ori Userinfo has, err := testEngine.Get(&ori) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - t.Error(errors.New("not exist")) - panic(errors.New("not exist")) - } + assert.NoError(t, err) + assert.True(t, has) + // update by id user := Userinfo{Username: "xxx", Height: 1.2} cnt, err := testEngine.ID(ori.Uid).Update(&user) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update not returned 1") - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) condi := Condi{"Username": "zzz", "Departname": ""} cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("update not returned 1") - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) cnt, err = testEngine.Update(&Userinfo{Username: "yyy"}, &user) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) total, err := testEngine.Count(&user) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != total { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) + assert.EqualValues(t, cnt, total) err = testEngine.Sync(&Article{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) defer func() { err = testEngine.DropTables(&Article{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) }() a := &Article{0, "1", "2", "3", "4", "5", 2} cnt, err = testEngine.Insert(a) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if cnt != 1 { err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) @@ -875,10 +820,7 @@ func TestUpdateSameMapper(t *testing.T) { } cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if cnt != 1 { err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) @@ -889,30 +831,18 @@ func TestUpdateSameMapper(t *testing.T) { col1 := &UpdateAllCols{} err = testEngine.Sync(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col2 := &UpdateAllCols{col1.Id, true, "", nil} _, err = testEngine.ID(col2.Id).AllCols().Update(col2) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col3 := &UpdateAllCols{} has, err = testEngine.ID(col2.Id).Get(col3) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if !has { err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) @@ -931,32 +861,20 @@ func TestUpdateSameMapper(t *testing.T) { { col1 := &UpdateMustCols{} err = testEngine.Sync(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col2 := &UpdateMustCols{col1.Id, true, ""} boolStr := testEngine.GetColumnMapper().Obj2Table("Bool") stringStr := testEngine.GetColumnMapper().Obj2Table("String") _, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col3 := &UpdateMustCols{} has, err := testEngine.ID(col2.Id).Get(col3) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if !has { err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) diff --git a/statement.go b/statement.go index 02d73559..603b5990 100644 --- a/statement.go +++ b/statement.go @@ -221,26 +221,18 @@ func (statement *Statement) setRefValue(v reflect.Value) error { if err != nil { return err } - statement.tableName = statement.Engine.tbName(v) + statement.tableName = statement.Engine.TableName(v.Interface(), true) return nil } -// Table tempororily set table name, the parameter could be a string or a pointer of struct -func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { - v := rValue(tableNameOrBean) - t := v.Type() - if t.Kind() == reflect.String { - statement.AltTableName = tableNameOrBean.(string) - } else if t.Kind() == reflect.Struct { - var err error - statement.RefTable, err = statement.Engine.autoMapType(v) - if err != nil { - statement.Engine.logger.Error(err) - return statement - } - statement.AltTableName = statement.Engine.tbName(v) +func (statement *Statement) setRefBean(bean interface{}) error { + var err error + statement.RefTable, err = statement.Engine.autoMapType(rValue(bean)) + if err != nil { + return err } - return statement + statement.tableName = statement.Engine.TableName(bean, true) + return nil } // Auto generating update columnes and values according a struct @@ -743,6 +735,23 @@ func (statement *Statement) Asc(colNames ...string) *Statement { return statement } +// Table tempororily set table name, the parameter could be a string or a pointer of struct +func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { + v := rValue(tableNameOrBean) + t := v.Type() + if t.Kind() == reflect.Struct { + var err error + statement.RefTable, err = statement.Engine.autoMapType(v) + if err != nil { + statement.Engine.logger.Error(err) + return statement + } + } + + statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true) + return statement +} + // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { var buf bytes.Buffer @@ -752,56 +761,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition fmt.Fprintf(&buf, "%v JOIN ", joinOP) } - switch tablename.(type) { - case []string: - t := tablename.([]string) - if len(t) > 1 { - fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1])) - } else if len(t) == 1 { - fmt.Fprintf(&buf, statement.Engine.Quote(t[0])) - } - case []interface{}: - t := tablename.([]interface{}) - l := len(t) - var table string - if l > 0 { - f := t[0] - switch f.(type) { - case string: - table = f.(string) - case TableName: - table = f.(TableName).TableName() - default: - v := rValue(f) - t := v.Type() - if t.Kind() == reflect.Struct { - fmt.Fprintf(&buf, statement.Engine.tbName(v)) - } else { - fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", f))) - } - } - } - if l > 1 { - fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table), - statement.Engine.Quote(fmt.Sprintf("%v", t[1]))) - } else if l == 1 { - fmt.Fprintf(&buf, statement.Engine.Quote(table)) - } - case TableName: - fmt.Fprintf(&buf, tablename.(TableName).TableName()) - case string: - fmt.Fprintf(&buf, tablename.(string)) - default: - v := rValue(tablename) - t := v.Type() - if t.Kind() == reflect.Struct { - fmt.Fprintf(&buf, statement.Engine.tbName(v)) - } else { - fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename))) - } - } + tbName := statement.Engine.TableName(tablename, true) - fmt.Fprintf(&buf, " ON %v", condition) + fmt.Fprintf(&buf, "%s ON %v", tbName, condition) statement.JoinStr = buf.String() statement.joinArgs = append(statement.joinArgs, args...) return statement @@ -876,11 +838,13 @@ func (statement *Statement) genCreateTableSQL() string { func (statement *Statement) genIndexSQL() []string { var sqls []string tbName := statement.TableName() - quote := statement.Engine.Quote - for idxName, index := range statement.RefTable.Indexes { + for _, index := range statement.RefTable.Indexes { if index.Type == core.IndexType { - sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)), - quote(tbName), quote(strings.Join(index.Cols, quote(",")))) + sql := statement.Engine.dialect.CreateIndexSql(tbName, index) + /*idxTBName := strings.Replace(tbName, ".", "_", -1) + idxTBName = strings.Replace(idxTBName, `"`, "", -1) + sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)), + quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/ sqls = append(sqls, sql) } } @@ -906,16 +870,18 @@ func (statement *Statement) genUniqueSQL() []string { func (statement *Statement) genDelIndexSQL() []string { var sqls []string tbName := statement.TableName() + idxPrefixName := strings.Replace(tbName, `"`, "", -1) + idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1) for idxName, index := range statement.RefTable.Indexes { var rIdxName string if index.Type == core.UniqueType { - rIdxName = uniqueName(tbName, idxName) + rIdxName = uniqueName(idxPrefixName, idxName) } else if index.Type == core.IndexType { - rIdxName = indexName(tbName, idxName) + rIdxName = indexName(idxPrefixName, idxName) } - sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName)) + sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true))) if statement.Engine.dialect.IndexOnTable() { - sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName())) + sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName)) } sqls = append(sqls, sql) } @@ -966,7 +932,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, v := rValue(bean) isStruct := v.Kind() == reflect.Struct if isStruct { - statement.setRefValue(v) + statement.setRefBean(bean) } var columnStr = statement.ColumnStr @@ -1018,7 +984,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa var condArgs []interface{} var err error if len(beans) > 0 { - statement.setRefValue(rValue(beans[0])) + statement.setRefBean(beans[0]) condSQL, condArgs, err = statement.genConds(beans[0]) } else { condSQL, condArgs, err = builder.ToSQL(statement.cond) @@ -1044,7 +1010,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa } func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { - statement.setRefValue(rValue(bean)) + statement.setRefBean(bean) var sumStrs = make([]string, 0, len(columns)) for _, colName := range columns { diff --git a/tag_extends_test.go b/tag_extends_test.go index b70eefe3..4a4150ba 100644 --- a/tag_extends_test.go +++ b/tag_extends_test.go @@ -202,17 +202,14 @@ func TestExtends(t *testing.T) { var info UserAndDetail qt := testEngine.Quote - ui := testEngine.GetTableMapper().Obj2Table("Userinfo") - ud := testEngine.GetTableMapper().Obj2Table("Userdetail") - uiid := testEngine.GetTableMapper().Obj2Table("Id") + ui := testEngine.TableName(new(Userinfo), true) + ud := testEngine.TableName(&detail, true) + uiid := testEngine.GetColumnMapper().Obj2Table("Id") udid := "detail_id" sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s", qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid)) b, err := testEngine.SQL(sql).NoCascade().Get(&info) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if !b { err = errors.New("should has lest one record") t.Error(err) @@ -341,19 +338,17 @@ func TestExtends2(t *testing.T) { } var mapper = testEngine.GetTableMapper().Obj2Table - userTableName := mapper("MessageUser") - typeTableName := mapper("MessageType") - msgTableName := mapper("Message") + var quote = testEngine.Quote + userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) + typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) + msgTableName := quote(testEngine.TableName(mapper("Message"), true)) list := make([]Message, 0) - err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`"). - Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`"). - Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`"). + err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). + Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). + Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). Find(&list) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if len(list) != 1 { err = errors.New(fmt.Sprintln("should have 1 message, got", len(list))) @@ -406,25 +401,20 @@ func TestExtends3(t *testing.T) { assert.NoError(t, err) } _, err = testEngine.Insert(&msg) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) var mapper = testEngine.GetTableMapper().Obj2Table - userTableName := mapper("MessageUser") - typeTableName := mapper("MessageType") - msgTableName := mapper("Message") + var quote = testEngine.Quote + userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) + typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) + msgTableName := quote(testEngine.TableName(mapper("Message"), true)) list := make([]MessageExtend3, 0) - err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`"). - Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`"). - Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`"). + err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). + Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). + Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). Find(&list) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if len(list) != 1 { err = errors.New(fmt.Sprintln("should have 1 message, got", len(list))) @@ -499,13 +489,14 @@ func TestExtends4(t *testing.T) { } var mapper = testEngine.GetTableMapper().Obj2Table - userTableName := mapper("MessageUser") - typeTableName := mapper("MessageType") - msgTableName := mapper("Message") + var quote = testEngine.Quote + userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) + typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) + msgTableName := quote(testEngine.TableName(mapper("Message"), true)) list := make([]MessageExtend4, 0) - err = testEngine.Table(msgTableName).Join("LEFT", userTableName, "`"+userTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`"). - Join("LEFT", typeTableName, "`"+typeTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`"). + err = testEngine.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). + Join("LEFT", typeTableName, typeTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). Find(&list) if err != nil { t.Error(err) diff --git a/types_test.go b/types_test.go index 3dc1cf9d..20511407 100644 --- a/types_test.go +++ b/types_test.go @@ -301,10 +301,11 @@ type UserCus struct { func TestCustomType2(t *testing.T) { assert.NoError(t, prepareEngine()) - err := testEngine.CreateTables(&UserCus{}) + var uc UserCus + err := testEngine.CreateTables(&uc) assert.NoError(t, err) - tableName := testEngine.GetTableMapper().Obj2Table("UserCus") + tableName := testEngine.TableName(&uc, true) _, err = testEngine.Exec("delete from " + testEngine.Quote(tableName)) assert.NoError(t, err) diff --git a/xorm_test.go b/xorm_test.go index 569bc681..4e88dc40 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -27,6 +27,7 @@ var ( cache = flag.Bool("cache", false, "if enable cache") cluster = flag.Bool("cluster", false, "if this is a cluster") splitter = flag.String("splitter", ";", "the splitter on connstr for cluster") + schema = flag.String("schema", "", "specify the schema") ) func createEngine(dbType, connStr string) error { @@ -35,7 +36,6 @@ func createEngine(dbType, connStr string) error { if !*cluster { testEngine, err = NewEngine(dbType, connStr) - } else { testEngine, err = NewEngineGroup(dbType, strings.Split(connStr, *splitter)) } @@ -43,6 +43,9 @@ func createEngine(dbType, connStr string) error { return err } + if *schema != "" { + testEngine.SetSchema(*schema) + } testEngine.ShowSQL(*showSQL) testEngine.SetLogLevel(core.LOG_DEBUG) if *cache {