More improvements

This commit is contained in:
Lunny Xiao 2020-02-27 21:32:59 +08:00
parent 28a2613fab
commit 063c39001b
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
6 changed files with 69 additions and 93 deletions

View File

@ -68,6 +68,7 @@ type Dialect interface {
CreateIndexSQL(tableName string, index *schemas.Index) string
DropIndexSQL(tableName string, index *schemas.Index) string
AddColumnSQL(tableName string, col *schemas.Column) string
ModifyColumnSQL(tableName string, col *schemas.Column) string
ForUpdateSQL(query string) string
@ -94,55 +95,6 @@ type Base struct {
uri *URI
}
// String generate column description string according dialect
func String(d Dialect, col *schemas.Column) string {
sql := d.Quoter().Quote(col.Name) + " "
sql += d.SQLType(col) + " "
if col.IsPrimaryKey {
sql += "PRIMARY KEY "
if col.IsAutoIncrement {
sql += d.AutoIncrStr() + " "
}
}
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
if d.ShowCreateNull() {
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
}
return sql
}
// StringNoPk generate column description string according dialect without primary keys
func StringNoPk(d Dialect, col *schemas.Column) string {
sql := d.Quoter().Quote(col.Name) + " "
sql += d.SQLType(col) + " "
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
if d.ShowCreateNull() {
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
}
return sql
}
func (b *Base) DB() *core.DB {
return b.db
}
@ -165,6 +117,55 @@ func (b *Base) DBType() DBType {
return b.uri.DBType
}
// String generate column description string according dialect
func (b *Base) String(col *schemas.Column) string {
sql := b.dialect.Quoter().Quote(col.Name) + " "
sql += b.dialect.SQLType(col) + " "
if col.IsPrimaryKey {
sql += "PRIMARY KEY "
if col.IsAutoIncrement {
sql += b.dialect.AutoIncrStr() + " "
}
}
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
if b.dialect.ShowCreateNull() {
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
}
return sql
}
// StringNoPk generate column description string according dialect without primary keys
func (b *Base) StringNoPk(col *schemas.Column) string {
sql := b.dialect.Quoter().Quote(col.Name) + " "
sql += b.dialect.SQLType(col) + " "
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
if b.dialect.ShowCreateNull() {
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
}
return sql
}
func (b *Base) FormatBytes(bs []byte) string {
return fmt.Sprintf("0x%x", bs)
}
@ -222,29 +223,15 @@ func (db *Base) IsColumnExist(ctx context.Context, tableName, colName string) (b
return db.HasRecords(ctx, query, db.uri.DBName, tableName, colName)
}
/*
func (db *Base) CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error {
sql, args := db.dialect.TableCheckSQL(tableName)
rows, err := db.DB().Query(sql, args...)
if db.Logger != nil {
db.Logger.Info("[sql]", sql, args)
func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string {
quoter := db.dialect.Quoter()
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName),
db.String(col))
if db.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
if err != nil {
return err
}
defer rows.Close()
if rows.Next() {
return nil
}
sql = db.dialect.CreateTableSQL(table, tableName, storeEngine, charset)
_, err = db.DB().Exec(sql)
if db.Logger != nil {
db.Logger.Info("[sql]", sql)
}
return err
}*/
return sql
}
func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
quoter := db.dialect.Quoter()
@ -271,7 +258,7 @@ func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
}
func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string {
return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, StringNoPk(db.dialect, col))
return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, db.StringNoPk(col))
}
func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
@ -291,12 +278,12 @@ func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, char
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += String(b.dialect, col)
sql += b.String(col)
} else {
sql += StringNoPk(b.dialect, col)
sql += b.StringNoPk(col)
}
sql = strings.TrimSpace(sql)
if b.DriverName() == schemas.MYSQL && len(col.Comment) > 0 {
if b.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ", "

View File

@ -511,9 +511,9 @@ func (db *mssql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += String(db, col)
sql += db.String(col)
} else {
sql += StringNoPk(db, col)
sql += db.StringNoPk(col)
}
sql = strings.TrimSpace(sql)
sql += ", "

View File

@ -524,9 +524,9 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += String(db, col)
sql += db.String(col)
} else {
sql += StringNoPk(db, col)
sql += db.StringNoPk(col)
}
sql = strings.TrimSpace(sql)
if len(col.Comment) > 0 {

View File

@ -593,7 +593,7 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c
/*if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(b.dialect)
} else {*/
sql += StringNoPk(db, col)
sql += db.StringNoPk(col)
// }
sql = strings.TrimSpace(sql)
sql += ", "

View File

@ -201,8 +201,8 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo
func (session *Session) addColumn(colName string) error {
col := session.statement.RefTable.GetColumn(colName)
sql, args := session.statement.genAddColumnStr(col)
_, err := session.exec(sql, args...)
sql := session.statement.dialect.AddColumnSQL(session.statement.TableName(), col)
_, err := session.exec(sql)
return err
}

View File

@ -902,17 +902,6 @@ func (statement *Statement) genDelIndexSQL() []string {
return sqls
}
func (statement *Statement) genAddColumnStr(col *schemas.Column) (string, []interface{}) {
quote := statement.quote
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
dialects.String(statement.dialect, col))
if statement.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ";"
return sql, []interface{}{}
}
func (statement *Statement) buildConds(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)