diff --git a/dialects/dialect.go b/dialects/dialect.go index b074d485..806d8949 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -43,8 +43,6 @@ type Dialect interface { SetQuotePolicy(quotePolicy QuotePolicy) AutoIncrStr() string - SupportInsertMany() bool - SupportDropIfExists() bool GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) IndexCheckSQL(tableName, idxName string) (string, []interface{}) @@ -52,9 +50,9 @@ type Dialect interface { DropIndexSQL(tableName string, index *schemas.Index) string GetTables(ctx context.Context) ([]*schemas.Table, error) - TableCheckSQL(tableName string) (string, []interface{}) - CreateTableSQL(table *schemas.Table, tableName string) string - DropTableSQL(tableName string) string + IsTableExist(ctx context.Context, tableName string) (bool, error) + CreateTableSQL(table *schemas.Table, tableName string) (string, bool) + DropTableSQL(tableName string) (string, bool) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) IsColumnExist(ctx context.Context, tableName string, colName string) (bool, error) @@ -149,13 +147,9 @@ func (b *Base) FormatBytes(bs []byte) string { return fmt.Sprintf("0x%x", bs) } -func (db *Base) SupportDropIfExists() bool { - return true -} - -func (db *Base) DropTableSQL(tableName string) string { +func (db *Base) DropTableSQL(tableName string) (string, bool) { quote := db.dialect.Quoter().Quote - return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)) + return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true } func (db *Base) HasRecords(ctx context.Context, query string, args ...interface{}) (bool, error) { diff --git a/dialects/mssql.go b/dialects/mssql.go index 06ab0b78..0d857f32 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -281,10 +281,6 @@ func (db *mssql) SQLType(c *schemas.Column) string { return res } -func (db *mssql) SupportInsertMany() bool { - return true -} - func (db *mssql) IsReserved(name string) bool { _, ok := mssqlReservedWords[strings.ToUpper(name)] return ok @@ -311,10 +307,10 @@ func (db *mssql) AutoIncrStr() string { return "IDENTITY" } -func (db *mssql) DropTableSQL(tableName string) string { +func (db *mssql) DropTableSQL(tableName string) (string, bool) { return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+ "object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+ - "DROP TABLE \"%s\"", tableName, tableName) + "DROP TABLE \"%s\"", tableName, tableName), true } func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { @@ -329,10 +325,9 @@ func (db *mssql) IsColumnExist(ctx context.Context, tableName, colName string) ( return db.HasRecords(ctx, query, tableName, colName) } -func (db *mssql) TableCheckSQL(tableName string) (string, []interface{}) { - args := []interface{}{} +func (db *mssql) IsTableExist(ctx context.Context, tableName string) (bool, error) { sql := "select * from sysobjects where id = object_id(N'" + tableName + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1" - return sql, args + return db.HasRecords(ctx, sql) } func (db *mssql) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { @@ -491,7 +486,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? return indexes, nil } -func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) string { +func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) { var sql string if tableName == "" { tableName = table.Name @@ -522,7 +517,7 @@ func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) string { sql = sql[:len(sql)-2] + ")" sql += ";" - return sql + return sql, true } func (db *mssql) ForUpdateSQL(query string) string { diff --git a/dialects/mysql.go b/dialects/mysql.go index 364f22b6..3c8d3c2a 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -270,10 +270,6 @@ func (db *mysql) SQLType(c *schemas.Column) string { return res } -func (db *mysql) SupportInsertMany() bool { - return true -} - func (db *mysql) IsReserved(name string) bool { _, ok := mysqlReservedWords[strings.ToUpper(name)] return ok @@ -290,10 +286,9 @@ func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{} return sql, args } -func (db *mysql) TableCheckSQL(tableName string) (string, []interface{}) { - args := []interface{}{db.uri.DBName, tableName} +func (db *mysql) IsTableExist(ctx context.Context, tableName string) (bool, error) { sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" - return sql, args + return db.HasRecords(ctx, sql, db.uri.DBName, tableName) } func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { @@ -512,7 +507,7 @@ func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]* return indexes, nil } -func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) string { +func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) { var sql = "CREATE TABLE IF NOT EXISTS " if tableName == "" { tableName = table.Name @@ -565,7 +560,7 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) string { if db.rowFormat != "" { sql += " ROW_FORMAT=" + db.rowFormat } - return sql + return sql, true } func (db *mysql) Filters() []Filter { diff --git a/dialects/oracle.go b/dialects/oracle.go index e0d83115..fd2f0fd1 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -547,24 +547,16 @@ func (db *oracle) AutoIncrStr() string { return "AUTO_INCREMENT" } -func (db *oracle) SupportInsertMany() bool { - return true -} - func (db *oracle) IsReserved(name string) bool { _, ok := oracleReservedWords[strings.ToUpper(name)] return ok } -func (db *oracle) SupportDropIfExists() bool { - return false +func (db *oracle) DropTableSQL(tableName string) (string, bool) { + return fmt.Sprintf("DROP TABLE `%s`", tableName), false } -func (db *oracle) DropTableSQL(tableName string) string { - return fmt.Sprintf("DROP TABLE `%s`", tableName) -} - -func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) string { +func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) { var sql = "CREATE TABLE " if tableName == "" { tableName = table.Name @@ -593,7 +585,7 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) string } sql = sql[:len(sql)-2] + ")" - return sql + return sql, false } func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { @@ -619,26 +611,15 @@ func (db *oracle) IndexCheckSQL(tableName, idxName string) (string, []interface{ `WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args } -func (db *oracle) TableCheckSQL(tableName string) (string, []interface{}) { - args := []interface{}{tableName} - return `SELECT table_name FROM user_tables WHERE table_name = :1`, args +func (db *oracle) IsTableExist(ctx context.Context, tableName string) (bool, error) { + return db.HasRecords(ctx, `SELECT table_name FROM user_tables WHERE table_name = :1`, tableName) } func (db *oracle) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { args := []interface{}{tableName, colName} query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" + " AND column_name = :2" - - rows, err := db.DB().QueryContext(ctx, query, args...) - if err != nil { - return false, err - } - defer rows.Close() - - if rows.Next() { - return true, nil - } - return false, nil + return db.HasRecords(ctx, query, args...) } func (db *oracle) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { diff --git a/dialects/postgres.go b/dialects/postgres.go index 31cd49b6..69100627 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -884,10 +884,6 @@ func (db *postgres) SQLType(c *schemas.Column) string { return res } -func (db *postgres) SupportInsertMany() bool { - return true -} - func (db *postgres) IsReserved(name string) bool { _, ok := postgresReservedWords[strings.ToUpper(name)] return ok @@ -897,7 +893,7 @@ func (db *postgres) AutoIncrStr() string { return "" } -func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) string { +func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) { var sql string sql = "CREATE TABLE IF NOT EXISTS " if tableName == "" { @@ -932,7 +928,7 @@ func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) strin } sql += ")" - return sql + return sql, true } func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { @@ -946,14 +942,13 @@ func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interfac `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args } -func (db *postgres) TableCheckSQL(tableName string) (string, []interface{}) { +func (db *postgres) IsTableExist(ctx context.Context, tableName string) (bool, error) { if len(db.uri.Schema) == 0 { - args := []interface{}{tableName} - return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args + return db.HasRecords(ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName) } - args := []interface{}{db.uri.Schema, tableName} - return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args + return db.HasRecords(ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`, + db.uri.Schema, tableName) } func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 212c5a8e..4bd3147d 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -211,10 +211,6 @@ func (db *sqlite3) FormatBytes(bs []byte) string { return fmt.Sprintf("X'%x'", bs) } -func (db *sqlite3) SupportInsertMany() bool { - return true -} - func (db *sqlite3) IsReserved(name string) bool { _, ok := sqlite3ReservedWords[strings.ToUpper(name)] return ok @@ -229,9 +225,8 @@ func (db *sqlite3) IndexCheckSQL(tableName, idxName string) (string, []interface return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args } -func (db *sqlite3) TableCheckSQL(tableName string) (string, []interface{}) { - args := []interface{}{tableName} - return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args +func (db *sqlite3) IsTableExist(ctx context.Context, tableName string) (bool, error) { + return db.HasRecords(ctx, "SELECT name FROM sqlite_master WHERE type='table' and name = ?", tableName) } func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string { @@ -249,7 +244,7 @@ func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string { return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } -func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) string { +func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) { var sql string sql = "CREATE TABLE IF NOT EXISTS " if tableName == "" { @@ -284,7 +279,7 @@ func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) string } sql += ")" - return sql + return sql, true } func (db *sqlite3) ForUpdateSQL(query string) string { diff --git a/engine.go b/engine.go index c330e9f5..8d77aeef 100644 --- a/engine.go +++ b/engine.go @@ -125,14 +125,6 @@ func (engine *Engine) SetColumnMapper(mapper names.Mapper) { engine.tagParser.SetColumnMapper(mapper) } -// SupportInsertMany If engine's database support batch insert records like -// "insert into user values (name, age), (name, age)". -// When the return is ture, then engine.Insert(&users) will -// generate batch sql and exeute. -func (engine *Engine) SupportInsertMany() bool { - return engine.dialect.SupportInsertMany() -} - // Quote Use QuoteStr quote the string sql func (engine *Engine) Quote(value string) string { value = strings.TrimSpace(value) @@ -388,7 +380,8 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } } - _, err = io.WriteString(w, dialect.CreateTableSQL(table, "")+";\n") + s, _ := dialect.CreateTableSQL(table, "") + _, err = io.WriteString(w, s+";\n") if err != nil { return err } diff --git a/internal/statements/statement.go b/internal/statements/statement.go index d6dd58b1..0b670da5 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -643,7 +643,8 @@ func (statement *Statement) genColumnStr() string { func (statement *Statement) GenCreateTableSQL() string { statement.RefTable.StoreEngine = statement.StoreEngine statement.RefTable.Charset = statement.Charset - return statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName()) + s, _ := statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName()) + return s } func (statement *Statement) GenIndexSQL() []string { diff --git a/session_insert.go b/session_insert.go index e5368571..91257f0a 100644 --- a/session_insert.go +++ b/session_insert.go @@ -75,21 +75,11 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { return 0, ErrNoElementsOnSlice } - if session.engine.SupportInsertMany() { - cnt, err := session.innerInsertMulti(bean) - if err != nil { - return affected, err - } - affected += cnt - } else { - for i := 0; i < size; i++ { - cnt, err := session.innerInsert(sliceValue.Index(i).Interface()) - if err != nil { - return affected, err - } - affected += cnt - } + cnt, err := session.innerInsertMulti(bean) + if err != nil { + return affected, err } + affected += cnt } else { cnt, err := session.innerInsert(bean) if err != nil { diff --git a/session_schema.go b/session_schema.go index 6d363521..047be23d 100644 --- a/session_schema.go +++ b/session_schema.go @@ -124,18 +124,16 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { func (session *Session) dropTable(beanOrTableName interface{}) error { tableName := session.engine.TableName(beanOrTableName) - var needDrop = true - if !session.engine.dialect.SupportDropIfExists() { - sqlStr, args := session.engine.dialect.TableCheckSQL(tableName) - results, err := session.queryBytes(sqlStr, args...) + sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true)) + if !checkIfExist { + exist, err := session.engine.dialect.IsTableExist(session.ctx, tableName) if err != nil { return err } - needDrop = len(results) > 0 + checkIfExist = exist } - if needDrop { - sqlStr := session.engine.Dialect().DropTableSQL(session.engine.TableName(tableName, true)) + if checkIfExist { _, err := session.exec(sqlStr) return err } @@ -154,9 +152,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) } func (session *Session) isTableExist(tableName string) (bool, error) { - sqlStr, args := session.engine.dialect.TableCheckSQL(tableName) - results, err := session.queryBytes(sqlStr, args...) - return len(results) > 0, err + return session.engine.dialect.IsTableExist(session.ctx, tableName) } // IsTableEmpty if table have any records