Improve oracle

This commit is contained in:
Lunny Xiao 2020-12-02 22:05:02 +08:00
parent 572e277b42
commit 3de08f5396
4 changed files with 48 additions and 37 deletions

View File

@ -43,8 +43,10 @@ const (
SequenceAutoincrMode SequenceAutoincrMode
) )
// DialectFeatures represents the features that the dialect supports
type DialectFeatures struct { type DialectFeatures struct {
AutoincrMode int // 0 autoincrement column, 1 sequence AutoincrMode int // 0 autoincrement column, 1 sequence
SupportReturnIDWhenInsert bool
} }
// Dialect represents a kind of database // Dialect represents a kind of database

View File

@ -542,7 +542,8 @@ func (db *oracle) Version(ctx context.Context, queryer core.Queryer) (*schemas.V
func (db *oracle) Features() *DialectFeatures { func (db *oracle) Features() *DialectFeatures {
return &DialectFeatures{ return &DialectFeatures{
AutoincrMode: SequenceAutoincrMode, AutoincrMode: SequenceAutoincrMode,
SupportReturnIDWhenInsert: false,
} }
} }
@ -553,8 +554,9 @@ func (db *oracle) SQLType(c *schemas.Column) string {
res = "NUMBER" res = "NUMBER"
case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea: case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea:
return schemas.Blob return schemas.Blob
case schemas.Time, schemas.DateTime, schemas.TimeStamp: case schemas.Date, schemas.Time, schemas.DateTime, schemas.TimeStamp:
res = schemas.Date res = schemas.Date
return res
case schemas.TimeStampz: case schemas.TimeStampz:
res = "TIMESTAMP" res = "TIMESTAMP"
case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal: case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal:
@ -607,42 +609,44 @@ func (db *oracle) DropTableSQL(tableName, autoincrCol string) ([]string, bool) {
fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)), fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)),
} }
if autoincrCol != "" { if autoincrCol != "" {
sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", seqName(tableName))) sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", OracleSeqName(tableName)))
} }
return sqls, false return sqls, false
} }
func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
var sql = "CREATE TABLE "
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
} }
quoter := db.Quoter() quoter := db.Quoter()
sql += quoter.Quote(tableName) + " (" var b strings.Builder
b.WriteString("CREATE TABLE ")
quoter.QuoteTo(&b, tableName)
b.WriteString(" (")
pkList := table.PrimaryKeys pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() { for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
/*if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(b.dialect)
} else {*/
s, _ := ColumnString(db, col, false) s, _ := ColumnString(db, col, false)
sql += s b.WriteString(s)
// } if i != len(table.ColumnsSeq())-1 {
sql = strings.TrimSpace(sql) b.WriteString(", ")
sql += ", " }
} }
if len(pkList) > 0 { if len(pkList) > 0 {
sql += "PRIMARY KEY ( " if len(table.ColumnsSeq()) > 0 {
sql += quoter.Join(pkList, ",") b.WriteString(", ")
sql += " ), " }
b.WriteString("PRIMARY KEY (")
quoter.JoinWrite(&b, pkList, ",")
b.WriteString(")")
} }
b.WriteString(")")
sql = sql[:len(sql)-2] + ")" return b.String(), false, nil
return sql, false, nil
} }
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {
@ -679,7 +683,7 @@ func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, table
return db.HasRecords(queryer, ctx, query, args...) return db.HasRecords(queryer, ctx, query, args...)
} }
func seqName(tableName string) string { func OracleSeqName(tableName string) string {
return "SEQ_" + strings.ToUpper(tableName) return "SEQ_" + strings.ToUpper(tableName)
} }
@ -746,7 +750,7 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
if pkName != "" && pkName == col.Name { if pkName != "" && pkName == col.Name {
col.IsPrimaryKey = true col.IsPrimaryKey = true
has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", seqName(tableName)) has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", OracleSeqName(tableName))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@ -10,6 +10,7 @@ import (
"testing" "testing"
"xorm.io/xorm" "xorm.io/xorm"
"xorm.io/xorm/dialects"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -885,11 +886,12 @@ func TestAfterLoadProcessor(t *testing.T) {
} }
type AfterInsertStruct struct { type AfterInsertStruct struct {
Id int64 Id int64
Dialect dialects.Dialect `xorm:"-"`
} }
func (a *AfterInsertStruct) AfterInsert() { func (a *AfterInsertStruct) AfterInsert() {
if a.Id == 0 { if a.Dialect.Features().SupportReturnIDWhenInsert && a.Id == 0 {
panic("a.Id") panic("a.Id")
} }
} }
@ -899,6 +901,8 @@ func TestAfterInsert(t *testing.T) {
assertSync(t, new(AfterInsertStruct)) assertSync(t, new(AfterInsertStruct))
_, err := testEngine.Insert(&AfterInsertStruct{}) _, err := testEngine.Insert(&AfterInsertStruct{
Dialect: testEngine.Dialect(),
})
assert.NoError(t, err) assert.NoError(t, err)
} }

View File

@ -44,7 +44,6 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
} }
var hasInsertColumns = len(colNames) > 0 var hasInsertColumns = len(colNames) > 0
<<<<<<< HEAD
var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG) var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
if needSeq { if needSeq {
for _, col := range colNames { for _, col := range colNames {
@ -57,10 +56,6 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE && if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE &&
statement.dialect.URI().DBType != schemas.DAMENG { statement.dialect.URI().DBType != schemas.DAMENG {
=======
if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE {
>>>>>>> d24e7cb (Fix insert)
if statement.dialect.URI().DBType == schemas.MYSQL { if statement.dialect.URI().DBType == schemas.MYSQL {
if _, err := buf.WriteString(" VALUES ()"); err != nil { if _, err := buf.WriteString(" VALUES ()"); err != nil {
return nil, err return nil, err
@ -173,14 +168,20 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
} }
} }
if len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.POSTGRES || if len(table.AutoIncrement) > 0 {
statement.dialect.URI().DBType == schemas.ORACLE) { if statement.dialect.URI().DBType == schemas.POSTGRES {
if _, err := buf.WriteString(" RETURNING "); err != nil { if _, err := buf.WriteString(" RETURNING "); err != nil {
return nil, err return nil, err
} }
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil { if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil {
return nil, err return nil, err
} }
} /* else if statement.dialect.URI().DBType == schemas.ORACLE {
if _, err := buf.WriteString(fmt.Sprintf("; select %s.currval from dual",
dialects.OracleSeqName(tableName))); err != nil {
return nil, err
}
}*/
} }
return buf, nil return buf, nil