From 479deaff0213831f24e14bc87a4ebd28b8c79c41 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 12 Mar 2015 17:21:02 +0800 Subject: [PATCH] oci8 support --- oracle_dialect.go | 45 +++++++++++++++++++++++++++++---------------- statement.go | 2 +- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/oracle_dialect.go b/oracle_dialect.go index 5dfdda36..5caa0b27 100644 --- a/oracle_dialect.go +++ b/oracle_dialect.go @@ -571,7 +571,7 @@ func (db *oracle) IndexOnTable() bool { func (b *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { var sql string - sql = "CREATE TABLE IF NOT EXISTS " + sql = "CREATE TABLE " if tableName == "" { tableName = table.Name } @@ -609,18 +609,17 @@ func (b *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, chars sql += " DEFAULT CHARSET " + charset } } - sql += ";" return sql } func (db *oracle) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(idxName)} + args := []interface{}{tableName, idxName} return `SELECT INDEX_NAME FROM USER_INDEXES ` + - `WHERE TABLE_NAME = ? AND INDEX_NAME = ?`, args + `WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args } func (db *oracle) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{strings.ToUpper(tableName)} + args := []interface{}{tableName} return `SELECT table_name FROM user_tables WHERE table_name = :1`, args } @@ -640,7 +639,7 @@ func (db *oracle) MustDropTable(tableName string) error { return nil } - sql = "Drop Table \"" + tableName + "\";" + sql = "Drop Table \"" + tableName + "\"" if db.Logger != nil { db.Logger.Info("[sql]", sql) } @@ -655,9 +654,9 @@ func (db *oracle) MustDropTable(tableName string) error { }*/ func (db *oracle) IsColumnExist(tableName string, col *core.Column) (bool, error) { - args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(col.Name)} - query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" + - " AND column_name = ?" + args := []interface{}{tableName, col.Name} + query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" + + " AND column_name = :2" rows, err := db.DB().Query(query, args...) if db.Logger != nil { db.Logger.Info("[sql]", query, args) @@ -674,7 +673,7 @@ func (db *oracle) IsColumnExist(tableName string, col *core.Column) (bool, error } func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { - args := []interface{}{strings.ToUpper(tableName)} + args := []interface{}{tableName} s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" @@ -716,13 +715,27 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum var ignore bool - switch *dataType { + var dt string + var len1, len2 int + dts := strings.Split(*dataType, "(") + dt = dts[0] + if len(dts) > 1 { + lens := strings.Split(dts[1][:len(dts[1])-1], ",") + if len(lens) > 1 { + len1, _ = strconv.Atoi(lens[0]) + len2, _ = strconv.Atoi(lens[1]) + } else { + len1, _ = strconv.Atoi(lens[0]) + } + } + + switch dt { case "VARCHAR2": - col.SQLType = core.SQLType{core.Varchar, 0, 0} + col.SQLType = core.SQLType{core.Varchar, len1, len2} case "TIMESTAMP WITH TIME ZONE": col.SQLType = core.SQLType{core.TimeStampz, 0, 0} case "NUMBER": - col.SQLType = core.SQLType{core.Double, 0, 0} + col.SQLType = core.SQLType{core.Double, len1, len2} case "LONG", "LONG RAW": col.SQLType = core.SQLType{core.Text, 0, 0} case "RAW": @@ -732,15 +745,15 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum case "AQ$_SUBSCRIBERS": ignore = true default: - col.SQLType = core.SQLType{strings.ToUpper(*dataType), 0, 0} + col.SQLType = core.SQLType{strings.ToUpper(dt), len1, len2} } - //fmt.Println(tableName, ":", col.Name) + if ignore { continue } if _, ok := core.SqlTypes[col.SQLType.Name]; !ok { - return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", *dataType)) + return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v %v", *dataType, col.SQLType)) } col.Length = dataLen diff --git a/statement.go b/statement.go index 9024e406..18137440 100644 --- a/statement.go +++ b/statement.go @@ -1185,7 +1185,7 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{} id = "" } statement.attachInSql() - return statement.genSelectSql(fmt.Sprintf("count(%v) AS %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...) + return statement.genSelectSql(fmt.Sprintf("count(%v)", id)), append(statement.Params, statement.BeanArgs...) } func (statement *Statement) genSelectSql(columnStr string) (a string) {