From 4d5681caf8955a7253cfa5c6da2b19b9bb51e701 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 3 Mar 2015 14:58:01 +0800 Subject: [PATCH] oracle support, to be continued --- oracle_dialect.go | 84 ++++++++++++++++++++++++++++++++++++++++++----- session.go | 10 +++--- statement.go | 5 +-- 3 files changed, 85 insertions(+), 14 deletions(-) diff --git a/oracle_dialect.go b/oracle_dialect.go index e4a933e5..5dfdda36 100644 --- a/oracle_dialect.go +++ b/oracle_dialect.go @@ -509,7 +509,7 @@ func (db *oracle) SqlType(c *core.Column) string { var res string switch t := c.SQLType.Name; t { case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool, core.Serial, core.BigSerial: - return "NUMBER" + res = "NUMBER" case core.Binary, core.VarBinary, core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob, core.Bytea: return core.Blob case core.Time, core.DateTime, core.TimeStamp: @@ -521,7 +521,7 @@ func (db *oracle) SqlType(c *core.Column) string { case core.Text, core.MediumText, core.LongText: res = "CLOB" case core.Char, core.Varchar, core.TinyText: - return "VARCHAR2" + res = "VARCHAR2" default: res = t } @@ -536,6 +536,10 @@ func (db *oracle) SqlType(c *core.Column) string { return res } +func (db *oracle) AutoIncrStr() string { + return "AUTO_INCREMENT" +} + func (db *oracle) SupportInsertMany() bool { return true } @@ -553,10 +557,6 @@ func (db *oracle) QuoteStr() string { return "\"" } -func (db *oracle) AutoIncrStr() string { - return "" -} - func (db *oracle) SupportEngine() bool { return false } @@ -569,6 +569,50 @@ func (db *oracle) IndexOnTable() bool { return false } +func (b *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { + var sql string + sql = "CREATE TABLE IF NOT EXISTS " + if tableName == "" { + tableName = table.Name + } + + sql += b.Quote(tableName) + " (" + + pkList := table.PrimaryKeys + + for _, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + /*if col.IsPrimaryKey && len(pkList) == 1 { + sql += col.String(b.dialect) + } else {*/ + sql += col.StringNoPk(b) + //} + sql = strings.TrimSpace(sql) + sql += ", " + } + + if len(pkList) > 0 { + sql += "PRIMARY KEY ( " + sql += b.Quote(strings.Join(pkList, b.Quote(","))) + sql += " ), " + } + + sql = sql[:len(sql)-2] + ")" + if b.SupportEngine() && storeEngine != "" { + sql += " ENGINE=" + storeEngine + } + if b.SupportCharset() { + if len(charset) == 0 { + charset = b.URI().Charset + } + if len(charset) > 0 { + sql += " DEFAULT CHARSET " + charset + } + } + sql += ";" + return sql +} + func (db *oracle) IndexCheckSql(tableName, idxName string) (string, []interface{}) { args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(idxName)} return `SELECT INDEX_NAME FROM USER_INDEXES ` + @@ -577,7 +621,31 @@ func (db *oracle) IndexCheckSql(tableName, idxName string) (string, []interface{ func (db *oracle) TableCheckSql(tableName string) (string, []interface{}) { args := []interface{}{strings.ToUpper(tableName)} - return `SELECT table_name FROM user_tables WHERE table_name = ?`, args + return `SELECT table_name FROM user_tables WHERE table_name = :1`, args +} + +func (db *oracle) MustDropTable(tableName string) error { + sql, args := db.TableCheckSql(tableName) + if db.Logger != nil { + db.Logger.Info("[sql]", sql, args) + } + + rows, err := db.DB().Query(sql, args...) + if err != nil { + return err + } + defer rows.Close() + + if !rows.Next() { + return nil + } + + sql = "Drop Table \"" + tableName + "\";" + if db.Logger != nil { + db.Logger.Info("[sql]", sql) + } + _, err = db.DB().Exec(sql) + return err } /*func (db *oracle) ColumnCheckSql(tableName, colName string) (string, []interface{}) { @@ -666,7 +734,7 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum default: col.SQLType = core.SQLType{strings.ToUpper(*dataType), 0, 0} } - fmt.Println(tableName, ":", col.Name) + //fmt.Println(tableName, ":", col.Name) if ignore { continue } diff --git a/session.go b/session.go index cf566902..c0c4d4dc 100644 --- a/session.go +++ b/session.go @@ -572,9 +572,10 @@ func (session *Session) DropTable(bean interface{}) error { return errors.New("Unsupported type") } - sqlStr := session.Statement.genDropSQL() + return session.Engine.Dialect().MustDropTable(session.Statement.TableName()) + /*sqlStr := session.Statement.genDropSQL() _, err := session.exec(sqlStr) - return err + return err*/ } func (statement *Statement) JoinColumns(cols []*core.Column) string { @@ -1491,8 +1492,9 @@ func (session *Session) dropAll() error { for _, table := range session.Engine.Tables { session.Statement.Init() session.Statement.RefTable = table - sqlStr := session.Statement.genDropSQL() - _, err := session.exec(sqlStr) + err := session.Engine.Dialect().MustDropTable(session.Statement.TableName()) + //sqlStr := session.Statement.genDropSQL() + //_, err := session.exec(sqlStr) if err != nil { return err } diff --git a/statement.go b/statement.go index f645a1d3..9024e406 100644 --- a/statement.go +++ b/statement.go @@ -1111,9 +1111,10 @@ func (s *Statement) genDelIndexSQL() []string { return sqls } +/* func (s *Statement) genDropSQL() string { - return s.Engine.dialect.DropTableSql(s.TableName()) + ";" -} + return s.Engine.dialect.MustDropTa(s.TableName()) + ";" +}*/ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) { var table *core.Table