Add DefaultPostgresSchema back
This commit is contained in:
parent
414ca15269
commit
32b224abe5
|
@ -770,7 +770,10 @@ var (
|
||||||
postgresQuoter = schemas.Quoter{'"', '"', schemas.AlwaysReserve}
|
postgresQuoter = schemas.Quoter{'"', '"', schemas.AlwaysReserve}
|
||||||
)
|
)
|
||||||
|
|
||||||
const postgresPublicSchema = "public"
|
var (
|
||||||
|
// DefaultPostgresSchema default postgres schema
|
||||||
|
DefaultPostgresSchema = "public"
|
||||||
|
)
|
||||||
|
|
||||||
type postgres struct {
|
type postgres struct {
|
||||||
Base
|
Base
|
||||||
|
@ -778,14 +781,11 @@ type postgres struct {
|
||||||
|
|
||||||
func (db *postgres) Init(uri *URI) error {
|
func (db *postgres) Init(uri *URI) error {
|
||||||
db.quoter = postgresQuoter
|
db.quoter = postgresQuoter
|
||||||
err := db.Base.Init(db, uri)
|
if uri.Schema == "" {
|
||||||
if err != nil {
|
uri.Schema = DefaultPostgresSchema
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
if db.uri.Schema == "" {
|
|
||||||
db.uri.Schema = postgresPublicSchema
|
return db.Base.Init(db, uri)
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) needQuote(name string) bool {
|
func (db *postgres) needQuote(name string) bool {
|
||||||
|
@ -996,7 +996,7 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||||
args := []interface{}{db.uri.Schema, tableName, db.uri.Schema}
|
args := []interface{}{tableName}
|
||||||
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length,
|
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length,
|
||||||
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
|
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
|
||||||
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
|
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
|
||||||
|
@ -1007,7 +1007,15 @@ FROM pg_attribute f
|
||||||
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
|
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
|
||||||
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
|
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
|
||||||
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
|
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
|
||||||
WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_schema = $3 AND f.attnum > 0 ORDER BY f.attnum;`
|
WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;`
|
||||||
|
|
||||||
|
schema := db.uri.Schema
|
||||||
|
if schema != "" {
|
||||||
|
s = fmt.Sprintf(s, "AND s.table_schema = $2")
|
||||||
|
args = append(args, schema)
|
||||||
|
} else {
|
||||||
|
s = fmt.Sprintf(s, "")
|
||||||
|
}
|
||||||
|
|
||||||
rows, err := queryer.QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1131,8 +1139,9 @@ WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_sch
|
||||||
func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
|
func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
s := "SELECT tablename FROM pg_tables"
|
s := "SELECT tablename FROM pg_tables"
|
||||||
if len(db.uri.Schema) != 0 {
|
schema := db.uri.Schema
|
||||||
args = append(args, db.uri.Schema)
|
if schema != "" {
|
||||||
|
args = append(args, schema)
|
||||||
s = s + " WHERE schemaname = $1"
|
s = s + " WHERE schemaname = $1"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1315,3 +1324,22 @@ func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
}
|
}
|
||||||
return pgx.pqDriver.Parse(driverName, dataSourceName)
|
return pgx.pqDriver.Parse(driverName, dataSourceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryDefaultPostgresSchema returns the default postgres schema
|
||||||
|
func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (string, error) {
|
||||||
|
rows, err := queryer.QueryContext(ctx, "SHOW SEARCH_PATH")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if rows.Next() {
|
||||||
|
var defaultSchema string
|
||||||
|
if err = rows.Scan(&defaultSchema); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
parts := strings.Split(defaultSchema, ",")
|
||||||
|
return strings.TrimSpace(parts[len(parts)-1]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", errors.New("No default schema")
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue