diff --git a/dialects/dialect.go b/dialects/dialect.go index 806d8949..35139817 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -51,7 +51,7 @@ type Dialect interface { GetTables(ctx context.Context) ([]*schemas.Table, error) IsTableExist(ctx context.Context, tableName string) (bool, error) - CreateTableSQL(table *schemas.Table, tableName string) (string, bool) + CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) DropTableSQL(tableName string) (string, bool) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) diff --git a/dialects/mssql.go b/dialects/mssql.go index 0d857f32..6ba2cd97 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -486,7 +486,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? return indexes, nil } -func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) { +func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { var sql string if tableName == "" { tableName = table.Name @@ -517,7 +517,7 @@ func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) (string, sql = sql[:len(sql)-2] + ")" sql += ";" - return sql, true + return []string{sql}, true } func (db *mssql) ForUpdateSQL(query string) string { diff --git a/dialects/mysql.go b/dialects/mysql.go index 3c8d3c2a..78acf1d0 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -507,7 +507,7 @@ func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]* return indexes, nil } -func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) { +func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { var sql = "CREATE TABLE IF NOT EXISTS " if tableName == "" { tableName = table.Name @@ -560,7 +560,7 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) (string, if db.rowFormat != "" { sql += " ROW_FORMAT=" + db.rowFormat } - return sql, true + return []string{sql}, true } func (db *mysql) Filters() []Filter { diff --git a/dialects/oracle.go b/dialects/oracle.go index fd2f0fd1..045ad99b 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -556,7 +556,7 @@ func (db *oracle) DropTableSQL(tableName string) (string, bool) { return fmt.Sprintf("DROP TABLE `%s`", tableName), false } -func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) { +func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { var sql = "CREATE TABLE " if tableName == "" { tableName = table.Name @@ -585,7 +585,7 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) (string } sql = sql[:len(sql)-2] + ")" - return sql, false + return []string{sql}, false } func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { diff --git a/dialects/postgres.go b/dialects/postgres.go index 69100627..e393452f 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -893,7 +893,7 @@ func (db *postgres) AutoIncrStr() string { return "" } -func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) { +func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { var sql string sql = "CREATE TABLE IF NOT EXISTS " if tableName == "" { @@ -928,7 +928,7 @@ func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) (stri } sql += ")" - return sql, true + return []string{sql}, true } func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 4bd3147d..710babe6 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -244,7 +244,7 @@ func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string { return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } -func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) { +func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { var sql string sql = "CREATE TABLE IF NOT EXISTS " if tableName == "" { @@ -279,7 +279,7 @@ func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) (strin } sql += ")" - return sql, true + return []string{sql}, true } func (db *sqlite3) ForUpdateSQL(query string) string { diff --git a/engine.go b/engine.go index 8d77aeef..d94591e1 100644 --- a/engine.go +++ b/engine.go @@ -380,10 +380,12 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } } - s, _ := dialect.CreateTableSQL(table, "") - _, err = io.WriteString(w, s+";\n") - if err != nil { - return err + sqls, _ := dialect.CreateTableSQL(table, "") + for _, s := range sqls { + _, err = io.WriteString(w, s+";\n") + if err != nil { + return err + } } for _, index := range table.Indexes { _, err = io.WriteString(w, dialect.CreateIndexSQL(table.Name, index)+";\n") diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 0b670da5..e8675443 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -640,7 +640,7 @@ func (statement *Statement) genColumnStr() string { return buf.String() } -func (statement *Statement) GenCreateTableSQL() string { +func (statement *Statement) GenCreateTableSQL() []string { statement.RefTable.StoreEngine = statement.StoreEngine statement.RefTable.Charset = statement.Charset s, _ := statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName()) diff --git a/session_schema.go b/session_schema.go index 047be23d..ca4e2d75 100644 --- a/session_schema.go +++ b/session_schema.go @@ -37,9 +37,14 @@ func (session *Session) createTable(bean interface{}) error { return err } - sqlStr := session.statement.GenCreateTableSQL() - _, err := session.exec(sqlStr) - return err + sqlStrs := session.statement.GenCreateTableSQL() + for _, s := range sqlStrs { + _, err := session.exec(s) + if err != nil { + return err + } + } + return nil } // CreateIndexes create indexes