fix schema support

This commit is contained in:
Lunny Xiao 2018-04-09 23:18:31 +08:00
parent f90ef0062d
commit a21c00771f
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
11 changed files with 71 additions and 79 deletions

View File

@ -913,6 +913,9 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
quote := db.Quote quote := db.Quote
idxName := index.Name idxName := index.Name
tableName = strings.Replace(tableName, `"`, "", -1)
tableName = strings.Replace(tableName, `.`, "_", -1)
if !strings.HasPrefix(idxName, "UQE_") && if !strings.HasPrefix(idxName, "UQE_") &&
!strings.HasPrefix(idxName, "IDX_") { !strings.HasPrefix(idxName, "IDX_") {
if index.Type == core.UniqueType { if index.Type == core.UniqueType {
@ -921,6 +924,9 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
} }
} }
if db.Uri.Schema != "" {
idxName = db.Uri.Schema + "." + idxName
}
return fmt.Sprintf("DROP INDEX %v", quote(idxName)) return fmt.Sprintf("DROP INDEX %v", quote(idxName))
} }

View File

@ -1185,7 +1185,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
for _, bean := range beans { for _, bean := range beans {
v := rValue(bean) v := rValue(bean)
tableNameNoSchema := engine.tbNameNoSchemaString(v.Interface()) tableNameNoSchema := engine.tbNameNoSchema(v.Interface())
table, err := engine.autoMapType(v) table, err := engine.autoMapType(v)
if err != nil { if err != nil {
return err return err

View File

@ -5,9 +5,7 @@
package xorm package xorm
import ( import (
"bytes"
"fmt" "fmt"
"io"
"reflect" "reflect"
"strings" "strings"
@ -28,7 +26,7 @@ func (engine *Engine) TableNameWithSchema(v string) string {
} }
func (engine *Engine) tableName(bean interface{}) string { func (engine *Engine) tableName(bean interface{}) string {
return engine.TableNameWithSchema(engine.tbNameNoSchemaString(bean)) return engine.TableNameWithSchema(engine.tbNameNoSchema(bean))
} }
func (engine *Engine) tbNameForMap(v reflect.Value) string { func (engine *Engine) tbNameForMap(v reflect.Value) string {
@ -44,14 +42,14 @@ func (engine *Engine) tbNameForMap(v reflect.Value) string {
return engine.TableMapper.Obj2Table(t.Name()) return engine.TableMapper.Obj2Table(t.Name())
} }
func (engine *Engine) tbNameNoSchema(w io.Writer, tablename interface{}) { func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
switch tablename.(type) { switch tablename.(type) {
case []string: case []string:
t := tablename.([]string) t := tablename.([]string)
if len(t) > 1 { if len(t) > 1 {
fmt.Fprintf(w, "%v AS %v", engine.Quote(t[0]), engine.Quote(t[1])) return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1]))
} else if len(t) == 1 { } else if len(t) == 1 {
fmt.Fprintf(w, engine.Quote(t[0])) return engine.Quote(t[0])
} }
case []interface{}: case []interface{}:
t := tablename.([]interface{}) t := tablename.([]interface{})
@ -68,35 +66,29 @@ func (engine *Engine) tbNameNoSchema(w io.Writer, tablename interface{}) {
v := rValue(f) v := rValue(f)
t := v.Type() t := v.Type()
if t.Kind() == reflect.Struct { if t.Kind() == reflect.Struct {
fmt.Fprintf(w, engine.tbNameForMap(v)) table = engine.tbNameForMap(v)
} else { } else {
fmt.Fprintf(w, engine.Quote(fmt.Sprintf("%v", f))) table = engine.Quote(fmt.Sprintf("%v", f))
} }
} }
} }
if l > 1 { if l > 1 {
fmt.Fprintf(w, "%v AS %v", engine.Quote(table), return fmt.Sprintf("%v AS %v", engine.Quote(table),
engine.Quote(fmt.Sprintf("%v", t[1]))) engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 { } else if l == 1 {
fmt.Fprintf(w, engine.Quote(table)) return engine.Quote(table)
} }
case TableName: case TableName:
fmt.Fprintf(w, tablename.(TableName).TableName()) return tablename.(TableName).TableName()
case string: case string:
fmt.Fprintf(w, tablename.(string)) return tablename.(string)
default: default:
v := rValue(tablename) v := rValue(tablename)
t := v.Type() t := v.Type()
if t.Kind() == reflect.Struct { if t.Kind() == reflect.Struct {
fmt.Fprintf(w, engine.tbNameForMap(v)) return engine.tbNameForMap(v)
} else {
fmt.Fprintf(w, engine.Quote(fmt.Sprintf("%v", tablename)))
} }
return engine.Quote(fmt.Sprintf("%v", tablename))
} }
} return ""
func (engine *Engine) tbNameNoSchemaString(tablename interface{}) string {
var buf bytes.Buffer
engine.tbNameNoSchema(&buf, tablename)
return buf.String()
} }

View File

@ -54,11 +54,11 @@ func TestExistStruct(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, has) assert.False(t, has)
has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist() has, err = testEngine.SQL("select * from "+testEngine.TableNameWithSchema("record_exist")+" where name = ?", "test1").Exist()
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
has, err = testEngine.SQL("select * from record_exist where name = ?", "test2").Exist() has, err = testEngine.SQL("select * from "+testEngine.TableNameWithSchema("record_exist")+" where name = ?", "test2").Exist()
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, has) assert.False(t, has)

View File

@ -84,7 +84,7 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money)) assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money))
var money2 float64 var money2 float64
has, err = testEngine.SQL("SELECT money FROM get_var LIMIT 1").Get(&money2) has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableNameWithSchema("get_var") + " LIMIT 1").Get(&money2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2)) assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2))

View File

@ -716,8 +716,9 @@ func (MyUserinfo2) TableName() string {
func TestInsertMulti4(t *testing.T) { func TestInsertMulti4(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
testEngine.ShowSQL(true) testEngine.ShowSQL(false)
assertSync(t, new(MyUserinfo2)) assertSync(t, new(MyUserinfo2))
testEngine.ShowSQL(true)
users := []MyUserinfo2{ users := []MyUserinfo2{
{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, {Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()},

View File

@ -122,12 +122,10 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
} }
func (session *Session) dropTable(beanOrTableName interface{}) error { func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName := session.engine.tbNameNoSchemaString(beanOrTableName) tableName := session.engine.tbNameNoSchema(beanOrTableName)
var needDrop = true var needDrop = true
if !session.engine.dialect.SupportDropIfExists() { if !session.engine.dialect.SupportDropIfExists() {
fmt.Println("TableCheckSql:", tableName)
sqlStr, args := session.engine.dialect.TableCheckSql(tableName) sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
fmt.Println("sqlStr:", sqlStr)
results, err := session.queryBytes(sqlStr, args...) results, err := session.queryBytes(sqlStr, args...)
if err != nil { if err != nil {
return err return err
@ -149,7 +147,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
defer session.Close() defer session.Close()
} }
tableName := session.engine.tbNameNoSchemaString(beanOrTableName) tableName := session.engine.tbNameNoSchema(beanOrTableName)
return session.isTableExist(tableName) return session.isTableExist(tableName)
} }
@ -165,7 +163,7 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
return session.isTableEmpty(session.engine.tbNameNoSchemaString(bean)) return session.isTableEmpty(session.engine.tbNameNoSchema(bean))
} }
func (session *Session) isTableEmpty(tableName string) (bool, error) { func (session *Session) isTableEmpty(tableName string) (bool, error) {

View File

@ -153,7 +153,7 @@ func TestSQLCount(t *testing.T) {
assertSync(t, new(UserinfoCount2), new(UserinfoBooks)) assertSync(t, new(UserinfoCount2), new(UserinfoBooks))
total, err := testEngine.SQL("SELECT count(id) FROM userinfo_count2"). total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableNameWithSchema("userinfo_count2")).
Count() Count()
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 0, total) assert.EqualValues(t, 0, total)

View File

@ -231,7 +231,7 @@ func (statement *Statement) setRefBean(bean interface{}) error {
if err != nil { if err != nil {
return err return err
} }
statement.tableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchemaString(bean)) statement.tableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchema(bean))
return nil return nil
} }
@ -748,7 +748,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
} }
} }
statement.AltTableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchemaString(tableNameOrBean)) statement.AltTableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchema(tableNameOrBean))
return statement return statement
} }
@ -761,9 +761,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
fmt.Fprintf(&buf, "%v JOIN ", joinOP) fmt.Fprintf(&buf, "%v JOIN ", joinOP)
} }
statement.Engine.tbNameNoSchema(&buf, tablename) tbName := statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchema(tablename))
fmt.Fprintf(&buf, " ON %v", condition) fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
statement.JoinStr = buf.String() statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...) statement.joinArgs = append(statement.joinArgs, args...)
return statement return statement
@ -838,11 +838,13 @@ func (statement *Statement) genCreateTableSQL() string {
func (statement *Statement) genIndexSQL() []string { func (statement *Statement) genIndexSQL() []string {
var sqls []string var sqls []string
tbName := statement.TableName() tbName := statement.TableName()
quote := statement.Engine.Quote for _, index := range statement.RefTable.Indexes {
for idxName, index := range statement.RefTable.Indexes {
if index.Type == core.IndexType { if index.Type == core.IndexType {
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)), sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
quote(tbName), quote(strings.Join(index.Cols, quote(",")))) /*idxTBName := strings.Replace(tbName, ".", "_", -1)
idxTBName = strings.Replace(idxTBName, `"`, "", -1)
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)),
quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/
sqls = append(sqls, sql) sqls = append(sqls, sql)
} }
} }
@ -868,14 +870,16 @@ func (statement *Statement) genUniqueSQL() []string {
func (statement *Statement) genDelIndexSQL() []string { func (statement *Statement) genDelIndexSQL() []string {
var sqls []string var sqls []string
tbName := statement.TableName() tbName := statement.TableName()
idxPrefixName := strings.Replace(tbName, `"`, "", -1)
idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
for idxName, index := range statement.RefTable.Indexes { for idxName, index := range statement.RefTable.Indexes {
var rIdxName string var rIdxName string
if index.Type == core.UniqueType { if index.Type == core.UniqueType {
rIdxName = uniqueName(tbName, idxName) rIdxName = uniqueName(idxPrefixName, idxName)
} else if index.Type == core.IndexType { } else if index.Type == core.IndexType {
rIdxName = indexName(tbName, idxName) rIdxName = indexName(idxPrefixName, idxName)
} }
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName)) sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.TableNameWithSchema(statement.Engine.Quote(rIdxName)))
if statement.Engine.dialect.IndexOnTable() { if statement.Engine.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName)) sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
} }

View File

@ -202,17 +202,14 @@ func TestExtends(t *testing.T) {
var info UserAndDetail var info UserAndDetail
qt := testEngine.Quote qt := testEngine.Quote
ui := testEngine.GetTableMapper().Obj2Table("Userinfo") ui := testEngine.TableNameWithSchema(testEngine.GetTableMapper().Obj2Table("Userinfo"))
ud := testEngine.GetTableMapper().Obj2Table("Userdetail") ud := testEngine.TableNameWithSchema(testEngine.GetTableMapper().Obj2Table("Userdetail"))
uiid := testEngine.GetTableMapper().Obj2Table("Id") uiid := testEngine.GetTableMapper().Obj2Table("Id")
udid := "detail_id" udid := "detail_id"
sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s", sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s",
qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid)) qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid))
b, err := testEngine.SQL(sql).NoCascade().Get(&info) b, err := testEngine.SQL(sql).NoCascade().Get(&info)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
if !b { if !b {
err = errors.New("should has lest one record") err = errors.New("should has lest one record")
t.Error(err) t.Error(err)
@ -341,19 +338,17 @@ func TestExtends2(t *testing.T) {
} }
var mapper = testEngine.GetTableMapper().Obj2Table var mapper = testEngine.GetTableMapper().Obj2Table
userTableName := mapper("MessageUser") var quote = testEngine.Quote
typeTableName := mapper("MessageType") userTableName := quote(testEngine.TableNameWithSchema(mapper("MessageUser")))
msgTableName := mapper("Message") typeTableName := quote(testEngine.TableNameWithSchema(mapper("MessageType")))
msgTableName := quote(testEngine.TableNameWithSchema(mapper("Message")))
list := make([]Message, 0) list := make([]Message, 0)
err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`"). err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`"). Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`").
Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`"). Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
Find(&list) Find(&list)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
if len(list) != 1 { if len(list) != 1 {
err = errors.New(fmt.Sprintln("should have 1 message, got", len(list))) err = errors.New(fmt.Sprintln("should have 1 message, got", len(list)))
@ -406,25 +401,20 @@ func TestExtends3(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
_, err = testEngine.Insert(&msg) _, err = testEngine.Insert(&msg)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
var mapper = testEngine.GetTableMapper().Obj2Table var mapper = testEngine.GetTableMapper().Obj2Table
userTableName := mapper("MessageUser") var quote = testEngine.Quote
typeTableName := mapper("MessageType") userTableName := quote(testEngine.TableNameWithSchema(mapper("MessageUser")))
msgTableName := mapper("Message") typeTableName := quote(testEngine.TableNameWithSchema(mapper("MessageType")))
msgTableName := quote(testEngine.TableNameWithSchema(mapper("Message")))
list := make([]MessageExtend3, 0) list := make([]MessageExtend3, 0)
err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`"). err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`"). Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`").
Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`"). Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
Find(&list) Find(&list)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
if len(list) != 1 { if len(list) != 1 {
err = errors.New(fmt.Sprintln("should have 1 message, got", len(list))) err = errors.New(fmt.Sprintln("should have 1 message, got", len(list)))
@ -499,13 +489,14 @@ func TestExtends4(t *testing.T) {
} }
var mapper = testEngine.GetTableMapper().Obj2Table var mapper = testEngine.GetTableMapper().Obj2Table
userTableName := mapper("MessageUser") var quote = testEngine.Quote
typeTableName := mapper("MessageType") userTableName := quote(testEngine.TableNameWithSchema(mapper("MessageUser")))
msgTableName := mapper("Message") typeTableName := quote(testEngine.TableNameWithSchema(mapper("MessageType")))
msgTableName := quote(testEngine.TableNameWithSchema(mapper("Message")))
list := make([]MessageExtend4, 0) list := make([]MessageExtend4, 0)
err = testEngine.Table(msgTableName).Join("LEFT", userTableName, "`"+userTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`"). err = testEngine.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
Join("LEFT", typeTableName, "`"+typeTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`"). Join("LEFT", typeTableName, typeTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
Find(&list) Find(&list)
if err != nil { if err != nil {
t.Error(err) t.Error(err)

View File

@ -304,7 +304,7 @@ func TestCustomType2(t *testing.T) {
err := testEngine.CreateTables(&UserCus{}) err := testEngine.CreateTables(&UserCus{})
assert.NoError(t, err) assert.NoError(t, err)
tableName := testEngine.GetTableMapper().Obj2Table("UserCus") tableName := testEngine.TableNameWithSchema(testEngine.GetTableMapper().Obj2Table("UserCus"))
_, err = testEngine.Exec("delete from " + testEngine.Quote(tableName)) _, err = testEngine.Exec("delete from " + testEngine.Quote(tableName))
assert.NoError(t, err) assert.NoError(t, err)