diff --git a/names/table_name.go b/names/table_name.go index 5826a638..2af7e77a 100644 --- a/names/table_name.go +++ b/names/table_name.go @@ -22,6 +22,7 @@ var ( tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() tpTableComment = reflect.TypeOf((*TableComment)(nil)).Elem() tvCache sync.Map + tcCache sync.Map ) // GetTableName returns table name @@ -67,5 +68,33 @@ func GetTableComment(v reflect.Value) string { return v.Interface().(TableComment).TableComment() } + if v.Kind() == reflect.Ptr {//如果是指针 + v = v.Elem() + if v.Type().Implements(tpTableComment) { + return v.Interface().(TableComment).TableComment() + } + } else if v.CanAddr() {//如果可以用地址访问 + v1 := v.Addr() + if v1.Type().Implements(tpTableComment) { + return v1.Interface().(TableComment).TableComment() + } + } else { + comment, ok := tcCache.Load(v.Type()) + if ok { + if comment.(string) != "" { + return comment.(string) + } + } else { + v2 := reflect.New(v.Type()) + if v2.Type().Implements(tpTableComment) { + tableComment := v2.Interface().(TableComment).TableComment() + tcCache.Store(v.Type(), tableComment) + return tableComment + } + + tcCache.Store(v.Type(), "") + } + } + return "" } diff --git a/tags/parser_test.go b/tags/parser_test.go index b43a7af1..c65f4f2b 100644 --- a/tags/parser_test.go +++ b/tags/parser_test.go @@ -26,10 +26,21 @@ func (p ParseTableName2) TableName() string { return "p_parseTableName" } -func (p ParseTableName2) TableComment() string { - return "p2_testTableComment" +type ParseTableComment struct{} + +type ParseTableComment1 struct{} + +type ParseTableComment2 struct{} + +func (p ParseTableComment1) TableComment() string { + return "p_parseTableComment1" } +func (p *ParseTableComment2) TableComment() string { + return "p_parseTableComment2" +} + + func TestParseTableName(t *testing.T) { parser := NewParser( "xorm", @@ -59,17 +70,30 @@ func TestParseTableComment(t *testing.T) { names.SnakeMapper{}, caches.NewManager(), ) - table, err := parser.Parse(reflect.ValueOf(new(ParseTableName1))) + + table, err := parser.Parse(reflect.ValueOf(new(ParseTableComment))) + assert.NoError(t, err) + assert.EqualValues(t, "", table.Name) + + table, err = parser.Parse(reflect.ValueOf(new(ParseTableName1))) assert.NoError(t, err) assert.EqualValues(t, "", table.Comment) - table, err = parser.Parse(reflect.ValueOf(new(ParseTableName2))) + table, err = parser.Parse(reflect.ValueOf(new(ParseTableComment1))) assert.NoError(t, err) - assert.EqualValues(t, "p2_testTableComment", table.Comment) + assert.EqualValues(t, "p_parseTableComment1", table.Comment) - table, err = parser.Parse(reflect.ValueOf(ParseTableName2{})) + table, err = parser.Parse(reflect.ValueOf(ParseTableComment1{})) assert.NoError(t, err) - assert.EqualValues(t, "p2_testTableComment", table.Comment) + assert.EqualValues(t, "p_parseTableComment1", table.Comment) + + table, err = parser.Parse(reflect.ValueOf(new(ParseTableComment2))) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableComment2", table.Comment) + + table, err = parser.Parse(reflect.ValueOf(ParseTableComment2{})) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableComment2", table.Comment) } func TestUnexportField(t *testing.T) {