Fix postgres schema problem (#1624)

Fix postgres

Add DefaultPostgresSchema back

force push

Fix postgres schema problem

Reviewed-on: https://gitea.com/xorm/xorm/pulls/1624
This commit is contained in:
Lunny Xiao 2020-03-25 09:36:45 +00:00
parent 6132eea08c
commit 0a3685be83
3 changed files with 57 additions and 35 deletions

View File

@ -32,8 +32,9 @@ type URI struct {
// SetSchema set schema // SetSchema set schema
func (uri *URI) SetSchema(schema string) { func (uri *URI) SetSchema(schema string) {
// hack me
if uri.DBType == schemas.POSTGRES { if uri.DBType == schemas.POSTGRES {
uri.Schema = schema uri.Schema = strings.TrimSpace(schema)
} }
} }
@ -43,7 +44,6 @@ type Dialect interface {
URI() *URI URI() *URI
SQLType(*schemas.Column) string SQLType(*schemas.Column) string
FormatBytes(b []byte) string FormatBytes(b []byte) string
DefaultSchema() string
IsReserved(string) bool IsReserved(string) bool
Quoter() schemas.Quoter Quoter() schemas.Quoter
@ -83,10 +83,6 @@ func (b *Base) Quoter() schemas.Quoter {
return b.quoter return b.quoter
} }
func (b *Base) DefaultSchema() string {
return ""
}
func (b *Base) Init(dialect Dialect, uri *URI) error { func (b *Base) Init(dialect Dialect, uri *URI) error {
b.dialect, b.uri = dialect, uri b.dialect, b.uri = dialect, uri
return nil return nil

View File

@ -770,7 +770,10 @@ var (
postgresQuoter = schemas.Quoter{'"', '"', schemas.AlwaysReserve} postgresQuoter = schemas.Quoter{'"', '"', schemas.AlwaysReserve}
) )
const postgresPublicSchema = "public" var (
// DefaultPostgresSchema default postgres schema
DefaultPostgresSchema = "public"
)
type postgres struct { type postgres struct {
Base Base
@ -778,14 +781,14 @@ type postgres struct {
func (db *postgres) Init(uri *URI) error { func (db *postgres) Init(uri *URI) error {
db.quoter = postgresQuoter db.quoter = postgresQuoter
err := db.Base.Init(db, uri) return db.Base.Init(db, uri)
if err != nil {
return err
} }
if db.uri.Schema == "" {
db.uri.Schema = postgresPublicSchema func (db *postgres) getSchema() string {
if db.uri.Schema != "" {
return db.uri.Schema
} }
return nil return DefaultPostgresSchema
} }
func (db *postgres) needQuote(name string) bool { func (db *postgres) needQuote(name string) bool {
@ -817,10 +820,6 @@ func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) {
} }
} }
func (db *postgres) DefaultSchema() string {
return postgresPublicSchema
}
func (db *postgres) SQLType(c *schemas.Column) string { func (db *postgres) SQLType(c *schemas.Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
@ -932,32 +931,32 @@ func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]st
} }
func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
if len(db.uri.Schema) == 0 { if len(db.getSchema()) == 0 {
args := []interface{}{tableName, idxName} args := []interface{}{tableName, idxName}
return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args
} }
args := []interface{}{db.uri.Schema, tableName, idxName} args := []interface{}{db.getSchema(), tableName, idxName}
return `SELECT indexname FROM pg_indexes ` + return `SELECT indexname FROM pg_indexes ` +
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
} }
func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) {
if len(db.uri.Schema) == 0 { if len(db.getSchema()) == 0 {
return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName) return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName)
} }
return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`, return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`,
db.uri.Schema, tableName) db.getSchema(), tableName)
} }
func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string {
if len(db.uri.Schema) == 0 || strings.Contains(tableName, ".") { if len(db.getSchema()) == 0 || strings.Contains(tableName, ".") {
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
tableName, col.Name, db.SQLType(col)) tableName, col.Name, db.SQLType(col))
} }
return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s",
db.uri.Schema, tableName, col.Name, db.SQLType(col)) db.getSchema(), tableName, col.Name, db.SQLType(col))
} }
func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string { func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string {
@ -974,17 +973,17 @@ func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
} }
} }
if db.uri.Schema != "" { if db.getSchema() != "" {
idxName = db.uri.Schema + "." + idxName idxName = db.getSchema() + "." + idxName
} }
return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName))
} }
func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
args := []interface{}{db.uri.Schema, tableName, colName} args := []interface{}{db.getSchema(), tableName, colName}
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" +
" AND column_name = $3" " AND column_name = $3"
if len(db.uri.Schema) == 0 { if len(db.getSchema()) == 0 {
args = []interface{}{tableName, colName} args = []interface{}{tableName, colName}
query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $2" " AND column_name = $2"
@ -1000,7 +999,7 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab
} }
func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{db.uri.Schema, tableName, db.uri.Schema} args := []interface{}{tableName}
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length,
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
@ -1011,7 +1010,15 @@ FROM pg_attribute f
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN pg_class AS g ON p.confrelid = g.oid
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_schema = $3 AND f.attnum > 0 ORDER BY f.attnum;` WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;`
schema := db.getSchema()
if schema != "" {
s = fmt.Sprintf(s, "AND s.table_schema = $2")
args = append(args, schema)
} else {
s = fmt.Sprintf(s, "")
}
rows, err := queryer.QueryContext(ctx, s, args...) rows, err := queryer.QueryContext(ctx, s, args...)
if err != nil { if err != nil {
@ -1135,8 +1142,9 @@ WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_sch
func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT tablename FROM pg_tables" s := "SELECT tablename FROM pg_tables"
if len(db.uri.Schema) != 0 { schema := db.getSchema()
args = append(args, db.uri.Schema) if schema != "" {
args = append(args, schema)
s = s + " WHERE schemaname = $1" s = s + " WHERE schemaname = $1"
} }
@ -1174,8 +1182,8 @@ func getIndexColName(indexdef string) []string {
func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
if len(db.uri.Schema) != 0 { if len(db.getSchema()) != 0 {
args = append(args, db.uri.Schema) args = append(args, db.getSchema())
s = s + " AND schemaname=$2" s = s + " AND schemaname=$2"
} }
@ -1319,3 +1327,22 @@ func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*URI, error) {
} }
return pgx.pqDriver.Parse(driverName, dataSourceName) return pgx.pqDriver.Parse(driverName, dataSourceName)
} }
// QueryDefaultPostgresSchema returns the default postgres schema
func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (string, error) {
rows, err := queryer.QueryContext(ctx, "SHOW SEARCH_PATH")
if err != nil {
return "", err
}
defer rows.Close()
if rows.Next() {
var defaultSchema string
if err = rows.Scan(&defaultSchema); err != nil {
return "", err
}
parts := strings.Split(defaultSchema, ",")
return strings.TrimSpace(parts[len(parts)-1]), nil
}
return "", errors.New("No default schema")
}

View File

@ -18,7 +18,6 @@ func TableNameWithSchema(dialect Dialect, tableName string) string {
// Add schema name as prefix of table name. // Add schema name as prefix of table name.
// Only for postgres database. // Only for postgres database.
if dialect.URI().Schema != "" && if dialect.URI().Schema != "" &&
dialect.URI().Schema != dialect.DefaultSchema() &&
strings.Index(tableName, ".") == -1 { strings.Index(tableName, ".") == -1 {
return fmt.Sprintf("%s.%s", dialect.URI().Schema, tableName) return fmt.Sprintf("%s.%s", dialect.URI().Schema, tableName)
} }