diff --git a/dialects/dialect.go b/dialects/dialect.go index c6ce3653..98ad73ae 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -43,8 +43,10 @@ const ( SequenceAutoincrMode ) +// DialectFeatures represents the features that the dialect supports type DialectFeatures struct { - AutoincrMode int // 0 autoincrement column, 1 sequence + AutoincrMode int // 0 autoincrement column, 1 sequence + SupportReturnIDWhenInsert bool } // Dialect represents a kind of database diff --git a/dialects/oracle.go b/dialects/oracle.go index 4f3975d0..4965eab6 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -542,7 +542,8 @@ func (db *oracle) Version(ctx context.Context, queryer core.Queryer) (*schemas.V func (db *oracle) Features() *DialectFeatures { return &DialectFeatures{ - AutoincrMode: SequenceAutoincrMode, + AutoincrMode: SequenceAutoincrMode, + SupportReturnIDWhenInsert: false, } } @@ -553,8 +554,9 @@ func (db *oracle) SQLType(c *schemas.Column) string { res = "NUMBER" case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea: return schemas.Blob - case schemas.Time, schemas.DateTime, schemas.TimeStamp: + case schemas.Date, schemas.Time, schemas.DateTime, schemas.TimeStamp: res = schemas.Date + return res case schemas.TimeStampz: res = "TIMESTAMP" case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal: @@ -607,42 +609,44 @@ func (db *oracle) DropTableSQL(tableName, autoincrCol string) ([]string, bool) { fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)), } if autoincrCol != "" { - sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", seqName(tableName))) + sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", OracleSeqName(tableName))) } return sqls, false } func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { - var sql = "CREATE TABLE " if tableName == "" { tableName = table.Name } quoter := db.Quoter() - sql += quoter.Quote(tableName) + " (" + var b strings.Builder + b.WriteString("CREATE TABLE ") + quoter.QuoteTo(&b, tableName) + b.WriteString(" (") pkList := table.PrimaryKeys - for _, colName := range table.ColumnsSeq() { + for i, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - /*if col.IsPrimaryKey && len(pkList) == 1 { - sql += col.String(b.dialect) - } else {*/ s, _ := ColumnString(db, col, false) - sql += s - // } - sql = strings.TrimSpace(sql) - sql += ", " + b.WriteString(s) + if i != len(table.ColumnsSeq())-1 { + b.WriteString(", ") + } } if len(pkList) > 0 { - sql += "PRIMARY KEY ( " - sql += quoter.Join(pkList, ",") - sql += " ), " + if len(table.ColumnsSeq()) > 0 { + b.WriteString(", ") + } + b.WriteString("PRIMARY KEY (") + quoter.JoinWrite(&b, pkList, ",") + b.WriteString(")") } + b.WriteString(")") - sql = sql[:len(sql)-2] + ")" - return sql, false, nil + return b.String(), false, nil } func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { @@ -679,7 +683,7 @@ func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, table return db.HasRecords(queryer, ctx, query, args...) } -func seqName(tableName string) string { +func OracleSeqName(tableName string) string { return "SEQ_" + strings.ToUpper(tableName) } @@ -746,7 +750,7 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam if pkName != "" && pkName == col.Name { col.IsPrimaryKey = true - has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", seqName(tableName)) + has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", OracleSeqName(tableName)) if err != nil { return nil, nil, err } diff --git a/integrations/processors_test.go b/integrations/processors_test.go index b32f6fbb..33c71edc 100644 --- a/integrations/processors_test.go +++ b/integrations/processors_test.go @@ -10,6 +10,7 @@ import ( "testing" "xorm.io/xorm" + "xorm.io/xorm/dialects" "github.com/stretchr/testify/assert" ) @@ -885,11 +886,12 @@ func TestAfterLoadProcessor(t *testing.T) { } type AfterInsertStruct struct { - Id int64 + Id int64 + Dialect dialects.Dialect `xorm:"-"` } func (a *AfterInsertStruct) AfterInsert() { - if a.Id == 0 { + if a.Dialect.Features().SupportReturnIDWhenInsert && a.Id == 0 { panic("a.Id") } } @@ -899,6 +901,8 @@ func TestAfterInsert(t *testing.T) { assertSync(t, new(AfterInsertStruct)) - _, err := testEngine.Insert(&AfterInsertStruct{}) + _, err := testEngine.Insert(&AfterInsertStruct{ + Dialect: testEngine.Dialect(), + }) assert.NoError(t, err) } diff --git a/internal/statements/insert.go b/internal/statements/insert.go index d773f64b..127a4449 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -44,7 +44,6 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) } var hasInsertColumns = len(colNames) > 0 -<<<<<<< HEAD var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG) if needSeq { for _, col := range colNames { @@ -57,10 +56,6 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE && statement.dialect.URI().DBType != schemas.DAMENG { -======= - - if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE { ->>>>>>> d24e7cb (Fix insert) if statement.dialect.URI().DBType == schemas.MYSQL { if _, err := buf.WriteString(" VALUES ()"); err != nil { return nil, err @@ -173,14 +168,20 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) } } - if len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.POSTGRES || - statement.dialect.URI().DBType == schemas.ORACLE) { - if _, err := buf.WriteString(" RETURNING "); err != nil { - return nil, err - } - if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil { - return nil, err - } + if len(table.AutoIncrement) > 0 { + if statement.dialect.URI().DBType == schemas.POSTGRES { + if _, err := buf.WriteString(" RETURNING "); err != nil { + return nil, err + } + if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil { + return nil, err + } + } /* else if statement.dialect.URI().DBType == schemas.ORACLE { + if _, err := buf.WriteString(fmt.Sprintf("; select %s.currval from dual", + dialects.OracleSeqName(tableName))); err != nil { + return nil, err + } + }*/ } return buf, nil