Rebase codes

This commit is contained in:
Lunny Xiao 2020-09-11 15:09:05 +08:00
parent 922be56e32
commit e7df46bb60
3 changed files with 47 additions and 35 deletions

View File

@ -5,6 +5,7 @@
package dialects
import (
"context"
"errors"
"fmt"
"strconv"
@ -16,18 +17,20 @@ import (
var (
db2ReservedWords = map[string]bool{}
db2Quoter = schemas.Quoter{
Prefix: '"',
Suffix: '"',
IsReserved: schemas.AlwaysReserve,
}
)
type db2 struct {
Base
}
func (db *db2) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
err := db.Base.Init(d, db, uri, drivername, dataSourceName)
if err != nil {
return err
}
return nil
func (db *db2) Init(uri *URI) error {
db.quoter = db2Quoter
return db.Base.Init(db, uri)
}
func (db *db2) SQLType(c *schemas.Column) string {
@ -81,10 +84,6 @@ func (db *db2) IsReserved(name string) bool {
return ok
}
func (db *db2) Quoter() schemas.Quoter {
return schemas.Quoter{"\"", "\""}
}
func (db *db2) AutoIncrStr() string {
return ""
}
@ -101,7 +100,7 @@ func (db *db2) IndexOnTable() bool {
return false
}
func (db *db2) CreateTableSql(table *schemas.Table, tableName, storeEngine, charset string) string {
func (db *db2) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
var sql string
sql = "CREATE TABLE "
if tableName == "" {
@ -114,7 +113,8 @@ func (db *db2) CreateTableSql(table *schemas.Table, tableName, storeEngine, char
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
sql += StringNoPk(db, col)
s, _ := ColumnString(db, col, false)
sql += s
if col.IsAutoIncrement {
sql += " GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1 )"
}
@ -129,7 +129,7 @@ func (db *db2) CreateTableSql(table *schemas.Table, tableName, storeEngine, char
}
sql = sql[:len(sql)-2] + ")"
return sql
return []string{sql}, false
}
func (db *db2) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
@ -143,14 +143,30 @@ func (db *db2) IndexCheckSQL(tableName, idxName string) (string, []interface{})
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
}
func (db *db2) TableCheckSQL(tableName string) (string, []interface{}) {
if len(db.uri.Schema) == 0 {
args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
func (db *db2) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy {
case QuotePolicyNone:
var q = oracleQuoter
q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q
case QuotePolicyReserved:
var q = oracleQuoter
q.IsReserved = db.IsReserved
db.quoter = q
case QuotePolicyAlways:
fallthrough
default:
db.quoter = oracleQuoter
}
}
args := []interface{}{db.uri.Schema, tableName}
return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args
func (db *db2) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) {
if len(db.uri.Schema) == 0 {
return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE tablename = ?`, tableName)
}
return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`,
db.uri.Schema, tableName,
)
}
func (db *db2) ModifyColumnSQL(tableName string, col *schemas.Column) string {
@ -183,7 +199,7 @@ func (db *db2) DropIndexSQL(tableName string, index *schemas.Index) string {
return fmt.Sprintf("DROP INDEX %v", quote(idxName))
}
func (db *db2) IsColumnExist(tableName, colName string) (bool, error) {
func (db *db2) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
args := []interface{}{db.uri.Schema, tableName, colName}
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" +
" AND column_name = $3"
@ -192,9 +208,8 @@ func (db *db2) IsColumnExist(tableName, colName string) (bool, error) {
query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $2"
}
db.LogSQL(query, args)
rows, err := db.DB().Query(query, args...)
rows, err := queryer.QueryContext(ctx, query, args...)
if err != nil {
return false, err
}
@ -203,7 +218,7 @@ func (db *db2) IsColumnExist(tableName, colName string) (bool, error) {
return rows.Next(), nil
}
func (db *db2) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) {
func (db *db2) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{tableName}
s := `Select c.colname as column_name,
c.colno as position,
@ -228,9 +243,7 @@ where t.type = 'T' AND c.tabname = ?`
}
s = s + f
db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...)
rows, err := queryer.QueryContext(ctx, s, args...)
if err != nil {
return nil, nil, err
}
@ -310,7 +323,7 @@ where t.type = 'T' AND c.tabname = ?`
return colSeq, cols, nil
}
func (db *db2) GetTables() ([]*schemas.Table, error) {
func (db *db2) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
args := []interface{}{}
s := "SELECT TABNAME FROM SYSCAT.TABLES WHERE type = 'T' AND OWNERTYPE = 'U'"
if len(db.uri.Schema) != 0 {
@ -318,9 +331,7 @@ func (db *db2) GetTables() ([]*schemas.Table, error) {
s = s + " AND TABSCHEMA = ?"
}
db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...)
rows, err := queryer.QueryContext(ctx, s, args...)
if err != nil {
return nil, err
}
@ -340,7 +351,7 @@ func (db *db2) GetTables() ([]*schemas.Table, error) {
return tables, nil
}
func (db *db2) GetIndexes(tableName string) (map[string]*schemas.Index, error) {
func (db *db2) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName}
s := fmt.Sprintf(`select uniquerule,
indname as index_name,
@ -350,9 +361,8 @@ from syscat.indexes WHERE tabname = ?`)
args = append(args, db.uri.Schema)
s = s + " AND tabschema=?"
}
db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...)
rows, err := queryer.QueryContext(ctx, s, args...)
if err != nil {
return nil, err
}
@ -399,7 +409,7 @@ from syscat.indexes WHERE tabname = ?`)
}
func (db *db2) Filters() []Filter {
return []Filter{&QuoteFilter{}}
return []Filter{}
}
type db2Driver struct{}

2
go.mod
View File

@ -7,7 +7,7 @@ require (
github.com/go-sql-driver/mysql v1.6.0
github.com/goccy/go-json v0.7.4
github.com/jackc/pgx/v4 v4.12.0
github.com/ibmdb/go_ibm_db v0.1.0
github.com/ibmdb/go_ibm_db v0.3.0
github.com/json-iterator/go v1.1.11
github.com/lib/pq v1.10.2
github.com/mattn/go-sqlite3 v1.14.8

2
go.sum
View File

@ -205,6 +205,8 @@ github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/u
github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/ibmdb/go_ibm_db v0.1.0 h1:Ok7W7wysBUa8eyVYxWLS5vIA0VomTsurK57l5Rah1M8=
github.com/ibmdb/go_ibm_db v0.1.0/go.mod h1:nl5aUh1IzBVExcqYXaZLApaq8RUvTEph3VP49UTmEvg=
github.com/ibmdb/go_ibm_db v0.3.0 h1:KCSVFS9eXmlTEFL8ScyROsYWmP02G3eGce7VRAt4Csk=
github.com/ibmdb/go_ibm_db v0.3.0/go.mod h1:nl5aUh1IzBVExcqYXaZLApaq8RUvTEph3VP49UTmEvg=
github.com/json-iterator/go v1.1.11 h1:uVUAXhF2To8cbw/3xN3pxj6kk7TYKs98NIrTqPlMWAQ=
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=