diff --git a/dialects/dialect.go b/dialects/dialect.go index fc11eac1..b6c0853a 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -103,6 +103,39 @@ func (db *Base) URI() *URI { return db.uri } +// CreateTableSQL implements Dialect +func (db *Base) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { + if tableName == "" { + tableName = table.Name + } + + quoter := db.dialect.Quoter() + var b strings.Builder + b.WriteString("CREATE TABLE IF NOT EXISTS ") + quoter.QuoteTo(&b, tableName) + b.WriteString(" (") + + for i, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + b.WriteString(s) + + if i != len(table.ColumnsSeq())-1 { + b.WriteString(", ") + } + } + + if len(table.PrimaryKeys) > 1 { + b.WriteString(", PRIMARY KEY (") + b.WriteString(quoter.Join(table.PrimaryKeys, ",")) + b.WriteString(")") + } + + b.WriteString(")") + + return []string{b.String()}, false +} + // DropTableSQL returns drop table SQL func (db *Base) DropTableSQL(tableName string) (string, bool) { quote := db.dialect.Quoter().Quote diff --git a/dialects/mssql.go b/dialects/mssql.go index 2121e71d..ab010eb0 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -626,34 +626,37 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? } func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { - var sql string if tableName == "" { tableName = table.Name } - sql = "IF NOT EXISTS (SELECT [name] FROM sys.tables WHERE [name] = '" + tableName + "' ) CREATE TABLE " + quoter := db.dialect.Quoter() + var b strings.Builder + b.WriteString("IF NOT EXISTS (SELECT [name] FROM sys.tables WHERE [name] = '") + quoter.QuoteTo(&b, tableName) + b.WriteString("' ) CREATE TABLE ") + quoter.QuoteTo(&b, tableName) + b.WriteString(" (") - sql += db.Quoter().Quote(tableName) + " (" - - pkList := table.PrimaryKeys - - for _, colName := range table.ColumnsSeq() { + for i, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) - sql += s - sql = strings.TrimSpace(sql) - sql += ", " + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + b.WriteString(s) + + if i != len(table.ColumnsSeq())-1 { + b.WriteString(", ") + } } - if len(pkList) > 1 { - sql += "PRIMARY KEY ( " - sql += strings.Join(pkList, ",") - sql += " ), " + if len(table.PrimaryKeys) > 1 { + b.WriteString(", PRIMARY KEY (") + b.WriteString(quoter.Join(table.PrimaryKeys, ",")) + b.WriteString(")") } - sql = sql[:len(sql)-2] + ")" - sql += ";" - return []string{sql}, true + b.WriteString(")") + + return []string{b.String()}, true } func (db *mssql) ForUpdateSQL(query string) string { diff --git a/dialects/mysql.go b/dialects/mysql.go index 21128527..0489904a 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -626,42 +626,43 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName } func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { - var sql = "CREATE TABLE IF NOT EXISTS " if tableName == "" { tableName = table.Name } - quoter := db.Quoter() + quoter := db.dialect.Quoter() + var b strings.Builder + b.WriteString("CREATE TABLE IF NOT EXISTS ") + quoter.QuoteTo(&b, tableName) + b.WriteString(" (") - sql += quoter.Quote(tableName) - sql += " (" + for i, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + b.WriteString(s) - if len(table.ColumnsSeq()) > 0 { - pkList := table.PrimaryKeys - - for _, colName := range table.ColumnsSeq() { - col := table.GetColumn(colName) - s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) - sql += s - sql = strings.TrimSpace(sql) - if len(col.Comment) > 0 { - sql += " COMMENT '" + col.Comment + "'" - } - sql += ", " + if len(col.Comment) > 0 { + b.WriteString(" COMMENT '") + b.WriteString(col.Comment) + b.WriteString("'") } - if len(pkList) > 1 { - sql += "PRIMARY KEY ( " - sql += quoter.Join(pkList, ",") - sql += " ), " + if i != len(table.ColumnsSeq())-1 { + b.WriteString(", ") } - - sql = sql[:len(sql)-2] } - sql += ")" + + if len(table.PrimaryKeys) > 1 { + b.WriteString(", PRIMARY KEY (") + b.WriteString(quoter.Join(table.PrimaryKeys, ",")) + b.WriteString(")") + } + + b.WriteString(")") if table.StoreEngine != "" { - sql += " ENGINE=" + table.StoreEngine + b.WriteString(" ENGINE=") + b.WriteString(table.StoreEngine) } var charset = table.Charset @@ -669,13 +670,15 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) ([]strin charset = db.URI().Charset } if len(charset) != 0 { - sql += " DEFAULT CHARSET " + charset + b.WriteString(" DEFAULT CHARSET ") + b.WriteString(charset) } if db.rowFormat != "" { - sql += " ROW_FORMAT=" + db.rowFormat + b.WriteString(" ROW_FORMAT=") + b.WriteString(db.rowFormat) } - return []string{sql}, true + return []string{b.String()}, true } func (db *mysql) Filters() []Filter { diff --git a/dialects/postgres.go b/dialects/postgres.go index 96ebfc85..6b5a8b2f 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -965,41 +965,6 @@ func (db *postgres) AutoIncrStr() string { return "" } -func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { - var sql string - sql = "CREATE TABLE IF NOT EXISTS " - if tableName == "" { - tableName = table.Name - } - - quoter := db.Quoter() - sql += quoter.Quote(tableName) - sql += " (" - - if len(table.ColumnsSeq()) > 0 { - pkList := table.PrimaryKeys - - for _, colName := range table.ColumnsSeq() { - col := table.GetColumn(colName) - s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) - sql += s - sql = strings.TrimSpace(sql) - sql += ", " - } - - if len(pkList) > 1 { - sql += "PRIMARY KEY ( " - sql += quoter.Join(pkList, ",") - sql += " ), " - } - - sql = sql[:len(sql)-2] - } - sql += ")" - - return []string{sql}, true -} - func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { if len(db.getSchema()) == 0 { args := []interface{}{tableName, idxName} diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index ac17fd92..4eba8dad 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -285,41 +285,6 @@ func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string { return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } -func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { - var sql string - sql = "CREATE TABLE IF NOT EXISTS " - if tableName == "" { - tableName = table.Name - } - - quoter := db.Quoter() - sql += quoter.Quote(tableName) - sql += " (" - - if len(table.ColumnsSeq()) > 0 { - pkList := table.PrimaryKeys - - for _, colName := range table.ColumnsSeq() { - col := table.GetColumn(colName) - s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) - sql += s - sql = strings.TrimSpace(sql) - sql += ", " - } - - if len(pkList) > 1 { - sql += "PRIMARY KEY ( " - sql += quoter.Join(pkList, ",") - sql += " ), " - } - - sql = sql[:len(sql)-2] - } - sql += ")" - - return []string{sql}, true -} - func (db *sqlite3) ForUpdateSQL(query string) string { return query }