From 32b224abe5a9c8832e2ccc4eedb0f0c283994fa5 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 25 Mar 2020 13:15:24 +0800 Subject: [PATCH] Add DefaultPostgresSchema back --- dialects/postgres.go | 52 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/dialects/postgres.go b/dialects/postgres.go index cd4751a5..f1da0f2c 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -770,7 +770,10 @@ var ( postgresQuoter = schemas.Quoter{'"', '"', schemas.AlwaysReserve} ) -const postgresPublicSchema = "public" +var ( + // DefaultPostgresSchema default postgres schema + DefaultPostgresSchema = "public" +) type postgres struct { Base @@ -778,14 +781,11 @@ type postgres struct { func (db *postgres) Init(uri *URI) error { db.quoter = postgresQuoter - err := db.Base.Init(db, uri) - if err != nil { - return err + if uri.Schema == "" { + uri.Schema = DefaultPostgresSchema } - if db.uri.Schema == "" { - db.uri.Schema = postgresPublicSchema - } - return nil + + return db.Base.Init(db, uri) } func (db *postgres) needQuote(name string) bool { @@ -996,7 +996,7 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab } func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { - args := []interface{}{db.uri.Schema, tableName, db.uri.Schema} + args := []interface{}{tableName} s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey @@ -1007,7 +1007,15 @@ FROM pg_attribute f LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name -WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_schema = $3 AND f.attnum > 0 ORDER BY f.attnum;` +WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` + + schema := db.uri.Schema + if schema != "" { + s = fmt.Sprintf(s, "AND s.table_schema = $2") + args = append(args, schema) + } else { + s = fmt.Sprintf(s, "") + } rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { @@ -1131,8 +1139,9 @@ WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_sch func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT tablename FROM pg_tables" - if len(db.uri.Schema) != 0 { - args = append(args, db.uri.Schema) + schema := db.uri.Schema + if schema != "" { + args = append(args, schema) s = s + " WHERE schemaname = $1" } @@ -1315,3 +1324,22 @@ func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*URI, error) { } return pgx.pqDriver.Parse(driverName, dataSourceName) } + +// QueryDefaultPostgresSchema returns the default postgres schema +func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (string, error) { + rows, err := queryer.QueryContext(ctx, "SHOW SEARCH_PATH") + if err != nil { + return "", err + } + defer rows.Close() + if rows.Next() { + var defaultSchema string + if err = rows.Scan(&defaultSchema); err != nil { + return "", err + } + parts := strings.Split(defaultSchema, ",") + return strings.TrimSpace(parts[len(parts)-1]), nil + } + + return "", errors.New("No default schema") +}