return sqls for create table (#1580)

return sqls for create table

Reviewed-on: https://gitea.com/xorm/xorm/pulls/1580
This commit is contained in:
Lunny Xiao 2020-03-07 12:06:28 +00:00
parent ccf65397e8
commit 257653726e
9 changed files with 26 additions and 19 deletions

View File

@ -51,7 +51,7 @@ type Dialect interface {
GetTables(ctx context.Context) ([]*schemas.Table, error) GetTables(ctx context.Context) ([]*schemas.Table, error)
IsTableExist(ctx context.Context, tableName string) (bool, error) IsTableExist(ctx context.Context, tableName string) (bool, error)
CreateTableSQL(table *schemas.Table, tableName string) (string, bool) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool)
DropTableSQL(tableName string) (string, bool) DropTableSQL(tableName string) (string, bool)
GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)

View File

@ -486,7 +486,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
return indexes, nil return indexes, nil
} }
func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) { func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
var sql string var sql string
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
@ -517,7 +517,7 @@ func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) (string,
sql = sql[:len(sql)-2] + ")" sql = sql[:len(sql)-2] + ")"
sql += ";" sql += ";"
return sql, true return []string{sql}, true
} }
func (db *mssql) ForUpdateSQL(query string) string { func (db *mssql) ForUpdateSQL(query string) string {

View File

@ -507,7 +507,7 @@ func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]*
return indexes, nil return indexes, nil
} }
func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) { func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
var sql = "CREATE TABLE IF NOT EXISTS " var sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
@ -560,7 +560,7 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) (string,
if db.rowFormat != "" { if db.rowFormat != "" {
sql += " ROW_FORMAT=" + db.rowFormat sql += " ROW_FORMAT=" + db.rowFormat
} }
return sql, true return []string{sql}, true
} }
func (db *mysql) Filters() []Filter { func (db *mysql) Filters() []Filter {

View File

@ -556,7 +556,7 @@ func (db *oracle) DropTableSQL(tableName string) (string, bool) {
return fmt.Sprintf("DROP TABLE `%s`", tableName), false return fmt.Sprintf("DROP TABLE `%s`", tableName), false
} }
func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) { func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
var sql = "CREATE TABLE " var sql = "CREATE TABLE "
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
@ -585,7 +585,7 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) (string
} }
sql = sql[:len(sql)-2] + ")" sql = sql[:len(sql)-2] + ")"
return sql, false return []string{sql}, false
} }
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {

View File

@ -893,7 +893,7 @@ func (db *postgres) AutoIncrStr() string {
return "" return ""
} }
func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) { func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
var sql string var sql string
sql = "CREATE TABLE IF NOT EXISTS " sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" { if tableName == "" {
@ -928,7 +928,7 @@ func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) (stri
} }
sql += ")" sql += ")"
return sql, true return []string{sql}, true
} }
func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {

View File

@ -244,7 +244,7 @@ func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string {
return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName))
} }
func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) { func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
var sql string var sql string
sql = "CREATE TABLE IF NOT EXISTS " sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" { if tableName == "" {
@ -279,7 +279,7 @@ func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) (strin
} }
sql += ")" sql += ")"
return sql, true return []string{sql}, true
} }
func (db *sqlite3) ForUpdateSQL(query string) string { func (db *sqlite3) ForUpdateSQL(query string) string {

View File

@ -380,10 +380,12 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return err return err
} }
} }
s, _ := dialect.CreateTableSQL(table, "") sqls, _ := dialect.CreateTableSQL(table, "")
_, err = io.WriteString(w, s+";\n") for _, s := range sqls {
if err != nil { _, err = io.WriteString(w, s+";\n")
return err if err != nil {
return err
}
} }
for _, index := range table.Indexes { for _, index := range table.Indexes {
_, err = io.WriteString(w, dialect.CreateIndexSQL(table.Name, index)+";\n") _, err = io.WriteString(w, dialect.CreateIndexSQL(table.Name, index)+";\n")

View File

@ -640,7 +640,7 @@ func (statement *Statement) genColumnStr() string {
return buf.String() return buf.String()
} }
func (statement *Statement) GenCreateTableSQL() string { func (statement *Statement) GenCreateTableSQL() []string {
statement.RefTable.StoreEngine = statement.StoreEngine statement.RefTable.StoreEngine = statement.StoreEngine
statement.RefTable.Charset = statement.Charset statement.RefTable.Charset = statement.Charset
s, _ := statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName()) s, _ := statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName())

View File

@ -37,9 +37,14 @@ func (session *Session) createTable(bean interface{}) error {
return err return err
} }
sqlStr := session.statement.GenCreateTableSQL() sqlStrs := session.statement.GenCreateTableSQL()
_, err := session.exec(sqlStr) for _, s := range sqlStrs {
return err _, err := session.exec(s)
if err != nil {
return err
}
}
return nil
} }
// CreateIndexes create indexes // CreateIndexes create indexes