diff --git a/dialect_postgres.go b/dialect_postgres.go index b8c8308e..d907c68c 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -913,6 +913,9 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { quote := db.Quote idxName := index.Name + tableName = strings.Replace(tableName, `"`, "", -1) + tableName = strings.Replace(tableName, `.`, "_", -1) + if !strings.HasPrefix(idxName, "UQE_") && !strings.HasPrefix(idxName, "IDX_") { if index.Type == core.UniqueType { @@ -921,6 +924,9 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) } } + if db.Uri.Schema != "" { + idxName = db.Uri.Schema + "." + idxName + } return fmt.Sprintf("DROP INDEX %v", quote(idxName)) } diff --git a/engine.go b/engine.go index 2c8434c7..6eadb93a 100644 --- a/engine.go +++ b/engine.go @@ -1185,7 +1185,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { for _, bean := range beans { v := rValue(bean) - tableNameNoSchema := engine.tbNameNoSchemaString(v.Interface()) + tableNameNoSchema := engine.tbNameNoSchema(v.Interface()) table, err := engine.autoMapType(v) if err != nil { return err diff --git a/engine_table.go b/engine_table.go index 78218cb3..2ff5c388 100644 --- a/engine_table.go +++ b/engine_table.go @@ -5,9 +5,7 @@ package xorm import ( - "bytes" "fmt" - "io" "reflect" "strings" @@ -28,7 +26,7 @@ func (engine *Engine) TableNameWithSchema(v string) string { } func (engine *Engine) tableName(bean interface{}) string { - return engine.TableNameWithSchema(engine.tbNameNoSchemaString(bean)) + return engine.TableNameWithSchema(engine.tbNameNoSchema(bean)) } func (engine *Engine) tbNameForMap(v reflect.Value) string { @@ -44,14 +42,14 @@ func (engine *Engine) tbNameForMap(v reflect.Value) string { return engine.TableMapper.Obj2Table(t.Name()) } -func (engine *Engine) tbNameNoSchema(w io.Writer, tablename interface{}) { +func (engine *Engine) tbNameNoSchema(tablename interface{}) string { switch tablename.(type) { case []string: t := tablename.([]string) if len(t) > 1 { - fmt.Fprintf(w, "%v AS %v", engine.Quote(t[0]), engine.Quote(t[1])) + return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1])) } else if len(t) == 1 { - fmt.Fprintf(w, engine.Quote(t[0])) + return engine.Quote(t[0]) } case []interface{}: t := tablename.([]interface{}) @@ -68,35 +66,29 @@ func (engine *Engine) tbNameNoSchema(w io.Writer, tablename interface{}) { v := rValue(f) t := v.Type() if t.Kind() == reflect.Struct { - fmt.Fprintf(w, engine.tbNameForMap(v)) + table = engine.tbNameForMap(v) } else { - fmt.Fprintf(w, engine.Quote(fmt.Sprintf("%v", f))) + table = engine.Quote(fmt.Sprintf("%v", f)) } } } if l > 1 { - fmt.Fprintf(w, "%v AS %v", engine.Quote(table), + return fmt.Sprintf("%v AS %v", engine.Quote(table), engine.Quote(fmt.Sprintf("%v", t[1]))) } else if l == 1 { - fmt.Fprintf(w, engine.Quote(table)) + return engine.Quote(table) } case TableName: - fmt.Fprintf(w, tablename.(TableName).TableName()) + return tablename.(TableName).TableName() case string: - fmt.Fprintf(w, tablename.(string)) + return tablename.(string) default: v := rValue(tablename) t := v.Type() if t.Kind() == reflect.Struct { - fmt.Fprintf(w, engine.tbNameForMap(v)) - } else { - fmt.Fprintf(w, engine.Quote(fmt.Sprintf("%v", tablename))) + return engine.tbNameForMap(v) } + return engine.Quote(fmt.Sprintf("%v", tablename)) } -} - -func (engine *Engine) tbNameNoSchemaString(tablename interface{}) string { - var buf bytes.Buffer - engine.tbNameNoSchema(&buf, tablename) - return buf.String() + return "" } diff --git a/session_exist_test.go b/session_exist_test.go index 857bf4a1..9c27ce9d 100644 --- a/session_exist_test.go +++ b/session_exist_test.go @@ -54,11 +54,11 @@ func TestExistStruct(t *testing.T) { assert.NoError(t, err) assert.False(t, has) - has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist() + has, err = testEngine.SQL("select * from "+testEngine.TableNameWithSchema("record_exist")+" where name = ?", "test1").Exist() assert.NoError(t, err) assert.True(t, has) - has, err = testEngine.SQL("select * from record_exist where name = ?", "test2").Exist() + has, err = testEngine.SQL("select * from "+testEngine.TableNameWithSchema("record_exist")+" where name = ?", "test2").Exist() assert.NoError(t, err) assert.False(t, has) diff --git a/session_get_test.go b/session_get_test.go index e27e6de9..193c3dfe 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -84,7 +84,7 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money)) var money2 float64 - has, err = testEngine.SQL("SELECT money FROM get_var LIMIT 1").Get(&money2) + has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableNameWithSchema("get_var") + " LIMIT 1").Get(&money2) assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2)) diff --git a/session_insert_test.go b/session_insert_test.go index 080ace9b..c87ae3a9 100644 --- a/session_insert_test.go +++ b/session_insert_test.go @@ -716,8 +716,9 @@ func (MyUserinfo2) TableName() string { func TestInsertMulti4(t *testing.T) { assert.NoError(t, prepareEngine()) - testEngine.ShowSQL(true) + testEngine.ShowSQL(false) assertSync(t, new(MyUserinfo2)) + testEngine.ShowSQL(true) users := []MyUserinfo2{ {Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, diff --git a/session_schema.go b/session_schema.go index fad811b8..5596e075 100644 --- a/session_schema.go +++ b/session_schema.go @@ -122,12 +122,10 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { } func (session *Session) dropTable(beanOrTableName interface{}) error { - tableName := session.engine.tbNameNoSchemaString(beanOrTableName) + tableName := session.engine.tbNameNoSchema(beanOrTableName) var needDrop = true if !session.engine.dialect.SupportDropIfExists() { - fmt.Println("TableCheckSql:", tableName) sqlStr, args := session.engine.dialect.TableCheckSql(tableName) - fmt.Println("sqlStr:", sqlStr) results, err := session.queryBytes(sqlStr, args...) if err != nil { return err @@ -149,7 +147,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) defer session.Close() } - tableName := session.engine.tbNameNoSchemaString(beanOrTableName) + tableName := session.engine.tbNameNoSchema(beanOrTableName) return session.isTableExist(tableName) } @@ -165,7 +163,7 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { if session.isAutoClose { defer session.Close() } - return session.isTableEmpty(session.engine.tbNameNoSchemaString(bean)) + return session.isTableEmpty(session.engine.tbNameNoSchema(bean)) } func (session *Session) isTableEmpty(tableName string) (bool, error) { diff --git a/session_stats_test.go b/session_stats_test.go index ec5cace1..b79aed03 100644 --- a/session_stats_test.go +++ b/session_stats_test.go @@ -153,7 +153,7 @@ func TestSQLCount(t *testing.T) { assertSync(t, new(UserinfoCount2), new(UserinfoBooks)) - total, err := testEngine.SQL("SELECT count(id) FROM userinfo_count2"). + total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableNameWithSchema("userinfo_count2")). Count() assert.NoError(t, err) assert.EqualValues(t, 0, total) diff --git a/statement.go b/statement.go index e8acb1e3..4d20510e 100644 --- a/statement.go +++ b/statement.go @@ -231,7 +231,7 @@ func (statement *Statement) setRefBean(bean interface{}) error { if err != nil { return err } - statement.tableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchemaString(bean)) + statement.tableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchema(bean)) return nil } @@ -748,7 +748,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { } } - statement.AltTableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchemaString(tableNameOrBean)) + statement.AltTableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchema(tableNameOrBean)) return statement } @@ -761,9 +761,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition fmt.Fprintf(&buf, "%v JOIN ", joinOP) } - statement.Engine.tbNameNoSchema(&buf, tablename) + tbName := statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchema(tablename)) - fmt.Fprintf(&buf, " ON %v", condition) + fmt.Fprintf(&buf, "%s ON %v", tbName, condition) statement.JoinStr = buf.String() statement.joinArgs = append(statement.joinArgs, args...) return statement @@ -838,11 +838,13 @@ func (statement *Statement) genCreateTableSQL() string { func (statement *Statement) genIndexSQL() []string { var sqls []string tbName := statement.TableName() - quote := statement.Engine.Quote - for idxName, index := range statement.RefTable.Indexes { + for _, index := range statement.RefTable.Indexes { if index.Type == core.IndexType { - sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)), - quote(tbName), quote(strings.Join(index.Cols, quote(",")))) + sql := statement.Engine.dialect.CreateIndexSql(tbName, index) + /*idxTBName := strings.Replace(tbName, ".", "_", -1) + idxTBName = strings.Replace(idxTBName, `"`, "", -1) + sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)), + quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/ sqls = append(sqls, sql) } } @@ -868,14 +870,16 @@ func (statement *Statement) genUniqueSQL() []string { func (statement *Statement) genDelIndexSQL() []string { var sqls []string tbName := statement.TableName() + idxPrefixName := strings.Replace(tbName, `"`, "", -1) + idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1) for idxName, index := range statement.RefTable.Indexes { var rIdxName string if index.Type == core.UniqueType { - rIdxName = uniqueName(tbName, idxName) + rIdxName = uniqueName(idxPrefixName, idxName) } else if index.Type == core.IndexType { - rIdxName = indexName(tbName, idxName) + rIdxName = indexName(idxPrefixName, idxName) } - sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName)) + sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.TableNameWithSchema(statement.Engine.Quote(rIdxName))) if statement.Engine.dialect.IndexOnTable() { sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName)) } diff --git a/tag_extends_test.go b/tag_extends_test.go index b70eefe3..f0a95244 100644 --- a/tag_extends_test.go +++ b/tag_extends_test.go @@ -202,17 +202,14 @@ func TestExtends(t *testing.T) { var info UserAndDetail qt := testEngine.Quote - ui := testEngine.GetTableMapper().Obj2Table("Userinfo") - ud := testEngine.GetTableMapper().Obj2Table("Userdetail") + ui := testEngine.TableNameWithSchema(testEngine.GetTableMapper().Obj2Table("Userinfo")) + ud := testEngine.TableNameWithSchema(testEngine.GetTableMapper().Obj2Table("Userdetail")) uiid := testEngine.GetTableMapper().Obj2Table("Id") udid := "detail_id" sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s", qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid)) b, err := testEngine.SQL(sql).NoCascade().Get(&info) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if !b { err = errors.New("should has lest one record") t.Error(err) @@ -341,19 +338,17 @@ func TestExtends2(t *testing.T) { } var mapper = testEngine.GetTableMapper().Obj2Table - userTableName := mapper("MessageUser") - typeTableName := mapper("MessageType") - msgTableName := mapper("Message") + var quote = testEngine.Quote + userTableName := quote(testEngine.TableNameWithSchema(mapper("MessageUser"))) + typeTableName := quote(testEngine.TableNameWithSchema(mapper("MessageType"))) + msgTableName := quote(testEngine.TableNameWithSchema(mapper("Message"))) list := make([]Message, 0) - err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`"). - Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`"). - Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`"). + err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). + Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). + Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). Find(&list) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if len(list) != 1 { err = errors.New(fmt.Sprintln("should have 1 message, got", len(list))) @@ -406,25 +401,20 @@ func TestExtends3(t *testing.T) { assert.NoError(t, err) } _, err = testEngine.Insert(&msg) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) var mapper = testEngine.GetTableMapper().Obj2Table - userTableName := mapper("MessageUser") - typeTableName := mapper("MessageType") - msgTableName := mapper("Message") + var quote = testEngine.Quote + userTableName := quote(testEngine.TableNameWithSchema(mapper("MessageUser"))) + typeTableName := quote(testEngine.TableNameWithSchema(mapper("MessageType"))) + msgTableName := quote(testEngine.TableNameWithSchema(mapper("Message"))) list := make([]MessageExtend3, 0) - err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`"). - Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`"). - Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`"). + err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). + Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). + Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). Find(&list) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if len(list) != 1 { err = errors.New(fmt.Sprintln("should have 1 message, got", len(list))) @@ -499,13 +489,14 @@ func TestExtends4(t *testing.T) { } var mapper = testEngine.GetTableMapper().Obj2Table - userTableName := mapper("MessageUser") - typeTableName := mapper("MessageType") - msgTableName := mapper("Message") + var quote = testEngine.Quote + userTableName := quote(testEngine.TableNameWithSchema(mapper("MessageUser"))) + typeTableName := quote(testEngine.TableNameWithSchema(mapper("MessageType"))) + msgTableName := quote(testEngine.TableNameWithSchema(mapper("Message"))) list := make([]MessageExtend4, 0) - err = testEngine.Table(msgTableName).Join("LEFT", userTableName, "`"+userTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`"). - Join("LEFT", typeTableName, "`"+typeTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`"). + err = testEngine.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). + Join("LEFT", typeTableName, typeTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). Find(&list) if err != nil { t.Error(err) diff --git a/types_test.go b/types_test.go index 3dc1cf9d..2d5c55d5 100644 --- a/types_test.go +++ b/types_test.go @@ -304,7 +304,7 @@ func TestCustomType2(t *testing.T) { err := testEngine.CreateTables(&UserCus{}) assert.NoError(t, err) - tableName := testEngine.GetTableMapper().Obj2Table("UserCus") + tableName := testEngine.TableNameWithSchema(testEngine.GetTableMapper().Obj2Table("UserCus")) _, err = testEngine.Exec("delete from " + testEngine.Quote(tableName)) assert.NoError(t, err)