diff --git a/dialects/dameng.go b/dialects/dameng.go index 52b4592a..e988887d 100644 --- a/dialects/dameng.go +++ b/dialects/dameng.go @@ -559,9 +559,12 @@ func (db *dameng) SQLType(c *schemas.Column) string { res = "NUMBER" case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea: return schemas.Blob - case schemas.Date, schemas.Time, schemas.DateTime, schemas.TimeStamp: - res = schemas.Date - return res + case schemas.Date: + return schemas.Date + case schemas.Time: + return schemas.Time + case schemas.DateTime, schemas.TimeStamp: + return schemas.TimeStamp case schemas.TimeStampz: res = "TIMESTAMP" case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal: @@ -613,7 +616,12 @@ func (db *dameng) DropTableSQL(tableName string) (string, bool) { return fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)), false } -func (db *dameng) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { +// SeqName returns sequence name for some table +func SeqName(tableName string) string { + return "SEQ_" + strings.ToUpper(tableName) +} + +func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) ([]string, bool, error) { if tableName == "" { tableName = table.Name } @@ -645,7 +653,38 @@ func (db *dameng) CreateTableSQL(table *schemas.Table, tableName string) ([]stri } b.WriteString(")") - return []string{b.String()}, false + var seqName = SeqName(tableName) + if table.AutoIncrColumn() != nil { + var cnt int + rows, err := queryer.QueryContext(ctx, "SELECT COUNT(*) FROM user_sequences WHERE sequence_name = ?", seqName) + if err != nil { + return nil, false, err + } + defer rows.Close() + if !rows.Next() { + if rows.Err() != nil { + return nil, false, rows.Err() + } + return nil, false, errors.New("query sequence failed") + } + + if err := rows.Scan(&cnt); err != nil { + return nil, false, err + } + + if cnt == 0 { + var sql2 = fmt.Sprintf(`CREATE sequence %s + minvalue 1 + nomaxvalue + start with 1 + increment by 1 + nocycle + nocache`, seqName) + return []string{b.String(), sql2}, false, nil + } + } + + return []string{b.String()}, false, nil } func (db *dameng) SetQuotePolicy(quotePolicy QuotePolicy) { @@ -826,13 +865,13 @@ func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableNam switch dt { case "VARCHAR2": + col.SQLType = schemas.SQLType{Name: "VARCHAR2", DefaultLength: len1, DefaultLength2: len2} + case "VARCHAR": col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: len1, DefaultLength2: len2} - case "NVARCHAR2": - col.SQLType = schemas.SQLType{Name: schemas.NVarchar, DefaultLength: len1, DefaultLength2: len2} case "TIMESTAMP WITH TIME ZONE": col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} case "NUMBER": - col.SQLType = schemas.SQLType{Name: schemas.Double, DefaultLength: len1, DefaultLength2: len2} + col.SQLType = schemas.SQLType{Name: "NUMBER", DefaultLength: len1, DefaultLength2: len2} case "LONG", "LONG RAW", "NCLOB", "CLOB": col.SQLType = schemas.SQLType{Name: schemas.Text, DefaultLength: 0, DefaultLength2: 0} case "RAW": @@ -959,7 +998,7 @@ type damengDriver struct { // Features return features func (p *damengDriver) Features() *DriverFeatures { return &DriverFeatures{ - SupportReturnInsertedID: true, + SupportReturnInsertedID: false, } } diff --git a/dialects/dialect.go b/dialects/dialect.go index b6c0853a..49198a70 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -61,7 +61,7 @@ type Dialect interface { GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) - CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) + CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) ([]string, bool, error) DropTableSQL(tableName string) (string, bool) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) @@ -285,43 +285,35 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) return "", err } - if err := bd.WriteByte(' '); err != nil { - return "", err - } - if includePrimaryKey && col.IsPrimaryKey { - if _, err := bd.WriteString("PRIMARY KEY "); err != nil { + if _, err := bd.WriteString(" PRIMARY KEY"); err != nil { return "", err } - if col.IsAutoIncrement { - if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil { + if err := bd.WriteByte(' '); err != nil { return "", err } - if err := bd.WriteByte(' '); err != nil { + if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil { return "", err } } } if col.Default != "" { - if _, err := bd.WriteString("DEFAULT "); err != nil { + if _, err := bd.WriteString(" DEFAULT "); err != nil { return "", err } if _, err := bd.WriteString(col.Default); err != nil { return "", err } - if err := bd.WriteByte(' '); err != nil { - return "", err - } } if col.Nullable { - if _, err := bd.WriteString("NULL "); err != nil { + if _, err := bd.WriteString(" NULL"); err != nil { return "", err } } else { - if _, err := bd.WriteString("NOT NULL "); err != nil { + if _, err := bd.WriteString(" NOT NULL"); err != nil { return "", err } } diff --git a/dialects/mysql.go b/dialects/mysql.go index 0489904a..5cf9ffc9 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -678,7 +678,11 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) ([]strin b.WriteString(" ROW_FORMAT=") b.WriteString(db.rowFormat) } +<<<<<<< HEAD return []string{b.String()}, true +======= + return []string{sql}, true, nil +>>>>>>> 4dbe145 (fix insert) } func (db *mysql) Filters() []Filter { diff --git a/dialects/oracle.go b/dialects/oracle.go index 11a6653b..e3188df9 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -599,7 +599,7 @@ func (db *oracle) DropTableSQL(tableName string) (string, bool) { return fmt.Sprintf("DROP TABLE `%s`", tableName), false } -func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { +func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) ([]string, bool, error) { var sql = "CREATE TABLE " if tableName == "" { tableName = table.Name @@ -629,7 +629,7 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]stri } sql = sql[:len(sql)-2] + ")" - return []string{sql}, false + return []string{sql}, false, nil } func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { diff --git a/engine.go b/engine.go index 8937d0a1..ca258185 100644 --- a/engine.go +++ b/engine.go @@ -436,7 +436,7 @@ func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp .. // DumpTables dump specify tables to io.Writer func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { - return engine.dumpTables(tables, w, tp...) + return engine.dumpTables(context.Background(), tables, w, tp...) } func formatBool(s string, dstDialect dialects.Dialect) string { @@ -452,7 +452,7 @@ func formatBool(s string, dstDialect dialects.Dialect) string { } // dumpTables dump database all table structs and data to w with specify db type -func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { +func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { var dstDialect dialects.Dialect if len(tp) == 0 { dstDialect = engine.dialect @@ -504,7 +504,10 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } } - sqls, _ := dstDialect.CreateTableSQL(dstTable, dstTableName) + sqls, _, err := dstDialect.CreateTableSQL(ctx, engine.db, dstTable, dstTableName) + if err != nil { + return err + } for _, s := range sqls { _, err = io.WriteString(w, s+";\n") if err != nil { diff --git a/integrations/cache_test.go b/integrations/cache_test.go index e07d7e21..80cd45e8 100644 --- a/integrations/cache_test.go +++ b/integrations/cache_test.go @@ -166,14 +166,14 @@ func TestCacheGet(t *testing.T) { assert.NoError(t, err) var box1 MailBox3 - has, err := testEngine.Where("id = ?", inserts[0].Id).Get(&box1) + has, err := testEngine.Where(testEngine.Quote("id")+" = ?", inserts[0].Id).Get(&box1) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "user1", box1.Username) assert.EqualValues(t, "pass1", box1.Password) var box2 MailBox3 - has, err = testEngine.Where("id = ?", inserts[0].Id).Get(&box2) + has, err = testEngine.Where(testEngine.Quote("id")+" = ?", inserts[0].Id).Get(&box2) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "user1", box2.Username) diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 84547cdf..72b6e36c 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -10,6 +10,7 @@ import ( "strings" "xorm.io/builder" + "xorm.io/xorm/dialects" "xorm.io/xorm/schemas" ) @@ -42,7 +43,11 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - if len(colNames) <= 0 { + var hasInsertColumns = len(colNames) > 0 + var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG) + + if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE && + statement.dialect.URI().DBType != schemas.DAMENG { if statement.dialect.URI().DBType == schemas.MYSQL { if _, err := buf.WriteString(" VALUES ()"); err != nil { return "", nil, err @@ -60,6 +65,10 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } + if needSeq { + colNames = append(colNames, table.AutoIncrement) + } + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames()...), ","); err != nil { return "", nil, err } @@ -113,6 +122,18 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } + // Insert tablename (id) Values(seq_tablename.nextval) + if needSeq { + if hasInsertColumns { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + } + if _, err := buf.WriteString(dialects.SeqName(tableName) + ".nextval"); err != nil { + return "", nil, err + } + } + if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { return "", nil, err diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 1fcc0bba..5be46ef9 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -642,14 +642,6 @@ func (statement *Statement) genColumnStr() string { return buf.String() } -// GenCreateTableSQL generated create table SQL -func (statement *Statement) GenCreateTableSQL() []string { - statement.RefTable.StoreEngine = statement.StoreEngine - statement.RefTable.Charset = statement.Charset - s, _ := statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName()) - return s -} - // GenIndexSQL generated create index SQL func (statement *Statement) GenIndexSQL() []string { var sqls []string diff --git a/schemas/type.go b/schemas/type.go index be25230d..d192bac6 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -106,12 +106,14 @@ var ( Integer = "INTEGER" BigInt = "BIGINT" UnsignedBigInt = "UNSIGNED BIGINT" + Number = "NUMBER" Enum = "ENUM" Set = "SET" Char = "CHAR" Varchar = "VARCHAR" + VARCHAR2 = "VARCHAR2" NChar = "NCHAR" NVarchar = "NVARCHAR" TinyText = "TINYTEXT" @@ -175,6 +177,7 @@ var ( Integer: NUMERIC_TYPE, BigInt: NUMERIC_TYPE, UnsignedBigInt: NUMERIC_TYPE, + Number: NUMERIC_TYPE, Enum: TEXT_TYPE, Set: TEXT_TYPE, @@ -186,6 +189,7 @@ var ( Char: TEXT_TYPE, NChar: TEXT_TYPE, Varchar: TEXT_TYPE, + VARCHAR2: TEXT_TYPE, NVarchar: TEXT_TYPE, TinyText: TEXT_TYPE, Text: TEXT_TYPE, diff --git a/session_insert.go b/session_insert.go index a8f365c7..09873b68 100644 --- a/session_insert.go +++ b/session_insert.go @@ -307,16 +307,39 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { // if there is auto increment column and driver don't support return it if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID { - var sql = sqlStr - if session.engine.dialect.URI().DBType == schemas.ORACLE { - sql = "select seq_atable.currval from dual" + var sql string + var newArgs []interface{} + var needCommit bool + if session.engine.dialect.URI().DBType == schemas.ORACLE || session.engine.dialect.URI().DBType == schemas.DAMENG { + if session.isAutoCommit { // if it's not in transaction + if err := session.Begin(); err != nil { + return 0, err + } + needCommit = true + } + _, err := session.exec(sqlStr, args...) + if err != nil { + return 0, err + } + sql = fmt.Sprintf("select %s.currval from dual", dialects.SeqName(tableName)) + } else { + sql = sqlStr + newArgs = args } - rows, err := session.queryRows(sql, args...) + var id int64 + err := session.queryRow(sql, newArgs...).Scan(&id) if err != nil { return 0, err } - defer rows.Close() + if needCommit { + if err := session.Commit(); err != nil { + return 0, err + } + } + if id == 0 { + return 0, errors.New("insert successfully but not returned id") + } defer handleAfterInsertProcessorFunc(bean) @@ -331,16 +354,6 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { } } - var id int64 - if !rows.Next() { - if rows.Err() != nil { - return 0, rows.Err() - } - return 0, errors.New("insert successfully but not returned id") - } - if err := rows.Scan(&id); err != nil { - return 1, err - } aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { session.engine.logger.Errorf("%v", err) diff --git a/session_schema.go b/session_schema.go index 2e64350f..7bbe75f8 100644 --- a/session_schema.go +++ b/session_schema.go @@ -6,6 +6,7 @@ package xorm import ( "bufio" + "context" "database/sql" "fmt" "io" @@ -40,7 +41,13 @@ func (session *Session) createTable(bean interface{}) error { return err } - sqlStrs := session.statement.GenCreateTableSQL() + session.statement.RefTable.StoreEngine = session.statement.StoreEngine + session.statement.RefTable.Charset = session.statement.Charset + sqlStrs, _, err := session.engine.dialect.CreateTableSQL(context.Background(), session.engine.db, session.statement.RefTable, session.statement.TableName()) + if err != nil { + return err + } + for _, s := range sqlStrs { _, err := session.exec(s) if err != nil {