diff --git a/engine.go b/engine.go index 6eadb93a..876004ba 100644 --- a/engine.go +++ b/engine.go @@ -819,7 +819,7 @@ func (engine *Engine) TableInfo(bean interface{}) *Table { if err != nil { engine.logger.Error(err) } - return &Table{tb, engine.tableName(bean)} + return &Table{tb, engine.TableName(bean)} } func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) { @@ -1134,7 +1134,7 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { if t.Kind() != reflect.Struct { return errors.New("error params") } - tableName := engine.tableName(bean) + tableName := engine.TableName(bean) table, err := engine.autoMapType(v) if err != nil { return err @@ -1158,7 +1158,7 @@ func (engine *Engine) ClearCache(beans ...interface{}) error { if t.Kind() != reflect.Struct { return errors.New("error params") } - tableName := engine.tableName(bean) + tableName := engine.TableName(bean) table, err := engine.autoMapType(v) if err != nil { return err diff --git a/engine_table.go b/engine_table.go index 2ff5c388..1319871f 100644 --- a/engine_table.go +++ b/engine_table.go @@ -13,7 +13,7 @@ import ( ) // TableNameWithSchema will automatically add schema prefix on table name -func (engine *Engine) TableNameWithSchema(v string) string { +func (engine *Engine) tbNameWithSchema(v string) string { // Add schema name as prefix of table name. // Only for postgres database. if engine.dialect.DBType() == core.POSTGRES && @@ -25,8 +25,23 @@ func (engine *Engine) TableNameWithSchema(v string) string { return v } -func (engine *Engine) tableName(bean interface{}) string { - return engine.TableNameWithSchema(engine.tbNameNoSchema(bean)) +// TableName returns table name with schema prefix if has +func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { + tbName := engine.tbNameNoSchema(bean) + if len(includeSchema) > 0 && includeSchema[0] { + tbName = engine.tbNameWithSchema(tbName) + } + + return tbName +} + +// tbName get some table's table name +func (session *Session) tbNameNoSchema(table *core.Table) string { + if len(session.statement.AltTableName) > 0 { + return session.statement.AltTableName + } + + return table.Name } func (engine *Engine) tbNameForMap(v reflect.Value) string { diff --git a/interface.go b/interface.go index ecc2c4f7..5ce49f48 100644 --- a/interface.go +++ b/interface.go @@ -95,7 +95,7 @@ type EngineInterface interface { Sync2(...interface{}) error StoreEngine(storeEngine string) *Session TableInfo(bean interface{}) *Table - TableNameWithSchema(string) string + TableName(interface{}, ...bool) string UnMapType(reflect.Type) } diff --git a/rows_test.go b/rows_test.go index ebfabef0..ee121c5e 100644 --- a/rows_test.go +++ b/rows_test.go @@ -54,10 +54,7 @@ func TestRows(t *testing.T) { } assert.EqualValues(t, 1, cnt) - var tbName = testEngine.Quote("user_rows") - if testEngine.Dialect().URI().Schema != "" { - tbName = testEngine.Quote(testEngine.Dialect().URI().Schema) + "." + tbName - } + var tbName = testEngine.Quote(testEngine.TableName(user, true)) rows2, err := testEngine.SQL("SELECT * FROM " + tbName).Rows(new(UserRows)) assert.NoError(t, err) defer rows2.Close() diff --git a/session.go b/session.go index 5c6cb5f9..15283624 100644 --- a/session.go +++ b/session.go @@ -828,15 +828,6 @@ func (session *Session) LastSQL() (string, []interface{}) { return session.lastSQL, session.lastSQLArgs } -// tbName get some table's table name -func (session *Session) tbNameNoSchema(table *core.Table) string { - if len(session.statement.AltTableName) > 0 { - return session.statement.AltTableName - } - - return table.Name -} - // Unscoped always disable struct tag "deleted" func (session *Session) Unscoped() *Session { session.statement.Unscoped() diff --git a/session_exist_test.go b/session_exist_test.go index 9c27ce9d..9d985771 100644 --- a/session_exist_test.go +++ b/session_exist_test.go @@ -54,11 +54,11 @@ func TestExistStruct(t *testing.T) { assert.NoError(t, err) assert.False(t, has) - has, err = testEngine.SQL("select * from "+testEngine.TableNameWithSchema("record_exist")+" where name = ?", "test1").Exist() + has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test1").Exist() assert.NoError(t, err) assert.True(t, has) - has, err = testEngine.SQL("select * from "+testEngine.TableNameWithSchema("record_exist")+" where name = ?", "test2").Exist() + has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test2").Exist() assert.NoError(t, err) assert.False(t, has) diff --git a/session_find_test.go b/session_find_test.go index 40747dce..4db7f9ce 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -102,11 +102,7 @@ func TestFind(t *testing.T) { } users2 := make([]Userinfo, 0) - userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") - var tbName = testEngine.Quote(userinfo) - if testEngine.Dialect().URI().Schema != "" { - tbName = testEngine.Quote(testEngine.Dialect().URI().Schema) + "." + tbName - } + var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true)) err = testEngine.SQL("select * from " + tbName).Find(&users2) assert.NoError(t, err) } diff --git a/session_get_test.go b/session_get_test.go index 193c3dfe..6a2fd575 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -84,7 +84,7 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money)) var money2 float64 - has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableNameWithSchema("get_var") + " LIMIT 1").Get(&money2) + has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2) assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2)) diff --git a/session_query_test.go b/session_query_test.go index 59ce386c..8b4aefad 100644 --- a/session_query_test.go +++ b/session_query_test.go @@ -36,7 +36,7 @@ func TestQueryString(t *testing.T) { _, err := testEngine.InsertOne(data) assert.NoError(t, err) - records, err := testEngine.QueryString("select * from " + testEngine.TableNameWithSchema("get_var2")) + records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var2", true)) assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 5, len(records[0])) @@ -62,7 +62,7 @@ func TestQueryString2(t *testing.T) { _, err := testEngine.Insert(data) assert.NoError(t, err) - records, err := testEngine.QueryString("select * from " + testEngine.TableNameWithSchema("get_var3")) + records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var3", true)) assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 2, len(records[0])) @@ -127,7 +127,7 @@ func TestQueryInterface(t *testing.T) { _, err := testEngine.InsertOne(data) assert.NoError(t, err) - records, err := testEngine.QueryInterface("select * from " + testEngine.TableNameWithSchema("get_var_interface")) + records, err := testEngine.QueryInterface("select * from " + testEngine.TableName("get_var_interface", true)) assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 5, len(records[0])) @@ -181,7 +181,7 @@ func TestQueryNoParams(t *testing.T) { assert.NoError(t, err) assertResult(t, results) - results, err = testEngine.SQL("select * from " + testEngine.TableNameWithSchema("query_no_params")).Query() + results, err = testEngine.SQL("select * from " + testEngine.TableName("query_no_params", true)).Query() assert.NoError(t, err) assertResult(t, results) } @@ -226,7 +226,7 @@ func TestQueryWithBuilder(t *testing.T) { assert.EqualValues(t, 3000, money) } - results, err := testEngine.Query(builder.Select("*").From(testEngine.TableNameWithSchema("query_with_builder"))) + results, err := testEngine.Query(builder.Select("*").From(testEngine.TableName("query_with_builder", true))) assert.NoError(t, err) assertResult(t, results) } diff --git a/session_raw_test.go b/session_raw_test.go index 6fafd848..766206a4 100644 --- a/session_raw_test.go +++ b/session_raw_test.go @@ -21,13 +21,13 @@ func TestExecAndQuery(t *testing.T) { assert.NoError(t, testEngine.Sync2(new(UserinfoQuery))) - res, err := testEngine.Exec("INSERT INTO "+testEngine.TableNameWithSchema("`userinfo_query`")+" (uid, name) VALUES (?, ?)", 1, "user") + res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_query`", true)+" (uid, name) VALUES (?, ?)", 1, "user") assert.NoError(t, err) cnt, err := res.RowsAffected() assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - results, err := testEngine.Query("select * from " + testEngine.TableNameWithSchema("userinfo_query")) + results, err := testEngine.Query("select * from " + testEngine.TableName("userinfo_query", true)) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) id, err := strconv.Atoi(string(results[0]["uid"])) diff --git a/session_schema.go b/session_schema.go index 5596e075..f0628661 100644 --- a/session_schema.go +++ b/session_schema.go @@ -134,7 +134,7 @@ func (session *Session) dropTable(beanOrTableName interface{}) error { } if needDrop { - sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableNameWithSchema(tableName)) + sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true)) _, err := session.exec(sqlStr) return err } @@ -168,7 +168,7 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { func (session *Session) isTableEmpty(tableName string) (bool, error) { var total int64 - sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.TableNameWithSchema(session.engine.Quote(tableName))) + sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true))) err := session.queryRow(sqlStr).Scan(&total) if err != nil { if err == sql.ErrNoRows { @@ -249,7 +249,7 @@ func (session *Session) Sync2(beans ...interface{}) error { } structTables = append(structTables, table) tbName := session.tbNameNoSchema(table) - tbNameWithSchema := engine.TableNameWithSchema(tbName) + tbNameWithSchema := engine.TableName(tbName, true) var oriTable *core.Table for _, tb := range tables { @@ -413,7 +413,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for _, colName := range table.ColumnsSeq() { if oriTable.GetColumn(colName) == nil { - engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableNameWithSchema(table.Name), colName) + engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(table.Name, true), colName) } } } diff --git a/session_stats_test.go b/session_stats_test.go index b79aed03..564fd99a 100644 --- a/session_stats_test.go +++ b/session_stats_test.go @@ -153,7 +153,7 @@ func TestSQLCount(t *testing.T) { assertSync(t, new(UserinfoCount2), new(UserinfoBooks)) - total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableNameWithSchema("userinfo_count2")). + total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableName("userinfo_count2", true)). Count() assert.NoError(t, err) assert.EqualValues(t, 0, total) diff --git a/session_tx_test.go b/session_tx_test.go index 5f2b5817..568ed0f1 100644 --- a/session_tx_test.go +++ b/session_tx_test.go @@ -77,7 +77,7 @@ func TestCombineTransaction(t *testing.T) { _, err = session.Where("id = ?", 0).Update(&user2) assert.NoError(t, err) - _, err = session.Exec("delete from "+testEngine.TableNameWithSchema("userinfo")+" where username = ?", user2.Username) + _, err = session.Exec("delete from "+testEngine.TableName("userinfo", true)+" where username = ?", user2.Username) assert.NoError(t, err) err = session.Commit() @@ -122,7 +122,7 @@ func TestCombineTransactionSameMapper(t *testing.T) { _, err = session.Where("(id) = ?", 0).Update(&user2) assert.NoError(t, err) - _, err = session.Exec("delete from "+testEngine.TableNameWithSchema("`Userinfo`")+" where `Username` = ?", user2.Username) + _, err = session.Exec("delete from "+testEngine.TableName("`Userinfo`", true)+" where `Username` = ?", user2.Username) assert.NoError(t, err) err = session.Commit() diff --git a/statement.go b/statement.go index 4d20510e..603b5990 100644 --- a/statement.go +++ b/statement.go @@ -221,7 +221,7 @@ func (statement *Statement) setRefValue(v reflect.Value) error { if err != nil { return err } - statement.tableName = statement.Engine.tableName(v.Interface()) + statement.tableName = statement.Engine.TableName(v.Interface(), true) return nil } @@ -231,7 +231,7 @@ func (statement *Statement) setRefBean(bean interface{}) error { if err != nil { return err } - statement.tableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchema(bean)) + statement.tableName = statement.Engine.TableName(bean, true) return nil } @@ -748,7 +748,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { } } - statement.AltTableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchema(tableNameOrBean)) + statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true) return statement } @@ -761,7 +761,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition fmt.Fprintf(&buf, "%v JOIN ", joinOP) } - tbName := statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchema(tablename)) + tbName := statement.Engine.TableName(tablename, true) fmt.Fprintf(&buf, "%s ON %v", tbName, condition) statement.JoinStr = buf.String() @@ -879,7 +879,7 @@ func (statement *Statement) genDelIndexSQL() []string { } else if index.Type == core.IndexType { rIdxName = indexName(idxPrefixName, idxName) } - sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.TableNameWithSchema(statement.Engine.Quote(rIdxName))) + sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true))) if statement.Engine.dialect.IndexOnTable() { sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName)) } diff --git a/tag_extends_test.go b/tag_extends_test.go index f0a95244..4a4150ba 100644 --- a/tag_extends_test.go +++ b/tag_extends_test.go @@ -202,9 +202,9 @@ func TestExtends(t *testing.T) { var info UserAndDetail qt := testEngine.Quote - ui := testEngine.TableNameWithSchema(testEngine.GetTableMapper().Obj2Table("Userinfo")) - ud := testEngine.TableNameWithSchema(testEngine.GetTableMapper().Obj2Table("Userdetail")) - uiid := testEngine.GetTableMapper().Obj2Table("Id") + ui := testEngine.TableName(new(Userinfo), true) + ud := testEngine.TableName(&detail, true) + uiid := testEngine.GetColumnMapper().Obj2Table("Id") udid := "detail_id" sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s", qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid)) @@ -339,9 +339,9 @@ func TestExtends2(t *testing.T) { var mapper = testEngine.GetTableMapper().Obj2Table var quote = testEngine.Quote - userTableName := quote(testEngine.TableNameWithSchema(mapper("MessageUser"))) - typeTableName := quote(testEngine.TableNameWithSchema(mapper("MessageType"))) - msgTableName := quote(testEngine.TableNameWithSchema(mapper("Message"))) + userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) + typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) + msgTableName := quote(testEngine.TableName(mapper("Message"), true)) list := make([]Message, 0) err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). @@ -405,9 +405,9 @@ func TestExtends3(t *testing.T) { var mapper = testEngine.GetTableMapper().Obj2Table var quote = testEngine.Quote - userTableName := quote(testEngine.TableNameWithSchema(mapper("MessageUser"))) - typeTableName := quote(testEngine.TableNameWithSchema(mapper("MessageType"))) - msgTableName := quote(testEngine.TableNameWithSchema(mapper("Message"))) + userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) + typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) + msgTableName := quote(testEngine.TableName(mapper("Message"), true)) list := make([]MessageExtend3, 0) err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). @@ -490,9 +490,9 @@ func TestExtends4(t *testing.T) { var mapper = testEngine.GetTableMapper().Obj2Table var quote = testEngine.Quote - userTableName := quote(testEngine.TableNameWithSchema(mapper("MessageUser"))) - typeTableName := quote(testEngine.TableNameWithSchema(mapper("MessageType"))) - msgTableName := quote(testEngine.TableNameWithSchema(mapper("Message"))) + userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) + typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) + msgTableName := quote(testEngine.TableName(mapper("Message"), true)) list := make([]MessageExtend4, 0) err = testEngine.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). diff --git a/types_test.go b/types_test.go index 2d5c55d5..20511407 100644 --- a/types_test.go +++ b/types_test.go @@ -301,10 +301,11 @@ type UserCus struct { func TestCustomType2(t *testing.T) { assert.NoError(t, prepareEngine()) - err := testEngine.CreateTables(&UserCus{}) + var uc UserCus + err := testEngine.CreateTables(&uc) assert.NoError(t, err) - tableName := testEngine.TableNameWithSchema(testEngine.GetTableMapper().Obj2Table("UserCus")) + tableName := testEngine.TableName(&uc, true) _, err = testEngine.Exec("delete from " + testEngine.Quote(tableName)) assert.NoError(t, err)