From d76d6f0aa93a14c10965d8b02aee312fc6ec6ec2 Mon Sep 17 00:00:00 2001 From: Diego Sogari Date: Sat, 1 Apr 2023 13:03:28 -0300 Subject: [PATCH] Add code documentation for association preloading --- engine.go | 2 + preload.go | 109 +++++++++++++++++++++++------------------ schemas/association.go | 15 +++++- session.go | 4 +- session_find.go | 2 +- session_get.go | 2 +- 6 files changed, 81 insertions(+), 53 deletions(-) diff --git a/engine.go b/engine.go index 350f8666..3542e9be 100644 --- a/engine.go +++ b/engine.go @@ -1449,12 +1449,14 @@ func (engine *Engine) Transaction(f func(*Session) (interface{}, error)) (interf return result, nil } +// Preloads builds the preloads func (engine *Engine) Preloads(preloads ...*Preload) *Session { session := engine.NewSession() session.isAutoClose = true return session.Preloads(preloads...) } +// Preload creates a preload func (engine *Engine) Preload(path string) *Preload { return NewPreload(path) } diff --git a/preload.go b/preload.go index 34653bda..dddc6a35 100644 --- a/preload.go +++ b/preload.go @@ -8,6 +8,7 @@ import ( "xorm.io/xorm/schemas" ) +// Preload is the representation of an association preload type Preload struct { path []string cols []string @@ -15,70 +16,79 @@ type Preload struct { noPrune bool } +// NewPreload creates a new preload with the specified path func NewPreload(path string) *Preload { return &Preload{ - path: strings.Split(path, "."), + path: strings.Split(path, "."), // list of association names composing the path cond: builder.NewCond(), } } +// Cols sets column selection for this preload func (p *Preload) Cols(cols ...string) *Preload { p.cols = append(p.cols, cols...) return p } +// Where sets the where condition for this preload func (p *Preload) Where(cond builder.Cond) *Preload { p.cond = p.cond.And(cond) return p } +// NoPrune sets a flag to avoid pruning empty associations func (p *Preload) NoPrune() *Preload { p.noPrune = true return p } -type PreloadNode struct { +// PreloadTreeNode is a tree node for the association preloads +type PreloadTreeNode struct { preload *Preload - children map[string]*PreloadNode + children map[string]*PreloadTreeNode association *schemas.Association - ExtraCols []string + extraCols []string } -func NewPreloadNode() *PreloadNode { - return &PreloadNode{ - children: make(map[string]*PreloadNode), +// NewPreloadTeeeNode creates a new preload tree node +func NewPreloadTeeeNode() *PreloadTreeNode { + return &PreloadTreeNode{ + children: make(map[string]*PreloadTreeNode), } } -func (pn *PreloadNode) Add(preload *Preload) error { - return pn.add(preload, 0) +// Add adds a node to the preload tree +func (node *PreloadTreeNode) Add(preload *Preload) error { + return node.add(preload, 0) } -func (pn *PreloadNode) add(preload *Preload, index int) error { - if index == len(preload.path) { - if pn.preload != nil { +// add adds a node to the preload tree in a recursion level +func (node *PreloadTreeNode) add(preload *Preload, level int) error { + if level == len(preload.path) { + if node.preload != nil { return fmt.Errorf("preload: duplicated path: %s", strings.Join(preload.path, ",")) } - pn.preload = preload + node.preload = preload return nil } - child, ok := pn.children[preload.path[index]] + child, ok := node.children[preload.path[level]] if !ok { - child = NewPreloadNode() - pn.children[preload.path[index]] = child + child = NewPreloadTeeeNode() + node.children[preload.path[level]] = child } - return child.add(preload, index+1) + return child.add(preload, level+1) } -func (pn *PreloadNode) Validate(table *schemas.Table) error { - if pn.preload != nil { - for _, col := range pn.preload.cols { +// Validate validates a preload tree node against a table schema and sets the association +func (node *PreloadTreeNode) Validate(table *schemas.Table) error { + if node.preload != nil { + for _, col := range node.preload.cols { if table.GetColumn(col) == nil { return fmt.Errorf("preload: missing col %s in table %s", col, table.Name) } } } - for name, node := range pn.children { + for name, child := range node.children { column := table.GetColumn(name) if column == nil { return fmt.Errorf("preload: missing field %s in struct %s", name, table.Type.Name()) @@ -87,24 +97,25 @@ func (pn *PreloadNode) Validate(table *schemas.Table) error { return fmt.Errorf("preload: missing association in field %s", name) } if column.Association.JoinTable == nil && len(column.Association.SourceCol) > 0 { - pn.ExtraCols = append(pn.ExtraCols, column.Association.SourceCol) + node.extraCols = append(node.extraCols, column.Association.SourceCol) } if len(column.Association.TargetCol) > 0 { - pn.ExtraCols = append(pn.ExtraCols, table.PrimaryKeys[0]) // pk might be added many times, but that's ok + node.extraCols = append(node.extraCols, table.PrimaryKeys[0]) // pk might be added many times, but that's ok } if column.Association.JoinTable == nil && len(column.Association.TargetCol) > 0 { - node.ExtraCols = append(node.ExtraCols, column.Association.TargetCol) + child.extraCols = append(child.extraCols, column.Association.TargetCol) } - node.association = column.Association - if err := node.Validate(column.Association.RefTable); err != nil { + child.association = column.Association + if err := child.Validate(column.Association.RefTable); err != nil { return err } } return nil } -func (pn *PreloadNode) Compute(session *Session, ownMap reflect.Value) error { - for _, node := range pn.children { +// Compute preloads the associations contained in the preload tree +func (node *PreloadTreeNode) Compute(session *Session, ownMap reflect.Value) error { + for _, node := range node.children { if err := node.compute(session, ownMap, reflect.Value{}); err != nil { return err } @@ -112,33 +123,34 @@ func (pn *PreloadNode) Compute(session *Session, ownMap reflect.Value) error { return nil } -func (pn *PreloadNode) compute(session *Session, ownMap, pruneMap reflect.Value) error { - // non-root node: pn.association is not nil - if err := pn.association.ValidateOwnMap(ownMap); err != nil { +// compute preloads the association contained in a preload tree node +func (node *PreloadTreeNode) compute(session *Session, ownMap, pruneMap reflect.Value) error { + // non-root node: association is not nil + if err := node.association.ValidateOwnMap(ownMap); err != nil { return err } var joinMap reflect.Value - cond := pn.association.GetCond(ownMap) - if pn.association.JoinTable != nil { + cond := node.association.GetCond(ownMap) + if node.association.JoinTable != nil { var err error - cond, joinMap, err = pn.preloadJoin(session, cond) + cond, joinMap, err = node.preloadJoin(session, cond) if err != nil { return err } } - refMap := pn.association.MakeRefMap() - preloadSession := session.Engine().Cols(pn.ExtraCols...).Where(cond) - if pn.preload != nil { - preloadSession.Cols(pn.preload.cols...).Where(pn.preload.cond) + refMap := node.association.MakeRefMap() + preloadSession := session.Engine().Cols(node.extraCols...).Where(cond) + if node.preload != nil { + preloadSession.Cols(node.preload.cols...).Where(node.preload.cond) } if err := preloadSession.Find(refMap.Interface()); err != nil { return err } var refPruneMap reflect.Value - if len(pn.children) > 0 && !(pn.preload != nil && (len(pn.preload.cols) > 0 || pn.preload.noPrune)) { + if len(node.children) > 0 && !(node.preload != nil && (len(node.preload.cols) > 0 || node.preload.noPrune)) { refPruneMap = reflect.MakeMap(reflect.MapOf(refMap.Type().Key(), reflect.TypeOf(true))) refIter := refMap.MapRange() for refIter.Next() { @@ -146,7 +158,7 @@ func (pn *PreloadNode) compute(session *Session, ownMap, pruneMap reflect.Value) } } - for _, node := range pn.children { + for _, node := range node.children { if err := node.compute(session, refMap, refPruneMap); err != nil { return err } @@ -159,26 +171,27 @@ func (pn *PreloadNode) compute(session *Session, ownMap, pruneMap reflect.Value) } } - pn.association.Link(ownMap, refMap, pruneMap, joinMap) + node.association.Link(ownMap, refMap, pruneMap, joinMap) return nil } -func (pn *PreloadNode) preloadJoin(session *Session, cond builder.Cond) (builder.Cond, reflect.Value, error) { - joinSlicePtr := pn.association.MakeJoinSlice() +// preloadJoin obtains a join condition and a join map for a many-to-many association +func (node *PreloadTreeNode) preloadJoin(session *Session, cond builder.Cond) (builder.Cond, reflect.Value, error) { + joinSlicePtr := node.association.NewJoinSlice() if err := session.Engine(). - Table(pn.association.JoinTable.Name).Where(cond). - Cols(pn.association.SourceCol, pn.association.TargetCol). + Table(node.association.JoinTable.Name).Where(cond). + Cols(node.association.SourceCol, node.association.TargetCol). Find(joinSlicePtr.Interface()); err != nil { return nil, reflect.Value{}, err } joinSlice := joinSlicePtr.Elem() - joinMap := pn.association.MakeJoinMap() + joinMap := node.association.MakeJoinMap() for i := 0; i < joinSlice.Len(); i++ { entry := joinSlice.Index(i) pkSlice := joinMap.MapIndex(entry.Field(1)) if !pkSlice.IsValid() { - pkSlice = reflect.MakeSlice(reflect.SliceOf(pn.association.OwnPkType), 0, 0) + pkSlice = reflect.MakeSlice(reflect.SliceOf(node.association.OwnPkType), 0, 0) } joinMap.SetMapIndex(entry.Field(1), reflect.Append(pkSlice, entry.Field(0))) } @@ -189,6 +202,6 @@ func (pn *PreloadNode) preloadJoin(session *Session, cond builder.Cond) (builder for iter.Next() { refPks = append(refPks, iter.Key().Interface()) } - cond = builder.In(pn.association.RefTable.PrimaryKeys[0], refPks) + cond = builder.In(node.association.RefTable.PrimaryKeys[0], refPks) return cond, joinMap, nil } diff --git a/schemas/association.go b/schemas/association.go index 60cfdcee..56cbca6d 100644 --- a/schemas/association.go +++ b/schemas/association.go @@ -6,6 +6,7 @@ import ( "xorm.io/builder" ) +// Association is the representation of an association type Association struct { OwnTable *Table OwnColumn *Column @@ -17,18 +18,22 @@ type Association struct { TargetCol string // has_one, has_many, many_to_many } -func (a *Association) MakeJoinSlice() reflect.Value { +// NewJoinSlice creates a slice to hold the intermediate result of a many-to-many association query +func (a *Association) NewJoinSlice() reflect.Value { return reflect.New(reflect.SliceOf(a.JoinTable.Type)) } +// MakeJoinMap creates a map to hold the intermediate result of a many-to-many association func (a *Association) MakeJoinMap() reflect.Value { return reflect.MakeMap(reflect.MapOf(a.RefPkType, reflect.SliceOf(a.OwnPkType))) } +// MakeRefMap creates a map to hold the result of an association query func (a *Association) MakeRefMap() reflect.Value { return reflect.MakeMap(reflect.MapOf(a.RefPkType, reflect.PointerTo(a.RefTable.Type))) } +// ValidateOwnMap validates the type of the owner map (parent of an association) func (a *Association) ValidateOwnMap(ownMap reflect.Value) error { if ownMap.Type() != reflect.MapOf(a.OwnPkType, reflect.PointerTo(a.OwnTable.Type)) { return fmt.Errorf("wrong map type: %v", ownMap.Type()) @@ -36,6 +41,7 @@ func (a *Association) ValidateOwnMap(ownMap reflect.Value) error { return nil } +// GetCond gets a where condition to use in an association query func (a *Association) GetCond(ownMap reflect.Value) builder.Cond { if a.JoinTable != nil { return a.condManyToMany(ownMap) @@ -46,6 +52,7 @@ func (a *Association) GetCond(ownMap reflect.Value) builder.Cond { return a.condHasOneOrMany(ownMap) } +// condBelongsTo gets a where condition to use in a belongs-to association query func (a *Association) condBelongsTo(ownMap reflect.Value) builder.Cond { pkMap := make(map[interface{}]bool) fkCol := a.OwnTable.GetColumn(a.SourceCol) @@ -69,6 +76,7 @@ func (a *Association) condBelongsTo(ownMap reflect.Value) builder.Cond { return builder.In(a.RefTable.PrimaryKeys[0], pks) } +// condHasOneOrMany gets a where condition to use in a has-one or has-many association query func (a *Association) condHasOneOrMany(ownMap reflect.Value) builder.Cond { var pks []interface{} iter := ownMap.MapRange() @@ -78,6 +86,7 @@ func (a *Association) condHasOneOrMany(ownMap reflect.Value) builder.Cond { return builder.In(a.TargetCol, pks) } +// condHasOneOrMany gets a where condition to use in a many-to-many association query func (a *Association) condManyToMany(ownMap reflect.Value) builder.Cond { var pks []interface{} iter := ownMap.MapRange() @@ -87,6 +96,7 @@ func (a *Association) condManyToMany(ownMap reflect.Value) builder.Cond { return builder.In(a.SourceCol, pks) } +// Link links the owner (parent) values with the referenced association values func (a *Association) Link(ownMap, refMap, pruneMap, joinMap reflect.Value) { if a.JoinTable != nil { a.linkManyToMany(ownMap, refMap, pruneMap, joinMap) @@ -97,6 +107,7 @@ func (a *Association) Link(ownMap, refMap, pruneMap, joinMap reflect.Value) { } } +// linkBelongsTo links the owner (parent) values with the referenced belongs-to association values func (a *Association) linkBelongsTo(ownMap, refMap, pruneMap reflect.Value) { fkCol := a.OwnTable.GetColumn(a.SourceCol) iter := ownMap.MapRange() @@ -121,6 +132,7 @@ func (a *Association) linkBelongsTo(ownMap, refMap, pruneMap reflect.Value) { } } +// linkBelongsTo links the owner (parent) values with the referenced has-one or has-many association values func (a *Association) linkHasOneOrMany(ownMap, refMap, pruneMap reflect.Value) { hasMany := a.OwnColumn.FieldType.Kind() == reflect.Slice fkCol := a.RefTable.GetColumn(a.TargetCol) @@ -148,6 +160,7 @@ func (a *Association) linkHasOneOrMany(ownMap, refMap, pruneMap reflect.Value) { } } +// linkManyToMany links the owner (parent) values with the referenced many-to-many association values func (a *Association) linkManyToMany(ownMap, refMap, pruneMap, joinMap reflect.Value) { iter := refMap.MapRange() for iter.Next() { diff --git a/session.go b/session.go index 56644bd1..a268de75 100644 --- a/session.go +++ b/session.go @@ -87,7 +87,7 @@ type Session struct { ctx context.Context sessionType sessionType - preloadNode *PreloadNode + preloadNode *PreloadTreeNode } func newSessionID() string { @@ -805,7 +805,7 @@ func (session *Session) NoVersionCheck() *Session { // Preloads adds preloads func (session *Session) Preloads(preloads ...*Preload) *Session { if session.preloadNode == nil { - session.preloadNode = NewPreloadNode() + session.preloadNode = NewPreloadTeeeNode() } for _, preload := range preloads { if err := session.preloadNode.Add(preload); err != nil { diff --git a/session_find.go b/session_find.go index 5ef68e1a..c00832ab 100644 --- a/session_find.go +++ b/session_find.go @@ -153,7 +153,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } // we need the columns required for the preloads if !session.statement.ColumnMap.IsEmpty() { - for _, k := range session.preloadNode.ExtraCols { + for _, k := range session.preloadNode.extraCols { session.statement.ColumnMap.Add(k) } } diff --git a/session_get.go b/session_get.go index 7f257c4f..80f62098 100644 --- a/session_get.go +++ b/session_get.go @@ -87,7 +87,7 @@ func (session *Session) get(beans ...interface{}) (bool, error) { } // we need the columns required for the preloads if !session.statement.ColumnMap.IsEmpty() { - for _, k := range session.preloadNode.ExtraCols { + for _, k := range session.preloadNode.extraCols { session.statement.ColumnMap.Add(k) } }