Merge branch 'association-preloading' of gitea.com:sogari/xorm into sogari-association-preloading
This commit is contained in:
commit
e4a88e57ed
12
engine.go
12
engine.go
|
@ -1353,3 +1353,15 @@ func (engine *Engine) Transaction(f func(*Session) (any, error)) (any, error) {
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Preloads builds the preloads
|
||||||
|
func (engine *Engine) Preloads(preloads ...*Preload) *Session {
|
||||||
|
session := engine.NewSession()
|
||||||
|
session.isAutoClose = true
|
||||||
|
return session.Preloads(preloads...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preload creates a preload
|
||||||
|
func (engine *Engine) Preload(path string) *Preload {
|
||||||
|
return NewPreload(path)
|
||||||
|
}
|
||||||
|
|
39
find.go
39
find.go
|
@ -92,6 +92,12 @@ func (session *Session) find(rowsSlicePtr any, condiBean ...any) error {
|
||||||
|
|
||||||
sliceElementType := sliceValue.Type().Elem()
|
sliceElementType := sliceValue.Type().Elem()
|
||||||
|
|
||||||
|
if session.preloadNode != nil {
|
||||||
|
if sliceElementType.Kind() != reflect.Ptr || sliceElementType.Elem().Kind() != reflect.Struct {
|
||||||
|
return errors.New("preloading requires a pointer to a slice or a map of struct pointer")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
tp := tpStruct
|
tp := tpStruct
|
||||||
if session.statement.RefTable == nil {
|
if session.statement.RefTable == nil {
|
||||||
if sliceElementType.Kind() == reflect.Ptr {
|
if sliceElementType.Kind() == reflect.Ptr {
|
||||||
|
@ -142,12 +148,43 @@ func (session *Session) find(rowsSlicePtr any, condiBean ...any) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if session.preloadNode != nil {
|
||||||
|
if err := session.preloadNode.Validate(table); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// we need the columns required for the preloads
|
||||||
|
if !session.statement.ColumnMap.IsEmpty() {
|
||||||
|
for _, k := range session.preloadNode.extraCols {
|
||||||
|
session.statement.ColumnMap.Add(k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sqlStr, args, err := session.statement.GenFindSQL(autoCond)
|
sqlStr, args, err := session.statement.GenFindSQL(autoCond)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return session.noCacheFind(table, sliceValue, sqlStr, args...)
|
err = session.noCacheFind(table, sliceValue, sqlStr, args...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.preloadNode != nil {
|
||||||
|
if isSlice {
|
||||||
|
// convert to a map before preloading
|
||||||
|
originalSliceValue := sliceValue
|
||||||
|
pkColumn := table.PKColumns()[0]
|
||||||
|
sliceValue = reflect.MakeMap(reflect.MapOf(pkColumn.FieldType, sliceElementType))
|
||||||
|
for i := 0; i < originalSliceValue.Len(); i++ {
|
||||||
|
element := originalSliceValue.Index(i)
|
||||||
|
pkValue, _ := pkColumn.ValueOfV(&element)
|
||||||
|
sliceValue.SetMapIndex(*pkValue, element)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return session.preloadNode.Compute(session, sliceValue)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryedField struct {
|
type QueryedField struct {
|
||||||
|
|
30
get.go
30
get.go
|
@ -67,12 +67,28 @@ func (session *Session) get(beans ...any) (bool, error) {
|
||||||
if err := session.statement.SetRefBean(beans[0]); err != nil {
|
if err := session.statement.SetRefBean(beans[0]); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
} else if session.preloadNode != nil {
|
||||||
|
return false, errors.New("preloading requires a pointer to struct")
|
||||||
}
|
}
|
||||||
|
|
||||||
var sqlStr string
|
var sqlStr string
|
||||||
var args []any
|
var args []any
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
table := session.statement.RefTable
|
||||||
|
|
||||||
|
if session.preloadNode != nil {
|
||||||
|
if err := session.preloadNode.Validate(table); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
// we need the columns required for the preloads
|
||||||
|
if !session.statement.ColumnMap.IsEmpty() {
|
||||||
|
for _, k := range session.preloadNode.extraCols {
|
||||||
|
session.statement.ColumnMap.Add(k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if session.statement.RawSQL == "" {
|
if session.statement.RawSQL == "" {
|
||||||
if len(session.statement.TableName()) == 0 {
|
if len(session.statement.TableName()) == 0 {
|
||||||
return false, ErrTableNotFound
|
return false, ErrTableNotFound
|
||||||
|
@ -87,7 +103,6 @@ func (session *Session) get(beans ...any) (bool, error) {
|
||||||
args = session.statement.RawParams
|
args = session.statement.RawParams
|
||||||
}
|
}
|
||||||
|
|
||||||
table := session.statement.RefTable
|
|
||||||
context := session.statement.Context
|
context := session.statement.Context
|
||||||
if context != nil && isStruct {
|
if context != nil && isStruct {
|
||||||
res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args))
|
res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args))
|
||||||
|
@ -107,6 +122,19 @@ func (session *Session) get(beans ...any) (bool, error) {
|
||||||
return has, err
|
return has, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if session.preloadNode != nil {
|
||||||
|
// convert to a map before preloading
|
||||||
|
pkColumn := table.PKColumns()[0]
|
||||||
|
sliceValue := reflect.MakeMap(reflect.MapOf(pkColumn.FieldType, beanValue.Type()))
|
||||||
|
dataStruct := beanValue.Elem()
|
||||||
|
pkValue, _ := pkColumn.ValueOfV(&dataStruct)
|
||||||
|
sliceValue.SetMapIndex(*pkValue, beanValue)
|
||||||
|
|
||||||
|
if err := session.preloadNode.Compute(session, sliceValue); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if context != nil && isStruct {
|
if context != nil && isStruct {
|
||||||
context.Put(fmt.Sprintf("%v-%v", sqlStr, args), beans[0])
|
context.Put(fmt.Sprintf("%v-%v", sqlStr, args), beans[0])
|
||||||
}
|
}
|
||||||
|
|
|
@ -98,6 +98,10 @@ func (statement *Statement) genColumnStr() string {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if col.Association != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if buf.Len() != 0 {
|
if buf.Len() != 0 {
|
||||||
buf.WriteString(", ")
|
buf.WriteString(", ")
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,206 @@
|
||||||
|
package xorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"xorm.io/builder"
|
||||||
|
"xorm.io/xorm/v2/schemas"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Preload is the representation of an association preload
|
||||||
|
type Preload struct {
|
||||||
|
path []string
|
||||||
|
cols []string
|
||||||
|
cond builder.Cond
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPreload creates a new preload with the specified path
|
||||||
|
func NewPreload(path string) *Preload {
|
||||||
|
return &Preload{
|
||||||
|
path: strings.Split(path, "."), // list of association names composing the path
|
||||||
|
cond: builder.NewCond(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cols sets column selection for this preload
|
||||||
|
func (p *Preload) Cols(cols ...string) *Preload {
|
||||||
|
p.cols = append(p.cols, cols...)
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where sets the where condition for this preload
|
||||||
|
func (p *Preload) Where(cond builder.Cond) *Preload {
|
||||||
|
p.cond = p.cond.And(cond)
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreloadTreeNode is a tree node for the association preloads
|
||||||
|
type PreloadTreeNode struct {
|
||||||
|
preload *Preload
|
||||||
|
children map[string]*PreloadTreeNode
|
||||||
|
association *schemas.Association
|
||||||
|
extraCols []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPreloadTeeeNode creates a new preload tree node
|
||||||
|
func NewPreloadTeeeNode() *PreloadTreeNode {
|
||||||
|
return &PreloadTreeNode{
|
||||||
|
children: make(map[string]*PreloadTreeNode),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds a node to the preload tree
|
||||||
|
func (node *PreloadTreeNode) Add(preload *Preload) error {
|
||||||
|
return node.add(preload, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// add adds a node to the preload tree in a recursion level
|
||||||
|
func (node *PreloadTreeNode) add(preload *Preload, level int) error {
|
||||||
|
if level == len(preload.path) {
|
||||||
|
if node.preload != nil {
|
||||||
|
return fmt.Errorf("preload: duplicated path: %s", strings.Join(preload.path, ","))
|
||||||
|
}
|
||||||
|
node.preload = preload
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
child, ok := node.children[preload.path[level]]
|
||||||
|
if !ok {
|
||||||
|
child = NewPreloadTeeeNode()
|
||||||
|
node.children[preload.path[level]] = child
|
||||||
|
}
|
||||||
|
return child.add(preload, level+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates a preload tree node against a table schema and sets the association
|
||||||
|
func (node *PreloadTreeNode) Validate(table *schemas.Table) error {
|
||||||
|
if node.preload != nil {
|
||||||
|
for _, col := range node.preload.cols {
|
||||||
|
if table.GetColumn(col) == nil {
|
||||||
|
return fmt.Errorf("preload: missing col %s in table %s", col, table.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for name, child := range node.children {
|
||||||
|
column := table.GetColumn(name)
|
||||||
|
if column == nil {
|
||||||
|
return fmt.Errorf("preload: missing field %s in struct %s", name, table.Type.Name())
|
||||||
|
}
|
||||||
|
if column.Association == nil {
|
||||||
|
return fmt.Errorf("preload: missing association in field %s", name)
|
||||||
|
}
|
||||||
|
if column.Association.JoinTable == nil && len(column.Association.SourceCol) > 0 {
|
||||||
|
node.extraCols = append(node.extraCols, column.Association.SourceCol)
|
||||||
|
}
|
||||||
|
if len(column.Association.TargetCol) > 0 {
|
||||||
|
node.extraCols = append(node.extraCols, table.PrimaryKeys[0]) // pk might be added many times, but that's ok
|
||||||
|
}
|
||||||
|
if column.Association.JoinTable == nil && len(column.Association.TargetCol) > 0 {
|
||||||
|
child.extraCols = append(child.extraCols, column.Association.TargetCol)
|
||||||
|
}
|
||||||
|
child.association = column.Association
|
||||||
|
if err := child.Validate(column.Association.RefTable); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute preloads the associations contained in the preload tree
|
||||||
|
func (node *PreloadTreeNode) Compute(session *Session, ownMap reflect.Value) error {
|
||||||
|
for _, node := range node.children {
|
||||||
|
if err := node.compute(session, ownMap, reflect.Value{}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute preloads the association contained in a preload tree node
|
||||||
|
func (node *PreloadTreeNode) compute(session *Session, ownMap, pruneMap reflect.Value) error {
|
||||||
|
// non-root node: association is not nil
|
||||||
|
if err := node.association.ValidateOwnMap(ownMap); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var joinMap reflect.Value
|
||||||
|
cond := node.association.GetCond(ownMap)
|
||||||
|
if node.association.JoinTable != nil {
|
||||||
|
var err error
|
||||||
|
cond, joinMap, err = node.preloadJoin(session, cond)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
refMap := node.association.MakeRefMap()
|
||||||
|
preloadSession := session.Engine().Where(cond)
|
||||||
|
if node.preload != nil {
|
||||||
|
if len(node.preload.cols) > 0 {
|
||||||
|
preloadSession.Cols(node.extraCols...).Cols(node.preload.cols...)
|
||||||
|
}
|
||||||
|
preloadSession.Where(node.preload.cond)
|
||||||
|
} else {
|
||||||
|
preloadSession.Cols(node.extraCols...)
|
||||||
|
}
|
||||||
|
if err := preloadSession.Find(refMap.Interface()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var refPruneMap reflect.Value
|
||||||
|
if len(node.children) > 0 && !(node.preload != nil && len(node.preload.cols) > 0) {
|
||||||
|
refPruneMap = reflect.MakeMap(reflect.MapOf(refMap.Type().Key(), reflect.TypeOf(true)))
|
||||||
|
refIter := refMap.MapRange()
|
||||||
|
for refIter.Next() {
|
||||||
|
refPruneMap.SetMapIndex(refIter.Key(), reflect.ValueOf(true))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, node := range node.children {
|
||||||
|
if err := node.compute(session, refMap, refPruneMap); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if refPruneMap.IsValid() {
|
||||||
|
pruneIter := refPruneMap.MapRange()
|
||||||
|
for pruneIter.Next() {
|
||||||
|
refMap.SetMapIndex(pruneIter.Key(), reflect.Value{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
node.association.Link(ownMap, refMap, pruneMap, joinMap)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// preloadJoin obtains a join condition and a join map for a many-to-many association
|
||||||
|
func (node *PreloadTreeNode) preloadJoin(session *Session, cond builder.Cond) (builder.Cond, reflect.Value, error) {
|
||||||
|
joinSlicePtr := node.association.NewJoinSlice()
|
||||||
|
if err := session.Engine().
|
||||||
|
Table(node.association.JoinTable.Name).Where(cond).
|
||||||
|
Cols(node.association.SourceCol, node.association.TargetCol).
|
||||||
|
Find(joinSlicePtr.Interface()); err != nil {
|
||||||
|
return nil, reflect.Value{}, err
|
||||||
|
}
|
||||||
|
joinSlice := joinSlicePtr.Elem()
|
||||||
|
|
||||||
|
joinMap := node.association.MakeJoinMap()
|
||||||
|
for i := 0; i < joinSlice.Len(); i++ {
|
||||||
|
entry := joinSlice.Index(i)
|
||||||
|
pkSlice := joinMap.MapIndex(entry.Field(1))
|
||||||
|
if !pkSlice.IsValid() {
|
||||||
|
pkSlice = reflect.MakeSlice(reflect.SliceOf(node.association.OwnPkType), 0, 0)
|
||||||
|
}
|
||||||
|
joinMap.SetMapIndex(entry.Field(1), reflect.Append(pkSlice, entry.Field(0)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var refPks []interface{}
|
||||||
|
iter := joinMap.MapRange()
|
||||||
|
joinMap.MapKeys()
|
||||||
|
for iter.Next() {
|
||||||
|
refPks = append(refPks, iter.Key().Interface())
|
||||||
|
}
|
||||||
|
cond = builder.In(node.association.RefTable.PrimaryKeys[0], refPks)
|
||||||
|
return cond, joinMap, nil
|
||||||
|
}
|
|
@ -0,0 +1,278 @@
|
||||||
|
package xorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
"xorm.io/builder"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://gitea.com/xorm/xorm/issues/2240
|
||||||
|
|
||||||
|
type Employee struct {
|
||||||
|
Id int64
|
||||||
|
Name string
|
||||||
|
BuddyId *int64
|
||||||
|
ManagerId *int64
|
||||||
|
Buddy *Employee `xorm:"belongs_to(buddy_id)"`
|
||||||
|
Apprentice *Employee `xorm:"has_one(buddy_id)"`
|
||||||
|
Manager *Employee `xorm:"belongs_to(manager_id)"`
|
||||||
|
Subordinates []*Employee `xorm:"has_many(manager_id)"`
|
||||||
|
Indications []*Employee `xorm:"many_to_many(employee_indication, indicator_id, indicated_id)"`
|
||||||
|
IndicatedBy []*Employee `xorm:"many_to_many(employee_indication, indicated_id, indicator_id)"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreload(t *testing.T) {
|
||||||
|
engine, err := NewEngine("sqlite3", ":memory:")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
const sql = `
|
||||||
|
create table employee (
|
||||||
|
id integer primary key autoincrement,
|
||||||
|
name text not null,
|
||||||
|
buddy_id integer references employee(id) check (buddy_id <> id) unique,
|
||||||
|
manager_id integer references employee(id) check (manager_id <> id)
|
||||||
|
);
|
||||||
|
|
||||||
|
create table employee_indication (
|
||||||
|
indicator_id integer not null references employee(id),
|
||||||
|
indicated_id integer not null references employee(id),
|
||||||
|
primary key (indicator_id, indicated_id),
|
||||||
|
check (indicator_id <> indicated_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
insert into employee (name) values ('John'), ('Bob');
|
||||||
|
insert into employee (name,manager_id) values ('Alice',1), ('Riya',2);
|
||||||
|
insert into employee (name,manager_id,buddy_id) values ('Emilie',1,3), ('Cynthia',2,4);
|
||||||
|
insert into employee_indication values (1,2), (1,3), (2,3), (2,4), (2,5), (3,5), (3,6);
|
||||||
|
|
||||||
|
-- John manages Alice and Emilie
|
||||||
|
-- Bob manages Riya and Cynthia
|
||||||
|
-- Alice is buddy of Emilie
|
||||||
|
-- Riya is buddy of Cynthia
|
||||||
|
-- John indicated Bob and Alice
|
||||||
|
-- Bob indicated Alice, Riya and Emilie
|
||||||
|
-- Alice indicated Emilie and Cynthia
|
||||||
|
`
|
||||||
|
_, err = engine.Exec(sql)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var employee Employee
|
||||||
|
_, err = engine.Preloads(
|
||||||
|
engine.Preload("Indications.Buddy").Cols("name"),
|
||||||
|
engine.Preload("Indications").Cols("id"),
|
||||||
|
).Cols("name").Where(builder.Eq{"id": 2}).Get(&employee)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// order is not preserved when preloading, so we compensate for it in the test
|
||||||
|
sort.Slice(employee.Indications, func(i, j int) bool {
|
||||||
|
return employee.Indications[i].Id < employee.Indications[j].Id
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, Employee{
|
||||||
|
Id: 2,
|
||||||
|
Name: "Bob",
|
||||||
|
Indications: []*Employee{
|
||||||
|
{Id: 3},
|
||||||
|
{Id: 4},
|
||||||
|
{
|
||||||
|
Id: 5,
|
||||||
|
BuddyId: &[]int64{3}[0],
|
||||||
|
Buddy: &Employee{
|
||||||
|
Id: 3,
|
||||||
|
Name: "Alice",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, employee)
|
||||||
|
|
||||||
|
var employees []*Employee
|
||||||
|
err = engine.Preloads(
|
||||||
|
// 1. preload the names of all subordinates of this employee's manager
|
||||||
|
engine.Preload("Manager.Subordinates").Cols("name"),
|
||||||
|
// 2. preload the name of the buddy of each employee indicated by this employee
|
||||||
|
engine.Preload("Indications.Buddy").Cols("name"),
|
||||||
|
// 3. preload the names of all employees who indicated this employee's apprentice, except non-subordinates
|
||||||
|
engine.Preload("Apprentice.IndicatedBy").Cols("name").Where(builder.NotNull{"manager_id"}),
|
||||||
|
// 4. preload the names of:
|
||||||
|
// all employees who don't have a maanger and were indicated by:
|
||||||
|
engine.Preload("Subordinates.IndicatedBy.Indications").Where(builder.IsNull{"manager_id"}).Cols("name"),
|
||||||
|
// employees whose name is 4 letters long and indicated:
|
||||||
|
engine.Preload("Subordinates.IndicatedBy").Where(builder.Like{"name", "____"}),
|
||||||
|
// this employee's subordinates who don't have a buddy
|
||||||
|
engine.Preload("Subordinates").Where(builder.IsNull{"buddy_id"}),
|
||||||
|
// 0. find the names of all employees
|
||||||
|
).Cols("name").Find(&employees)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// order is not preserved when preloading, so we compensate for it in the test
|
||||||
|
for k := 2; k < 6; k++ {
|
||||||
|
sort.Slice(employees[k].Manager.Subordinates, func(i, j int) bool {
|
||||||
|
return employees[k].Manager.Subordinates[i].Id < employees[k].Manager.Subordinates[j].Id
|
||||||
|
})
|
||||||
|
}
|
||||||
|
sort.Slice(employees[2].Indications, func(i, j int) bool {
|
||||||
|
return employees[2].Indications[i].Id < employees[2].Indications[j].Id
|
||||||
|
})
|
||||||
|
|
||||||
|
expected := []*Employee{
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Name: "John",
|
||||||
|
Subordinates: []*Employee{
|
||||||
|
{
|
||||||
|
Id: 3,
|
||||||
|
Name: "Alice",
|
||||||
|
ManagerId: &[]int64{1}[0],
|
||||||
|
IndicatedBy: []*Employee{
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Name: "John",
|
||||||
|
Indications: []*Employee{
|
||||||
|
{
|
||||||
|
Id: 2,
|
||||||
|
Name: "Bob",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 2,
|
||||||
|
Name: "Bob",
|
||||||
|
Indications: []*Employee{
|
||||||
|
{
|
||||||
|
Id: 5,
|
||||||
|
BuddyId: &[]int64{3}[0],
|
||||||
|
Buddy: &Employee{
|
||||||
|
Id: 3,
|
||||||
|
Name: "Alice",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 3,
|
||||||
|
Name: "Alice",
|
||||||
|
ManagerId: &[]int64{1}[0],
|
||||||
|
Manager: &Employee{
|
||||||
|
Id: 1,
|
||||||
|
Subordinates: []*Employee{
|
||||||
|
{
|
||||||
|
Id: 3,
|
||||||
|
Name: "Alice",
|
||||||
|
ManagerId: &[]int64{1}[0],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 5,
|
||||||
|
Name: "Emilie",
|
||||||
|
ManagerId: &[]int64{1}[0],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Apprentice: &Employee{
|
||||||
|
Id: 5,
|
||||||
|
BuddyId: &[]int64{3}[0],
|
||||||
|
IndicatedBy: []*Employee{
|
||||||
|
{
|
||||||
|
Id: 3,
|
||||||
|
Name: "Alice",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Indications: []*Employee{
|
||||||
|
{
|
||||||
|
Id: 5,
|
||||||
|
BuddyId: &[]int64{3}[0],
|
||||||
|
Buddy: &Employee{
|
||||||
|
Id: 3,
|
||||||
|
Name: "Alice",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 6,
|
||||||
|
BuddyId: &[]int64{4}[0],
|
||||||
|
Buddy: &Employee{
|
||||||
|
Id: 4,
|
||||||
|
Name: "Riya",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 4,
|
||||||
|
Name: "Riya",
|
||||||
|
ManagerId: &[]int64{2}[0],
|
||||||
|
Manager: &Employee{
|
||||||
|
Id: 2,
|
||||||
|
Subordinates: []*Employee{
|
||||||
|
{
|
||||||
|
Id: 4,
|
||||||
|
Name: "Riya",
|
||||||
|
ManagerId: &[]int64{2}[0],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 6,
|
||||||
|
Name: "Cynthia",
|
||||||
|
ManagerId: &[]int64{2}[0],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Apprentice: &Employee{
|
||||||
|
Id: 6,
|
||||||
|
BuddyId: &[]int64{4}[0],
|
||||||
|
IndicatedBy: []*Employee{
|
||||||
|
{
|
||||||
|
Id: 3,
|
||||||
|
Name: "Alice",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 5,
|
||||||
|
Name: "Emilie",
|
||||||
|
ManagerId: &[]int64{1}[0],
|
||||||
|
Manager: &Employee{
|
||||||
|
Id: 1,
|
||||||
|
Subordinates: []*Employee{
|
||||||
|
{
|
||||||
|
Id: 3,
|
||||||
|
Name: "Alice",
|
||||||
|
ManagerId: &[]int64{1}[0],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 5,
|
||||||
|
Name: "Emilie",
|
||||||
|
ManagerId: &[]int64{1}[0],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 6,
|
||||||
|
Name: "Cynthia",
|
||||||
|
ManagerId: &[]int64{2}[0],
|
||||||
|
Manager: &Employee{
|
||||||
|
Id: 2,
|
||||||
|
Subordinates: []*Employee{
|
||||||
|
{
|
||||||
|
Id: 4,
|
||||||
|
Name: "Riya",
|
||||||
|
ManagerId: &[]int64{2}[0],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 6,
|
||||||
|
Name: "Cynthia",
|
||||||
|
ManagerId: &[]int64{2}[0],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.Equal(t, expected, employees)
|
||||||
|
}
|
|
@ -0,0 +1,180 @@
|
||||||
|
package schemas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"xorm.io/builder"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Association is the representation of an association
|
||||||
|
type Association struct {
|
||||||
|
OwnTable *Table
|
||||||
|
OwnColumn *Column
|
||||||
|
OwnPkType reflect.Type
|
||||||
|
RefTable *Table
|
||||||
|
RefPkType reflect.Type
|
||||||
|
JoinTable *Table // many_to_many
|
||||||
|
SourceCol string // belongs_to, many_to_many
|
||||||
|
TargetCol string // has_one, has_many, many_to_many
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJoinSlice creates a slice to hold the intermediate result of a many-to-many association query
|
||||||
|
func (a *Association) NewJoinSlice() reflect.Value {
|
||||||
|
return reflect.New(reflect.SliceOf(a.JoinTable.Type))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeJoinMap creates a map to hold the intermediate result of a many-to-many association
|
||||||
|
func (a *Association) MakeJoinMap() reflect.Value {
|
||||||
|
return reflect.MakeMap(reflect.MapOf(a.RefPkType, reflect.SliceOf(a.OwnPkType)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeRefMap creates a map to hold the result of an association query
|
||||||
|
func (a *Association) MakeRefMap() reflect.Value {
|
||||||
|
return reflect.MakeMap(reflect.MapOf(a.RefPkType, reflect.PtrTo(a.RefTable.Type)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateOwnMap validates the type of the owner map (parent of an association)
|
||||||
|
func (a *Association) ValidateOwnMap(ownMap reflect.Value) error {
|
||||||
|
if ownMap.Type() != reflect.MapOf(a.OwnPkType, reflect.PtrTo(a.OwnTable.Type)) {
|
||||||
|
return fmt.Errorf("wrong map type: %v", ownMap.Type())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCond gets a where condition to use in an association query
|
||||||
|
func (a *Association) GetCond(ownMap reflect.Value) builder.Cond {
|
||||||
|
if a.JoinTable != nil {
|
||||||
|
return a.condManyToMany(ownMap)
|
||||||
|
}
|
||||||
|
if len(a.SourceCol) > 0 {
|
||||||
|
return a.condBelongsTo(ownMap)
|
||||||
|
}
|
||||||
|
return a.condHasOneOrMany(ownMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
// condBelongsTo gets a where condition to use in a belongs-to association query
|
||||||
|
func (a *Association) condBelongsTo(ownMap reflect.Value) builder.Cond {
|
||||||
|
pkMap := make(map[interface{}]bool)
|
||||||
|
fkCol := a.OwnTable.GetColumn(a.SourceCol)
|
||||||
|
iter := ownMap.MapRange()
|
||||||
|
for iter.Next() {
|
||||||
|
structPtr := iter.Value()
|
||||||
|
fk, _ := fkCol.ValueOfV(&structPtr)
|
||||||
|
if fk.Type().Kind() == reflect.Ptr {
|
||||||
|
if fk.IsNil() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
*fk = fk.Elem()
|
||||||
|
}
|
||||||
|
// don't check fk.IsZero(), because it might be a valid fk value
|
||||||
|
pkMap[fk.Interface()] = true
|
||||||
|
}
|
||||||
|
pks := make([]interface{}, 0, len(pkMap))
|
||||||
|
for pk := range pkMap {
|
||||||
|
pks = append(pks, pk)
|
||||||
|
}
|
||||||
|
return builder.In(a.RefTable.PrimaryKeys[0], pks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// condHasOneOrMany gets a where condition to use in a has-one or has-many association query
|
||||||
|
func (a *Association) condHasOneOrMany(ownMap reflect.Value) builder.Cond {
|
||||||
|
var pks []interface{}
|
||||||
|
iter := ownMap.MapRange()
|
||||||
|
for iter.Next() {
|
||||||
|
pks = append(pks, iter.Key().Interface())
|
||||||
|
}
|
||||||
|
return builder.In(a.TargetCol, pks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// condHasOneOrMany gets a where condition to use in a many-to-many association query
|
||||||
|
func (a *Association) condManyToMany(ownMap reflect.Value) builder.Cond {
|
||||||
|
var pks []interface{}
|
||||||
|
iter := ownMap.MapRange()
|
||||||
|
for iter.Next() {
|
||||||
|
pks = append(pks, iter.Key().Interface())
|
||||||
|
}
|
||||||
|
return builder.In(a.SourceCol, pks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Link links the owner (parent) values with the referenced association values
|
||||||
|
func (a *Association) Link(ownMap, refMap, pruneMap, joinMap reflect.Value) {
|
||||||
|
if a.JoinTable != nil {
|
||||||
|
a.linkManyToMany(ownMap, refMap, pruneMap, joinMap)
|
||||||
|
} else if len(a.SourceCol) > 0 {
|
||||||
|
a.linkBelongsTo(ownMap, refMap, pruneMap)
|
||||||
|
} else {
|
||||||
|
a.linkHasOneOrMany(ownMap, refMap, pruneMap)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// linkBelongsTo links the owner (parent) values with the referenced belongs-to association values
|
||||||
|
func (a *Association) linkBelongsTo(ownMap, refMap, pruneMap reflect.Value) {
|
||||||
|
fkCol := a.OwnTable.GetColumn(a.SourceCol)
|
||||||
|
iter := ownMap.MapRange()
|
||||||
|
for iter.Next() {
|
||||||
|
structPtr := iter.Value()
|
||||||
|
fk, _ := fkCol.ValueOfV(&structPtr)
|
||||||
|
if fk.Type().Kind() == reflect.Ptr {
|
||||||
|
if fk.IsNil() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
*fk = fk.Elem()
|
||||||
|
}
|
||||||
|
// don't check fk.IsZero(), because it might be a valid fk value
|
||||||
|
refStructPtr := refMap.MapIndex(*fk)
|
||||||
|
if refStructPtr.IsValid() {
|
||||||
|
refField, _ := a.OwnColumn.ValueOfV(&structPtr)
|
||||||
|
refField.Set(refStructPtr)
|
||||||
|
if pruneMap.IsValid() {
|
||||||
|
pruneMap.SetMapIndex(iter.Key(), reflect.Value{}) // do not prune this key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// linkBelongsTo links the owner (parent) values with the referenced has-one or has-many association values
|
||||||
|
func (a *Association) linkHasOneOrMany(ownMap, refMap, pruneMap reflect.Value) {
|
||||||
|
hasMany := a.OwnColumn.FieldType.Kind() == reflect.Slice
|
||||||
|
fkCol := a.RefTable.GetColumn(a.TargetCol)
|
||||||
|
iter := refMap.MapRange()
|
||||||
|
for iter.Next() {
|
||||||
|
refStructPtr := iter.Value()
|
||||||
|
fk, _ := fkCol.ValueOfV(&refStructPtr)
|
||||||
|
if fk.Type().Kind() == reflect.Ptr {
|
||||||
|
if fk.IsNil() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
*fk = fk.Elem()
|
||||||
|
}
|
||||||
|
// don't check fk.IsZero(), because it might be a valid fk value
|
||||||
|
structPtr := ownMap.MapIndex(*fk) // structPtr should be valid at this point
|
||||||
|
refField, _ := a.OwnColumn.ValueOfV(&structPtr)
|
||||||
|
if hasMany {
|
||||||
|
refField.Set(reflect.Append(*refField, refStructPtr))
|
||||||
|
} else {
|
||||||
|
refField.Set(refStructPtr)
|
||||||
|
}
|
||||||
|
if pruneMap.IsValid() {
|
||||||
|
pruneMap.SetMapIndex(*fk, reflect.Value{}) // do not prune this key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// linkManyToMany links the owner (parent) values with the referenced many-to-many association values
|
||||||
|
func (a *Association) linkManyToMany(ownMap, refMap, pruneMap, joinMap reflect.Value) {
|
||||||
|
iter := refMap.MapRange()
|
||||||
|
for iter.Next() {
|
||||||
|
refStructPtr := iter.Value()
|
||||||
|
refPk := iter.Key() // refPk should not be a pointer
|
||||||
|
pkSlice := joinMap.MapIndex(refPk) // pkSlice should be valid at this point
|
||||||
|
for i := 0; i < pkSlice.Len(); i++ {
|
||||||
|
pk := pkSlice.Index(i)
|
||||||
|
structPtr := ownMap.MapIndex(pk) // structPtr should be valid at this point
|
||||||
|
refField, _ := a.OwnColumn.ValueOfV(&structPtr)
|
||||||
|
refField.Set(reflect.Append(*refField, refStructPtr))
|
||||||
|
if pruneMap.IsValid() {
|
||||||
|
pruneMap.SetMapIndex(pk, reflect.Value{}) // do not prune this key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -21,9 +21,9 @@ const (
|
||||||
// Column defines database column
|
// Column defines database column
|
||||||
type Column struct {
|
type Column struct {
|
||||||
Name string
|
Name string
|
||||||
TableName string
|
FieldName string // Available only when parsed from a struct
|
||||||
FieldName string // Available only when parsed from a struct
|
FieldIndex []int // Available only when parsed from a struct
|
||||||
FieldIndex []int // Available only when parsed from a struct
|
FieldType reflect.Type // Available only when parsed from a struct
|
||||||
SQLType SQLType
|
SQLType SQLType
|
||||||
IsJSON bool
|
IsJSON bool
|
||||||
Length int64
|
Length int64
|
||||||
|
@ -46,6 +46,7 @@ type Column struct {
|
||||||
TimeZone *time.Location // column specified time zone
|
TimeZone *time.Location // column specified time zone
|
||||||
Comment string
|
Comment string
|
||||||
Collation string
|
Collation string
|
||||||
|
Association *Association
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewColumn creates a new column
|
// NewColumn creates a new column
|
||||||
|
@ -53,7 +54,6 @@ func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int64, nullab
|
||||||
return &Column{
|
return &Column{
|
||||||
Name: name,
|
Name: name,
|
||||||
IsJSON: sqlType.IsJson(),
|
IsJSON: sqlType.IsJson(),
|
||||||
TableName: "",
|
|
||||||
FieldName: fieldName,
|
FieldName: fieldName,
|
||||||
SQLType: sqlType,
|
SQLType: sqlType,
|
||||||
Length: len1,
|
Length: len1,
|
||||||
|
|
15
session.go
15
session.go
|
@ -85,6 +85,7 @@ type Session struct {
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
sessionType sessionType
|
sessionType sessionType
|
||||||
|
preloadNode *PreloadTreeNode
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSessionID() string {
|
func newSessionID() string {
|
||||||
|
@ -712,3 +713,17 @@ func (session *Session) NoVersionCheck() *Session {
|
||||||
func SetDefaultJSONHandler(jsonHandler json.Interface) {
|
func SetDefaultJSONHandler(jsonHandler json.Interface) {
|
||||||
json.DefaultJSONHandler = jsonHandler
|
json.DefaultJSONHandler = jsonHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Preloads adds preloads
|
||||||
|
func (session *Session) Preloads(preloads ...*Preload) *Session {
|
||||||
|
if session.preloadNode == nil {
|
||||||
|
session.preloadNode = NewPreloadTeeeNode()
|
||||||
|
}
|
||||||
|
for _, preload := range preloads {
|
||||||
|
if err := session.preloadNode.Add(preload); err != nil {
|
||||||
|
session.statement.LastError = err
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
|
|
@ -85,7 +85,7 @@ func (parser *Parser) SetIdentifier(identifier string) {
|
||||||
|
|
||||||
// ParseWithCache parse a struct with cache
|
// ParseWithCache parse a struct with cache
|
||||||
func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) {
|
func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) {
|
||||||
t := v.Type()
|
t := reflect.Indirect(v).Type()
|
||||||
tableI, ok := parser.tableCache.Load(t)
|
tableI, ok := parser.tableCache.Load(t)
|
||||||
if ok {
|
if ok {
|
||||||
return tableI.(*schemas.Table), nil
|
return tableI.(*schemas.Table), nil
|
||||||
|
@ -165,6 +165,7 @@ func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructFi
|
||||||
field.Name, sqlType, sqlType.DefaultLength,
|
field.Name, sqlType, sqlType.DefaultLength,
|
||||||
sqlType.DefaultLength2, true)
|
sqlType.DefaultLength2, true)
|
||||||
col.FieldIndex = []int{fieldIndex}
|
col.FieldIndex = []int{fieldIndex}
|
||||||
|
col.FieldType = field.Type
|
||||||
|
|
||||||
if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
|
if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
|
||||||
col.IsAutoIncrement = true
|
col.IsAutoIncrement = true
|
||||||
|
@ -178,6 +179,7 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, f
|
||||||
col := &schemas.Column{
|
col := &schemas.Column{
|
||||||
FieldName: field.Name,
|
FieldName: field.Name,
|
||||||
FieldIndex: []int{fieldIndex},
|
FieldIndex: []int{fieldIndex},
|
||||||
|
FieldType: field.Type,
|
||||||
Nullable: true,
|
Nullable: true,
|
||||||
IsPrimaryKey: false,
|
IsPrimaryKey: false,
|
||||||
IsAutoIncrement: false,
|
IsAutoIncrement: false,
|
||||||
|
|
196
tags/tag.go
196
tags/tag.go
|
@ -5,6 +5,7 @@
|
||||||
package tags
|
package tags
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -99,27 +100,31 @@ type Handler func(ctx *Context) error
|
||||||
|
|
||||||
// defaultTagHandlers enumerates all the default tag handler
|
// defaultTagHandlers enumerates all the default tag handler
|
||||||
var defaultTagHandlers = map[string]Handler{
|
var defaultTagHandlers = map[string]Handler{
|
||||||
"-": IgnoreHandler,
|
"-": IgnoreHandler,
|
||||||
"<-": OnlyFromDBTagHandler,
|
"<-": OnlyFromDBTagHandler,
|
||||||
"->": OnlyToDBTagHandler,
|
"->": OnlyToDBTagHandler,
|
||||||
"PK": PKTagHandler,
|
"PK": PKTagHandler,
|
||||||
"NULL": NULLTagHandler,
|
"NULL": NULLTagHandler,
|
||||||
"NOT": NotTagHandler,
|
"NOT": NotTagHandler,
|
||||||
"AUTOINCR": AutoIncrTagHandler,
|
"AUTOINCR": AutoIncrTagHandler,
|
||||||
"DEFAULT": DefaultTagHandler,
|
"DEFAULT": DefaultTagHandler,
|
||||||
"CREATED": CreatedTagHandler,
|
"CREATED": CreatedTagHandler,
|
||||||
"UPDATED": UpdatedTagHandler,
|
"UPDATED": UpdatedTagHandler,
|
||||||
"DELETED": DeletedTagHandler,
|
"DELETED": DeletedTagHandler,
|
||||||
"VERSION": VersionTagHandler,
|
"VERSION": VersionTagHandler,
|
||||||
"UTC": UTCTagHandler,
|
"UTC": UTCTagHandler,
|
||||||
"LOCAL": LocalTagHandler,
|
"LOCAL": LocalTagHandler,
|
||||||
"NOTNULL": NotNullTagHandler,
|
"NOTNULL": NotNullTagHandler,
|
||||||
"INDEX": IndexTagHandler,
|
"INDEX": IndexTagHandler,
|
||||||
"UNIQUE": UniqueTagHandler,
|
"UNIQUE": UniqueTagHandler,
|
||||||
"COMMENT": CommentTagHandler,
|
"COMMENT": CommentTagHandler,
|
||||||
"EXTENDS": ExtendsTagHandler,
|
"EXTENDS": ExtendsTagHandler,
|
||||||
"UNSIGNED": UnsignedTagHandler,
|
"UNSIGNED": UnsignedTagHandler,
|
||||||
"COLLATE": CollateTagHandler,
|
"COLLATE": CollateTagHandler,
|
||||||
|
"BELONGS_TO": BelongsToTagHandler,
|
||||||
|
"HAS_ONE": HasOneTagHandler,
|
||||||
|
"HAS_MANY": HasManyTagHandler,
|
||||||
|
"MANY_TO_MANY": ManyToManyTagHandler,
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -387,3 +392,152 @@ func ExtendsTagHandler(ctx *Context) error {
|
||||||
}
|
}
|
||||||
return ErrIgnoreField
|
return ErrIgnoreField
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isStructPtr(t reflect.Type) bool {
|
||||||
|
return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct
|
||||||
|
}
|
||||||
|
|
||||||
|
func isStructPtrSlice(t reflect.Type) bool {
|
||||||
|
return t.Kind() == reflect.Slice && isStructPtr(t.Elem())
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRefTableAndPks(ctx *Context, refType reflect.Type) ([]*schemas.Column, *schemas.Table, []*schemas.Column, error) {
|
||||||
|
pks := ctx.table.PKColumns()
|
||||||
|
if len(pks) != 1 {
|
||||||
|
return nil, nil, nil, errors.New("a single-column primary key must be declared before the association")
|
||||||
|
}
|
||||||
|
|
||||||
|
refTable := ctx.table // self-reference case
|
||||||
|
refPks := pks
|
||||||
|
if refTable.Type != refType {
|
||||||
|
var err error
|
||||||
|
refTable, err = ctx.parser.ParseWithCache(reflect.New(refType).Elem())
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
refPks = refTable.PKColumns()
|
||||||
|
if len(refPks) != 1 {
|
||||||
|
return nil, nil, nil, errors.New("only single-column primary key is supported in associations")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pks, refTable, refPks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BelongsToTagHandler describes belongs_to tag handler
|
||||||
|
func BelongsToTagHandler(ctx *Context) error {
|
||||||
|
if !isStructPtr(ctx.fieldValue.Type()) {
|
||||||
|
return errors.New("tag belongs_to can only be applied to struct pointer field")
|
||||||
|
}
|
||||||
|
if len(ctx.params) != 1 {
|
||||||
|
return errors.New("tag belongs_to requires one parameter")
|
||||||
|
}
|
||||||
|
pks, refTable, refPks, err := getRefTableAndPks(ctx, ctx.fieldValue.Type().Elem())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.col.Association = &schemas.Association{
|
||||||
|
OwnTable: ctx.table,
|
||||||
|
OwnColumn: ctx.col,
|
||||||
|
OwnPkType: pks[0].FieldType,
|
||||||
|
RefTable: refTable,
|
||||||
|
RefPkType: refPks[0].FieldType,
|
||||||
|
SourceCol: ctx.params[0],
|
||||||
|
}
|
||||||
|
ctx.col.Name = ctx.col.FieldName
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasOneTagHandler describes has_one tag handler
|
||||||
|
func HasOneTagHandler(ctx *Context) error {
|
||||||
|
if !isStructPtr(ctx.fieldValue.Type()) {
|
||||||
|
return errors.New("tag has_one can only be applied to struct pointer field")
|
||||||
|
}
|
||||||
|
if len(ctx.params) != 1 {
|
||||||
|
return errors.New("tag has_one requires one parameter")
|
||||||
|
}
|
||||||
|
pks, refTable, refPks, err := getRefTableAndPks(ctx, ctx.fieldValue.Type().Elem())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.col.Association = &schemas.Association{
|
||||||
|
OwnTable: ctx.table,
|
||||||
|
OwnColumn: ctx.col,
|
||||||
|
OwnPkType: pks[0].FieldType,
|
||||||
|
RefTable: refTable,
|
||||||
|
RefPkType: refPks[0].FieldType,
|
||||||
|
TargetCol: ctx.params[0],
|
||||||
|
}
|
||||||
|
ctx.col.Name = ctx.col.FieldName
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasManyTagHandler describes has_many tag handler
|
||||||
|
func HasManyTagHandler(ctx *Context) error {
|
||||||
|
if !isStructPtrSlice(ctx.fieldValue.Type()) {
|
||||||
|
return errors.New("tag has_many can only be applied to slice of struct pointer field")
|
||||||
|
}
|
||||||
|
if len(ctx.params) != 1 {
|
||||||
|
return errors.New("tag has_many requires one parameter")
|
||||||
|
}
|
||||||
|
pks, refTable, refPks, err := getRefTableAndPks(ctx, ctx.fieldValue.Type().Elem().Elem())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.col.Association = &schemas.Association{
|
||||||
|
OwnTable: ctx.table,
|
||||||
|
OwnColumn: ctx.col,
|
||||||
|
OwnPkType: pks[0].FieldType,
|
||||||
|
RefTable: refTable,
|
||||||
|
RefPkType: refPks[0].FieldType,
|
||||||
|
TargetCol: ctx.params[0],
|
||||||
|
}
|
||||||
|
ctx.col.Name = ctx.col.FieldName
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ManyToManyTagHandler describes many_to_many tag handler
|
||||||
|
func ManyToManyTagHandler(ctx *Context) error {
|
||||||
|
if !isStructPtrSlice(ctx.fieldValue.Type()) {
|
||||||
|
return errors.New("tag many_to_many can only be applied to slice of struct pointer field")
|
||||||
|
}
|
||||||
|
if len(ctx.params) != 3 {
|
||||||
|
return errors.New("tag many_to_many requires 3 parameters")
|
||||||
|
}
|
||||||
|
pks, refTable, refPks, err := getRefTableAndPks(ctx, ctx.fieldValue.Type().Elem().Elem())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
joinType := reflect.StructOf([]reflect.StructField{
|
||||||
|
{
|
||||||
|
Name: ctx.parser.GetColumnMapper().Table2Obj(ctx.params[1]),
|
||||||
|
Type: pks[0].FieldType,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: ctx.parser.GetColumnMapper().Table2Obj(ctx.params[2]),
|
||||||
|
Type: refPks[0].FieldType,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
joinTable, err := ctx.parser.ParseWithCache(reflect.New(joinType).Elem())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
joinTable.Name = ctx.params[0]
|
||||||
|
|
||||||
|
ctx.col.Association = &schemas.Association{
|
||||||
|
OwnTable: ctx.table,
|
||||||
|
OwnColumn: ctx.col,
|
||||||
|
OwnPkType: pks[0].FieldType,
|
||||||
|
RefTable: refTable,
|
||||||
|
RefPkType: refPks[0].FieldType,
|
||||||
|
JoinTable: joinTable,
|
||||||
|
SourceCol: ctx.params[1],
|
||||||
|
TargetCol: ctx.params[2],
|
||||||
|
}
|
||||||
|
ctx.col.Name = ctx.col.FieldName
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue