From 922be56e321f227607124afdc2c824d3f11c56e0 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 25 Feb 2020 13:30:15 +0800 Subject: [PATCH] Use new dialect interface --- dialect_db2.go => dialects/db2.go | 162 +++++++++++++++--------------- dialects/dialect.go | 14 +++ 2 files changed, 97 insertions(+), 79 deletions(-) rename dialect_db2.go => dialects/db2.go (68%) diff --git a/dialect_db2.go b/dialects/db2.go similarity index 68% rename from dialect_db2.go rename to dialects/db2.go index 68a44242..8f146100 100644 --- a/dialect_db2.go +++ b/dialects/db2.go @@ -1,8 +1,8 @@ -// Copyright 2015 The Xorm Authors. All rights reserved. +// Copyright 2020 The Xorm Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( "errors" @@ -11,13 +11,18 @@ import ( "strings" "xorm.io/xorm/core" + "xorm.io/xorm/schemas" +) + +var ( + db2ReservedWords = map[string]bool{} ) type db2 struct { - core.Base + Base } -func (db *db2) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { +func (db *db2) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { err := db.Base.Init(d, db, uri, drivername, dataSourceName) if err != nil { return err @@ -25,29 +30,29 @@ func (db *db2) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string return nil } -func (db *db2) SqlType(c *core.Column) string { +func (db *db2) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { - case core.TinyInt: - res = core.SmallInt + case schemas.TinyInt: + res = schemas.SmallInt return res - case core.Bit: - res = core.Boolean + case schemas.Bit: + res = schemas.Boolean return res - case core.Binary, core.VarBinary: - return core.Bytea - case core.DateTime: - res = core.TimeStamp - case core.TimeStampz: + case schemas.Binary, schemas.VarBinary: + return schemas.Bytea + case schemas.DateTime: + res = schemas.TimeStamp + case schemas.TimeStampz: return "timestamp with time zone" - case core.TinyText, core.MediumText, core.LongText: - res = core.Text - case core.NVarchar: - res = core.Varchar - case core.Uuid: - return core.Uuid - case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob: - return core.Bytea + case schemas.TinyText, schemas.MediumText, schemas.LongText: + res = schemas.Text + case schemas.NVarchar: + res = schemas.Varchar + case schemas.Uuid: + return schemas.Uuid + case schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob: + return schemas.Bytea default: res = t } @@ -72,13 +77,12 @@ func (db *db2) SupportInsertMany() bool { } func (db *db2) IsReserved(name string) bool { - _, ok := postgresReservedWords[name] + _, ok := db2ReservedWords[name] return ok } -func (db *db2) Quote(name string) string { - name = strings.Replace(name, ".", `"."`, -1) - return "\"" + name + "\"" +func (db *db2) Quoter() schemas.Quoter { + return schemas.Quoter{"\"", "\""} } func (db *db2) AutoIncrStr() string { @@ -97,20 +101,20 @@ func (db *db2) IndexOnTable() bool { return false } -func (db *db2) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { +func (db *db2) CreateTableSql(table *schemas.Table, tableName, storeEngine, charset string) string { var sql string sql = "CREATE TABLE " if tableName == "" { tableName = table.Name } - sql += db.Quote(tableName) + " (" + sql += db.Quoter().Quote(tableName) + " (" pkList := table.PrimaryKeys for _, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - sql += col.StringNoPk(db) + sql += StringNoPk(db, col) if col.IsAutoIncrement { sql += " GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1 )" } @@ -120,7 +124,7 @@ func (db *db2) CreateTableSql(table *core.Table, tableName, storeEngine, charset if len(pkList) > 0 { sql += "PRIMARY KEY ( " - sql += db.Quote(strings.Join(pkList, db.Quote(","))) + sql += db.Quoter().Join(pkList, ",") sql += " ), " } @@ -128,38 +132,38 @@ func (db *db2) CreateTableSql(table *core.Table, tableName, storeEngine, charset return sql } -func (db *db2) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - if len(db.Schema) == 0 { +func (db *db2) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { + if len(db.uri.Schema) == 0 { args := []interface{}{tableName, idxName} return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args } - args := []interface{}{db.Schema, tableName, idxName} + args := []interface{}{db.uri.Schema, tableName, idxName} return `SELECT indexname FROM pg_indexes ` + `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args } -func (db *db2) TableCheckSql(tableName string) (string, []interface{}) { - if len(db.Schema) == 0 { +func (db *db2) TableCheckSQL(tableName string) (string, []interface{}) { + if len(db.uri.Schema) == 0 { args := []interface{}{tableName} return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args } - args := []interface{}{db.Schema, tableName} + args := []interface{}{db.uri.Schema, tableName} return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args } -func (db *db2) ModifyColumnSql(tableName string, col *core.Column) string { - if len(db.Schema) == 0 { +func (db *db2) ModifyColumnSQL(tableName string, col *schemas.Column) string { + if len(db.uri.Schema) == 0 { return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", - tableName, col.Name, db.SqlType(col)) + tableName, col.Name, db.SQLType(col)) } return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", - db.Schema, tableName, col.Name, db.SqlType(col)) + db.uri.Schema, tableName, col.Name, db.SQLType(col)) } -func (db *db2) DropIndexSql(tableName string, index *core.Index) string { - quote := db.Quote +func (db *db2) DropIndexSQL(tableName string, index *schemas.Index) string { + quote := db.Quoter().Quote idxName := index.Name tableName = strings.Replace(tableName, `"`, "", -1) @@ -167,23 +171,23 @@ func (db *db2) DropIndexSql(tableName string, index *core.Index) string { if !strings.HasPrefix(idxName, "UQE_") && !strings.HasPrefix(idxName, "IDX_") { - if index.Type == core.UniqueType { + if index.Type == schemas.UniqueType { idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name) } else { idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) } } - if db.Uri.Schema != "" { - idxName = db.Uri.Schema + "." + idxName + if db.uri.Schema != "" { + idxName = db.uri.Schema + "." + idxName } return fmt.Sprintf("DROP INDEX %v", quote(idxName)) } func (db *db2) IsColumnExist(tableName, colName string) (bool, error) { - args := []interface{}{db.Schema, tableName, colName} + args := []interface{}{db.uri.Schema, tableName, colName} query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + " AND column_name = $3" - if len(db.Schema) == 0 { + if len(db.uri.Schema) == 0 { args = []interface{}{tableName, colName} query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + " AND column_name = $2" @@ -199,7 +203,7 @@ func (db *db2) IsColumnExist(tableName, colName string) (bool, error) { return rows.Next(), nil } -func (db *db2) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { +func (db *db2) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} s := `Select c.colname as column_name, c.colno as position, @@ -218,8 +222,8 @@ inner join syscat.tables t on where t.type = 'T' AND c.tabname = ?` var f string - if len(db.Schema) != 0 { - args = append(args, db.Schema) + if len(db.uri.Schema) != 0 { + args = append(args, db.uri.Schema) f = " AND c.tabschema = ?" } s = s + f @@ -232,11 +236,11 @@ where t.type = 'T' AND c.tabname = ?` } defer rows.Close() - cols := make(map[string]*core.Column) + cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - col := new(core.Column) + col := new(schemas.Column) col.Indexes = make(map[string]int) var colName, position, dataType, numericScale string @@ -268,23 +272,23 @@ where t.type = 'T' AND c.tabname = ?` switch dataType { case "character", "CHARACTER": - col.SQLType = core.SQLType{Name: core.Char, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Char, DefaultLength: 0, DefaultLength2: 0} case "timestamp without time zone": - col.SQLType = core.SQLType{Name: core.DateTime, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.DateTime, DefaultLength: 0, DefaultLength2: 0} case "timestamp with time zone": - col.SQLType = core.SQLType{Name: core.TimeStampz, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} case "double precision": - col.SQLType = core.SQLType{Name: core.Double, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Double, DefaultLength: 0, DefaultLength2: 0} case "boolean": - col.SQLType = core.SQLType{Name: core.Bool, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Bool, DefaultLength: 0, DefaultLength2: 0} case "time without time zone": - col.SQLType = core.SQLType{Name: core.Time, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Time, DefaultLength: 0, DefaultLength2: 0} case "oid": - col.SQLType = core.SQLType{Name: core.BigInt, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.BigInt, DefaultLength: 0, DefaultLength2: 0} default: - col.SQLType = core.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0} } - if _, ok := core.SqlTypes[col.SQLType.Name]; !ok { + if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok { return nil, nil, fmt.Errorf("Unknown colType: %v", dataType) } @@ -306,11 +310,11 @@ where t.type = 'T' AND c.tabname = ?` return colSeq, cols, nil } -func (db *db2) GetTables() ([]*core.Table, error) { +func (db *db2) GetTables() ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT TABNAME FROM SYSCAT.TABLES WHERE type = 'T' AND OWNERTYPE = 'U'" - if len(db.Schema) != 0 { - args = append(args, db.Schema) + if len(db.uri.Schema) != 0 { + args = append(args, db.uri.Schema) s = s + " AND TABSCHEMA = ?" } @@ -322,9 +326,9 @@ func (db *db2) GetTables() ([]*core.Table, error) { } defer rows.Close() - tables := make([]*core.Table, 0) + tables := make([]*schemas.Table, 0) for rows.Next() { - table := core.NewEmptyTable() + table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) if err != nil { @@ -336,14 +340,14 @@ func (db *db2) GetTables() ([]*core.Table, error) { return tables, nil } -func (db *db2) GetIndexes(tableName string) (map[string]*core.Index, error) { +func (db *db2) GetIndexes(tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := fmt.Sprintf(`select uniquerule, indname as index_name, replace(substring(colnames,2,length(colnames)),'+',',') as columns from syscat.indexes WHERE tabname = ?`) - if len(db.Schema) != 0 { - args = append(args, db.Schema) + if len(db.uri.Schema) != 0 { + args = append(args, db.uri.Schema) s = s + " AND tabschema=?" } db.LogSQL(s, args) @@ -354,7 +358,7 @@ from syscat.indexes WHERE tabname = ?`) } defer rows.Close() - indexes := make(map[string]*core.Index, 0) + indexes := make(map[string]*schemas.Index, 0) for rows.Next() { var indexTypeName, indexName, columns string /*when 'P' then 'Primary key' @@ -370,9 +374,9 @@ from syscat.indexes WHERE tabname = ?`) } var indexType int if strings.EqualFold(indexTypeName, "U") { - indexType = core.UniqueType + indexType = schemas.UniqueType } else if strings.EqualFold(indexTypeName, "D") { - indexType = core.IndexType + indexType = schemas.IndexType } var isRegular bool if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { @@ -383,7 +387,7 @@ from syscat.indexes WHERE tabname = ?`) } } - index := &core.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} + index := &schemas.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} colNames := strings.Split(columns, ",") for _, colName := range colNames { index.Cols = append(index.Cols, strings.Trim(colName, `" `)) @@ -394,13 +398,13 @@ from syscat.indexes WHERE tabname = ?`) return indexes, nil } -func (db *db2) Filters() []core.Filter { - return []core.Filter{&core.QuoteFilter{}} +func (db *db2) Filters() []Filter { + return []Filter{&QuoteFilter{}} } type db2Driver struct{} -func (p *db2Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) { +func (p *db2Driver) Parse(driverName, dataSourceName string) (*URI, error) { var dbName string var defaultSchema string @@ -420,9 +424,9 @@ func (p *db2Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) if dbName == "" { return nil, errors.New("no db name provided") } - return &core.Uri{ - DbName: dbName, - DbType: "db2", + return &URI{ + DBName: dbName, + DBType: "db2", Schema: defaultSchema, }, nil } diff --git a/dialects/dialect.go b/dialects/dialect.go index fc11eac1..d844aed9 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -211,6 +211,7 @@ func regDrvsNDialects() bool { getDriver func() Driver getDialect func() Dialect }{ +<<<<<<< HEAD "mssql": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, "odbc": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access "mysql": {"mysql", func() Driver { return &mysqlDriver{} }, func() Dialect { return &mysql{} }}, @@ -221,6 +222,19 @@ func regDrvsNDialects() bool { "sqlite": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }}, "oci8": {"oracle", func() Driver { return &oci8Driver{} }, func() Dialect { return &oracle{} }}, "godror": {"oracle", func() Driver { return &godrorDriver{} }, func() Dialect { return &oracle{} }}, +======= + "mssql": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, + "odbc": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access + "mysql": {"mysql", func() Driver { return &mysqlDriver{} }, func() Dialect { return &mysql{} }}, + "mymysql": {"mysql", func() Driver { return &mymysqlDriver{} }, func() Dialect { return &mysql{} }}, + "postgres": {"postgres", func() Driver { return &pqDriver{} }, func() Dialect { return &postgres{} }}, + "pgx": {"postgres", func() Driver { return &pqDriverPgx{} }, func() Dialect { return &postgres{} }}, + "sqlite3": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }}, + "sqlite": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }}, + "oci8": {"oracle", func() Driver { return &oci8Driver{} }, func() Dialect { return &oracle{} }}, + "goracle": {"oracle", func() Driver { return &goracleDriver{} }, func() Dialect { return &oracle{} }}, + "go_ibm_db": {"db2", func() Driver { return &db2Driver{} }, func() Dialect { return &db2{} }}, +>>>>>>> 538a3b2 (Use new dialect interface) } for driverName, v := range providedDrvsNDialects {