diff --git a/names/table_name.go b/names/table_name.go index cc0e9274..32ec2961 100644 --- a/names/table_name.go +++ b/names/table_name.go @@ -14,8 +14,13 @@ type TableName interface { TableName() string } +type TableComment interface { + TableComment() string +} + var ( tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() + tpTableComment = reflect.TypeOf((*TableComment)(nil)).Elem() tvCache sync.Map ) @@ -55,3 +60,12 @@ func GetTableName(mapper Mapper, v reflect.Value) string { return mapper.Obj2Table(v.Type().Name()) } + +// GetTableComment returns table comment +func GetTableComment(v reflect.Value) string { + if v.Type().Implements(tpTableComment) { + return v.Interface().(TableComment).TableComment() + } + + return "" +} diff --git a/tags/parser.go b/tags/parser.go index 5f816cf3..83026862 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -90,14 +90,6 @@ func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) { return nil, err } - // if bean has Comment Method, then set table.Comment - if _, ok := t.MethodByName("Comment"); ok { - tableCommentFn := v.MethodByName("Comment") - if tableCommentFn.Type().String() == "func() string" { - table.Comment = fmt.Sprintf("%s", tableCommentFn.Call(nil)[0]) - } - } - parser.tableCache.Store(t, table) if parser.cacherMgr.GetDefaultCacher() != nil { @@ -324,6 +316,7 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { table := schemas.NewEmptyTable() table.Type = t table.Name = names.GetTableName(parser.tableMapper, v) + table.Comment = names.GetTableComment(v) for i := 0; i < t.NumField(); i++ { col, err := parser.parseField(table, i, t.Field(i), v.Field(i)) diff --git a/tags/parser_test.go b/tags/parser_test.go index 70c57692..b43a7af1 100644 --- a/tags/parser_test.go +++ b/tags/parser_test.go @@ -26,6 +26,10 @@ func (p ParseTableName2) TableName() string { return "p_parseTableName" } +func (p ParseTableName2) TableComment() string { + return "p2_testTableComment" +} + func TestParseTableName(t *testing.T) { parser := NewParser( "xorm", @@ -47,6 +51,27 @@ func TestParseTableName(t *testing.T) { assert.EqualValues(t, "p_parseTableName", table.Name) } +func TestParseTableComment(t *testing.T) { + parser := NewParser( + "xorm", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.SnakeMapper{}, + caches.NewManager(), + ) + table, err := parser.Parse(reflect.ValueOf(new(ParseTableName1))) + assert.NoError(t, err) + assert.EqualValues(t, "", table.Comment) + + table, err = parser.Parse(reflect.ValueOf(new(ParseTableName2))) + assert.NoError(t, err) + assert.EqualValues(t, "p2_testTableComment", table.Comment) + + table, err = parser.Parse(reflect.ValueOf(ParseTableName2{})) + assert.NoError(t, err) + assert.EqualValues(t, "p2_testTableComment", table.Comment) +} + func TestUnexportField(t *testing.T) { parser := NewParser( "xorm",