diff --git a/dialects/dialect.go b/dialects/dialect.go index 5249e52c..26d6521a 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -68,6 +68,7 @@ type Dialect interface { CreateIndexSQL(tableName string, index *schemas.Index) string DropIndexSQL(tableName string, index *schemas.Index) string + AddColumnSQL(tableName string, col *schemas.Column) string ModifyColumnSQL(tableName string, col *schemas.Column) string ForUpdateSQL(query string) string @@ -94,55 +95,6 @@ type Base struct { uri *URI } -// String generate column description string according dialect -func String(d Dialect, col *schemas.Column) string { - sql := d.Quoter().Quote(col.Name) + " " - - sql += d.SQLType(col) + " " - - if col.IsPrimaryKey { - sql += "PRIMARY KEY " - if col.IsAutoIncrement { - sql += d.AutoIncrStr() + " " - } - } - - if col.Default != "" { - sql += "DEFAULT " + col.Default + " " - } - - if d.ShowCreateNull() { - if col.Nullable { - sql += "NULL " - } else { - sql += "NOT NULL " - } - } - - return sql -} - -// StringNoPk generate column description string according dialect without primary keys -func StringNoPk(d Dialect, col *schemas.Column) string { - sql := d.Quoter().Quote(col.Name) + " " - - sql += d.SQLType(col) + " " - - if col.Default != "" { - sql += "DEFAULT " + col.Default + " " - } - - if d.ShowCreateNull() { - if col.Nullable { - sql += "NULL " - } else { - sql += "NOT NULL " - } - } - - return sql -} - func (b *Base) DB() *core.DB { return b.db } @@ -165,6 +117,55 @@ func (b *Base) DBType() DBType { return b.uri.DBType } +// String generate column description string according dialect +func (b *Base) String(col *schemas.Column) string { + sql := b.dialect.Quoter().Quote(col.Name) + " " + + sql += b.dialect.SQLType(col) + " " + + if col.IsPrimaryKey { + sql += "PRIMARY KEY " + if col.IsAutoIncrement { + sql += b.dialect.AutoIncrStr() + " " + } + } + + if col.Default != "" { + sql += "DEFAULT " + col.Default + " " + } + + if b.dialect.ShowCreateNull() { + if col.Nullable { + sql += "NULL " + } else { + sql += "NOT NULL " + } + } + + return sql +} + +// StringNoPk generate column description string according dialect without primary keys +func (b *Base) StringNoPk(col *schemas.Column) string { + sql := b.dialect.Quoter().Quote(col.Name) + " " + + sql += b.dialect.SQLType(col) + " " + + if col.Default != "" { + sql += "DEFAULT " + col.Default + " " + } + + if b.dialect.ShowCreateNull() { + if col.Nullable { + sql += "NULL " + } else { + sql += "NOT NULL " + } + } + + return sql +} + func (b *Base) FormatBytes(bs []byte) string { return fmt.Sprintf("0x%x", bs) } @@ -222,29 +223,15 @@ func (db *Base) IsColumnExist(ctx context.Context, tableName, colName string) (b return db.HasRecords(ctx, query, db.uri.DBName, tableName, colName) } -/* -func (db *Base) CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error { - sql, args := db.dialect.TableCheckSQL(tableName) - rows, err := db.DB().Query(sql, args...) - if db.Logger != nil { - db.Logger.Info("[sql]", sql, args) +func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { + quoter := db.dialect.Quoter() + sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), + db.String(col)) + if db.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 { + sql += " COMMENT '" + col.Comment + "'" } - if err != nil { - return err - } - defer rows.Close() - - if rows.Next() { - return nil - } - - sql = db.dialect.CreateTableSQL(table, tableName, storeEngine, charset) - _, err = db.DB().Exec(sql) - if db.Logger != nil { - db.Logger.Info("[sql]", sql) - } - return err -}*/ + return sql +} func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string { quoter := db.dialect.Quoter() @@ -271,7 +258,7 @@ func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string { } func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string { - return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, StringNoPk(db.dialect, col)) + return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, db.StringNoPk(col)) } func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string { @@ -291,12 +278,12 @@ func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, char for _, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) if col.IsPrimaryKey && len(pkList) == 1 { - sql += String(b.dialect, col) + sql += b.String(col) } else { - sql += StringNoPk(b.dialect, col) + sql += b.StringNoPk(col) } sql = strings.TrimSpace(sql) - if b.DriverName() == schemas.MYSQL && len(col.Comment) > 0 { + if b.DBType() == schemas.MYSQL && len(col.Comment) > 0 { sql += " COMMENT '" + col.Comment + "'" } sql += ", " diff --git a/dialects/mssql.go b/dialects/mssql.go index c6046676..83844f4e 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -511,9 +511,9 @@ func (db *mssql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch for _, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) if col.IsPrimaryKey && len(pkList) == 1 { - sql += String(db, col) + sql += db.String(col) } else { - sql += StringNoPk(db, col) + sql += db.StringNoPk(col) } sql = strings.TrimSpace(sql) sql += ", " diff --git a/dialects/mysql.go b/dialects/mysql.go index cda2543c..62fc6eb1 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -524,9 +524,9 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch for _, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) if col.IsPrimaryKey && len(pkList) == 1 { - sql += String(db, col) + sql += db.String(col) } else { - sql += StringNoPk(db, col) + sql += db.StringNoPk(col) } sql = strings.TrimSpace(sql) if len(col.Comment) > 0 { diff --git a/dialects/oracle.go b/dialects/oracle.go index 5b157727..1247d7a4 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -593,7 +593,7 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c /*if col.IsPrimaryKey && len(pkList) == 1 { sql += col.String(b.dialect) } else {*/ - sql += StringNoPk(db, col) + sql += db.StringNoPk(col) // } sql = strings.TrimSpace(sql) sql += ", " diff --git a/session_schema.go b/session_schema.go index 7ed61b68..05b24c91 100644 --- a/session_schema.go +++ b/session_schema.go @@ -201,8 +201,8 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo func (session *Session) addColumn(colName string) error { col := session.statement.RefTable.GetColumn(colName) - sql, args := session.statement.genAddColumnStr(col) - _, err := session.exec(sql, args...) + sql := session.statement.dialect.AddColumnSQL(session.statement.TableName(), col) + _, err := session.exec(sql) return err } diff --git a/statement.go b/statement.go index c07ddfe9..b1593621 100644 --- a/statement.go +++ b/statement.go @@ -902,17 +902,6 @@ func (statement *Statement) genDelIndexSQL() []string { return sqls } -func (statement *Statement) genAddColumnStr(col *schemas.Column) (string, []interface{}) { - quote := statement.quote - sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()), - dialects.String(statement.dialect, col)) - if statement.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 { - sql += " COMMENT '" + col.Comment + "'" - } - sql += ";" - return sql, []interface{}{} -} - func (statement *Statement) buildConds(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)