diff --git a/dialects/dialect.go b/dialects/dialect.go index 8e512c4f..d77d30a1 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -147,6 +147,13 @@ func (db *Base) CreateTableSQL(ctx context.Context, queryer core.Queryer, table b.WriteString(")") } + for _, col := range table.FKColumns() { + b.WriteString(", FOREIGN KEY (") + b.WriteString(col.Name) + b.WriteString(") REFERENCES ") + b.WriteString(col.Reference) + } + b.WriteString(")") return b.String(), false, nil diff --git a/dialects/mssql.go b/dialects/mssql.go index 13399ed2..f306c0ad 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -663,6 +663,11 @@ func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true) b.WriteString(s) + if col.Reference != "" { + b.WriteString(" FOREIGN KEY REFERENCES ") + b.WriteString(col.Reference) + } + if i != len(table.ColumnsSeq())-1 { b.WriteString(", ") } diff --git a/dialects/mysql.go b/dialects/mysql.go index 2c061a14..366b9d68 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -683,6 +683,13 @@ func (db *mysql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table b.WriteString(")") } + for _, col := range table.FKColumns() { + b.WriteString(", FOREIGN KEY (") + b.WriteString(col.Name) + b.WriteString(") REFERENCES ") + b.WriteString(col.Reference) + } + b.WriteString(")") if table.StoreEngine != "" { diff --git a/dialects/oracle.go b/dialects/oracle.go index ac0fb944..2f588468 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -644,6 +644,9 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl sql += s // } sql = strings.TrimSpace(sql) + if col.Reference != "" { + sql += " FOREIGN KEY REFERENCES "+col.Reference + } sql += ", " } diff --git a/schemas/column.go b/schemas/column.go index 08d34b91..721fd514 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -30,6 +30,7 @@ type Column struct { Length2 int64 Nullable bool Default string + Reference string Indexes map[string]int IsPrimaryKey bool IsAutoIncrement bool @@ -60,6 +61,7 @@ func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int64, nullab Length2: len2, Nullable: nullable, Default: "", + Reference: "", Indexes: make(map[string]int), IsPrimaryKey: false, IsAutoIncrement: false, diff --git a/schemas/table.go b/schemas/table.go index 5c38cc70..0f1872ad 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -19,6 +19,7 @@ type Table struct { columns []*Column Indexes map[string]*Index PrimaryKeys []string + ForeignKeys []string AutoIncrement string Created map[string]bool Updated string @@ -45,6 +46,7 @@ func NewTable(name string, t reflect.Type) *Table { Indexes: make(map[string]*Index), Created: make(map[string]bool), PrimaryKeys: make([]string, 0), + ForeignKeys: make([]string, 0), } } @@ -91,6 +93,15 @@ func (table *Table) PKColumns() []*Column { return columns } +// FKColumns reprents all foreign key columns +func (table *Table) FKColumns() []*Column { + columns := make([]*Column, len(table.ForeignKeys)) + for i, name := range table.ForeignKeys { + columns[i] = table.GetColumn(name) + } + return columns +} + // ColumnType returns a column's type func (table *Table) ColumnType(name string) reflect.Type { t, _ := table.Type.FieldByName(name) @@ -131,6 +142,9 @@ func (table *Table) AddColumn(col *Column) { if col.IsPrimaryKey { table.PrimaryKeys = append(table.PrimaryKeys, col.Name) } + if col.Reference != "" { + table.ForeignKeys = append(table.ForeignKeys, col.Name) + } if col.IsAutoIncrement { table.AutoIncrement = col.Name } diff --git a/tags/parser_test.go b/tags/parser_test.go index 434cfc07..70fbae99 100644 --- a/tags/parser_test.go +++ b/tags/parser_test.go @@ -534,6 +534,39 @@ func TestParseWithOnlyToDB(t *testing.T) { assert.EqualValues(t, schemas.ONLYFROMDB, table.Columns()[1].MapType) } +func TestParseWithBelongsTo(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type VanillaStruct struct { + ID int64 + Name string + CreatedAt time.Time `db:"created"` + UpdatedAt time.Time `db:"updated"` + DeletedAt time.Time `db:"deleted"` + } + + type StructWithBelongsTo struct { + Name string + Vanilla VanillaStruct `db:"belongsto"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithBelongsTo))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_belongs_to", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "vanilla", table.Columns()[1].Name) + assert.EqualValues(t, "vanilla_struct(id)", table.Columns()[1].Reference) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) +} + func TestParseWithJSON(t *testing.T) { parser := NewParser( "db", diff --git a/tags/tag.go b/tags/tag.go index 024c9c18..075a0381 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -124,6 +124,8 @@ var defaultTagHandlers = map[string]Handler{ "EXTENDS": ExtendsTagHandler, "UNSIGNED": UnsignedTagHandler, "COLLATE": CollateTagHandler, + + "BELONGSTO": BelongsToTagHandler, } func init() { @@ -269,6 +271,45 @@ func UniqueTagHandler(ctx *Context) error { return nil } +// BelongsToTagHandler describes belongs tag handler +func BelongsToTagHandler(ctx *Context) error { + fieldValue := ctx.fieldValue + colName := "" + if len(ctx.params) > 0 { + colName = ctx.params[0] + } + switch fieldValue.Kind() { + case reflect.Pointer: + f := fieldValue.Type().Elem() + if f.Kind() == reflect.Struct { + fieldPtr := fieldValue + fieldValue = fieldValue.Elem() + if !fieldValue.IsValid() || fieldPtr.IsNil() { + fieldValue = reflect.New(f).Elem() + } + } + fallthrough + case reflect.Struct: + parentTable, err := ctx.parser.ParseWithCache(fieldValue) + if err != nil { + return err + } + tableName := parentTable.Name + if len(ctx.params) > 1 { + tableName = ctx.params[1] + } + for _, col := range parentTable.Columns() { + if len(colName) == 0 && col.IsPrimaryKey || colName == col.Name { + ctx.col.Reference = tableName+"("+col.Name+")" + return nil + } + } + default: + // TODO: warning + } + return ErrIgnoreField +} + // UnsignedTagHandler represents the column is unsigned func UnsignedTagHandler(ctx *Context) error { ctx.isUnsigned = true diff --git a/tags/tag_test.go b/tags/tag_test.go index 6c456f2a..f9ec3acb 100644 --- a/tags/tag_test.go +++ b/tags/tag_test.go @@ -90,6 +90,14 @@ func TestSplitTag(t *testing.T) { }, }, }, + { + "belongsto(foo, bar)", []tag{ + { + name: "belongsto", + params: []string{"foo", "bar"}, + }, + }, + }, } for _, kase := range cases {