From e7df46bb606fe3acca9f38858965ce05ab7491bd Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 11 Sep 2020 15:09:05 +0800 Subject: [PATCH] Rebase codes --- dialects/db2.go | 78 ++++++++++++++++++++++++++++--------------------- go.mod | 2 +- go.sum | 2 ++ 3 files changed, 47 insertions(+), 35 deletions(-) diff --git a/dialects/db2.go b/dialects/db2.go index 8f146100..1b2f0038 100644 --- a/dialects/db2.go +++ b/dialects/db2.go @@ -5,6 +5,7 @@ package dialects import ( + "context" "errors" "fmt" "strconv" @@ -16,18 +17,20 @@ import ( var ( db2ReservedWords = map[string]bool{} + db2Quoter = schemas.Quoter{ + Prefix: '"', + Suffix: '"', + IsReserved: schemas.AlwaysReserve, + } ) type db2 struct { Base } -func (db *db2) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - err := db.Base.Init(d, db, uri, drivername, dataSourceName) - if err != nil { - return err - } - return nil +func (db *db2) Init(uri *URI) error { + db.quoter = db2Quoter + return db.Base.Init(db, uri) } func (db *db2) SQLType(c *schemas.Column) string { @@ -81,10 +84,6 @@ func (db *db2) IsReserved(name string) bool { return ok } -func (db *db2) Quoter() schemas.Quoter { - return schemas.Quoter{"\"", "\""} -} - func (db *db2) AutoIncrStr() string { return "" } @@ -101,7 +100,7 @@ func (db *db2) IndexOnTable() bool { return false } -func (db *db2) CreateTableSql(table *schemas.Table, tableName, storeEngine, charset string) string { +func (db *db2) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { var sql string sql = "CREATE TABLE " if tableName == "" { @@ -114,7 +113,8 @@ func (db *db2) CreateTableSql(table *schemas.Table, tableName, storeEngine, char for _, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - sql += StringNoPk(db, col) + s, _ := ColumnString(db, col, false) + sql += s if col.IsAutoIncrement { sql += " GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1 )" } @@ -129,7 +129,7 @@ func (db *db2) CreateTableSql(table *schemas.Table, tableName, storeEngine, char } sql = sql[:len(sql)-2] + ")" - return sql + return []string{sql}, false } func (db *db2) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { @@ -143,14 +143,30 @@ func (db *db2) IndexCheckSQL(tableName, idxName string) (string, []interface{}) `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args } -func (db *db2) TableCheckSQL(tableName string) (string, []interface{}) { - if len(db.uri.Schema) == 0 { - args := []interface{}{tableName} - return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args +func (db *db2) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = oracleQuoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = oracleQuoter + q.IsReserved = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = oracleQuoter } +} - args := []interface{}{db.uri.Schema, tableName} - return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args +func (db *db2) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { + if len(db.uri.Schema) == 0 { + return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE tablename = ?`, tableName) + } + return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, + db.uri.Schema, tableName, + ) } func (db *db2) ModifyColumnSQL(tableName string, col *schemas.Column) string { @@ -183,7 +199,7 @@ func (db *db2) DropIndexSQL(tableName string, index *schemas.Index) string { return fmt.Sprintf("DROP INDEX %v", quote(idxName)) } -func (db *db2) IsColumnExist(tableName, colName string) (bool, error) { +func (db *db2) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { args := []interface{}{db.uri.Schema, tableName, colName} query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + " AND column_name = $3" @@ -192,9 +208,8 @@ func (db *db2) IsColumnExist(tableName, colName string) (bool, error) { query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + " AND column_name = $2" } - db.LogSQL(query, args) - rows, err := db.DB().Query(query, args...) + rows, err := queryer.QueryContext(ctx, query, args...) if err != nil { return false, err } @@ -203,7 +218,7 @@ func (db *db2) IsColumnExist(tableName, colName string) (bool, error) { return rows.Next(), nil } -func (db *db2) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *db2) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} s := `Select c.colname as column_name, c.colno as position, @@ -228,9 +243,7 @@ where t.type = 'T' AND c.tabname = ?` } s = s + f - db.LogSQL(s, args) - - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } @@ -310,7 +323,7 @@ where t.type = 'T' AND c.tabname = ?` return colSeq, cols, nil } -func (db *db2) GetTables() ([]*schemas.Table, error) { +func (db *db2) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT TABNAME FROM SYSCAT.TABLES WHERE type = 'T' AND OWNERTYPE = 'U'" if len(db.uri.Schema) != 0 { @@ -318,9 +331,7 @@ func (db *db2) GetTables() ([]*schemas.Table, error) { s = s + " AND TABSCHEMA = ?" } - db.LogSQL(s, args) - - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } @@ -340,7 +351,7 @@ func (db *db2) GetTables() ([]*schemas.Table, error) { return tables, nil } -func (db *db2) GetIndexes(tableName string) (map[string]*schemas.Index, error) { +func (db *db2) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := fmt.Sprintf(`select uniquerule, indname as index_name, @@ -350,9 +361,8 @@ from syscat.indexes WHERE tabname = ?`) args = append(args, db.uri.Schema) s = s + " AND tabschema=?" } - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } @@ -399,7 +409,7 @@ from syscat.indexes WHERE tabname = ?`) } func (db *db2) Filters() []Filter { - return []Filter{&QuoteFilter{}} + return []Filter{} } type db2Driver struct{} diff --git a/go.mod b/go.mod index 7e2bbc2f..27ee1752 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/go-sql-driver/mysql v1.6.0 github.com/goccy/go-json v0.7.4 github.com/jackc/pgx/v4 v4.12.0 - github.com/ibmdb/go_ibm_db v0.1.0 + github.com/ibmdb/go_ibm_db v0.3.0 github.com/json-iterator/go v1.1.11 github.com/lib/pq v1.10.2 github.com/mattn/go-sqlite3 v1.14.8 diff --git a/go.sum b/go.sum index 58754e7d..e2a816bf 100644 --- a/go.sum +++ b/go.sum @@ -205,6 +205,8 @@ github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/u github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/ibmdb/go_ibm_db v0.1.0 h1:Ok7W7wysBUa8eyVYxWLS5vIA0VomTsurK57l5Rah1M8= github.com/ibmdb/go_ibm_db v0.1.0/go.mod h1:nl5aUh1IzBVExcqYXaZLApaq8RUvTEph3VP49UTmEvg= +github.com/ibmdb/go_ibm_db v0.3.0 h1:KCSVFS9eXmlTEFL8ScyROsYWmP02G3eGce7VRAt4Csk= +github.com/ibmdb/go_ibm_db v0.3.0/go.mod h1:nl5aUh1IzBVExcqYXaZLApaq8RUvTEph3VP49UTmEvg= github.com/json-iterator/go v1.1.11 h1:uVUAXhF2To8cbw/3xN3pxj6kk7TYKs98NIrTqPlMWAQ= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=