Rebase codes

This commit is contained in:
Lunny Xiao 2020-09-11 15:09:05 +08:00
parent 922be56e32
commit e7df46bb60
3 changed files with 47 additions and 35 deletions

View File

@ -5,6 +5,7 @@
package dialects package dialects
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
@ -16,18 +17,20 @@ import (
var ( var (
db2ReservedWords = map[string]bool{} db2ReservedWords = map[string]bool{}
db2Quoter = schemas.Quoter{
Prefix: '"',
Suffix: '"',
IsReserved: schemas.AlwaysReserve,
}
) )
type db2 struct { type db2 struct {
Base Base
} }
func (db *db2) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { func (db *db2) Init(uri *URI) error {
err := db.Base.Init(d, db, uri, drivername, dataSourceName) db.quoter = db2Quoter
if err != nil { return db.Base.Init(db, uri)
return err
}
return nil
} }
func (db *db2) SQLType(c *schemas.Column) string { func (db *db2) SQLType(c *schemas.Column) string {
@ -81,10 +84,6 @@ func (db *db2) IsReserved(name string) bool {
return ok return ok
} }
func (db *db2) Quoter() schemas.Quoter {
return schemas.Quoter{"\"", "\""}
}
func (db *db2) AutoIncrStr() string { func (db *db2) AutoIncrStr() string {
return "" return ""
} }
@ -101,7 +100,7 @@ func (db *db2) IndexOnTable() bool {
return false return false
} }
func (db *db2) CreateTableSql(table *schemas.Table, tableName, storeEngine, charset string) string { func (db *db2) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
var sql string var sql string
sql = "CREATE TABLE " sql = "CREATE TABLE "
if tableName == "" { if tableName == "" {
@ -114,7 +113,8 @@ func (db *db2) CreateTableSql(table *schemas.Table, tableName, storeEngine, char
for _, colName := range table.ColumnsSeq() { for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
sql += StringNoPk(db, col) s, _ := ColumnString(db, col, false)
sql += s
if col.IsAutoIncrement { if col.IsAutoIncrement {
sql += " GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1 )" sql += " GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1 )"
} }
@ -129,7 +129,7 @@ func (db *db2) CreateTableSql(table *schemas.Table, tableName, storeEngine, char
} }
sql = sql[:len(sql)-2] + ")" sql = sql[:len(sql)-2] + ")"
return sql return []string{sql}, false
} }
func (db *db2) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { func (db *db2) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
@ -143,14 +143,30 @@ func (db *db2) IndexCheckSQL(tableName, idxName string) (string, []interface{})
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
} }
func (db *db2) TableCheckSQL(tableName string) (string, []interface{}) { func (db *db2) SetQuotePolicy(quotePolicy QuotePolicy) {
if len(db.uri.Schema) == 0 { switch quotePolicy {
args := []interface{}{tableName} case QuotePolicyNone:
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args var q = oracleQuoter
q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q
case QuotePolicyReserved:
var q = oracleQuoter
q.IsReserved = db.IsReserved
db.quoter = q
case QuotePolicyAlways:
fallthrough
default:
db.quoter = oracleQuoter
} }
}
args := []interface{}{db.uri.Schema, tableName} func (db *db2) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) {
return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args if len(db.uri.Schema) == 0 {
return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE tablename = ?`, tableName)
}
return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`,
db.uri.Schema, tableName,
)
} }
func (db *db2) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *db2) ModifyColumnSQL(tableName string, col *schemas.Column) string {
@ -183,7 +199,7 @@ func (db *db2) DropIndexSQL(tableName string, index *schemas.Index) string {
return fmt.Sprintf("DROP INDEX %v", quote(idxName)) return fmt.Sprintf("DROP INDEX %v", quote(idxName))
} }
func (db *db2) IsColumnExist(tableName, colName string) (bool, error) { func (db *db2) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
args := []interface{}{db.uri.Schema, tableName, colName} args := []interface{}{db.uri.Schema, tableName, colName}
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" +
" AND column_name = $3" " AND column_name = $3"
@ -192,9 +208,8 @@ func (db *db2) IsColumnExist(tableName, colName string) (bool, error) {
query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $2" " AND column_name = $2"
} }
db.LogSQL(query, args)
rows, err := db.DB().Query(query, args...) rows, err := queryer.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -203,7 +218,7 @@ func (db *db2) IsColumnExist(tableName, colName string) (bool, error) {
return rows.Next(), nil return rows.Next(), nil
} }
func (db *db2) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) { func (db *db2) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := `Select c.colname as column_name, s := `Select c.colname as column_name,
c.colno as position, c.colno as position,
@ -228,9 +243,7 @@ where t.type = 'T' AND c.tabname = ?`
} }
s = s + f s = s + f
db.LogSQL(s, args) rows, err := queryer.QueryContext(ctx, s, args...)
rows, err := db.DB().Query(s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -310,7 +323,7 @@ where t.type = 'T' AND c.tabname = ?`
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *db2) GetTables() ([]*schemas.Table, error) { func (db *db2) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT TABNAME FROM SYSCAT.TABLES WHERE type = 'T' AND OWNERTYPE = 'U'" s := "SELECT TABNAME FROM SYSCAT.TABLES WHERE type = 'T' AND OWNERTYPE = 'U'"
if len(db.uri.Schema) != 0 { if len(db.uri.Schema) != 0 {
@ -318,9 +331,7 @@ func (db *db2) GetTables() ([]*schemas.Table, error) {
s = s + " AND TABSCHEMA = ?" s = s + " AND TABSCHEMA = ?"
} }
db.LogSQL(s, args) rows, err := queryer.QueryContext(ctx, s, args...)
rows, err := db.DB().Query(s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -340,7 +351,7 @@ func (db *db2) GetTables() ([]*schemas.Table, error) {
return tables, nil return tables, nil
} }
func (db *db2) GetIndexes(tableName string) (map[string]*schemas.Index, error) { func (db *db2) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := fmt.Sprintf(`select uniquerule, s := fmt.Sprintf(`select uniquerule,
indname as index_name, indname as index_name,
@ -350,9 +361,8 @@ from syscat.indexes WHERE tabname = ?`)
args = append(args, db.uri.Schema) args = append(args, db.uri.Schema)
s = s + " AND tabschema=?" s = s + " AND tabschema=?"
} }
db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...) rows, err := queryer.QueryContext(ctx, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -399,7 +409,7 @@ from syscat.indexes WHERE tabname = ?`)
} }
func (db *db2) Filters() []Filter { func (db *db2) Filters() []Filter {
return []Filter{&QuoteFilter{}} return []Filter{}
} }
type db2Driver struct{} type db2Driver struct{}

2
go.mod
View File

@ -7,7 +7,7 @@ require (
github.com/go-sql-driver/mysql v1.6.0 github.com/go-sql-driver/mysql v1.6.0
github.com/goccy/go-json v0.7.4 github.com/goccy/go-json v0.7.4
github.com/jackc/pgx/v4 v4.12.0 github.com/jackc/pgx/v4 v4.12.0
github.com/ibmdb/go_ibm_db v0.1.0 github.com/ibmdb/go_ibm_db v0.3.0
github.com/json-iterator/go v1.1.11 github.com/json-iterator/go v1.1.11
github.com/lib/pq v1.10.2 github.com/lib/pq v1.10.2
github.com/mattn/go-sqlite3 v1.14.8 github.com/mattn/go-sqlite3 v1.14.8

2
go.sum
View File

@ -205,6 +205,8 @@ github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/u
github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/ibmdb/go_ibm_db v0.1.0 h1:Ok7W7wysBUa8eyVYxWLS5vIA0VomTsurK57l5Rah1M8= github.com/ibmdb/go_ibm_db v0.1.0 h1:Ok7W7wysBUa8eyVYxWLS5vIA0VomTsurK57l5Rah1M8=
github.com/ibmdb/go_ibm_db v0.1.0/go.mod h1:nl5aUh1IzBVExcqYXaZLApaq8RUvTEph3VP49UTmEvg= github.com/ibmdb/go_ibm_db v0.1.0/go.mod h1:nl5aUh1IzBVExcqYXaZLApaq8RUvTEph3VP49UTmEvg=
github.com/ibmdb/go_ibm_db v0.3.0 h1:KCSVFS9eXmlTEFL8ScyROsYWmP02G3eGce7VRAt4Csk=
github.com/ibmdb/go_ibm_db v0.3.0/go.mod h1:nl5aUh1IzBVExcqYXaZLApaq8RUvTEph3VP49UTmEvg=
github.com/json-iterator/go v1.1.11 h1:uVUAXhF2To8cbw/3xN3pxj6kk7TYKs98NIrTqPlMWAQ= github.com/json-iterator/go v1.1.11 h1:uVUAXhF2To8cbw/3xN3pxj6kk7TYKs98NIrTqPlMWAQ=
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=