diff --git a/dialects/dameng.go b/dialects/dameng.go index a33809de..5ecea321 100644 --- a/dialects/dameng.go +++ b/dialects/dameng.go @@ -18,7 +18,6 @@ import ( "gitee.com/travelliu/dm" "xorm.io/xorm/core" "xorm.io/xorm/internal/convert" - "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -529,7 +528,7 @@ func (db *dameng) Init(uri *URI) error { } func (db *dameng) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) { - rows, err := queryer.QueryContext(ctx, "select * from v$version where banner like 'Oracle%'") + rows, err := queryer.QueryContext(ctx, "SELECT * FROM V$VERSION") // select id_code if err != nil { return nil, err } @@ -553,7 +552,8 @@ func (db *dameng) Version(ctx context.Context, queryer core.Queryer) (*schemas.V func (db *dameng) Features() *DialectFeatures { return &DialectFeatures{ - AutoincrMode: SequenceAutoincrMode, + AutoincrMode: SequenceAutoincrMode, + SupportSequence: true, } } @@ -570,8 +570,12 @@ func (db *dameng) SQLType(c *schemas.Column) string { return "BIGINT" case schemas.Bit, schemas.Bool: return schemas.Bit - case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea: - return schemas.Binary + case schemas.Binary: + if c.Length == 0 { + return schemas.Binary + "(MAX)" + } + case schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea: + return schemas.VarBinary case schemas.Date: return schemas.Date case schemas.Time: @@ -635,7 +639,7 @@ func (db *dameng) DropTableSQL(tableName string) (string, bool) { return fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)), false } -func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) ([]string, bool, error) { +func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { if tableName == "" { tableName = table.Name } @@ -667,38 +671,7 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl } b.WriteString(")") - var seqName = utils.SeqName(tableName) - if table.AutoIncrColumn() != nil { - var cnt int - rows, err := queryer.QueryContext(ctx, "SELECT COUNT(*) FROM user_sequences WHERE sequence_name = ?", seqName) - if err != nil { - return nil, false, err - } - defer rows.Close() - if !rows.Next() { - if rows.Err() != nil { - return nil, false, rows.Err() - } - return nil, false, errors.New("query sequence failed") - } - - if err := rows.Scan(&cnt); err != nil { - return nil, false, err - } - - if cnt == 0 { - var sql2 = fmt.Sprintf(`CREATE sequence %s - minvalue 1 - nomaxvalue - start with 1 - increment by 1 - nocycle - nocache`, seqName) - return []string{b.String(), sql2}, false, nil - } - } - - return []string{b.String()}, false, nil + return b.String(), false, nil } func (db *dameng) SetQuotePolicy(quotePolicy QuotePolicy) { @@ -728,6 +701,26 @@ func (db *dameng) IsTableExist(queryer core.Queryer, ctx context.Context, tableN return db.HasRecords(queryer, ctx, `SELECT table_name FROM user_tables WHERE table_name = ?`, tableName) } +func (db *dameng) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) { + var cnt int + rows, err := queryer.QueryContext(ctx, "SELECT COUNT(*) FROM user_sequences WHERE sequence_name = ?", seqName) + if err != nil { + return false, err + } + defer rows.Close() + if !rows.Next() { + if rows.Err() != nil { + return false, rows.Err() + } + return false, errors.New("query sequence failed") + } + + if err := rows.Scan(&cnt); err != nil { + return false, err + } + return cnt > 0, nil +} + func (db *dameng) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { args := []interface{}{tableName, colName} query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" + @@ -839,7 +832,8 @@ func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableNam col.Name = strings.Trim(colName.String, `" `) if colDefault.valid { col.Default = colDefault.data - col.DefaultIsEmpty = false + } else { + col.DefaultIsEmpty = true } if nullable.String == "Y" { @@ -1052,12 +1046,21 @@ func (d *damengDriver) GenScanResult(colType string) (interface{}, error) { case "NUMBER": var s sql.NullString return &s, nil - case "DATE": - var s sql.NullTime + case "BIGINT": + var s sql.NullInt64 + return &s, nil + case "INTEGER": + var s sql.NullInt32 + return &s, nil + case "DATE", "TIMESTAMP": + var s sql.NullString return &s, nil case "BLOB": var r sql.RawBytes return &r, nil + case "FLOAT": + var s sql.NullFloat64 + return &s, nil default: var r sql.RawBytes return &r, nil diff --git a/dialects/dialect.go b/dialects/dialect.go index e9b9a5c8..f5b04368 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -44,7 +44,8 @@ const ( ) type DialectFeatures struct { - AutoincrMode int // 0 autoincrement column, 1 sequence + AutoincrMode int // 0 autoincrement column, 1 sequence + SupportSequence bool } // Dialect represents a kind of database @@ -71,9 +72,13 @@ type Dialect interface { GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) - CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) ([]string, bool, error) + CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) DropTableSQL(tableName string) (string, bool) + CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error) + IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) + DropSequenceSQL(seqName string) (string, error) + GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error) AddColumnSQL(tableName string, col *schemas.Column) string @@ -146,6 +151,24 @@ func (db *Base) CreateTableSQL(table *schemas.Table, tableName string) ([]string return []string{b.String()}, false } +func (db *Base) CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error) { + return fmt.Sprintf(`CREATE SEQUENCE %s + minvalue 1 + nomaxvalue + start with 1 + increment by 1 + nocycle + nocache`, seqName), nil +} + +func (db *Base) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) { + return false, fmt.Errorf("unsupported sequence feature") +} + +func (db *Base) DropSequenceSQL(seqName string) (string, error) { + return fmt.Sprintf("DROP SEQUENCE %s", seqName), nil +} + // DropTableSQL returns drop table SQL func (db *Base) DropTableSQL(tableName string) (string, bool) { quote := db.dialect.Quoter().Quote @@ -309,7 +332,7 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) } } - if col.Default != "" { + if !col.DefaultIsEmpty { if _, err := bd.WriteString(" DEFAULT "); err != nil { return "", err } diff --git a/dialects/mysql.go b/dialects/mysql.go index f3e2adc8..5f3e17ec 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -684,11 +684,15 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) ([]strin b.WriteString(" ROW_FORMAT=") b.WriteString(db.rowFormat) } +<<<<<<< HEAD <<<<<<< HEAD return []string{b.String()}, true ======= return []string{sql}, true, nil >>>>>>> 4dbe145 (fix insert) +======= + return sql, true, nil +>>>>>>> 21b6352 (Fix more bugs) } func (db *mysql) Filters() []Filter { diff --git a/dialects/oracle.go b/dialects/oracle.go index 63caa646..04652bd6 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -605,7 +605,7 @@ func (db *oracle) DropTableSQL(tableName string) (string, bool) { return fmt.Sprintf("DROP TABLE `%s`", tableName), false } -func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) ([]string, bool, error) { +func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { var sql = "CREATE TABLE " if tableName == "" { tableName = table.Name @@ -635,7 +635,7 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl } sql = sql[:len(sql)-2] + ")" - return []string{sql}, false, nil + return sql, false, nil } func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { diff --git a/engine.go b/engine.go index e80e6b89..b7c45563 100644 --- a/engine.go +++ b/engine.go @@ -504,16 +504,26 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w } } - sqls, _, err := dstDialect.CreateTableSQL(ctx, engine.db, dstTable, dstTableName) + sqlstr, _, err := dstDialect.CreateTableSQL(ctx, engine.db, dstTable, dstTableName) if err != nil { return err } - for _, s := range sqls { - _, err = io.WriteString(w, s+";\n") + _, err = io.WriteString(w, sqlstr+";\n") + if err != nil { + return err + } + + if dstTable.AutoIncrement != "" && dstDialect.Features().SupportSequence { + sqlstr, err = dstDialect.CreateSequenceSQL(ctx, engine.db, utils.SeqName(dstTableName)) + if err != nil { + return err + } + _, err = io.WriteString(w, sqlstr+";\n") if err != nil { return err } } + if len(dstTable.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL { fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", dstTable.Name) } diff --git a/integrations/session_find_test.go b/integrations/session_find_test.go index 4f9436c9..9b503f25 100644 --- a/integrations/session_find_test.go +++ b/integrations/session_find_test.go @@ -991,7 +991,7 @@ func TestMoreExtends(t *testing.T) { books = make([]MoreExtendsBooksExtend, 0, len(books)) err = testEngine.Table("more_extends_books"). Alias("m"). - Select("m.*, `more_extends_users`.*"). + Select("`m`.*, `more_extends_users`.*"). Join("INNER", "more_extends_users", "`m`.`user_id` = `more_extends_users`.`id`"). Where("`m`.`name` LIKE ?", "abc"). Limit(10, 10). diff --git a/session_schema.go b/session_schema.go index 7bbe75f8..ba622b82 100644 --- a/session_schema.go +++ b/session_schema.go @@ -43,17 +43,26 @@ func (session *Session) createTable(bean interface{}) error { session.statement.RefTable.StoreEngine = session.statement.StoreEngine session.statement.RefTable.Charset = session.statement.Charset - sqlStrs, _, err := session.engine.dialect.CreateTableSQL(context.Background(), session.engine.db, session.statement.RefTable, session.statement.TableName()) + tableName := session.statement.TableName() + refTable := session.statement.RefTable + sqlStr, _, err := session.engine.dialect.CreateTableSQL(context.Background(), session.engine.db, refTable, tableName) if err != nil { return err } + if _, err := session.exec(sqlStr); err != nil { + return err + } - for _, s := range sqlStrs { - _, err := session.exec(s) + if refTable.AutoIncrement != "" && session.engine.dialect.Features().SupportSequence { + sqlStr, err = session.engine.dialect.CreateSequenceSQL(context.Background(), session.engine.db, utils.SeqName(tableName)) if err != nil { return err } + if _, err := session.exec(sqlStr); err != nil { + return err + } } + return nil } @@ -148,11 +157,32 @@ func (session *Session) dropTable(beanOrTableName interface{}) error { checkIfExist = exist } - if checkIfExist { - _, err := session.exec(sqlStr) + if !checkIfExist { + return nil + } + if _, err := session.exec(sqlStr); err != nil { return err } - return nil + + if !session.engine.dialect.Features().SupportSequence { + return nil + } + + var seqName = utils.SeqName(tableName) + exist, err := session.engine.dialect.IsSequenceExist(session.ctx, session.getQueryer(), seqName) + if err != nil { + return err + } + if !exist { + return nil + } + + sqlStr, err = session.engine.dialect.DropSequenceSQL(seqName) + if err != nil { + return err + } + _, err = session.exec(sqlStr) + return err } // IsTableExist if a table is exist