From b5665ba1a7f0a317da6237389e637f1f9cbdfb0a Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 12 Aug 2021 09:20:43 +0800 Subject: [PATCH] Fix test --- dialects/dialect.go | 12 ++---------- dialects/mssql.go | 6 +++--- dialects/oracle.go | 18 ++++++++++-------- dialects/postgres.go | 6 ++++++ dialects/sqlite3.go | 6 ++++++ internal/statements/insert.go | 14 ++++++-------- session_insert.go | 1 - session_schema.go | 5 ++++- 8 files changed, 37 insertions(+), 31 deletions(-) diff --git a/dialects/dialect.go b/dialects/dialect.go index 3e3a3467..a336a087 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -171,9 +171,9 @@ func (db *Base) DropSequenceSQL(seqName string) (string, error) { } // DropTableSQL returns drop table SQL -func (db *Base) DropTableSQL(tableName, autoincrCol string) ([]string, bool) { +func (db *Base) DropTableSQL(tableName, autoincrCol string) (string, bool) { quote := db.dialect.Quoter().Quote - return []string{fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName))}, true + return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true } // HasRecords returns true if the SQL has records returned @@ -333,16 +333,8 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) } } -<<<<<<< HEAD if !col.DefaultIsEmpty { -<<<<<<< HEAD if _, err := bd.WriteString(" DEFAULT "); err != nil { -======= -======= - if col.Default != "" { ->>>>>>> 98251fc (Fix test) - if _, err := bd.WriteString("DEFAULT "); err != nil { ->>>>>>> ab9b694 (Fix test) return "", err } if col.Default == "" { diff --git a/dialects/mssql.go b/dialects/mssql.go index 1a8206ff..6244ce6d 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -421,10 +421,10 @@ func (db *mssql) AutoIncrStr() string { return "IDENTITY" } -func (db *mssql) DropTableSQL(tableName, autoincrCol string) ([]string, bool) { - return []string{fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+ +func (db *mssql) DropTableSQL(tableName, autoincrCol string) (string, bool) { + return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+ "object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+ - "DROP TABLE \"%s\"", tableName, tableName)}, true + "DROP TABLE \"%s\"", tableName, tableName), true } func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string { diff --git a/dialects/oracle.go b/dialects/oracle.go index d2561e0c..5137f6bd 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -598,6 +598,14 @@ func (db *oracle) ColumnTypeKind(t string) int { } } +func (db *oracle) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) { + var cnt int + if err := queryer.QueryRowContext(ctx, "SELECT COUNT(*) FROM user_sequences WHERE sequence_name = :1", seqName).Scan(&cnt); err != nil { + return false, err + } + return cnt > 0, nil +} + func (db *oracle) AutoIncrStr() string { return "AUTO_INCREMENT" } @@ -607,14 +615,8 @@ func (db *oracle) IsReserved(name string) bool { return ok } -func (db *oracle) DropTableSQL(tableName, autoincrCol string) ([]string, bool) { - var sqls = []string{ - fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)), - } - if autoincrCol != "" { - sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", OracleSeqName(tableName))) - } - return sqls, false +func (db *oracle) DropTableSQL(tableName, autoincrCol string) (string, bool) { + return fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)), false } func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { diff --git a/dialects/postgres.go b/dialects/postgres.go index 822d3a70..3fe71dcc 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -876,6 +876,12 @@ func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) { } } +func (db *postgres) Features() *DialectFeatures { + return &DialectFeatures{ + AutoincrMode: IncrAutoincrMode, + } +} + func (db *postgres) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 4ff9a39e..731b87f1 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -207,6 +207,12 @@ func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) { } } +func (db *sqlite3) Features() *DialectFeatures { + return &DialectFeatures{ + AutoincrMode: IncrAutoincrMode, + } +} + func (db *sqlite3) SQLType(c *schemas.Column) string { switch t := c.SQLType.Name; t { case schemas.Bool: diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 15978a43..8583b3b8 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -168,14 +168,12 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) } } - if len(table.AutoIncrement) > 0 { - if statement.dialect.URI().DBType == schemas.POSTGRES { - if _, err := buf.WriteString(" RETURNING "); err != nil { - return nil, err - } - if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil { - return nil, err - } + if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES { + if _, err := buf.WriteString(" RETURNING "); err != nil { + return nil, err + } + if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil { + return nil, err } } diff --git a/session_insert.go b/session_insert.go index 09c4b0f7..f6846a9d 100644 --- a/session_insert.go +++ b/session_insert.go @@ -91,7 +91,6 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, e colMultiPlaces []string args []interface{} cols []*schemas.Column - insertCnt int ) for i := 0; i < size; i++ { diff --git a/session_schema.go b/session_schema.go index bbb5d0be..ef0850cb 100644 --- a/session_schema.go +++ b/session_schema.go @@ -174,7 +174,7 @@ func (session *Session) dropTable(beanOrTableName interface{}) error { } } - sqlStrs, checkIfExist := session.engine.dialect.DropTableSQL(tableName, autoIncrementCol) + sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(tableName, autoIncrementCol) if !checkIfExist { exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) if err != nil { @@ -189,7 +189,10 @@ func (session *Session) dropTable(beanOrTableName interface{}) error { if _, err := session.exec(sqlStr); err != nil { return err } +<<<<<<< HEAD +======= +>>>>>>> 1805a60 (Fix test) if session.engine.dialect.Features().AutoincrMode == dialects.IncrAutoincrMode { return nil }