Support for belongs to relationships
This commit is contained in:
parent
e884f059a4
commit
8826dcbb1d
|
@ -147,6 +147,13 @@ func (db *Base) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
|
|||
b.WriteString(")")
|
||||
}
|
||||
|
||||
for _, col := range table.FKColumns() {
|
||||
b.WriteString(", FOREIGN KEY (")
|
||||
b.WriteString(col.Name)
|
||||
b.WriteString(") REFERENCES ")
|
||||
b.WriteString(col.Reference)
|
||||
}
|
||||
|
||||
b.WriteString(")")
|
||||
|
||||
return b.String(), false, nil
|
||||
|
|
|
@ -663,6 +663,11 @@ func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
|
|||
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true)
|
||||
b.WriteString(s)
|
||||
|
||||
if col.Reference != "" {
|
||||
b.WriteString(" FOREIGN KEY REFERENCES ")
|
||||
b.WriteString(col.Reference)
|
||||
}
|
||||
|
||||
if i != len(table.ColumnsSeq())-1 {
|
||||
b.WriteString(", ")
|
||||
}
|
||||
|
|
|
@ -683,6 +683,13 @@ func (db *mysql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
|
|||
b.WriteString(")")
|
||||
}
|
||||
|
||||
for _, col := range table.FKColumns() {
|
||||
b.WriteString(", FOREIGN KEY (")
|
||||
b.WriteString(col.Name)
|
||||
b.WriteString(") REFERENCES ")
|
||||
b.WriteString(col.Reference)
|
||||
}
|
||||
|
||||
b.WriteString(")")
|
||||
|
||||
if table.StoreEngine != "" {
|
||||
|
|
|
@ -644,6 +644,9 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
|
|||
sql += s
|
||||
// }
|
||||
sql = strings.TrimSpace(sql)
|
||||
if col.Reference != "" {
|
||||
sql += " FOREIGN KEY REFERENCES "+col.Reference
|
||||
}
|
||||
sql += ", "
|
||||
}
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ type Column struct {
|
|||
Length2 int64
|
||||
Nullable bool
|
||||
Default string
|
||||
Reference string
|
||||
Indexes map[string]int
|
||||
IsPrimaryKey bool
|
||||
IsAutoIncrement bool
|
||||
|
@ -60,6 +61,7 @@ func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int64, nullab
|
|||
Length2: len2,
|
||||
Nullable: nullable,
|
||||
Default: "",
|
||||
Reference: "",
|
||||
Indexes: make(map[string]int),
|
||||
IsPrimaryKey: false,
|
||||
IsAutoIncrement: false,
|
||||
|
|
|
@ -19,6 +19,7 @@ type Table struct {
|
|||
columns []*Column
|
||||
Indexes map[string]*Index
|
||||
PrimaryKeys []string
|
||||
ForeignKeys []string
|
||||
AutoIncrement string
|
||||
Created map[string]bool
|
||||
Updated string
|
||||
|
@ -45,6 +46,7 @@ func NewTable(name string, t reflect.Type) *Table {
|
|||
Indexes: make(map[string]*Index),
|
||||
Created: make(map[string]bool),
|
||||
PrimaryKeys: make([]string, 0),
|
||||
ForeignKeys: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -91,6 +93,15 @@ func (table *Table) PKColumns() []*Column {
|
|||
return columns
|
||||
}
|
||||
|
||||
// FKColumns reprents all foreign key columns
|
||||
func (table *Table) FKColumns() []*Column {
|
||||
columns := make([]*Column, len(table.ForeignKeys))
|
||||
for i, name := range table.ForeignKeys {
|
||||
columns[i] = table.GetColumn(name)
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
// ColumnType returns a column's type
|
||||
func (table *Table) ColumnType(name string) reflect.Type {
|
||||
t, _ := table.Type.FieldByName(name)
|
||||
|
@ -131,6 +142,9 @@ func (table *Table) AddColumn(col *Column) {
|
|||
if col.IsPrimaryKey {
|
||||
table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
|
||||
}
|
||||
if col.Reference != "" {
|
||||
table.ForeignKeys = append(table.ForeignKeys, col.Name)
|
||||
}
|
||||
if col.IsAutoIncrement {
|
||||
table.AutoIncrement = col.Name
|
||||
}
|
||||
|
|
|
@ -534,6 +534,39 @@ func TestParseWithOnlyToDB(t *testing.T) {
|
|||
assert.EqualValues(t, schemas.ONLYFROMDB, table.Columns()[1].MapType)
|
||||
}
|
||||
|
||||
func TestParseWithBelongsTo(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"db",
|
||||
dialects.QueryDialect("mysql"),
|
||||
names.SnakeMapper{},
|
||||
names.GonicMapper{},
|
||||
caches.NewManager(),
|
||||
)
|
||||
|
||||
type VanillaStruct struct {
|
||||
ID int64
|
||||
Name string
|
||||
CreatedAt time.Time `db:"created"`
|
||||
UpdatedAt time.Time `db:"updated"`
|
||||
DeletedAt time.Time `db:"deleted"`
|
||||
}
|
||||
|
||||
type StructWithBelongsTo struct {
|
||||
Name string
|
||||
Vanilla VanillaStruct `db:"belongsto"`
|
||||
}
|
||||
|
||||
table, err := parser.Parse(reflect.ValueOf(new(StructWithBelongsTo)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "struct_with_belongs_to", table.Name)
|
||||
assert.EqualValues(t, 2, len(table.Columns()))
|
||||
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||
assert.EqualValues(t, "vanilla", table.Columns()[1].Name)
|
||||
assert.EqualValues(t, "vanilla_struct(id)", table.Columns()[1].Reference)
|
||||
assert.True(t, table.Columns()[0].Nullable)
|
||||
assert.True(t, table.Columns()[1].Nullable)
|
||||
}
|
||||
|
||||
func TestParseWithJSON(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"db",
|
||||
|
|
41
tags/tag.go
41
tags/tag.go
|
@ -124,6 +124,8 @@ var defaultTagHandlers = map[string]Handler{
|
|||
"EXTENDS": ExtendsTagHandler,
|
||||
"UNSIGNED": UnsignedTagHandler,
|
||||
"COLLATE": CollateTagHandler,
|
||||
|
||||
"BELONGSTO": BelongsToTagHandler,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -269,6 +271,45 @@ func UniqueTagHandler(ctx *Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// BelongsToTagHandler describes belongs tag handler
|
||||
func BelongsToTagHandler(ctx *Context) error {
|
||||
fieldValue := ctx.fieldValue
|
||||
colName := ""
|
||||
if len(ctx.params) > 0 {
|
||||
colName = ctx.params[0]
|
||||
}
|
||||
switch fieldValue.Kind() {
|
||||
case reflect.Pointer:
|
||||
f := fieldValue.Type().Elem()
|
||||
if f.Kind() == reflect.Struct {
|
||||
fieldPtr := fieldValue
|
||||
fieldValue = fieldValue.Elem()
|
||||
if !fieldValue.IsValid() || fieldPtr.IsNil() {
|
||||
fieldValue = reflect.New(f).Elem()
|
||||
}
|
||||
}
|
||||
fallthrough
|
||||
case reflect.Struct:
|
||||
parentTable, err := ctx.parser.ParseWithCache(fieldValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tableName := parentTable.Name
|
||||
if len(ctx.params) > 1 {
|
||||
tableName = ctx.params[1]
|
||||
}
|
||||
for _, col := range parentTable.Columns() {
|
||||
if len(colName) == 0 && col.IsPrimaryKey || colName == col.Name {
|
||||
ctx.col.Reference = tableName+"("+col.Name+")"
|
||||
return nil
|
||||
}
|
||||
}
|
||||
default:
|
||||
// TODO: warning
|
||||
}
|
||||
return ErrIgnoreField
|
||||
}
|
||||
|
||||
// UnsignedTagHandler represents the column is unsigned
|
||||
func UnsignedTagHandler(ctx *Context) error {
|
||||
ctx.isUnsigned = true
|
||||
|
|
|
@ -90,6 +90,14 @@ func TestSplitTag(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"belongsto(foo, bar)", []tag{
|
||||
{
|
||||
name: "belongsto",
|
||||
params: []string{"foo", "bar"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, kase := range cases {
|
||||
|
|
Loading…
Reference in New Issue