diff --git a/names/table_name.go b/names/table_name.go index 6dd4e552..0afb1ae3 100644 --- a/names/table_name.go +++ b/names/table_name.go @@ -6,6 +6,7 @@ package names import ( "reflect" + "sync" ) // TableName table name interface to define customerize table name @@ -15,23 +16,40 @@ type TableName interface { var ( tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() + tvCache sync.Map ) func GetTableName(mapper Mapper, v reflect.Value) string { - if t, ok := v.Interface().(TableName); ok { - return t.TableName() - } if v.Type().Implements(tpTableName) { return v.Interface().(TableName).TableName() } + if v.Kind() == reflect.Ptr { v = v.Elem() - if t, ok := v.Interface().(TableName); ok { - return t.TableName() - } if v.Type().Implements(tpTableName) { return v.Interface().(TableName).TableName() } + } else if v.CanAddr() { + v1 := v.Addr() + if v1.Type().Implements(tpTableName) { + return v1.Interface().(TableName).TableName() + } + } else { + name, ok := tvCache.Load(v.Type()) + if ok { + if name.(string) != "" { + return name.(string) + } + } else { + v2 := reflect.New(v.Type()) + if v2.Type().Implements(tpTableName) { + tableName := v2.Interface().(TableName).TableName() + tvCache.Store(v.Type(), tableName) + return tableName + } + + tvCache.Store(v.Type(), "") + } } return mapper.Obj2Table(v.Type().Name()) diff --git a/names/table_name_test.go b/names/table_name_test.go index 1f20bfaa..76da4135 100644 --- a/names/table_name_test.go +++ b/names/table_name_test.go @@ -5,6 +5,7 @@ package names import ( + "fmt" "reflect" "testing" "time" @@ -43,8 +44,10 @@ func (MyGetCustomTableImpletation) TableName() string { type TestTableNameStruct struct{} +const getTestTableName = "my_test_table_name_struct" + func (t *TestTableNameStruct) TableName() string { - return "my_test_table_name_struct" + return getTestTableName } func TestGetTableName(t *testing.T) { @@ -85,13 +88,18 @@ func TestGetTableName(t *testing.T) { }, { SnakeMapper{}, - reflect.ValueOf(MyGetCustomTableImpletation{}), - getCustomTableName, + reflect.ValueOf(new(TestTableNameStruct)), + new(TestTableNameStruct).TableName(), }, { SnakeMapper{}, reflect.ValueOf(new(TestTableNameStruct)), - new(TestTableNameStruct).TableName(), + getTestTableName, + }, + { + SnakeMapper{}, + reflect.ValueOf(TestTableNameStruct{}), + getTestTableName, }, } @@ -99,3 +107,34 @@ func TestGetTableName(t *testing.T) { assert.EqualValues(t, kase.expectedTableName, GetTableName(kase.mapper, kase.v)) } } + +type OAuth2Application struct { +} + +// TableName sets the table name to `oauth2_application` +func (app *OAuth2Application) TableName() string { + return "oauth2_application" +} + +func TestGonicMapperCustomTable(t *testing.T) { + assert.EqualValues(t, "oauth2_application", + GetTableName(LintGonicMapper, reflect.ValueOf(new(OAuth2Application)))) + assert.EqualValues(t, "oauth2_application", + GetTableName(LintGonicMapper, reflect.ValueOf(OAuth2Application{}))) +} + +type MyTable struct { + Idx int +} + +func (t *MyTable) TableName() string { + return fmt.Sprintf("mytable_%d", t.Idx) +} + +func TestMyTable(t *testing.T) { + var table MyTable + for i := 0; i < 10; i++ { + table.Idx = i + assert.EqualValues(t, fmt.Sprintf("mytable_%d", i), GetTableName(SameMapper{}, reflect.ValueOf(&table))) + } +}