diff --git a/dialects/dialect.go b/dialects/dialect.go index a3328e05..3c98a4a5 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -32,8 +32,9 @@ type URI struct { // SetSchema set schema func (uri *URI) SetSchema(schema string) { + // hack me if uri.DBType == schemas.POSTGRES { - uri.Schema = schema + uri.Schema = strings.TrimSpace(schema) } } @@ -43,7 +44,6 @@ type Dialect interface { URI() *URI SQLType(*schemas.Column) string FormatBytes(b []byte) string - DefaultSchema() string IsReserved(string) bool Quoter() schemas.Quoter @@ -83,10 +83,6 @@ func (b *Base) Quoter() schemas.Quoter { return b.quoter } -func (b *Base) DefaultSchema() string { - return "" -} - func (b *Base) Init(dialect Dialect, uri *URI) error { b.dialect, b.uri = dialect, uri return nil diff --git a/dialects/postgres.go b/dialects/postgres.go index a83c3a5c..16213c76 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -770,7 +770,10 @@ var ( postgresQuoter = schemas.Quoter{'"', '"', schemas.AlwaysReserve} ) -const postgresPublicSchema = "public" +var ( + // DefaultPostgresSchema default postgres schema + DefaultPostgresSchema = "public" +) type postgres struct { Base @@ -778,14 +781,14 @@ type postgres struct { func (db *postgres) Init(uri *URI) error { db.quoter = postgresQuoter - err := db.Base.Init(db, uri) - if err != nil { - return err + return db.Base.Init(db, uri) +} + +func (db *postgres) getSchema() string { + if db.uri.Schema != "" { + return db.uri.Schema } - if db.uri.Schema == "" { - db.uri.Schema = postgresPublicSchema - } - return nil + return DefaultPostgresSchema } func (db *postgres) needQuote(name string) bool { @@ -817,10 +820,6 @@ func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) { } } -func (db *postgres) DefaultSchema() string { - return postgresPublicSchema -} - func (db *postgres) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { @@ -932,32 +931,32 @@ func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]st } func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { - if len(db.uri.Schema) == 0 { + if len(db.getSchema()) == 0 { args := []interface{}{tableName, idxName} return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args } - args := []interface{}{db.uri.Schema, tableName, idxName} + args := []interface{}{db.getSchema(), tableName, idxName} return `SELECT indexname FROM pg_indexes ` + `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args } func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { - if len(db.uri.Schema) == 0 { + if len(db.getSchema()) == 0 { return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName) } return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`, - db.uri.Schema, tableName) + db.getSchema(), tableName) } func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string { - if len(db.uri.Schema) == 0 || strings.Contains(tableName, ".") { + if len(db.getSchema()) == 0 || strings.Contains(tableName, ".") { return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", tableName, col.Name, db.SQLType(col)) } return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", - db.uri.Schema, tableName, col.Name, db.SQLType(col)) + db.getSchema(), tableName, col.Name, db.SQLType(col)) } func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string { @@ -974,17 +973,17 @@ func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) } } - if db.uri.Schema != "" { - idxName = db.uri.Schema + "." + idxName + if db.getSchema() != "" { + idxName = db.getSchema() + "." + idxName } return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { - args := []interface{}{db.uri.Schema, tableName, colName} + args := []interface{}{db.getSchema(), tableName, colName} query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + " AND column_name = $3" - if len(db.uri.Schema) == 0 { + if len(db.getSchema()) == 0 { args = []interface{}{tableName, colName} query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + " AND column_name = $2" @@ -1000,7 +999,7 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab } func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { - args := []interface{}{db.uri.Schema, tableName, db.uri.Schema} + args := []interface{}{tableName} s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey @@ -1011,7 +1010,15 @@ FROM pg_attribute f LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name -WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_schema = $3 AND f.attnum > 0 ORDER BY f.attnum;` +WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` + + schema := db.getSchema() + if schema != "" { + s = fmt.Sprintf(s, "AND s.table_schema = $2") + args = append(args, schema) + } else { + s = fmt.Sprintf(s, "") + } rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { @@ -1135,8 +1142,9 @@ WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_sch func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT tablename FROM pg_tables" - if len(db.uri.Schema) != 0 { - args = append(args, db.uri.Schema) + schema := db.getSchema() + if schema != "" { + args = append(args, schema) s = s + " WHERE schemaname = $1" } @@ -1174,8 +1182,8 @@ func getIndexColName(indexdef string) []string { func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") - if len(db.uri.Schema) != 0 { - args = append(args, db.uri.Schema) + if len(db.getSchema()) != 0 { + args = append(args, db.getSchema()) s = s + " AND schemaname=$2" } @@ -1319,3 +1327,22 @@ func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*URI, error) { } return pgx.pqDriver.Parse(driverName, dataSourceName) } + +// QueryDefaultPostgresSchema returns the default postgres schema +func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (string, error) { + rows, err := queryer.QueryContext(ctx, "SHOW SEARCH_PATH") + if err != nil { + return "", err + } + defer rows.Close() + if rows.Next() { + var defaultSchema string + if err = rows.Scan(&defaultSchema); err != nil { + return "", err + } + parts := strings.Split(defaultSchema, ",") + return strings.TrimSpace(parts[len(parts)-1]), nil + } + + return "", errors.New("No default schema") +} diff --git a/dialects/table_name.go b/dialects/table_name.go index a989b386..e190cd4b 100644 --- a/dialects/table_name.go +++ b/dialects/table_name.go @@ -18,7 +18,6 @@ func TableNameWithSchema(dialect Dialect, tableName string) string { // Add schema name as prefix of table name. // Only for postgres database. if dialect.URI().Schema != "" && - dialect.URI().Schema != dialect.DefaultSchema() && strings.Index(tableName, ".") == -1 { return fmt.Sprintf("%s.%s", dialect.URI().Schema, tableName) }