Improve oracle

This commit is contained in:
Lunny Xiao 2020-12-02 22:05:02 +08:00
parent 572e277b42
commit 3de08f5396
4 changed files with 48 additions and 37 deletions

View File

@ -43,8 +43,10 @@ const (
SequenceAutoincrMode
)
// DialectFeatures represents the features that the dialect supports
type DialectFeatures struct {
AutoincrMode int // 0 autoincrement column, 1 sequence
SupportReturnIDWhenInsert bool
}
// Dialect represents a kind of database

View File

@ -543,6 +543,7 @@ func (db *oracle) Version(ctx context.Context, queryer core.Queryer) (*schemas.V
func (db *oracle) Features() *DialectFeatures {
return &DialectFeatures{
AutoincrMode: SequenceAutoincrMode,
SupportReturnIDWhenInsert: false,
}
}
@ -553,8 +554,9 @@ func (db *oracle) SQLType(c *schemas.Column) string {
res = "NUMBER"
case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea:
return schemas.Blob
case schemas.Time, schemas.DateTime, schemas.TimeStamp:
case schemas.Date, schemas.Time, schemas.DateTime, schemas.TimeStamp:
res = schemas.Date
return res
case schemas.TimeStampz:
res = "TIMESTAMP"
case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal:
@ -607,42 +609,44 @@ func (db *oracle) DropTableSQL(tableName, autoincrCol string) ([]string, bool) {
fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)),
}
if autoincrCol != "" {
sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", seqName(tableName)))
sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", OracleSeqName(tableName)))
}
return sqls, false
}
func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
var sql = "CREATE TABLE "
if tableName == "" {
tableName = table.Name
}
quoter := db.Quoter()
sql += quoter.Quote(tableName) + " ("
var b strings.Builder
b.WriteString("CREATE TABLE ")
quoter.QuoteTo(&b, tableName)
b.WriteString(" (")
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
/*if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(b.dialect)
} else {*/
s, _ := ColumnString(db, col, false)
sql += s
// }
sql = strings.TrimSpace(sql)
sql += ", "
b.WriteString(s)
if i != len(table.ColumnsSeq())-1 {
b.WriteString(", ")
}
}
if len(pkList) > 0 {
sql += "PRIMARY KEY ( "
sql += quoter.Join(pkList, ",")
sql += " ), "
if len(table.ColumnsSeq()) > 0 {
b.WriteString(", ")
}
b.WriteString("PRIMARY KEY (")
quoter.JoinWrite(&b, pkList, ",")
b.WriteString(")")
}
b.WriteString(")")
sql = sql[:len(sql)-2] + ")"
return sql, false, nil
return b.String(), false, nil
}
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {
@ -679,7 +683,7 @@ func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, table
return db.HasRecords(queryer, ctx, query, args...)
}
func seqName(tableName string) string {
func OracleSeqName(tableName string) string {
return "SEQ_" + strings.ToUpper(tableName)
}
@ -746,7 +750,7 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
if pkName != "" && pkName == col.Name {
col.IsPrimaryKey = true
has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", seqName(tableName))
has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", OracleSeqName(tableName))
if err != nil {
return nil, nil, err
}

View File

@ -10,6 +10,7 @@ import (
"testing"
"xorm.io/xorm"
"xorm.io/xorm/dialects"
"github.com/stretchr/testify/assert"
)
@ -886,10 +887,11 @@ func TestAfterLoadProcessor(t *testing.T) {
type AfterInsertStruct struct {
Id int64
Dialect dialects.Dialect `xorm:"-"`
}
func (a *AfterInsertStruct) AfterInsert() {
if a.Id == 0 {
if a.Dialect.Features().SupportReturnIDWhenInsert && a.Id == 0 {
panic("a.Id")
}
}
@ -899,6 +901,8 @@ func TestAfterInsert(t *testing.T) {
assertSync(t, new(AfterInsertStruct))
_, err := testEngine.Insert(&AfterInsertStruct{})
_, err := testEngine.Insert(&AfterInsertStruct{
Dialect: testEngine.Dialect(),
})
assert.NoError(t, err)
}

View File

@ -44,7 +44,6 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
}
var hasInsertColumns = len(colNames) > 0
<<<<<<< HEAD
var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
if needSeq {
for _, col := range colNames {
@ -57,10 +56,6 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE &&
statement.dialect.URI().DBType != schemas.DAMENG {
=======
if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE {
>>>>>>> d24e7cb (Fix insert)
if statement.dialect.URI().DBType == schemas.MYSQL {
if _, err := buf.WriteString(" VALUES ()"); err != nil {
return nil, err
@ -173,14 +168,20 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
}
}
if len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.POSTGRES ||
statement.dialect.URI().DBType == schemas.ORACLE) {
if len(table.AutoIncrement) > 0 {
if statement.dialect.URI().DBType == schemas.POSTGRES {
if _, err := buf.WriteString(" RETURNING "); err != nil {
return nil, err
}
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil {
return nil, err
}
} /* else if statement.dialect.URI().DBType == schemas.ORACLE {
if _, err := buf.WriteString(fmt.Sprintf("; select %s.currval from dual",
dialects.OracleSeqName(tableName))); err != nil {
return nil, err
}
}*/
}
return buf, nil