diff --git a/dialects/mssql.go b/dialects/mssql.go index cf81d6de..08bcfae0 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -216,10 +216,12 @@ type mssql struct { Base defaultVarchar string defaultChar string + defaultSchema string } func (db *mssql) Init(uri *URI) error { db.quoter = mssqlQuoter + db.defaultSchema = "dbo" return db.Base.Init(db, uri) } @@ -249,10 +251,21 @@ func (db *mssql) SetParams(params map[string]string) { } else { db.defaultChar = "CHAR" } + + if defaultSchema, ok := params["DEFAULT_SCHEMA"]; ok && defaultSchema != "" { + db.defaultSchema = defaultSchema + } else { + db.defaultSchema = "dbo" + } + } func (db *mssql) SQLType(c *schemas.Column) string { var res string + var defaultVarchar string = schemas.Varchar + if db.defaultVarchar != "" { + defaultVarchar = db.defaultVarchar + } switch t := c.SQLType.Name; t { case schemas.Bool: res = schemas.Bit @@ -285,7 +298,7 @@ func (db *mssql) SQLType(c *schemas.Column) string { case schemas.MediumInt: res = schemas.Int case schemas.Text, schemas.MediumText, schemas.TinyText, schemas.LongText, schemas.Json: - res = db.defaultVarchar + "(MAX)" + res = defaultVarchar + "(MAX)" case schemas.Double: res = schemas.Real case schemas.Uuid: @@ -303,7 +316,7 @@ func (db *mssql) SQLType(c *schemas.Column) string { res += "(MAX)" } case schemas.Varchar: - res = db.defaultVarchar + res = defaultVarchar if c.Length == -1 { res += "(MAX)" } @@ -567,12 +580,23 @@ func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) ([]strin pkList := table.PrimaryKeys + var multiCommentSql []string + for _, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) sql += s sql = strings.TrimSpace(sql) sql += ", " + + if col.Comment != "" { + commentSql := fmt.Sprintf("EXEC sys.sp_addextendedproperty @name=N'MS_Description', "+ + "@value=N'%s' , "+ + "@level0type=N'SCHEMA',@level0name=N'%s', "+ + "@level1type=N'TABLE',@level1name=N'%s', "+ + "@level2type=N'COLUMN',@level2name=N'%s'; ", col.Comment, db.defaultSchema, tableName, col.FieldName) + multiCommentSql = append(multiCommentSql, commentSql) + } } if len(pkList) > 1 { @@ -583,7 +607,13 @@ func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) ([]strin sql = sql[:len(sql)-2] + ")" sql += ";" - return []string{sql}, true + + multiSql := []string{sql} + if len(multiCommentSql) > 0 { + multiSql = append(multiSql, multiCommentSql...) + } + + return multiSql, true } func (db *mssql) ForUpdateSQL(query string) string {