diff --git a/dialect_postgres.go b/dialect_postgres.go index 3cf2b9bb..2b2a0b78 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -771,12 +771,17 @@ var ( type postgres struct { core.Base - schema string } func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { - db.schema = DefaultPostgresSchema - return db.Base.Init(d, db, uri, drivername, dataSourceName) + err := db.Base.Init(d, db, uri, drivername, dataSourceName) + if err != nil { + return err + } + if db.Schema == "" { + db.Schema = DefaultPostgresSchema + } + return nil } func (db *postgres) SqlType(c *core.Column) string { @@ -873,23 +878,35 @@ func (db *postgres) IndexOnTable() bool { } func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - args := []interface{}{db.schema, tableName, idxName} + if len(db.Schema) == 0 { + args := []interface{}{tableName, idxName} + return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args + } + + args := []interface{}{db.Schema, tableName, idxName} return `SELECT indexname FROM pg_indexes ` + `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args } func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{db.schema, tableName} + if len(db.Schema) == 0 { + args := []interface{}{tableName} + return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args + } + args := []interface{}{db.Schema, tableName} return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args } func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string { + if len(db.Schema) == 0 { + return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", + tableName, col.Name, db.SqlType(col)) + } return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", - db.schema, tableName, col.Name, db.SqlType(col)) + db.Schema, tableName, col.Name, db.SqlType(col)) } func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { - //var unique string quote := db.Quote idxName := index.Name @@ -905,9 +922,14 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { } func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { - args := []interface{}{db.schema, tableName, colName} + args := []interface{}{db.Schema, tableName, colName} query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + " AND column_name = $3" + if len(db.Schema) == 0 { + args = []interface{}{tableName, colName} + query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + + " AND column_name = $2" + } db.LogSQL(query, args) rows, err := db.DB().Query(query, args...) @@ -920,8 +942,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { } func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { - // FIXME: the schema should be replaced by user custom's - args := []interface{}{tableName, db.schema} + args := []interface{}{tableName} s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix , CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey @@ -932,7 +953,15 @@ FROM pg_attribute f LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name -WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;` +WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` + + var f string + if len(db.Schema) != 0 { + args = append(args, db.Schema) + f = "AND s.table_schema = $2" + } + s = fmt.Sprintf(s, f) + db.LogSQL(s, args) rows, err := db.DB().Query(s, args...) @@ -1022,8 +1051,13 @@ WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.att } func (db *postgres) GetTables() ([]*core.Table, error) { - args := []interface{}{db.schema} - s := fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = $1") + args := []interface{}{} + s := "SELECT tablename FROM pg_tables" + if len(db.Schema) != 0 { + args = append(args, db.Schema) + s = s + " WHERE schemaname = $1" + } + db.LogSQL(s, args) rows, err := db.DB().Query(s, args...) @@ -1047,9 +1081,13 @@ func (db *postgres) GetTables() ([]*core.Table, error) { } func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { - args := []interface{}{db.schema, tableName} - s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE schemaname=$1 AND tablename=$2") + args := []interface{}{tableName} + s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") db.LogSQL(s, args) + if len(db.Schema) != 0 { + args = append(args, db.Schema) + s = s + " AND schemaname=$2" + } rows, err := db.DB().Query(s, args...) if err != nil {