fix db2 gettables

This commit is contained in:
Lunny Xiao 2019-11-18 09:39:01 +08:00
parent 296ec09941
commit 817dbe4f61
2 changed files with 46 additions and 43 deletions

View File

@ -191,24 +191,28 @@ func (db *db2) IsColumnExist(tableName, colName string) (bool, error) {
func (db *db2) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { func (db *db2) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix , s := `Select c.colname as column_name,
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, c.colno as position,
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey c.typename as data_type,
FROM pg_attribute f c.length,
JOIN pg_class c ON c.oid = f.attrelid JOIN pg_type t ON t.oid = f.atttypid c.scale,
LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum c.remarks as description,
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace case when c.nulls = 'Y' then 1 else 0 end as nullable,
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) default as default_value,
LEFT JOIN pg_class AS g ON p.confrelid = g.oid case when c.identity ='Y' then 1 else 0 end as is_identity,
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name case when c.generated ='' then 0 else 1 end as is_computed,
WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` c.text as computed_formula
from syscat.columns c
inner join syscat.tables t on
t.tabschema = c.tabschema and t.tabname = c.tabname
where t.type = 'T' AND c.tabname = ?`
var f string var f string
if len(db.Schema) != 0 { if len(db.Schema) != 0 {
args = append(args, db.Schema) args = append(args, db.Schema)
f = " AND s.table_schema = $2" f = " AND c.tabschema = ?"
} }
s = fmt.Sprintf(s, f) s = s + f
db.LogSQL(s, args) db.LogSQL(s, args)
@ -225,15 +229,15 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
col := new(core.Column) col := new(core.Column)
col.Indexes = make(map[string]int) col.Indexes = make(map[string]int)
var colName, isNullable, dataType string var colName, position, dataType, numericScale string
var maxLenStr, colDefault, numPrecision, numRadix *string var description, colDefault, computedFormula, maxLenStr *string
var isPK, isUnique bool var isComputed bool
err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &numPrecision, &numRadix, &isPK, &isUnique) err = rows.Scan(&colName, &position, &dataType, &maxLenStr, &numericScale, &description, &col.Nullable, &colDefault, &col.IsPrimaryKey, &isComputed, &computedFormula)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
// fmt.Println(args, colName, isNullable, dataType, maxLenStr, colDefault, numPrecision, numRadix, isPK, isUnique) //fmt.Println(colName, position, dataType, maxLenStr, numericScale, description, col.Nullable, colDefault, col.IsPrimaryKey, isComputed, computedFormula)
var maxLen int var maxLen int
if maxLenStr != nil { if maxLenStr != nil {
maxLen, err = strconv.Atoi(*maxLenStr) maxLen, err = strconv.Atoi(*maxLenStr)
@ -243,24 +247,18 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
} }
col.Name = strings.Trim(colName, `" `) col.Name = strings.Trim(colName, `" `)
if colDefault != nil {
if colDefault != nil || isPK { col.DefaultIsEmpty = false
if isPK {
col.IsPrimaryKey = true
} else {
col.Default = *colDefault col.Default = *colDefault
} }
}
if colDefault != nil && strings.HasPrefix(*colDefault, "nextval(") { if colDefault != nil && strings.HasPrefix(*colDefault, "nextval(") {
col.IsAutoIncrement = true col.IsAutoIncrement = true
} }
col.Nullable = (isNullable == "YES")
switch dataType { switch dataType {
case "character varying", "character": case "character", "CHARACTER":
col.SQLType = core.SQLType{Name: core.Varchar, DefaultLength: 0, DefaultLength2: 0} col.SQLType = core.SQLType{Name: core.Char, DefaultLength: 0, DefaultLength2: 0}
case "timestamp without time zone": case "timestamp without time zone":
col.SQLType = core.SQLType{Name: core.DateTime, DefaultLength: 0, DefaultLength2: 0} col.SQLType = core.SQLType{Name: core.DateTime, DefaultLength: 0, DefaultLength2: 0}
case "timestamp with time zone": case "timestamp with time zone":
@ -300,10 +298,10 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
func (db *db2) GetTables() ([]*core.Table, error) { func (db *db2) GetTables() ([]*core.Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT tablename FROM pg_tables" s := "SELECT NAME FROM SYSIBM.SYSTABLES WHERE type = 'T'"
if len(db.Schema) != 0 { if len(db.Schema) != 0 {
args = append(args, db.Schema) args = append(args, db.Schema)
s = s + " WHERE schemaname = $1" s = s + " AND creator = ?"
} }
db.LogSQL(s, args) db.LogSQL(s, args)
@ -330,10 +328,13 @@ func (db *db2) GetTables() ([]*core.Table, error) {
func (db *db2) GetIndexes(tableName string) (map[string]*core.Index, error) { func (db *db2) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") s := fmt.Sprintf(`select uniquerule,
indname as index_name,
replace(substring(colnames,2,length(colnames)),'+',',') as columns
from syscat.indexes WHERE tabname = ?`)
if len(db.Schema) != 0 { if len(db.Schema) != 0 {
args = append(args, db.Schema) args = append(args, db.Schema)
s = s + " AND schemaname=$2" s = s + " AND tabschema=?"
} }
db.LogSQL(s, args) db.LogSQL(s, args)
@ -345,10 +346,11 @@ func (db *db2) GetIndexes(tableName string) (map[string]*core.Index, error) {
indexes := make(map[string]*core.Index, 0) indexes := make(map[string]*core.Index, 0)
for rows.Next() { for rows.Next() {
var indexType int var indexTypeName, indexName, columns string
var indexName, indexdef string /*when 'P' then 'Primary key'
var colNames []string when 'U' then 'Unique'
err = rows.Scan(&indexName, &indexdef) when 'D' then 'Nonunique'*/
err = rows.Scan(&indexTypeName, &indexName, &columns)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -356,12 +358,12 @@ func (db *db2) GetIndexes(tableName string) (map[string]*core.Index, error) {
if strings.HasSuffix(indexName, "_pkey") { if strings.HasSuffix(indexName, "_pkey") {
continue continue
} }
if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") { var indexType int
if strings.EqualFold(indexTypeName, "U") {
indexType = core.UniqueType indexType = core.UniqueType
} else { } else if strings.EqualFold(indexTypeName, "D") {
indexType = core.IndexType indexType = core.IndexType
} }
colNames = getIndexColName(indexdef)
var isRegular bool var isRegular bool
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
newIdxName := indexName[5+len(tableName):] newIdxName := indexName[5+len(tableName):]
@ -372,6 +374,7 @@ func (db *db2) GetIndexes(tableName string) (map[string]*core.Index, error) {
} }
index := &core.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} index := &core.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
colNames := strings.Split(columns, ",")
for _, colName := range colNames { for _, colName := range colNames {
index.Cols = append(index.Cols, strings.Trim(colName, `" `)) index.Cols = append(index.Cols, strings.Trim(colName, `" `))
} }
@ -392,7 +395,7 @@ func (p *db2Driver) Parse(driverName, dataSourceName string) (*core.Uri, error)
kv := strings.Split(dataSourceName, ";") kv := strings.Split(dataSourceName, ";")
for _, c := range kv { for _, c := range kv {
vv := strings.Split(strings.TrimSpace(c), "=") vv := strings.SplitN(strings.TrimSpace(c), "=", 2)
if len(vv) == 2 { if len(vv) == 2 {
switch strings.ToLower(vv[0]) { switch strings.ToLower(vv[0]) {
case "database": case "database":
@ -404,5 +407,5 @@ func (p *db2Driver) Parse(driverName, dataSourceName string) (*core.Uri, error)
if dbName == "" { if dbName == "" {
return nil, errors.New("no db name provided") return nil, errors.New("no db name provided")
} }
return &core.Uri{DbName: dbName, DbType: core.MSSQL}, nil return &core.Uri{DbName: dbName, DbType: "db2"}, nil
} }

View File

@ -8,4 +8,4 @@ export CGO_CFLAGS=-I$DB2HOME/include
export CGO_LDFLAGS=-L$DB2HOME/lib export CGO_LDFLAGS=-L$DB2HOME/lib
export DYLD_LIBRARY_PATH=$DYLD_LIBRARY_PATH:$DB2HOME/clidriver/lib export DYLD_LIBRARY_PATH=$DYLD_LIBRARY_PATH:$DB2HOME/clidriver/lib
cd $cur cd $cur
go test -db=go_ibm_db -tags=db2 -conn_str="HOSTNAME=localhost;DATABASE=testdb;PORT=50000;UID=db2inst1;PWD=password" go test -db=go_ibm_db -tags=db2 -conn_str="HOSTNAME=localhost;DATABASE=testdb;PORT=50000;UID=db2inst1;PWD=123#2@23"