Support for belongs to relationships
This commit is contained in:
parent
e884f059a4
commit
8826dcbb1d
|
@ -147,6 +147,13 @@ func (db *Base) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
|
||||||
b.WriteString(")")
|
b.WriteString(")")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, col := range table.FKColumns() {
|
||||||
|
b.WriteString(", FOREIGN KEY (")
|
||||||
|
b.WriteString(col.Name)
|
||||||
|
b.WriteString(") REFERENCES ")
|
||||||
|
b.WriteString(col.Reference)
|
||||||
|
}
|
||||||
|
|
||||||
b.WriteString(")")
|
b.WriteString(")")
|
||||||
|
|
||||||
return b.String(), false, nil
|
return b.String(), false, nil
|
||||||
|
|
|
@ -663,6 +663,11 @@ func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
|
||||||
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true)
|
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true)
|
||||||
b.WriteString(s)
|
b.WriteString(s)
|
||||||
|
|
||||||
|
if col.Reference != "" {
|
||||||
|
b.WriteString(" FOREIGN KEY REFERENCES ")
|
||||||
|
b.WriteString(col.Reference)
|
||||||
|
}
|
||||||
|
|
||||||
if i != len(table.ColumnsSeq())-1 {
|
if i != len(table.ColumnsSeq())-1 {
|
||||||
b.WriteString(", ")
|
b.WriteString(", ")
|
||||||
}
|
}
|
||||||
|
|
|
@ -683,6 +683,13 @@ func (db *mysql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
|
||||||
b.WriteString(")")
|
b.WriteString(")")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, col := range table.FKColumns() {
|
||||||
|
b.WriteString(", FOREIGN KEY (")
|
||||||
|
b.WriteString(col.Name)
|
||||||
|
b.WriteString(") REFERENCES ")
|
||||||
|
b.WriteString(col.Reference)
|
||||||
|
}
|
||||||
|
|
||||||
b.WriteString(")")
|
b.WriteString(")")
|
||||||
|
|
||||||
if table.StoreEngine != "" {
|
if table.StoreEngine != "" {
|
||||||
|
|
|
@ -644,6 +644,9 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
|
||||||
sql += s
|
sql += s
|
||||||
// }
|
// }
|
||||||
sql = strings.TrimSpace(sql)
|
sql = strings.TrimSpace(sql)
|
||||||
|
if col.Reference != "" {
|
||||||
|
sql += " FOREIGN KEY REFERENCES "+col.Reference
|
||||||
|
}
|
||||||
sql += ", "
|
sql += ", "
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,7 @@ type Column struct {
|
||||||
Length2 int64
|
Length2 int64
|
||||||
Nullable bool
|
Nullable bool
|
||||||
Default string
|
Default string
|
||||||
|
Reference string
|
||||||
Indexes map[string]int
|
Indexes map[string]int
|
||||||
IsPrimaryKey bool
|
IsPrimaryKey bool
|
||||||
IsAutoIncrement bool
|
IsAutoIncrement bool
|
||||||
|
@ -60,6 +61,7 @@ func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int64, nullab
|
||||||
Length2: len2,
|
Length2: len2,
|
||||||
Nullable: nullable,
|
Nullable: nullable,
|
||||||
Default: "",
|
Default: "",
|
||||||
|
Reference: "",
|
||||||
Indexes: make(map[string]int),
|
Indexes: make(map[string]int),
|
||||||
IsPrimaryKey: false,
|
IsPrimaryKey: false,
|
||||||
IsAutoIncrement: false,
|
IsAutoIncrement: false,
|
||||||
|
|
|
@ -19,6 +19,7 @@ type Table struct {
|
||||||
columns []*Column
|
columns []*Column
|
||||||
Indexes map[string]*Index
|
Indexes map[string]*Index
|
||||||
PrimaryKeys []string
|
PrimaryKeys []string
|
||||||
|
ForeignKeys []string
|
||||||
AutoIncrement string
|
AutoIncrement string
|
||||||
Created map[string]bool
|
Created map[string]bool
|
||||||
Updated string
|
Updated string
|
||||||
|
@ -45,6 +46,7 @@ func NewTable(name string, t reflect.Type) *Table {
|
||||||
Indexes: make(map[string]*Index),
|
Indexes: make(map[string]*Index),
|
||||||
Created: make(map[string]bool),
|
Created: make(map[string]bool),
|
||||||
PrimaryKeys: make([]string, 0),
|
PrimaryKeys: make([]string, 0),
|
||||||
|
ForeignKeys: make([]string, 0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,6 +93,15 @@ func (table *Table) PKColumns() []*Column {
|
||||||
return columns
|
return columns
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FKColumns reprents all foreign key columns
|
||||||
|
func (table *Table) FKColumns() []*Column {
|
||||||
|
columns := make([]*Column, len(table.ForeignKeys))
|
||||||
|
for i, name := range table.ForeignKeys {
|
||||||
|
columns[i] = table.GetColumn(name)
|
||||||
|
}
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
// ColumnType returns a column's type
|
// ColumnType returns a column's type
|
||||||
func (table *Table) ColumnType(name string) reflect.Type {
|
func (table *Table) ColumnType(name string) reflect.Type {
|
||||||
t, _ := table.Type.FieldByName(name)
|
t, _ := table.Type.FieldByName(name)
|
||||||
|
@ -131,6 +142,9 @@ func (table *Table) AddColumn(col *Column) {
|
||||||
if col.IsPrimaryKey {
|
if col.IsPrimaryKey {
|
||||||
table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
|
table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
|
||||||
}
|
}
|
||||||
|
if col.Reference != "" {
|
||||||
|
table.ForeignKeys = append(table.ForeignKeys, col.Name)
|
||||||
|
}
|
||||||
if col.IsAutoIncrement {
|
if col.IsAutoIncrement {
|
||||||
table.AutoIncrement = col.Name
|
table.AutoIncrement = col.Name
|
||||||
}
|
}
|
||||||
|
|
|
@ -534,6 +534,39 @@ func TestParseWithOnlyToDB(t *testing.T) {
|
||||||
assert.EqualValues(t, schemas.ONLYFROMDB, table.Columns()[1].MapType)
|
assert.EqualValues(t, schemas.ONLYFROMDB, table.Columns()[1].MapType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseWithBelongsTo(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.SnakeMapper{},
|
||||||
|
names.GonicMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type VanillaStruct struct {
|
||||||
|
ID int64
|
||||||
|
Name string
|
||||||
|
CreatedAt time.Time `db:"created"`
|
||||||
|
UpdatedAt time.Time `db:"updated"`
|
||||||
|
DeletedAt time.Time `db:"deleted"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type StructWithBelongsTo struct {
|
||||||
|
Name string
|
||||||
|
Vanilla VanillaStruct `db:"belongsto"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithBelongsTo)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_belongs_to", table.Name)
|
||||||
|
assert.EqualValues(t, 2, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||||
|
assert.EqualValues(t, "vanilla", table.Columns()[1].Name)
|
||||||
|
assert.EqualValues(t, "vanilla_struct(id)", table.Columns()[1].Reference)
|
||||||
|
assert.True(t, table.Columns()[0].Nullable)
|
||||||
|
assert.True(t, table.Columns()[1].Nullable)
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseWithJSON(t *testing.T) {
|
func TestParseWithJSON(t *testing.T) {
|
||||||
parser := NewParser(
|
parser := NewParser(
|
||||||
"db",
|
"db",
|
||||||
|
|
41
tags/tag.go
41
tags/tag.go
|
@ -124,6 +124,8 @@ var defaultTagHandlers = map[string]Handler{
|
||||||
"EXTENDS": ExtendsTagHandler,
|
"EXTENDS": ExtendsTagHandler,
|
||||||
"UNSIGNED": UnsignedTagHandler,
|
"UNSIGNED": UnsignedTagHandler,
|
||||||
"COLLATE": CollateTagHandler,
|
"COLLATE": CollateTagHandler,
|
||||||
|
|
||||||
|
"BELONGSTO": BelongsToTagHandler,
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -269,6 +271,45 @@ func UniqueTagHandler(ctx *Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BelongsToTagHandler describes belongs tag handler
|
||||||
|
func BelongsToTagHandler(ctx *Context) error {
|
||||||
|
fieldValue := ctx.fieldValue
|
||||||
|
colName := ""
|
||||||
|
if len(ctx.params) > 0 {
|
||||||
|
colName = ctx.params[0]
|
||||||
|
}
|
||||||
|
switch fieldValue.Kind() {
|
||||||
|
case reflect.Pointer:
|
||||||
|
f := fieldValue.Type().Elem()
|
||||||
|
if f.Kind() == reflect.Struct {
|
||||||
|
fieldPtr := fieldValue
|
||||||
|
fieldValue = fieldValue.Elem()
|
||||||
|
if !fieldValue.IsValid() || fieldPtr.IsNil() {
|
||||||
|
fieldValue = reflect.New(f).Elem()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
case reflect.Struct:
|
||||||
|
parentTable, err := ctx.parser.ParseWithCache(fieldValue)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tableName := parentTable.Name
|
||||||
|
if len(ctx.params) > 1 {
|
||||||
|
tableName = ctx.params[1]
|
||||||
|
}
|
||||||
|
for _, col := range parentTable.Columns() {
|
||||||
|
if len(colName) == 0 && col.IsPrimaryKey || colName == col.Name {
|
||||||
|
ctx.col.Reference = tableName+"("+col.Name+")"
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// TODO: warning
|
||||||
|
}
|
||||||
|
return ErrIgnoreField
|
||||||
|
}
|
||||||
|
|
||||||
// UnsignedTagHandler represents the column is unsigned
|
// UnsignedTagHandler represents the column is unsigned
|
||||||
func UnsignedTagHandler(ctx *Context) error {
|
func UnsignedTagHandler(ctx *Context) error {
|
||||||
ctx.isUnsigned = true
|
ctx.isUnsigned = true
|
||||||
|
|
|
@ -90,6 +90,14 @@ func TestSplitTag(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"belongsto(foo, bar)", []tag{
|
||||||
|
{
|
||||||
|
name: "belongsto",
|
||||||
|
params: []string{"foo", "bar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, kase := range cases {
|
for _, kase := range cases {
|
||||||
|
|
Loading…
Reference in New Issue