Improve oracle
This commit is contained in:
parent
572e277b42
commit
3de08f5396
|
@ -43,8 +43,10 @@ const (
|
|||
SequenceAutoincrMode
|
||||
)
|
||||
|
||||
// DialectFeatures represents the features that the dialect supports
|
||||
type DialectFeatures struct {
|
||||
AutoincrMode int // 0 autoincrement column, 1 sequence
|
||||
SupportReturnIDWhenInsert bool
|
||||
}
|
||||
|
||||
// Dialect represents a kind of database
|
||||
|
|
|
@ -543,6 +543,7 @@ func (db *oracle) Version(ctx context.Context, queryer core.Queryer) (*schemas.V
|
|||
func (db *oracle) Features() *DialectFeatures {
|
||||
return &DialectFeatures{
|
||||
AutoincrMode: SequenceAutoincrMode,
|
||||
SupportReturnIDWhenInsert: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -553,8 +554,9 @@ func (db *oracle) SQLType(c *schemas.Column) string {
|
|||
res = "NUMBER"
|
||||
case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea:
|
||||
return schemas.Blob
|
||||
case schemas.Time, schemas.DateTime, schemas.TimeStamp:
|
||||
case schemas.Date, schemas.Time, schemas.DateTime, schemas.TimeStamp:
|
||||
res = schemas.Date
|
||||
return res
|
||||
case schemas.TimeStampz:
|
||||
res = "TIMESTAMP"
|
||||
case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal:
|
||||
|
@ -607,42 +609,44 @@ func (db *oracle) DropTableSQL(tableName, autoincrCol string) ([]string, bool) {
|
|||
fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)),
|
||||
}
|
||||
if autoincrCol != "" {
|
||||
sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", seqName(tableName)))
|
||||
sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", OracleSeqName(tableName)))
|
||||
}
|
||||
return sqls, false
|
||||
}
|
||||
|
||||
func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
|
||||
var sql = "CREATE TABLE "
|
||||
if tableName == "" {
|
||||
tableName = table.Name
|
||||
}
|
||||
|
||||
quoter := db.Quoter()
|
||||
sql += quoter.Quote(tableName) + " ("
|
||||
var b strings.Builder
|
||||
b.WriteString("CREATE TABLE ")
|
||||
quoter.QuoteTo(&b, tableName)
|
||||
b.WriteString(" (")
|
||||
|
||||
pkList := table.PrimaryKeys
|
||||
|
||||
for _, colName := range table.ColumnsSeq() {
|
||||
for i, colName := range table.ColumnsSeq() {
|
||||
col := table.GetColumn(colName)
|
||||
/*if col.IsPrimaryKey && len(pkList) == 1 {
|
||||
sql += col.String(b.dialect)
|
||||
} else {*/
|
||||
s, _ := ColumnString(db, col, false)
|
||||
sql += s
|
||||
// }
|
||||
sql = strings.TrimSpace(sql)
|
||||
sql += ", "
|
||||
b.WriteString(s)
|
||||
if i != len(table.ColumnsSeq())-1 {
|
||||
b.WriteString(", ")
|
||||
}
|
||||
}
|
||||
|
||||
if len(pkList) > 0 {
|
||||
sql += "PRIMARY KEY ( "
|
||||
sql += quoter.Join(pkList, ",")
|
||||
sql += " ), "
|
||||
if len(table.ColumnsSeq()) > 0 {
|
||||
b.WriteString(", ")
|
||||
}
|
||||
b.WriteString("PRIMARY KEY (")
|
||||
quoter.JoinWrite(&b, pkList, ",")
|
||||
b.WriteString(")")
|
||||
}
|
||||
b.WriteString(")")
|
||||
|
||||
sql = sql[:len(sql)-2] + ")"
|
||||
return sql, false, nil
|
||||
return b.String(), false, nil
|
||||
}
|
||||
|
||||
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {
|
||||
|
@ -679,7 +683,7 @@ func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, table
|
|||
return db.HasRecords(queryer, ctx, query, args...)
|
||||
}
|
||||
|
||||
func seqName(tableName string) string {
|
||||
func OracleSeqName(tableName string) string {
|
||||
return "SEQ_" + strings.ToUpper(tableName)
|
||||
}
|
||||
|
||||
|
@ -746,7 +750,7 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
|
|||
if pkName != "" && pkName == col.Name {
|
||||
col.IsPrimaryKey = true
|
||||
|
||||
has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", seqName(tableName))
|
||||
has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", OracleSeqName(tableName))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"xorm.io/xorm"
|
||||
"xorm.io/xorm/dialects"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -886,10 +887,11 @@ func TestAfterLoadProcessor(t *testing.T) {
|
|||
|
||||
type AfterInsertStruct struct {
|
||||
Id int64
|
||||
Dialect dialects.Dialect `xorm:"-"`
|
||||
}
|
||||
|
||||
func (a *AfterInsertStruct) AfterInsert() {
|
||||
if a.Id == 0 {
|
||||
if a.Dialect.Features().SupportReturnIDWhenInsert && a.Id == 0 {
|
||||
panic("a.Id")
|
||||
}
|
||||
}
|
||||
|
@ -899,6 +901,8 @@ func TestAfterInsert(t *testing.T) {
|
|||
|
||||
assertSync(t, new(AfterInsertStruct))
|
||||
|
||||
_, err := testEngine.Insert(&AfterInsertStruct{})
|
||||
_, err := testEngine.Insert(&AfterInsertStruct{
|
||||
Dialect: testEngine.Dialect(),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
|
|
@ -44,7 +44,6 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
|
|||
}
|
||||
|
||||
var hasInsertColumns = len(colNames) > 0
|
||||
<<<<<<< HEAD
|
||||
var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
|
||||
if needSeq {
|
||||
for _, col := range colNames {
|
||||
|
@ -57,10 +56,6 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
|
|||
|
||||
if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE &&
|
||||
statement.dialect.URI().DBType != schemas.DAMENG {
|
||||
=======
|
||||
|
||||
if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE {
|
||||
>>>>>>> d24e7cb (Fix insert)
|
||||
if statement.dialect.URI().DBType == schemas.MYSQL {
|
||||
if _, err := buf.WriteString(" VALUES ()"); err != nil {
|
||||
return nil, err
|
||||
|
@ -173,14 +168,20 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
|
|||
}
|
||||
}
|
||||
|
||||
if len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.POSTGRES ||
|
||||
statement.dialect.URI().DBType == schemas.ORACLE) {
|
||||
if len(table.AutoIncrement) > 0 {
|
||||
if statement.dialect.URI().DBType == schemas.POSTGRES {
|
||||
if _, err := buf.WriteString(" RETURNING "); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} /* else if statement.dialect.URI().DBType == schemas.ORACLE {
|
||||
if _, err := buf.WriteString(fmt.Sprintf("; select %s.currval from dual",
|
||||
dialects.OracleSeqName(tableName))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}*/
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
|
|
Loading…
Reference in New Issue