Support for belongs to relationships

This commit is contained in:
Martin Rodriguez Reboredo 2024-02-09 13:48:13 -03:00
parent e884f059a4
commit 8826dcbb1d
No known key found for this signature in database
GPG Key ID: 85883E1B6A5B51C7
9 changed files with 120 additions and 0 deletions

View File

@ -147,6 +147,13 @@ func (db *Base) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
b.WriteString(")")
}
for _, col := range table.FKColumns() {
b.WriteString(", FOREIGN KEY (")
b.WriteString(col.Name)
b.WriteString(") REFERENCES ")
b.WriteString(col.Reference)
}
b.WriteString(")")
return b.String(), false, nil

View File

@ -663,6 +663,11 @@ func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true)
b.WriteString(s)
if col.Reference != "" {
b.WriteString(" FOREIGN KEY REFERENCES ")
b.WriteString(col.Reference)
}
if i != len(table.ColumnsSeq())-1 {
b.WriteString(", ")
}

View File

@ -683,6 +683,13 @@ func (db *mysql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
b.WriteString(")")
}
for _, col := range table.FKColumns() {
b.WriteString(", FOREIGN KEY (")
b.WriteString(col.Name)
b.WriteString(") REFERENCES ")
b.WriteString(col.Reference)
}
b.WriteString(")")
if table.StoreEngine != "" {

View File

@ -644,6 +644,9 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
sql += s
// }
sql = strings.TrimSpace(sql)
if col.Reference != "" {
sql += " FOREIGN KEY REFERENCES "+col.Reference
}
sql += ", "
}

View File

@ -30,6 +30,7 @@ type Column struct {
Length2 int64
Nullable bool
Default string
Reference string
Indexes map[string]int
IsPrimaryKey bool
IsAutoIncrement bool
@ -60,6 +61,7 @@ func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int64, nullab
Length2: len2,
Nullable: nullable,
Default: "",
Reference: "",
Indexes: make(map[string]int),
IsPrimaryKey: false,
IsAutoIncrement: false,

View File

@ -19,6 +19,7 @@ type Table struct {
columns []*Column
Indexes map[string]*Index
PrimaryKeys []string
ForeignKeys []string
AutoIncrement string
Created map[string]bool
Updated string
@ -45,6 +46,7 @@ func NewTable(name string, t reflect.Type) *Table {
Indexes: make(map[string]*Index),
Created: make(map[string]bool),
PrimaryKeys: make([]string, 0),
ForeignKeys: make([]string, 0),
}
}
@ -91,6 +93,15 @@ func (table *Table) PKColumns() []*Column {
return columns
}
// FKColumns reprents all foreign key columns
func (table *Table) FKColumns() []*Column {
columns := make([]*Column, len(table.ForeignKeys))
for i, name := range table.ForeignKeys {
columns[i] = table.GetColumn(name)
}
return columns
}
// ColumnType returns a column's type
func (table *Table) ColumnType(name string) reflect.Type {
t, _ := table.Type.FieldByName(name)
@ -131,6 +142,9 @@ func (table *Table) AddColumn(col *Column) {
if col.IsPrimaryKey {
table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
}
if col.Reference != "" {
table.ForeignKeys = append(table.ForeignKeys, col.Name)
}
if col.IsAutoIncrement {
table.AutoIncrement = col.Name
}

View File

@ -534,6 +534,39 @@ func TestParseWithOnlyToDB(t *testing.T) {
assert.EqualValues(t, schemas.ONLYFROMDB, table.Columns()[1].MapType)
}
func TestParseWithBelongsTo(t *testing.T) {
parser := NewParser(
"db",
dialects.QueryDialect("mysql"),
names.SnakeMapper{},
names.GonicMapper{},
caches.NewManager(),
)
type VanillaStruct struct {
ID int64
Name string
CreatedAt time.Time `db:"created"`
UpdatedAt time.Time `db:"updated"`
DeletedAt time.Time `db:"deleted"`
}
type StructWithBelongsTo struct {
Name string
Vanilla VanillaStruct `db:"belongsto"`
}
table, err := parser.Parse(reflect.ValueOf(new(StructWithBelongsTo)))
assert.NoError(t, err)
assert.EqualValues(t, "struct_with_belongs_to", table.Name)
assert.EqualValues(t, 2, len(table.Columns()))
assert.EqualValues(t, "name", table.Columns()[0].Name)
assert.EqualValues(t, "vanilla", table.Columns()[1].Name)
assert.EqualValues(t, "vanilla_struct(id)", table.Columns()[1].Reference)
assert.True(t, table.Columns()[0].Nullable)
assert.True(t, table.Columns()[1].Nullable)
}
func TestParseWithJSON(t *testing.T) {
parser := NewParser(
"db",

View File

@ -124,6 +124,8 @@ var defaultTagHandlers = map[string]Handler{
"EXTENDS": ExtendsTagHandler,
"UNSIGNED": UnsignedTagHandler,
"COLLATE": CollateTagHandler,
"BELONGSTO": BelongsToTagHandler,
}
func init() {
@ -269,6 +271,45 @@ func UniqueTagHandler(ctx *Context) error {
return nil
}
// BelongsToTagHandler describes belongs tag handler
func BelongsToTagHandler(ctx *Context) error {
fieldValue := ctx.fieldValue
colName := ""
if len(ctx.params) > 0 {
colName = ctx.params[0]
}
switch fieldValue.Kind() {
case reflect.Pointer:
f := fieldValue.Type().Elem()
if f.Kind() == reflect.Struct {
fieldPtr := fieldValue
fieldValue = fieldValue.Elem()
if !fieldValue.IsValid() || fieldPtr.IsNil() {
fieldValue = reflect.New(f).Elem()
}
}
fallthrough
case reflect.Struct:
parentTable, err := ctx.parser.ParseWithCache(fieldValue)
if err != nil {
return err
}
tableName := parentTable.Name
if len(ctx.params) > 1 {
tableName = ctx.params[1]
}
for _, col := range parentTable.Columns() {
if len(colName) == 0 && col.IsPrimaryKey || colName == col.Name {
ctx.col.Reference = tableName+"("+col.Name+")"
return nil
}
}
default:
// TODO: warning
}
return ErrIgnoreField
}
// UnsignedTagHandler represents the column is unsigned
func UnsignedTagHandler(ctx *Context) error {
ctx.isUnsigned = true

View File

@ -90,6 +90,14 @@ func TestSplitTag(t *testing.T) {
},
},
},
{
"belongsto(foo, bar)", []tag{
{
name: "belongsto",
params: []string{"foo", "bar"},
},
},
},
}
for _, kase := range cases {