From aea91cc7dede96f071514cea007bd792613ae748 Mon Sep 17 00:00:00 2001 From: fanybook Date: Fri, 12 Nov 2021 20:58:05 +0800 Subject: [PATCH] =?UTF-8?q?add=20table=20&=20column=20comment=20for=20post?= =?UTF-8?q?gres=EF=BC=88add=20table=20comment=20for=20mysql=EF=BC=89=20(#2?= =?UTF-8?q?067)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 让 postgres 支持字段注释,只在 v1.2.5 上测试过(不知道怎么 import master) 发现 master 分支好像大改了?模式表名带 schema 了 使用方式和 mysql 相同 ```go type User struct { Id int64 `xorm:"pk autoincr"` Name string `json:"name" xorm:"not null default '' varchar(50) index(name_age) comment('用户 (it''s) 1; 名')"` Salt string Age int `json:"age" xorm:"not null default 0 int(10) index(name_age) comment('年龄')"` Passwd string `xorm:"varchar(200)"` CreatedAt time.Time `xorm:"created"` UpdatedAt time.Time `xorm:"updated"` } _ = engine.Sync(new(User)) func (model User) TableComment() string { return "表注释" } ``` Co-authored-by: fanybook Co-authored-by: Lunny Xiao Reviewed-on: https://gitea.com/xorm/xorm/pulls/2067 Co-authored-by: fanybook Co-committed-by: fanybook --- dialects/dialect.go | 2 +- dialects/mysql.go | 6 +++++ dialects/postgres.go | 54 ++++++++++++++++++++++++++++++++++++++++---- names/table_name.go | 47 ++++++++++++++++++++++++++++++++++++-- session_schema.go | 2 ++ tags/parser.go | 1 + tags/parser_test.go | 44 ++++++++++++++++++++++++++++++++++++ 7 files changed, 148 insertions(+), 8 deletions(-) diff --git a/dialects/dialect.go b/dialects/dialect.go index f3aa7470..2d772411 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -206,7 +206,7 @@ func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableNa // AddColumnSQL returns a SQL to add a column func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { s, _ := ColumnString(db.dialect, col, true) - return fmt.Sprintf("ALTER TABLE %v ADD %v", db.dialect.Quoter().Quote(tableName), s) + return fmt.Sprintf("ALTER TABLE %s ADD %s", db.dialect.Quoter().Quote(tableName), s) } // CreateIndexSQL returns a SQL to create index diff --git a/dialects/mysql.go b/dialects/mysql.go index ce3bd705..9cc695ef 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -685,6 +685,12 @@ func (db *mysql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table b.WriteString(db.rowFormat) } + if table.Comment != "" { + b.WriteString(" COMMENT='") + b.WriteString(table.Comment) + b.WriteString("'") + } + return b.String(), true, nil } diff --git a/dialects/postgres.go b/dialects/postgres.go index 1e99cd9d..3595f6c5 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -991,13 +991,37 @@ func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tabl db.getSchema(), tableName) } -func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string { +func (db *postgres) AddColumnSQL(tableName string, col *schemas.Column) string { + s, _ := ColumnString(db.dialect, col, true) + + quoter := db.dialect.Quoter() + addColumnSQL := "" + commentSQL := "; " if len(db.getSchema()) == 0 || strings.Contains(tableName, ".") { - return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", - db.quoter.Quote(tableName), db.quoter.Quote(col.Name), db.SQLType(col)) + addColumnSQL = fmt.Sprintf("ALTER TABLE %s ADD %s", quoter.Quote(tableName), s) + commentSQL += fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoter.Quote(tableName), quoter.Quote(col.Name), col.Comment) + return addColumnSQL + commentSQL } - return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", - db.quoter.Quote(db.getSchema()), db.quoter.Quote(tableName), db.quoter.Quote(col.Name), db.SQLType(col)) + + addColumnSQL = fmt.Sprintf("ALTER TABLE %s.%s ADD %s", quoter.Quote(db.getSchema()), quoter.Quote(tableName), s) + commentSQL += fmt.Sprintf("COMMENT ON COLUMN %s.%s.%s IS '%s'", quoter.Quote(db.getSchema()), quoter.Quote(tableName), quoter.Quote(col.Name), col.Comment) + return addColumnSQL + commentSQL +} + +func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string { + quoter := db.dialect.Quoter() + modifyColumnSQL := "" + commentSQL := "; " + + if len(db.getSchema()) == 0 || strings.Contains(tableName, ".") { + modifyColumnSQL = fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s TYPE %s", quoter.Quote(tableName), quoter.Quote(col.Name), db.SQLType(col)) + commentSQL += fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoter.Quote(tableName), quoter.Quote(col.Name), col.Comment) + return modifyColumnSQL + commentSQL + } + + modifyColumnSQL = fmt.Sprintf("ALTER TABLE %s.%s ALTER COLUMN %s TYPE %s", quoter.Quote(db.getSchema()), quoter.Quote(tableName), quoter.Quote(col.Name), db.SQLType(col)) + commentSQL += fmt.Sprintf("COMMENT ON COLUMN %s.%s.%s IS '%s'", quoter.Quote(db.getSchema()), quoter.Quote(tableName), quoter.Quote(col.Name), col.Comment) + return modifyColumnSQL + commentSQL } func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string { @@ -1302,6 +1326,26 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN return indexes, nil } +func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { + quoter := db.dialect.Quoter() + if len(db.getSchema()) != 0 && !strings.Contains(tableName, ".") { + tableName = fmt.Sprintf("%s.%s", db.getSchema(), tableName) + } + + createTableSQL, ok, err := db.Base.CreateTableSQL(ctx, queryer, table, tableName) + if err != nil { + return "", ok, err + } + + commentSql := "; " + if table.Comment != "" { + // support schema.table -> "schema"."table" + commentSql += fmt.Sprintf("COMMENT ON TABLE %s IS '%s'", quoter.Quote(tableName), table.Comment) + } + + return createTableSQL + commentSql, true, nil +} + func (db *postgres) Filters() []Filter { return []Filter{&SeqFilter{Prefix: "$", Start: 1}} } diff --git a/names/table_name.go b/names/table_name.go index cc0e9274..d7d71b51 100644 --- a/names/table_name.go +++ b/names/table_name.go @@ -14,9 +14,15 @@ type TableName interface { TableName() string } +type TableComment interface { + TableComment() string +} + var ( - tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() - tvCache sync.Map + tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() + tpTableComment = reflect.TypeOf((*TableComment)(nil)).Elem() + tvCache sync.Map + tcCache sync.Map ) // GetTableName returns table name @@ -55,3 +61,40 @@ func GetTableName(mapper Mapper, v reflect.Value) string { return mapper.Obj2Table(v.Type().Name()) } + +// GetTableComment returns table comment +func GetTableComment(v reflect.Value) string { + if v.Type().Implements(tpTableComment) { + return v.Interface().(TableComment).TableComment() + } + + if v.Kind() == reflect.Ptr { + v = v.Elem() + if v.Type().Implements(tpTableComment) { + return v.Interface().(TableComment).TableComment() + } + } else if v.CanAddr() { + v1 := v.Addr() + if v1.Type().Implements(tpTableComment) { + return v1.Interface().(TableComment).TableComment() + } + } else { + comment, ok := tcCache.Load(v.Type()) + if ok { + if comment.(string) != "" { + return comment.(string) + } + } else { + v2 := reflect.New(v.Type()) + if v2.Type().Implements(tpTableComment) { + tableComment := v2.Interface().(TableComment).TableComment() + tcCache.Store(v.Type(), tableComment) + return tableComment + } + + tcCache.Store(v.Type(), "") + } + } + + return "" +} diff --git a/session_schema.go b/session_schema.go index e9ed9ec5..75140426 100644 --- a/session_schema.go +++ b/session_schema.go @@ -394,6 +394,8 @@ func (session *Session) Sync(beans ...interface{}) error { _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) } } + } else if col.Comment != oriCol.Comment { + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) } if col.Default != oriCol.Default { diff --git a/tags/parser.go b/tags/parser.go index efee11e7..83026862 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -316,6 +316,7 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { table := schemas.NewEmptyTable() table.Type = t table.Name = names.GetTableName(parser.tableMapper, v) + table.Comment = names.GetTableComment(v) for i := 0; i < t.NumField(); i++ { col, err := parser.parseField(table, i, t.Field(i), v.Field(i)) diff --git a/tags/parser_test.go b/tags/parser_test.go index 70c57692..83c81a1e 100644 --- a/tags/parser_test.go +++ b/tags/parser_test.go @@ -26,6 +26,20 @@ func (p ParseTableName2) TableName() string { return "p_parseTableName" } +type ParseTableComment struct{} + +type ParseTableComment1 struct{} + +type ParseTableComment2 struct{} + +func (p ParseTableComment1) TableComment() string { + return "p_parseTableComment1" +} + +func (p *ParseTableComment2) TableComment() string { + return "p_parseTableComment2" +} + func TestParseTableName(t *testing.T) { parser := NewParser( "xorm", @@ -47,6 +61,36 @@ func TestParseTableName(t *testing.T) { assert.EqualValues(t, "p_parseTableName", table.Name) } +func TestParseTableComment(t *testing.T) { + parser := NewParser( + "xorm", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.SnakeMapper{}, + caches.NewManager(), + ) + + table, err := parser.Parse(reflect.ValueOf(new(ParseTableComment))) + assert.NoError(t, err) + assert.EqualValues(t, "", table.Comment) + + table, err = parser.Parse(reflect.ValueOf(new(ParseTableComment1))) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableComment1", table.Comment) + + table, err = parser.Parse(reflect.ValueOf(ParseTableComment1{})) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableComment1", table.Comment) + + table, err = parser.Parse(reflect.ValueOf(new(ParseTableComment2))) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableComment2", table.Comment) + + table, err = parser.Parse(reflect.ValueOf(ParseTableComment2{})) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableComment2", table.Comment) +} + func TestUnexportField(t *testing.T) { parser := NewParser( "xorm",