From 6a71ed36962c1d8ac37ef0a9034b0ab8851904da Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 25 Mar 2020 17:04:43 +0800 Subject: [PATCH] Fix postgres --- dialects/dialect.go | 3 ++- dialects/postgres.go | 39 +++++++++++++++++++++------------------ 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/dialects/dialect.go b/dialects/dialect.go index 4bc220ce..3c98a4a5 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -32,8 +32,9 @@ type URI struct { // SetSchema set schema func (uri *URI) SetSchema(schema string) { + // hack me if uri.DBType == schemas.POSTGRES { - uri.Schema = schema + uri.Schema = strings.TrimSpace(schema) } } diff --git a/dialects/postgres.go b/dialects/postgres.go index f1da0f2c..16213c76 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -781,13 +781,16 @@ type postgres struct { func (db *postgres) Init(uri *URI) error { db.quoter = postgresQuoter - if uri.Schema == "" { - uri.Schema = DefaultPostgresSchema - } - return db.Base.Init(db, uri) } +func (db *postgres) getSchema() string { + if db.uri.Schema != "" { + return db.uri.Schema + } + return DefaultPostgresSchema +} + func (db *postgres) needQuote(name string) bool { if db.IsReserved(name) { return true @@ -928,32 +931,32 @@ func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]st } func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { - if len(db.uri.Schema) == 0 { + if len(db.getSchema()) == 0 { args := []interface{}{tableName, idxName} return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args } - args := []interface{}{db.uri.Schema, tableName, idxName} + args := []interface{}{db.getSchema(), tableName, idxName} return `SELECT indexname FROM pg_indexes ` + `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args } func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { - if len(db.uri.Schema) == 0 { + if len(db.getSchema()) == 0 { return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName) } return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`, - db.uri.Schema, tableName) + db.getSchema(), tableName) } func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string { - if len(db.uri.Schema) == 0 || strings.Contains(tableName, ".") { + if len(db.getSchema()) == 0 || strings.Contains(tableName, ".") { return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", tableName, col.Name, db.SQLType(col)) } return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", - db.uri.Schema, tableName, col.Name, db.SQLType(col)) + db.getSchema(), tableName, col.Name, db.SQLType(col)) } func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string { @@ -970,17 +973,17 @@ func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) } } - if db.uri.Schema != "" { - idxName = db.uri.Schema + "." + idxName + if db.getSchema() != "" { + idxName = db.getSchema() + "." + idxName } return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { - args := []interface{}{db.uri.Schema, tableName, colName} + args := []interface{}{db.getSchema(), tableName, colName} query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + " AND column_name = $3" - if len(db.uri.Schema) == 0 { + if len(db.getSchema()) == 0 { args = []interface{}{tableName, colName} query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + " AND column_name = $2" @@ -1009,7 +1012,7 @@ FROM pg_attribute f LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` - schema := db.uri.Schema + schema := db.getSchema() if schema != "" { s = fmt.Sprintf(s, "AND s.table_schema = $2") args = append(args, schema) @@ -1139,7 +1142,7 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT tablename FROM pg_tables" - schema := db.uri.Schema + schema := db.getSchema() if schema != "" { args = append(args, schema) s = s + " WHERE schemaname = $1" @@ -1179,8 +1182,8 @@ func getIndexColName(indexdef string) []string { func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") - if len(db.uri.Schema) != 0 { - args = append(args, db.uri.Schema) + if len(db.getSchema()) != 0 { + args = append(args, db.getSchema()) s = s + " AND schemaname=$2" }