fix schema support

This commit is contained in:
Lunny Xiao 2018-04-09 23:18:31 +08:00
parent f90ef0062d
commit a21c00771f
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
11 changed files with 71 additions and 79 deletions

View File

@ -913,6 +913,9 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
quote := db.Quote
idxName := index.Name
tableName = strings.Replace(tableName, `"`, "", -1)
tableName = strings.Replace(tableName, `.`, "_", -1)
if !strings.HasPrefix(idxName, "UQE_") &&
!strings.HasPrefix(idxName, "IDX_") {
if index.Type == core.UniqueType {
@ -921,6 +924,9 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
}
}
if db.Uri.Schema != "" {
idxName = db.Uri.Schema + "." + idxName
}
return fmt.Sprintf("DROP INDEX %v", quote(idxName))
}

View File

@ -1185,7 +1185,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
for _, bean := range beans {
v := rValue(bean)
tableNameNoSchema := engine.tbNameNoSchemaString(v.Interface())
tableNameNoSchema := engine.tbNameNoSchema(v.Interface())
table, err := engine.autoMapType(v)
if err != nil {
return err

View File

@ -5,9 +5,7 @@
package xorm
import (
"bytes"
"fmt"
"io"
"reflect"
"strings"
@ -28,7 +26,7 @@ func (engine *Engine) TableNameWithSchema(v string) string {
}
func (engine *Engine) tableName(bean interface{}) string {
return engine.TableNameWithSchema(engine.tbNameNoSchemaString(bean))
return engine.TableNameWithSchema(engine.tbNameNoSchema(bean))
}
func (engine *Engine) tbNameForMap(v reflect.Value) string {
@ -44,14 +42,14 @@ func (engine *Engine) tbNameForMap(v reflect.Value) string {
return engine.TableMapper.Obj2Table(t.Name())
}
func (engine *Engine) tbNameNoSchema(w io.Writer, tablename interface{}) {
func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
switch tablename.(type) {
case []string:
t := tablename.([]string)
if len(t) > 1 {
fmt.Fprintf(w, "%v AS %v", engine.Quote(t[0]), engine.Quote(t[1]))
return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1]))
} else if len(t) == 1 {
fmt.Fprintf(w, engine.Quote(t[0]))
return engine.Quote(t[0])
}
case []interface{}:
t := tablename.([]interface{})
@ -68,35 +66,29 @@ func (engine *Engine) tbNameNoSchema(w io.Writer, tablename interface{}) {
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
fmt.Fprintf(w, engine.tbNameForMap(v))
table = engine.tbNameForMap(v)
} else {
fmt.Fprintf(w, engine.Quote(fmt.Sprintf("%v", f)))
table = engine.Quote(fmt.Sprintf("%v", f))
}
}
}
if l > 1 {
fmt.Fprintf(w, "%v AS %v", engine.Quote(table),
return fmt.Sprintf("%v AS %v", engine.Quote(table),
engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
fmt.Fprintf(w, engine.Quote(table))
return engine.Quote(table)
}
case TableName:
fmt.Fprintf(w, tablename.(TableName).TableName())
return tablename.(TableName).TableName()
case string:
fmt.Fprintf(w, tablename.(string))
return tablename.(string)
default:
v := rValue(tablename)
t := v.Type()
if t.Kind() == reflect.Struct {
fmt.Fprintf(w, engine.tbNameForMap(v))
} else {
fmt.Fprintf(w, engine.Quote(fmt.Sprintf("%v", tablename)))
return engine.tbNameForMap(v)
}
return engine.Quote(fmt.Sprintf("%v", tablename))
}
}
func (engine *Engine) tbNameNoSchemaString(tablename interface{}) string {
var buf bytes.Buffer
engine.tbNameNoSchema(&buf, tablename)
return buf.String()
return ""
}

View File

@ -54,11 +54,11 @@ func TestExistStruct(t *testing.T) {
assert.NoError(t, err)
assert.False(t, has)
has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist()
has, err = testEngine.SQL("select * from "+testEngine.TableNameWithSchema("record_exist")+" where name = ?", "test1").Exist()
assert.NoError(t, err)
assert.True(t, has)
has, err = testEngine.SQL("select * from record_exist where name = ?", "test2").Exist()
has, err = testEngine.SQL("select * from "+testEngine.TableNameWithSchema("record_exist")+" where name = ?", "test2").Exist()
assert.NoError(t, err)
assert.False(t, has)

View File

@ -84,7 +84,7 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money))
var money2 float64
has, err = testEngine.SQL("SELECT money FROM get_var LIMIT 1").Get(&money2)
has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableNameWithSchema("get_var") + " LIMIT 1").Get(&money2)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2))

View File

@ -716,8 +716,9 @@ func (MyUserinfo2) TableName() string {
func TestInsertMulti4(t *testing.T) {
assert.NoError(t, prepareEngine())
testEngine.ShowSQL(true)
testEngine.ShowSQL(false)
assertSync(t, new(MyUserinfo2))
testEngine.ShowSQL(true)
users := []MyUserinfo2{
{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()},

View File

@ -122,12 +122,10 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
}
func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName := session.engine.tbNameNoSchemaString(beanOrTableName)
tableName := session.engine.tbNameNoSchema(beanOrTableName)
var needDrop = true
if !session.engine.dialect.SupportDropIfExists() {
fmt.Println("TableCheckSql:", tableName)
sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
fmt.Println("sqlStr:", sqlStr)
results, err := session.queryBytes(sqlStr, args...)
if err != nil {
return err
@ -149,7 +147,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
defer session.Close()
}
tableName := session.engine.tbNameNoSchemaString(beanOrTableName)
tableName := session.engine.tbNameNoSchema(beanOrTableName)
return session.isTableExist(tableName)
}
@ -165,7 +163,7 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
if session.isAutoClose {
defer session.Close()
}
return session.isTableEmpty(session.engine.tbNameNoSchemaString(bean))
return session.isTableEmpty(session.engine.tbNameNoSchema(bean))
}
func (session *Session) isTableEmpty(tableName string) (bool, error) {

View File

@ -153,7 +153,7 @@ func TestSQLCount(t *testing.T) {
assertSync(t, new(UserinfoCount2), new(UserinfoBooks))
total, err := testEngine.SQL("SELECT count(id) FROM userinfo_count2").
total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableNameWithSchema("userinfo_count2")).
Count()
assert.NoError(t, err)
assert.EqualValues(t, 0, total)

View File

@ -231,7 +231,7 @@ func (statement *Statement) setRefBean(bean interface{}) error {
if err != nil {
return err
}
statement.tableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchemaString(bean))
statement.tableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchema(bean))
return nil
}
@ -748,7 +748,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
}
}
statement.AltTableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchemaString(tableNameOrBean))
statement.AltTableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchema(tableNameOrBean))
return statement
}
@ -761,9 +761,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
fmt.Fprintf(&buf, "%v JOIN ", joinOP)
}
statement.Engine.tbNameNoSchema(&buf, tablename)
tbName := statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchema(tablename))
fmt.Fprintf(&buf, " ON %v", condition)
fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...)
return statement
@ -838,11 +838,13 @@ func (statement *Statement) genCreateTableSQL() string {
func (statement *Statement) genIndexSQL() []string {
var sqls []string
tbName := statement.TableName()
quote := statement.Engine.Quote
for idxName, index := range statement.RefTable.Indexes {
for _, index := range statement.RefTable.Indexes {
if index.Type == core.IndexType {
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)),
quote(tbName), quote(strings.Join(index.Cols, quote(","))))
sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
/*idxTBName := strings.Replace(tbName, ".", "_", -1)
idxTBName = strings.Replace(idxTBName, `"`, "", -1)
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)),
quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/
sqls = append(sqls, sql)
}
}
@ -868,14 +870,16 @@ func (statement *Statement) genUniqueSQL() []string {
func (statement *Statement) genDelIndexSQL() []string {
var sqls []string
tbName := statement.TableName()
idxPrefixName := strings.Replace(tbName, `"`, "", -1)
idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
for idxName, index := range statement.RefTable.Indexes {
var rIdxName string
if index.Type == core.UniqueType {
rIdxName = uniqueName(tbName, idxName)
rIdxName = uniqueName(idxPrefixName, idxName)
} else if index.Type == core.IndexType {
rIdxName = indexName(tbName, idxName)
rIdxName = indexName(idxPrefixName, idxName)
}
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName))
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.TableNameWithSchema(statement.Engine.Quote(rIdxName)))
if statement.Engine.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
}

View File

@ -202,17 +202,14 @@ func TestExtends(t *testing.T) {
var info UserAndDetail
qt := testEngine.Quote
ui := testEngine.GetTableMapper().Obj2Table("Userinfo")
ud := testEngine.GetTableMapper().Obj2Table("Userdetail")
ui := testEngine.TableNameWithSchema(testEngine.GetTableMapper().Obj2Table("Userinfo"))
ud := testEngine.TableNameWithSchema(testEngine.GetTableMapper().Obj2Table("Userdetail"))
uiid := testEngine.GetTableMapper().Obj2Table("Id")
udid := "detail_id"
sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s",
qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid))
b, err := testEngine.SQL(sql).NoCascade().Get(&info)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
if !b {
err = errors.New("should has lest one record")
t.Error(err)
@ -341,19 +338,17 @@ func TestExtends2(t *testing.T) {
}
var mapper = testEngine.GetTableMapper().Obj2Table
userTableName := mapper("MessageUser")
typeTableName := mapper("MessageType")
msgTableName := mapper("Message")
var quote = testEngine.Quote
userTableName := quote(testEngine.TableNameWithSchema(mapper("MessageUser")))
typeTableName := quote(testEngine.TableNameWithSchema(mapper("MessageType")))
msgTableName := quote(testEngine.TableNameWithSchema(mapper("Message")))
list := make([]Message, 0)
err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`").
Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`").
Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`").
err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`").
Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
Find(&list)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
if len(list) != 1 {
err = errors.New(fmt.Sprintln("should have 1 message, got", len(list)))
@ -406,25 +401,20 @@ func TestExtends3(t *testing.T) {
assert.NoError(t, err)
}
_, err = testEngine.Insert(&msg)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
var mapper = testEngine.GetTableMapper().Obj2Table
userTableName := mapper("MessageUser")
typeTableName := mapper("MessageType")
msgTableName := mapper("Message")
var quote = testEngine.Quote
userTableName := quote(testEngine.TableNameWithSchema(mapper("MessageUser")))
typeTableName := quote(testEngine.TableNameWithSchema(mapper("MessageType")))
msgTableName := quote(testEngine.TableNameWithSchema(mapper("Message")))
list := make([]MessageExtend3, 0)
err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`").
Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`").
Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`").
err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`").
Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
Find(&list)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
if len(list) != 1 {
err = errors.New(fmt.Sprintln("should have 1 message, got", len(list)))
@ -499,13 +489,14 @@ func TestExtends4(t *testing.T) {
}
var mapper = testEngine.GetTableMapper().Obj2Table
userTableName := mapper("MessageUser")
typeTableName := mapper("MessageType")
msgTableName := mapper("Message")
var quote = testEngine.Quote
userTableName := quote(testEngine.TableNameWithSchema(mapper("MessageUser")))
typeTableName := quote(testEngine.TableNameWithSchema(mapper("MessageType")))
msgTableName := quote(testEngine.TableNameWithSchema(mapper("Message")))
list := make([]MessageExtend4, 0)
err = testEngine.Table(msgTableName).Join("LEFT", userTableName, "`"+userTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`").
Join("LEFT", typeTableName, "`"+typeTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`").
err = testEngine.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
Join("LEFT", typeTableName, typeTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
Find(&list)
if err != nil {
t.Error(err)

View File

@ -304,7 +304,7 @@ func TestCustomType2(t *testing.T) {
err := testEngine.CreateTables(&UserCus{})
assert.NoError(t, err)
tableName := testEngine.GetTableMapper().Obj2Table("UserCus")
tableName := testEngine.TableNameWithSchema(testEngine.GetTableMapper().Obj2Table("UserCus"))
_, err = testEngine.Exec("delete from " + testEngine.Quote(tableName))
assert.NoError(t, err)