Improve oracle
This commit is contained in:
parent
572e277b42
commit
3de08f5396
|
@ -43,8 +43,10 @@ const (
|
||||||
SequenceAutoincrMode
|
SequenceAutoincrMode
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DialectFeatures represents the features that the dialect supports
|
||||||
type DialectFeatures struct {
|
type DialectFeatures struct {
|
||||||
AutoincrMode int // 0 autoincrement column, 1 sequence
|
AutoincrMode int // 0 autoincrement column, 1 sequence
|
||||||
|
SupportReturnIDWhenInsert bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dialect represents a kind of database
|
// Dialect represents a kind of database
|
||||||
|
|
|
@ -543,6 +543,7 @@ func (db *oracle) Version(ctx context.Context, queryer core.Queryer) (*schemas.V
|
||||||
func (db *oracle) Features() *DialectFeatures {
|
func (db *oracle) Features() *DialectFeatures {
|
||||||
return &DialectFeatures{
|
return &DialectFeatures{
|
||||||
AutoincrMode: SequenceAutoincrMode,
|
AutoincrMode: SequenceAutoincrMode,
|
||||||
|
SupportReturnIDWhenInsert: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -553,8 +554,9 @@ func (db *oracle) SQLType(c *schemas.Column) string {
|
||||||
res = "NUMBER"
|
res = "NUMBER"
|
||||||
case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea:
|
case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea:
|
||||||
return schemas.Blob
|
return schemas.Blob
|
||||||
case schemas.Time, schemas.DateTime, schemas.TimeStamp:
|
case schemas.Date, schemas.Time, schemas.DateTime, schemas.TimeStamp:
|
||||||
res = schemas.Date
|
res = schemas.Date
|
||||||
|
return res
|
||||||
case schemas.TimeStampz:
|
case schemas.TimeStampz:
|
||||||
res = "TIMESTAMP"
|
res = "TIMESTAMP"
|
||||||
case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal:
|
case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal:
|
||||||
|
@ -607,42 +609,44 @@ func (db *oracle) DropTableSQL(tableName, autoincrCol string) ([]string, bool) {
|
||||||
fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)),
|
fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)),
|
||||||
}
|
}
|
||||||
if autoincrCol != "" {
|
if autoincrCol != "" {
|
||||||
sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", seqName(tableName)))
|
sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", OracleSeqName(tableName)))
|
||||||
}
|
}
|
||||||
return sqls, false
|
return sqls, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
|
func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
|
||||||
var sql = "CREATE TABLE "
|
|
||||||
if tableName == "" {
|
if tableName == "" {
|
||||||
tableName = table.Name
|
tableName = table.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
quoter := db.Quoter()
|
quoter := db.Quoter()
|
||||||
sql += quoter.Quote(tableName) + " ("
|
var b strings.Builder
|
||||||
|
b.WriteString("CREATE TABLE ")
|
||||||
|
quoter.QuoteTo(&b, tableName)
|
||||||
|
b.WriteString(" (")
|
||||||
|
|
||||||
pkList := table.PrimaryKeys
|
pkList := table.PrimaryKeys
|
||||||
|
|
||||||
for _, colName := range table.ColumnsSeq() {
|
for i, colName := range table.ColumnsSeq() {
|
||||||
col := table.GetColumn(colName)
|
col := table.GetColumn(colName)
|
||||||
/*if col.IsPrimaryKey && len(pkList) == 1 {
|
|
||||||
sql += col.String(b.dialect)
|
|
||||||
} else {*/
|
|
||||||
s, _ := ColumnString(db, col, false)
|
s, _ := ColumnString(db, col, false)
|
||||||
sql += s
|
b.WriteString(s)
|
||||||
// }
|
if i != len(table.ColumnsSeq())-1 {
|
||||||
sql = strings.TrimSpace(sql)
|
b.WriteString(", ")
|
||||||
sql += ", "
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(pkList) > 0 {
|
if len(pkList) > 0 {
|
||||||
sql += "PRIMARY KEY ( "
|
if len(table.ColumnsSeq()) > 0 {
|
||||||
sql += quoter.Join(pkList, ",")
|
b.WriteString(", ")
|
||||||
sql += " ), "
|
|
||||||
}
|
}
|
||||||
|
b.WriteString("PRIMARY KEY (")
|
||||||
|
quoter.JoinWrite(&b, pkList, ",")
|
||||||
|
b.WriteString(")")
|
||||||
|
}
|
||||||
|
b.WriteString(")")
|
||||||
|
|
||||||
sql = sql[:len(sql)-2] + ")"
|
return b.String(), false, nil
|
||||||
return sql, false, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {
|
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {
|
||||||
|
@ -679,7 +683,7 @@ func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, table
|
||||||
return db.HasRecords(queryer, ctx, query, args...)
|
return db.HasRecords(queryer, ctx, query, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func seqName(tableName string) string {
|
func OracleSeqName(tableName string) string {
|
||||||
return "SEQ_" + strings.ToUpper(tableName)
|
return "SEQ_" + strings.ToUpper(tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -746,7 +750,7 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
|
||||||
if pkName != "" && pkName == col.Name {
|
if pkName != "" && pkName == col.Name {
|
||||||
col.IsPrimaryKey = true
|
col.IsPrimaryKey = true
|
||||||
|
|
||||||
has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", seqName(tableName))
|
has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", OracleSeqName(tableName))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"xorm.io/xorm"
|
"xorm.io/xorm"
|
||||||
|
"xorm.io/xorm/dialects"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
@ -886,10 +887,11 @@ func TestAfterLoadProcessor(t *testing.T) {
|
||||||
|
|
||||||
type AfterInsertStruct struct {
|
type AfterInsertStruct struct {
|
||||||
Id int64
|
Id int64
|
||||||
|
Dialect dialects.Dialect `xorm:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AfterInsertStruct) AfterInsert() {
|
func (a *AfterInsertStruct) AfterInsert() {
|
||||||
if a.Id == 0 {
|
if a.Dialect.Features().SupportReturnIDWhenInsert && a.Id == 0 {
|
||||||
panic("a.Id")
|
panic("a.Id")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -899,6 +901,8 @@ func TestAfterInsert(t *testing.T) {
|
||||||
|
|
||||||
assertSync(t, new(AfterInsertStruct))
|
assertSync(t, new(AfterInsertStruct))
|
||||||
|
|
||||||
_, err := testEngine.Insert(&AfterInsertStruct{})
|
_, err := testEngine.Insert(&AfterInsertStruct{
|
||||||
|
Dialect: testEngine.Dialect(),
|
||||||
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,7 +44,6 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
var hasInsertColumns = len(colNames) > 0
|
var hasInsertColumns = len(colNames) > 0
|
||||||
<<<<<<< HEAD
|
|
||||||
var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
|
var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
|
||||||
if needSeq {
|
if needSeq {
|
||||||
for _, col := range colNames {
|
for _, col := range colNames {
|
||||||
|
@ -57,10 +56,6 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
|
||||||
|
|
||||||
if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE &&
|
if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE &&
|
||||||
statement.dialect.URI().DBType != schemas.DAMENG {
|
statement.dialect.URI().DBType != schemas.DAMENG {
|
||||||
=======
|
|
||||||
|
|
||||||
if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE {
|
|
||||||
>>>>>>> d24e7cb (Fix insert)
|
|
||||||
if statement.dialect.URI().DBType == schemas.MYSQL {
|
if statement.dialect.URI().DBType == schemas.MYSQL {
|
||||||
if _, err := buf.WriteString(" VALUES ()"); err != nil {
|
if _, err := buf.WriteString(" VALUES ()"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -173,14 +168,20 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.POSTGRES ||
|
if len(table.AutoIncrement) > 0 {
|
||||||
statement.dialect.URI().DBType == schemas.ORACLE) {
|
if statement.dialect.URI().DBType == schemas.POSTGRES {
|
||||||
if _, err := buf.WriteString(" RETURNING "); err != nil {
|
if _, err := buf.WriteString(" RETURNING "); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil {
|
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
} /* else if statement.dialect.URI().DBType == schemas.ORACLE {
|
||||||
|
if _, err := buf.WriteString(fmt.Sprintf("; select %s.currval from dual",
|
||||||
|
dialects.OracleSeqName(tableName))); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}*/
|
||||||
}
|
}
|
||||||
|
|
||||||
return buf, nil
|
return buf, nil
|
||||||
|
|
Loading…
Reference in New Issue