From 79554f640b2bccbd85eb3c41e1b1dd45c7915497 Mon Sep 17 00:00:00 2001 From: Diego Sogari Date: Sat, 1 Apr 2023 12:14:53 -0300 Subject: [PATCH 1/5] Add support for association preloading --- engine.go | 10 ++ preload.go | 194 +++++++++++++++++++++++++++++ preload_test.go | 276 +++++++++++++++++++++++++++++++++++++++++ schemas/association.go | 167 +++++++++++++++++++++++++ schemas/column.go | 8 +- session.go | 15 +++ session_find.go | 39 +++++- session_get.go | 31 ++++- tags/parser.go | 4 +- tags/tag.go | 198 +++++++++++++++++++++++++---- 10 files changed, 912 insertions(+), 30 deletions(-) create mode 100644 preload.go create mode 100644 preload_test.go create mode 100644 schemas/association.go diff --git a/engine.go b/engine.go index 389819e7..350f8666 100644 --- a/engine.go +++ b/engine.go @@ -1448,3 +1448,13 @@ func (engine *Engine) Transaction(f func(*Session) (interface{}, error)) (interf return result, nil } + +func (engine *Engine) Preloads(preloads ...*Preload) *Session { + session := engine.NewSession() + session.isAutoClose = true + return session.Preloads(preloads...) +} + +func (engine *Engine) Preload(path string) *Preload { + return NewPreload(path) +} diff --git a/preload.go b/preload.go new file mode 100644 index 00000000..34653bda --- /dev/null +++ b/preload.go @@ -0,0 +1,194 @@ +package xorm + +import ( + "fmt" + "reflect" + "strings" + "xorm.io/builder" + "xorm.io/xorm/schemas" +) + +type Preload struct { + path []string + cols []string + cond builder.Cond + noPrune bool +} + +func NewPreload(path string) *Preload { + return &Preload{ + path: strings.Split(path, "."), + cond: builder.NewCond(), + } +} + +func (p *Preload) Cols(cols ...string) *Preload { + p.cols = append(p.cols, cols...) + return p +} + +func (p *Preload) Where(cond builder.Cond) *Preload { + p.cond = p.cond.And(cond) + return p +} + +func (p *Preload) NoPrune() *Preload { + p.noPrune = true + return p +} + +type PreloadNode struct { + preload *Preload + children map[string]*PreloadNode + association *schemas.Association + ExtraCols []string +} + +func NewPreloadNode() *PreloadNode { + return &PreloadNode{ + children: make(map[string]*PreloadNode), + } +} + +func (pn *PreloadNode) Add(preload *Preload) error { + return pn.add(preload, 0) +} + +func (pn *PreloadNode) add(preload *Preload, index int) error { + if index == len(preload.path) { + if pn.preload != nil { + return fmt.Errorf("preload: duplicated path: %s", strings.Join(preload.path, ",")) + } + pn.preload = preload + return nil + } + child, ok := pn.children[preload.path[index]] + if !ok { + child = NewPreloadNode() + pn.children[preload.path[index]] = child + } + return child.add(preload, index+1) +} + +func (pn *PreloadNode) Validate(table *schemas.Table) error { + if pn.preload != nil { + for _, col := range pn.preload.cols { + if table.GetColumn(col) == nil { + return fmt.Errorf("preload: missing col %s in table %s", col, table.Name) + } + } + } + for name, node := range pn.children { + column := table.GetColumn(name) + if column == nil { + return fmt.Errorf("preload: missing field %s in struct %s", name, table.Type.Name()) + } + if column.Association == nil { + return fmt.Errorf("preload: missing association in field %s", name) + } + if column.Association.JoinTable == nil && len(column.Association.SourceCol) > 0 { + pn.ExtraCols = append(pn.ExtraCols, column.Association.SourceCol) + } + if len(column.Association.TargetCol) > 0 { + pn.ExtraCols = append(pn.ExtraCols, table.PrimaryKeys[0]) // pk might be added many times, but that's ok + } + if column.Association.JoinTable == nil && len(column.Association.TargetCol) > 0 { + node.ExtraCols = append(node.ExtraCols, column.Association.TargetCol) + } + node.association = column.Association + if err := node.Validate(column.Association.RefTable); err != nil { + return err + } + } + return nil +} + +func (pn *PreloadNode) Compute(session *Session, ownMap reflect.Value) error { + for _, node := range pn.children { + if err := node.compute(session, ownMap, reflect.Value{}); err != nil { + return err + } + } + return nil +} + +func (pn *PreloadNode) compute(session *Session, ownMap, pruneMap reflect.Value) error { + // non-root node: pn.association is not nil + if err := pn.association.ValidateOwnMap(ownMap); err != nil { + return err + } + + var joinMap reflect.Value + cond := pn.association.GetCond(ownMap) + if pn.association.JoinTable != nil { + var err error + cond, joinMap, err = pn.preloadJoin(session, cond) + if err != nil { + return err + } + } + + refMap := pn.association.MakeRefMap() + preloadSession := session.Engine().Cols(pn.ExtraCols...).Where(cond) + if pn.preload != nil { + preloadSession.Cols(pn.preload.cols...).Where(pn.preload.cond) + } + if err := preloadSession.Find(refMap.Interface()); err != nil { + return err + } + + var refPruneMap reflect.Value + if len(pn.children) > 0 && !(pn.preload != nil && (len(pn.preload.cols) > 0 || pn.preload.noPrune)) { + refPruneMap = reflect.MakeMap(reflect.MapOf(refMap.Type().Key(), reflect.TypeOf(true))) + refIter := refMap.MapRange() + for refIter.Next() { + refPruneMap.SetMapIndex(refIter.Key(), reflect.ValueOf(true)) + } + } + + for _, node := range pn.children { + if err := node.compute(session, refMap, refPruneMap); err != nil { + return err + } + } + + if refPruneMap.IsValid() { + pruneIter := refPruneMap.MapRange() + for pruneIter.Next() { + refMap.SetMapIndex(pruneIter.Key(), reflect.Value{}) + } + } + + pn.association.Link(ownMap, refMap, pruneMap, joinMap) + return nil +} + +func (pn *PreloadNode) preloadJoin(session *Session, cond builder.Cond) (builder.Cond, reflect.Value, error) { + joinSlicePtr := pn.association.MakeJoinSlice() + if err := session.Engine(). + Table(pn.association.JoinTable.Name).Where(cond). + Cols(pn.association.SourceCol, pn.association.TargetCol). + Find(joinSlicePtr.Interface()); err != nil { + return nil, reflect.Value{}, err + } + joinSlice := joinSlicePtr.Elem() + + joinMap := pn.association.MakeJoinMap() + for i := 0; i < joinSlice.Len(); i++ { + entry := joinSlice.Index(i) + pkSlice := joinMap.MapIndex(entry.Field(1)) + if !pkSlice.IsValid() { + pkSlice = reflect.MakeSlice(reflect.SliceOf(pn.association.OwnPkType), 0, 0) + } + joinMap.SetMapIndex(entry.Field(1), reflect.Append(pkSlice, entry.Field(0))) + } + + var refPks []interface{} + iter := joinMap.MapRange() + joinMap.MapKeys() + for iter.Next() { + refPks = append(refPks, iter.Key().Interface()) + } + cond = builder.In(pn.association.RefTable.PrimaryKeys[0], refPks) + return cond, joinMap, nil +} diff --git a/preload_test.go b/preload_test.go new file mode 100644 index 00000000..9d87ad72 --- /dev/null +++ b/preload_test.go @@ -0,0 +1,276 @@ +package xorm + +import ( + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "sort" + "testing" + "xorm.io/builder" +) + +// https://gitea.com/xorm/xorm/issues/2240 + +type Employee struct { + Id int64 + Name string + BuddyId *int64 + ManagerId *int64 + Buddy *Employee `xorm:"belongs_to(buddy_id)"` + Apprentice *Employee `xorm:"has_one(buddy_id)"` + Manager *Employee `xorm:"belongs_to(manager_id)"` + Subordinates []*Employee `xorm:"has_many(manager_id)"` + Indications []*Employee `xorm:"many_to_many(employee_indication, indicator_id, indicated_id)"` + IndicatedBy []*Employee `xorm:"many_to_many(employee_indication, indicated_id, indicator_id)"` +} + +func TestPreload(t *testing.T) { + engine, err := NewEngine("sqlite3", ":memory:") + require.NoError(t, err) + + const sql = ` +create table employee ( + id integer primary key autoincrement, + name text not null, + buddy_id integer references employee(id) check (buddy_id <> id) unique, + manager_id integer references employee(id) check (manager_id <> id) +); + +create table employee_indication ( + indicator_id integer not null references employee(id), + indicated_id integer not null references employee(id), + primary key (indicator_id, indicated_id), + check (indicator_id <> indicated_id) +); + +insert into employee (name) values ('John'), ('Bob'); +insert into employee (name,manager_id) values ('Alice',1), ('Riya',2); +insert into employee (name,manager_id,buddy_id) values ('Emilie',1,3), ('Cynthia',2,4); +insert into employee_indication values (1,2), (1,3), (2,3), (2,4), (2,5), (3,5), (3,6); + +-- John manages Alice and Emilie +-- Bob manages Riya and Cynthia +-- Alice is buddy of Emilie +-- Riya is buddy of Cynthia +-- John indicated Bob and Alice +-- Bob indicated Alice, Riya and Emilie +-- Alice indicated Emilie and Cynthia +` + _, err = engine.Exec(sql) + require.NoError(t, err) + + var employee Employee + _, err = engine.Preloads( + engine.Preload("Indications.Buddy").Cols("name"), + engine.Preload("Indications").NoPrune(), + ).Cols("name").Where(builder.Eq{"id": 2}).Get(&employee) + require.NoError(t, err) + + // order is not preserved when preloading, so we compensate for it in the test + sort.Slice(employee.Indications, func(i, j int) bool { + return employee.Indications[i].Id < employee.Indications[j].Id + }) + + assert.Equal(t, Employee{ + Id: 2, + Name: "Bob", + Indications: []*Employee{ + {Id: 3}, + {Id: 4}, + { + Id: 5, + BuddyId: &[]int64{3}[0], + Buddy: &Employee{ + Id: 3, + Name: "Alice", + }, + }, + }, + }, employee) + + var employees []*Employee + err = engine.Preloads( + // 1. preload the names of all subordinates of this employee's manager + engine.Preload("Manager.Subordinates").Cols("name"), + // 2. preload the name of the buddy of each employee indicated by this employee + engine.Preload("Indications.Buddy").Cols("name"), + // 3. preload the names of all employees who indicated this employee's apprentice, except non-subordinates + engine.Preload("Apprentice.IndicatedBy").Cols("name").Where(builder.NotNull{"manager_id"}), + // 4. preload the names of: + // all employees who don't have a maanger and were indicated by: + engine.Preload("Subordinates.IndicatedBy.Indications").Where(builder.IsNull{"manager_id"}).Cols("name"), + // employees whose name is 4 letters long and indicated: + engine.Preload("Subordinates.IndicatedBy").Where(builder.Like{"name", "____"}), + // this employee's subordinates who don't have a buddy + engine.Preload("Subordinates").Where(builder.IsNull{"buddy_id"}), + // 0. find the names of all employees + ).Cols("name").Find(&employees) + require.NoError(t, err) + + // order is not preserved when preloading, so we compensate for it in the test + for k := 2; k < 6; k++ { + sort.Slice(employees[k].Manager.Subordinates, func(i, j int) bool { + return employees[k].Manager.Subordinates[i].Id < employees[k].Manager.Subordinates[j].Id + }) + } + sort.Slice(employees[2].Indications, func(i, j int) bool { + return employees[2].Indications[i].Id < employees[2].Indications[j].Id + }) + + expected := []*Employee{ + { + Id: 1, + Name: "John", + Subordinates: []*Employee{ + { + Id: 3, + ManagerId: &[]int64{1}[0], + IndicatedBy: []*Employee{ + { + Id: 1, + Indications: []*Employee{ + { + Id: 2, + Name: "Bob", + }, + }, + }, + }, + }, + }, + }, + { + Id: 2, + Name: "Bob", + Indications: []*Employee{ + { + Id: 5, + BuddyId: &[]int64{3}[0], + Buddy: &Employee{ + Id: 3, + Name: "Alice", + }, + }, + }, + }, + { + Id: 3, + Name: "Alice", + ManagerId: &[]int64{1}[0], + Manager: &Employee{ + Id: 1, + Subordinates: []*Employee{ + { + Id: 3, + Name: "Alice", + ManagerId: &[]int64{1}[0], + }, + { + Id: 5, + Name: "Emilie", + ManagerId: &[]int64{1}[0], + }, + }, + }, + Apprentice: &Employee{ + Id: 5, + BuddyId: &[]int64{3}[0], + IndicatedBy: []*Employee{ + { + Id: 3, + Name: "Alice", + }, + }, + }, + Indications: []*Employee{ + { + Id: 5, + BuddyId: &[]int64{3}[0], + Buddy: &Employee{ + Id: 3, + Name: "Alice", + }, + }, + { + Id: 6, + BuddyId: &[]int64{4}[0], + Buddy: &Employee{ + Id: 4, + Name: "Riya", + }, + }, + }, + }, + { + Id: 4, + Name: "Riya", + ManagerId: &[]int64{2}[0], + Manager: &Employee{ + Id: 2, + Subordinates: []*Employee{ + { + Id: 4, + Name: "Riya", + ManagerId: &[]int64{2}[0], + }, + { + Id: 6, + Name: "Cynthia", + ManagerId: &[]int64{2}[0], + }, + }, + }, + Apprentice: &Employee{ + Id: 6, + BuddyId: &[]int64{4}[0], + IndicatedBy: []*Employee{ + { + Id: 3, + Name: "Alice", + }, + }, + }, + }, + { + Id: 5, + Name: "Emilie", + ManagerId: &[]int64{1}[0], + Manager: &Employee{ + Id: 1, + Subordinates: []*Employee{ + { + Id: 3, + Name: "Alice", + ManagerId: &[]int64{1}[0], + }, + { + Id: 5, + Name: "Emilie", + ManagerId: &[]int64{1}[0], + }, + }, + }, + }, + { + Id: 6, + Name: "Cynthia", + ManagerId: &[]int64{2}[0], + Manager: &Employee{ + Id: 2, + Subordinates: []*Employee{ + { + Id: 4, + Name: "Riya", + ManagerId: &[]int64{2}[0], + }, + { + Id: 6, + Name: "Cynthia", + ManagerId: &[]int64{2}[0], + }, + }, + }, + }, + } + assert.Equal(t, expected, employees) +} diff --git a/schemas/association.go b/schemas/association.go new file mode 100644 index 00000000..60cfdcee --- /dev/null +++ b/schemas/association.go @@ -0,0 +1,167 @@ +package schemas + +import ( + "fmt" + "reflect" + "xorm.io/builder" +) + +type Association struct { + OwnTable *Table + OwnColumn *Column + OwnPkType reflect.Type + RefTable *Table + RefPkType reflect.Type + JoinTable *Table // many_to_many + SourceCol string // belongs_to, many_to_many + TargetCol string // has_one, has_many, many_to_many +} + +func (a *Association) MakeJoinSlice() reflect.Value { + return reflect.New(reflect.SliceOf(a.JoinTable.Type)) +} + +func (a *Association) MakeJoinMap() reflect.Value { + return reflect.MakeMap(reflect.MapOf(a.RefPkType, reflect.SliceOf(a.OwnPkType))) +} + +func (a *Association) MakeRefMap() reflect.Value { + return reflect.MakeMap(reflect.MapOf(a.RefPkType, reflect.PointerTo(a.RefTable.Type))) +} + +func (a *Association) ValidateOwnMap(ownMap reflect.Value) error { + if ownMap.Type() != reflect.MapOf(a.OwnPkType, reflect.PointerTo(a.OwnTable.Type)) { + return fmt.Errorf("wrong map type: %v", ownMap.Type()) + } + return nil +} + +func (a *Association) GetCond(ownMap reflect.Value) builder.Cond { + if a.JoinTable != nil { + return a.condManyToMany(ownMap) + } + if len(a.SourceCol) > 0 { + return a.condBelongsTo(ownMap) + } + return a.condHasOneOrMany(ownMap) +} + +func (a *Association) condBelongsTo(ownMap reflect.Value) builder.Cond { + pkMap := make(map[interface{}]bool) + fkCol := a.OwnTable.GetColumn(a.SourceCol) + iter := ownMap.MapRange() + for iter.Next() { + structPtr := iter.Value() + fk, _ := fkCol.ValueOfV(&structPtr) + if fk.Type().Kind() == reflect.Pointer { + if fk.IsNil() { + continue + } + *fk = fk.Elem() + } + // don't check fk.IsZero(), because it might be a valid fk value + pkMap[fk.Interface()] = true + } + pks := make([]interface{}, 0, len(pkMap)) + for pk := range pkMap { + pks = append(pks, pk) + } + return builder.In(a.RefTable.PrimaryKeys[0], pks) +} + +func (a *Association) condHasOneOrMany(ownMap reflect.Value) builder.Cond { + var pks []interface{} + iter := ownMap.MapRange() + for iter.Next() { + pks = append(pks, iter.Key().Interface()) + } + return builder.In(a.TargetCol, pks) +} + +func (a *Association) condManyToMany(ownMap reflect.Value) builder.Cond { + var pks []interface{} + iter := ownMap.MapRange() + for iter.Next() { + pks = append(pks, iter.Key().Interface()) + } + return builder.In(a.SourceCol, pks) +} + +func (a *Association) Link(ownMap, refMap, pruneMap, joinMap reflect.Value) { + if a.JoinTable != nil { + a.linkManyToMany(ownMap, refMap, pruneMap, joinMap) + } else if len(a.SourceCol) > 0 { + a.linkBelongsTo(ownMap, refMap, pruneMap) + } else { + a.linkHasOneOrMany(ownMap, refMap, pruneMap) + } +} + +func (a *Association) linkBelongsTo(ownMap, refMap, pruneMap reflect.Value) { + fkCol := a.OwnTable.GetColumn(a.SourceCol) + iter := ownMap.MapRange() + for iter.Next() { + structPtr := iter.Value() + fk, _ := fkCol.ValueOfV(&structPtr) + if fk.Type().Kind() == reflect.Pointer { + if fk.IsNil() { + continue + } + *fk = fk.Elem() + } + // don't check fk.IsZero(), because it might be a valid fk value + refStructPtr := refMap.MapIndex(*fk) + if refStructPtr.IsValid() { + refField, _ := a.OwnColumn.ValueOfV(&structPtr) + refField.Set(refStructPtr) + if pruneMap.IsValid() { + pruneMap.SetMapIndex(iter.Key(), reflect.Value{}) // do not prune this key + } + } + } +} + +func (a *Association) linkHasOneOrMany(ownMap, refMap, pruneMap reflect.Value) { + hasMany := a.OwnColumn.FieldType.Kind() == reflect.Slice + fkCol := a.RefTable.GetColumn(a.TargetCol) + iter := refMap.MapRange() + for iter.Next() { + refStructPtr := iter.Value() + fk, _ := fkCol.ValueOfV(&refStructPtr) + if fk.Type().Kind() == reflect.Pointer { + if fk.IsNil() { + continue + } + *fk = fk.Elem() + } + // don't check fk.IsZero(), because it might be a valid fk value + structPtr := ownMap.MapIndex(*fk) // structPtr should be valid at this point + refField, _ := a.OwnColumn.ValueOfV(&structPtr) + if hasMany { + refField.Set(reflect.Append(*refField, refStructPtr)) + } else { + refField.Set(refStructPtr) + } + if pruneMap.IsValid() { + pruneMap.SetMapIndex(*fk, reflect.Value{}) // do not prune this key + } + } +} + +func (a *Association) linkManyToMany(ownMap, refMap, pruneMap, joinMap reflect.Value) { + iter := refMap.MapRange() + for iter.Next() { + refStructPtr := iter.Value() + refPk := iter.Key() // refPk should not be a pointer + pkSlice := joinMap.MapIndex(refPk) // pkSlice should be valid at this point + for i := 0; i < pkSlice.Len(); i++ { + pk := pkSlice.Index(i) + structPtr := ownMap.MapIndex(pk) // structPtr should be valid at this point + refField, _ := a.OwnColumn.ValueOfV(&structPtr) + refField.Set(reflect.Append(*refField, refStructPtr)) + if pruneMap.IsValid() { + pruneMap.SetMapIndex(pk, reflect.Value{}) // do not prune this key + } + } + } +} diff --git a/schemas/column.go b/schemas/column.go index 001769cd..25def641 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -21,9 +21,9 @@ const ( // Column defines database column type Column struct { Name string - TableName string - FieldName string // Available only when parsed from a struct - FieldIndex []int // Available only when parsed from a struct + FieldName string // Available only when parsed from a struct + FieldIndex []int // Available only when parsed from a struct + FieldType reflect.Type // Available only when parsed from a struct SQLType SQLType IsJSON bool Length int64 @@ -45,6 +45,7 @@ type Column struct { DisableTimeZone bool TimeZone *time.Location // column specified time zone Comment string + Association *Association } // NewColumn creates a new column @@ -52,7 +53,6 @@ func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int64, nullab return &Column{ Name: name, IsJSON: sqlType.IsJson(), - TableName: "", FieldName: fieldName, SQLType: sqlType, Length: len1, diff --git a/session.go b/session.go index e1a16e5b..56644bd1 100644 --- a/session.go +++ b/session.go @@ -87,6 +87,7 @@ type Session struct { ctx context.Context sessionType sessionType + preloadNode *PreloadNode } func newSessionID() string { @@ -800,3 +801,17 @@ func (session *Session) NoVersionCheck() *Session { session.statement.CheckVersion = false return session } + +// Preloads adds preloads +func (session *Session) Preloads(preloads ...*Preload) *Session { + if session.preloadNode == nil { + session.preloadNode = NewPreloadNode() + } + for _, preload := range preloads { + if err := session.preloadNode.Add(preload); err != nil { + session.statement.LastError = err + break + } + } + return session +} diff --git a/session_find.go b/session_find.go index 2270454b..5ef68e1a 100644 --- a/session_find.go +++ b/session_find.go @@ -91,6 +91,12 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) sliceElementType := sliceValue.Type().Elem() + if session.preloadNode != nil { + if sliceElementType.Kind() != reflect.Ptr || sliceElementType.Elem().Kind() != reflect.Struct { + return errors.New("preloading requires a pointer to a slice or a map of struct pointer") + } + } + tp := tpStruct if session.statement.RefTable == nil { if sliceElementType.Kind() == reflect.Ptr { @@ -141,6 +147,18 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } } + if session.preloadNode != nil { + if err := session.preloadNode.Validate(table); err != nil { + return err + } + // we need the columns required for the preloads + if !session.statement.ColumnMap.IsEmpty() { + for _, k := range session.preloadNode.ExtraCols { + session.statement.ColumnMap.Add(k) + } + } + } + sqlStr, args, err := session.statement.GenFindSQL(autoCond) if err != nil { return err @@ -158,7 +176,26 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } } - return session.noCacheFind(table, sliceValue, sqlStr, args...) + err = session.noCacheFind(table, sliceValue, sqlStr, args...) + if err != nil { + return err + } + + if session.preloadNode != nil { + if isSlice { + // convert to a map before preloading + originalSliceValue := sliceValue + pkColumn := table.PKColumns()[0] + sliceValue = reflect.MakeMap(reflect.MapOf(pkColumn.FieldType, sliceElementType)) + for i := 0; i < originalSliceValue.Len(); i++ { + element := originalSliceValue.Index(i) + pkValue, _ := pkColumn.ValueOfV(&element) + sliceValue.SetMapIndex(*pkValue, element) + } + } + return session.preloadNode.Compute(session, sliceValue) + } + return nil } func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error { diff --git a/session_get.go b/session_get.go index 9bb92a8b..7f257c4f 100644 --- a/session_get.go +++ b/session_get.go @@ -71,12 +71,28 @@ func (session *Session) get(beans ...interface{}) (bool, error) { if err := session.statement.SetRefBean(beans[0]); err != nil { return false, err } + } else if session.preloadNode != nil { + return false, errors.New("preloading requires a pointer to struct") } var sqlStr string var args []interface{} var err error + table := session.statement.RefTable + + if session.preloadNode != nil { + if err := session.preloadNode.Validate(table); err != nil { + return false, err + } + // we need the columns required for the preloads + if !session.statement.ColumnMap.IsEmpty() { + for _, k := range session.preloadNode.ExtraCols { + session.statement.ColumnMap.Add(k) + } + } + } + if session.statement.RawSQL == "" { if len(session.statement.TableName()) == 0 { return false, ErrTableNotFound @@ -91,8 +107,6 @@ func (session *Session) get(beans ...interface{}) (bool, error) { args = session.statement.RawParams } - table := session.statement.RefTable - if session.statement.ColumnMap.IsEmpty() && session.canCache() && isStruct { if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil && !session.statement.GetUnscoped() { @@ -122,6 +136,19 @@ func (session *Session) get(beans ...interface{}) (bool, error) { return has, err } + if session.preloadNode != nil { + // convert to a map before preloading + pkColumn := table.PKColumns()[0] + sliceValue := reflect.MakeMap(reflect.MapOf(pkColumn.FieldType, beanValue.Type())) + dataStruct := beanValue.Elem() + pkValue, _ := pkColumn.ValueOfV(&dataStruct) + sliceValue.SetMapIndex(*pkValue, beanValue) + + if err := session.preloadNode.Compute(session, sliceValue); err != nil { + return false, err + } + } + if context != nil && isStruct { context.Put(fmt.Sprintf("%v-%v", sqlStr, args), beans[0]) } diff --git a/tags/parser.go b/tags/parser.go index 028f8d0b..0176ab65 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -84,7 +84,7 @@ func (parser *Parser) SetIdentifier(identifier string) { // ParseWithCache parse a struct with cache func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) { - t := v.Type() + t := reflect.Indirect(v).Type() tableI, ok := parser.tableCache.Load(t) if ok { return tableI.(*schemas.Table), nil @@ -172,6 +172,7 @@ func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructFi field.Name, sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, true) col.FieldIndex = []int{fieldIndex} + col.FieldType = field.Type if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { col.IsAutoIncrement = true @@ -185,6 +186,7 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, f col := &schemas.Column{ FieldName: field.Name, FieldIndex: []int{fieldIndex}, + FieldType: field.Type, Nullable: true, IsPrimaryKey: false, IsAutoIncrement: false, diff --git a/tags/tag.go b/tags/tag.go index 41d525e1..67a0aa74 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -5,6 +5,7 @@ package tags import ( + "errors" "fmt" "reflect" "strconv" @@ -101,28 +102,32 @@ type Handler func(ctx *Context) error // defaultTagHandlers enumerates all the default tag handler var defaultTagHandlers = map[string]Handler{ - "-": IgnoreHandler, - "<-": OnlyFromDBTagHandler, - "->": OnlyToDBTagHandler, - "PK": PKTagHandler, - "NULL": NULLTagHandler, - "NOT": NotTagHandler, - "AUTOINCR": AutoIncrTagHandler, - "DEFAULT": DefaultTagHandler, - "CREATED": CreatedTagHandler, - "UPDATED": UpdatedTagHandler, - "DELETED": DeletedTagHandler, - "VERSION": VersionTagHandler, - "UTC": UTCTagHandler, - "LOCAL": LocalTagHandler, - "NOTNULL": NotNullTagHandler, - "INDEX": IndexTagHandler, - "UNIQUE": UniqueTagHandler, - "CACHE": CacheTagHandler, - "NOCACHE": NoCacheTagHandler, - "COMMENT": CommentTagHandler, - "EXTENDS": ExtendsTagHandler, - "UNSIGNED": UnsignedTagHandler, + "-": IgnoreHandler, + "<-": OnlyFromDBTagHandler, + "->": OnlyToDBTagHandler, + "PK": PKTagHandler, + "NULL": NULLTagHandler, + "NOT": NotTagHandler, + "AUTOINCR": AutoIncrTagHandler, + "DEFAULT": DefaultTagHandler, + "CREATED": CreatedTagHandler, + "UPDATED": UpdatedTagHandler, + "DELETED": DeletedTagHandler, + "VERSION": VersionTagHandler, + "UTC": UTCTagHandler, + "LOCAL": LocalTagHandler, + "NOTNULL": NotNullTagHandler, + "INDEX": IndexTagHandler, + "UNIQUE": UniqueTagHandler, + "CACHE": CacheTagHandler, + "NOCACHE": NoCacheTagHandler, + "COMMENT": CommentTagHandler, + "EXTENDS": ExtendsTagHandler, + "UNSIGNED": UnsignedTagHandler, + "BELONGS_TO": BelongsToTagHandler, + "HAS_ONE": HasOneTagHandler, + "HAS_MANY": HasManyTagHandler, + "MANY_TO_MANY": ManyToManyTagHandler, } func init() { @@ -396,3 +401,152 @@ func NoCacheTagHandler(ctx *Context) error { } return nil } + +func isStructPtr(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + +func isStructPtrSlice(t reflect.Type) bool { + return t.Kind() == reflect.Slice && isStructPtr(t.Elem()) +} + +func getRefTableAndPks(ctx *Context, refType reflect.Type) ([]*schemas.Column, *schemas.Table, []*schemas.Column, error) { + pks := ctx.table.PKColumns() + if len(pks) != 1 { + return nil, nil, nil, errors.New("a single-column primary key must be declared before the association") + } + + refTable := ctx.table // self-reference case + refPks := pks + if refTable.Type != refType { + var err error + refTable, err = ctx.parser.ParseWithCache(reflect.New(refType).Elem()) + if err != nil { + return nil, nil, nil, err + } + + refPks = refTable.PKColumns() + if len(refPks) != 1 { + return nil, nil, nil, errors.New("only single-column primary key is supported in associations") + } + } + return pks, refTable, refPks, nil +} + +// BelongsToTagHandler describes belongs_to tag handler +func BelongsToTagHandler(ctx *Context) error { + if !isStructPtr(ctx.fieldValue.Type()) { + return errors.New("tag belongs_to can only be applied to struct pointer field") + } + if len(ctx.params) != 1 { + return errors.New("tag belongs_to requires one parameter") + } + pks, refTable, refPks, err := getRefTableAndPks(ctx, ctx.fieldValue.Type().Elem()) + if err != nil { + return err + } + + ctx.col.Association = &schemas.Association{ + OwnTable: ctx.table, + OwnColumn: ctx.col, + OwnPkType: pks[0].FieldType, + RefTable: refTable, + RefPkType: refPks[0].FieldType, + SourceCol: ctx.params[0], + } + ctx.col.Name = ctx.col.FieldName + return nil +} + +// HasOneTagHandler describes has_one tag handler +func HasOneTagHandler(ctx *Context) error { + if !isStructPtr(ctx.fieldValue.Type()) { + return errors.New("tag has_one can only be applied to struct pointer field") + } + if len(ctx.params) != 1 { + return errors.New("tag has_one requires one parameter") + } + pks, refTable, refPks, err := getRefTableAndPks(ctx, ctx.fieldValue.Type().Elem()) + if err != nil { + return err + } + + ctx.col.Association = &schemas.Association{ + OwnTable: ctx.table, + OwnColumn: ctx.col, + OwnPkType: pks[0].FieldType, + RefTable: refTable, + RefPkType: refPks[0].FieldType, + TargetCol: ctx.params[0], + } + ctx.col.Name = ctx.col.FieldName + return nil +} + +// HasManyTagHandler describes has_many tag handler +func HasManyTagHandler(ctx *Context) error { + if !isStructPtrSlice(ctx.fieldValue.Type()) { + return errors.New("tag has_many can only be applied to slice of struct pointer field") + } + if len(ctx.params) != 1 { + return errors.New("tag has_many requires one parameter") + } + pks, refTable, refPks, err := getRefTableAndPks(ctx, ctx.fieldValue.Type().Elem().Elem()) + if err != nil { + return err + } + + ctx.col.Association = &schemas.Association{ + OwnTable: ctx.table, + OwnColumn: ctx.col, + OwnPkType: pks[0].FieldType, + RefTable: refTable, + RefPkType: refPks[0].FieldType, + TargetCol: ctx.params[0], + } + ctx.col.Name = ctx.col.FieldName + return nil +} + +// ManyToManyTagHandler describes many_to_many tag handler +func ManyToManyTagHandler(ctx *Context) error { + if !isStructPtrSlice(ctx.fieldValue.Type()) { + return errors.New("tag many_to_many can only be applied to slice of struct pointer field") + } + if len(ctx.params) != 3 { + return errors.New("tag many_to_many requires 3 parameters") + } + pks, refTable, refPks, err := getRefTableAndPks(ctx, ctx.fieldValue.Type().Elem().Elem()) + if err != nil { + return err + } + + joinType := reflect.StructOf([]reflect.StructField{ + { + Name: ctx.parser.GetColumnMapper().Table2Obj(ctx.params[1]), + Type: pks[0].FieldType, + }, + { + Name: ctx.parser.GetColumnMapper().Table2Obj(ctx.params[2]), + Type: refPks[0].FieldType, + }, + }) + joinTable, err := ctx.parser.ParseWithCache(reflect.New(joinType).Elem()) + if err != nil { + return err + } + joinTable.Name = ctx.params[0] + + ctx.col.Association = &schemas.Association{ + OwnTable: ctx.table, + OwnColumn: ctx.col, + OwnPkType: pks[0].FieldType, + RefTable: refTable, + RefPkType: refPks[0].FieldType, + JoinTable: joinTable, + SourceCol: ctx.params[1], + TargetCol: ctx.params[2], + } + ctx.col.Name = ctx.col.FieldName + return nil +} From d76d6f0aa93a14c10965d8b02aee312fc6ec6ec2 Mon Sep 17 00:00:00 2001 From: Diego Sogari Date: Sat, 1 Apr 2023 13:03:28 -0300 Subject: [PATCH 2/5] Add code documentation for association preloading --- engine.go | 2 + preload.go | 109 +++++++++++++++++++++++------------------ schemas/association.go | 15 +++++- session.go | 4 +- session_find.go | 2 +- session_get.go | 2 +- 6 files changed, 81 insertions(+), 53 deletions(-) diff --git a/engine.go b/engine.go index 350f8666..3542e9be 100644 --- a/engine.go +++ b/engine.go @@ -1449,12 +1449,14 @@ func (engine *Engine) Transaction(f func(*Session) (interface{}, error)) (interf return result, nil } +// Preloads builds the preloads func (engine *Engine) Preloads(preloads ...*Preload) *Session { session := engine.NewSession() session.isAutoClose = true return session.Preloads(preloads...) } +// Preload creates a preload func (engine *Engine) Preload(path string) *Preload { return NewPreload(path) } diff --git a/preload.go b/preload.go index 34653bda..dddc6a35 100644 --- a/preload.go +++ b/preload.go @@ -8,6 +8,7 @@ import ( "xorm.io/xorm/schemas" ) +// Preload is the representation of an association preload type Preload struct { path []string cols []string @@ -15,70 +16,79 @@ type Preload struct { noPrune bool } +// NewPreload creates a new preload with the specified path func NewPreload(path string) *Preload { return &Preload{ - path: strings.Split(path, "."), + path: strings.Split(path, "."), // list of association names composing the path cond: builder.NewCond(), } } +// Cols sets column selection for this preload func (p *Preload) Cols(cols ...string) *Preload { p.cols = append(p.cols, cols...) return p } +// Where sets the where condition for this preload func (p *Preload) Where(cond builder.Cond) *Preload { p.cond = p.cond.And(cond) return p } +// NoPrune sets a flag to avoid pruning empty associations func (p *Preload) NoPrune() *Preload { p.noPrune = true return p } -type PreloadNode struct { +// PreloadTreeNode is a tree node for the association preloads +type PreloadTreeNode struct { preload *Preload - children map[string]*PreloadNode + children map[string]*PreloadTreeNode association *schemas.Association - ExtraCols []string + extraCols []string } -func NewPreloadNode() *PreloadNode { - return &PreloadNode{ - children: make(map[string]*PreloadNode), +// NewPreloadTeeeNode creates a new preload tree node +func NewPreloadTeeeNode() *PreloadTreeNode { + return &PreloadTreeNode{ + children: make(map[string]*PreloadTreeNode), } } -func (pn *PreloadNode) Add(preload *Preload) error { - return pn.add(preload, 0) +// Add adds a node to the preload tree +func (node *PreloadTreeNode) Add(preload *Preload) error { + return node.add(preload, 0) } -func (pn *PreloadNode) add(preload *Preload, index int) error { - if index == len(preload.path) { - if pn.preload != nil { +// add adds a node to the preload tree in a recursion level +func (node *PreloadTreeNode) add(preload *Preload, level int) error { + if level == len(preload.path) { + if node.preload != nil { return fmt.Errorf("preload: duplicated path: %s", strings.Join(preload.path, ",")) } - pn.preload = preload + node.preload = preload return nil } - child, ok := pn.children[preload.path[index]] + child, ok := node.children[preload.path[level]] if !ok { - child = NewPreloadNode() - pn.children[preload.path[index]] = child + child = NewPreloadTeeeNode() + node.children[preload.path[level]] = child } - return child.add(preload, index+1) + return child.add(preload, level+1) } -func (pn *PreloadNode) Validate(table *schemas.Table) error { - if pn.preload != nil { - for _, col := range pn.preload.cols { +// Validate validates a preload tree node against a table schema and sets the association +func (node *PreloadTreeNode) Validate(table *schemas.Table) error { + if node.preload != nil { + for _, col := range node.preload.cols { if table.GetColumn(col) == nil { return fmt.Errorf("preload: missing col %s in table %s", col, table.Name) } } } - for name, node := range pn.children { + for name, child := range node.children { column := table.GetColumn(name) if column == nil { return fmt.Errorf("preload: missing field %s in struct %s", name, table.Type.Name()) @@ -87,24 +97,25 @@ func (pn *PreloadNode) Validate(table *schemas.Table) error { return fmt.Errorf("preload: missing association in field %s", name) } if column.Association.JoinTable == nil && len(column.Association.SourceCol) > 0 { - pn.ExtraCols = append(pn.ExtraCols, column.Association.SourceCol) + node.extraCols = append(node.extraCols, column.Association.SourceCol) } if len(column.Association.TargetCol) > 0 { - pn.ExtraCols = append(pn.ExtraCols, table.PrimaryKeys[0]) // pk might be added many times, but that's ok + node.extraCols = append(node.extraCols, table.PrimaryKeys[0]) // pk might be added many times, but that's ok } if column.Association.JoinTable == nil && len(column.Association.TargetCol) > 0 { - node.ExtraCols = append(node.ExtraCols, column.Association.TargetCol) + child.extraCols = append(child.extraCols, column.Association.TargetCol) } - node.association = column.Association - if err := node.Validate(column.Association.RefTable); err != nil { + child.association = column.Association + if err := child.Validate(column.Association.RefTable); err != nil { return err } } return nil } -func (pn *PreloadNode) Compute(session *Session, ownMap reflect.Value) error { - for _, node := range pn.children { +// Compute preloads the associations contained in the preload tree +func (node *PreloadTreeNode) Compute(session *Session, ownMap reflect.Value) error { + for _, node := range node.children { if err := node.compute(session, ownMap, reflect.Value{}); err != nil { return err } @@ -112,33 +123,34 @@ func (pn *PreloadNode) Compute(session *Session, ownMap reflect.Value) error { return nil } -func (pn *PreloadNode) compute(session *Session, ownMap, pruneMap reflect.Value) error { - // non-root node: pn.association is not nil - if err := pn.association.ValidateOwnMap(ownMap); err != nil { +// compute preloads the association contained in a preload tree node +func (node *PreloadTreeNode) compute(session *Session, ownMap, pruneMap reflect.Value) error { + // non-root node: association is not nil + if err := node.association.ValidateOwnMap(ownMap); err != nil { return err } var joinMap reflect.Value - cond := pn.association.GetCond(ownMap) - if pn.association.JoinTable != nil { + cond := node.association.GetCond(ownMap) + if node.association.JoinTable != nil { var err error - cond, joinMap, err = pn.preloadJoin(session, cond) + cond, joinMap, err = node.preloadJoin(session, cond) if err != nil { return err } } - refMap := pn.association.MakeRefMap() - preloadSession := session.Engine().Cols(pn.ExtraCols...).Where(cond) - if pn.preload != nil { - preloadSession.Cols(pn.preload.cols...).Where(pn.preload.cond) + refMap := node.association.MakeRefMap() + preloadSession := session.Engine().Cols(node.extraCols...).Where(cond) + if node.preload != nil { + preloadSession.Cols(node.preload.cols...).Where(node.preload.cond) } if err := preloadSession.Find(refMap.Interface()); err != nil { return err } var refPruneMap reflect.Value - if len(pn.children) > 0 && !(pn.preload != nil && (len(pn.preload.cols) > 0 || pn.preload.noPrune)) { + if len(node.children) > 0 && !(node.preload != nil && (len(node.preload.cols) > 0 || node.preload.noPrune)) { refPruneMap = reflect.MakeMap(reflect.MapOf(refMap.Type().Key(), reflect.TypeOf(true))) refIter := refMap.MapRange() for refIter.Next() { @@ -146,7 +158,7 @@ func (pn *PreloadNode) compute(session *Session, ownMap, pruneMap reflect.Value) } } - for _, node := range pn.children { + for _, node := range node.children { if err := node.compute(session, refMap, refPruneMap); err != nil { return err } @@ -159,26 +171,27 @@ func (pn *PreloadNode) compute(session *Session, ownMap, pruneMap reflect.Value) } } - pn.association.Link(ownMap, refMap, pruneMap, joinMap) + node.association.Link(ownMap, refMap, pruneMap, joinMap) return nil } -func (pn *PreloadNode) preloadJoin(session *Session, cond builder.Cond) (builder.Cond, reflect.Value, error) { - joinSlicePtr := pn.association.MakeJoinSlice() +// preloadJoin obtains a join condition and a join map for a many-to-many association +func (node *PreloadTreeNode) preloadJoin(session *Session, cond builder.Cond) (builder.Cond, reflect.Value, error) { + joinSlicePtr := node.association.NewJoinSlice() if err := session.Engine(). - Table(pn.association.JoinTable.Name).Where(cond). - Cols(pn.association.SourceCol, pn.association.TargetCol). + Table(node.association.JoinTable.Name).Where(cond). + Cols(node.association.SourceCol, node.association.TargetCol). Find(joinSlicePtr.Interface()); err != nil { return nil, reflect.Value{}, err } joinSlice := joinSlicePtr.Elem() - joinMap := pn.association.MakeJoinMap() + joinMap := node.association.MakeJoinMap() for i := 0; i < joinSlice.Len(); i++ { entry := joinSlice.Index(i) pkSlice := joinMap.MapIndex(entry.Field(1)) if !pkSlice.IsValid() { - pkSlice = reflect.MakeSlice(reflect.SliceOf(pn.association.OwnPkType), 0, 0) + pkSlice = reflect.MakeSlice(reflect.SliceOf(node.association.OwnPkType), 0, 0) } joinMap.SetMapIndex(entry.Field(1), reflect.Append(pkSlice, entry.Field(0))) } @@ -189,6 +202,6 @@ func (pn *PreloadNode) preloadJoin(session *Session, cond builder.Cond) (builder for iter.Next() { refPks = append(refPks, iter.Key().Interface()) } - cond = builder.In(pn.association.RefTable.PrimaryKeys[0], refPks) + cond = builder.In(node.association.RefTable.PrimaryKeys[0], refPks) return cond, joinMap, nil } diff --git a/schemas/association.go b/schemas/association.go index 60cfdcee..56cbca6d 100644 --- a/schemas/association.go +++ b/schemas/association.go @@ -6,6 +6,7 @@ import ( "xorm.io/builder" ) +// Association is the representation of an association type Association struct { OwnTable *Table OwnColumn *Column @@ -17,18 +18,22 @@ type Association struct { TargetCol string // has_one, has_many, many_to_many } -func (a *Association) MakeJoinSlice() reflect.Value { +// NewJoinSlice creates a slice to hold the intermediate result of a many-to-many association query +func (a *Association) NewJoinSlice() reflect.Value { return reflect.New(reflect.SliceOf(a.JoinTable.Type)) } +// MakeJoinMap creates a map to hold the intermediate result of a many-to-many association func (a *Association) MakeJoinMap() reflect.Value { return reflect.MakeMap(reflect.MapOf(a.RefPkType, reflect.SliceOf(a.OwnPkType))) } +// MakeRefMap creates a map to hold the result of an association query func (a *Association) MakeRefMap() reflect.Value { return reflect.MakeMap(reflect.MapOf(a.RefPkType, reflect.PointerTo(a.RefTable.Type))) } +// ValidateOwnMap validates the type of the owner map (parent of an association) func (a *Association) ValidateOwnMap(ownMap reflect.Value) error { if ownMap.Type() != reflect.MapOf(a.OwnPkType, reflect.PointerTo(a.OwnTable.Type)) { return fmt.Errorf("wrong map type: %v", ownMap.Type()) @@ -36,6 +41,7 @@ func (a *Association) ValidateOwnMap(ownMap reflect.Value) error { return nil } +// GetCond gets a where condition to use in an association query func (a *Association) GetCond(ownMap reflect.Value) builder.Cond { if a.JoinTable != nil { return a.condManyToMany(ownMap) @@ -46,6 +52,7 @@ func (a *Association) GetCond(ownMap reflect.Value) builder.Cond { return a.condHasOneOrMany(ownMap) } +// condBelongsTo gets a where condition to use in a belongs-to association query func (a *Association) condBelongsTo(ownMap reflect.Value) builder.Cond { pkMap := make(map[interface{}]bool) fkCol := a.OwnTable.GetColumn(a.SourceCol) @@ -69,6 +76,7 @@ func (a *Association) condBelongsTo(ownMap reflect.Value) builder.Cond { return builder.In(a.RefTable.PrimaryKeys[0], pks) } +// condHasOneOrMany gets a where condition to use in a has-one or has-many association query func (a *Association) condHasOneOrMany(ownMap reflect.Value) builder.Cond { var pks []interface{} iter := ownMap.MapRange() @@ -78,6 +86,7 @@ func (a *Association) condHasOneOrMany(ownMap reflect.Value) builder.Cond { return builder.In(a.TargetCol, pks) } +// condHasOneOrMany gets a where condition to use in a many-to-many association query func (a *Association) condManyToMany(ownMap reflect.Value) builder.Cond { var pks []interface{} iter := ownMap.MapRange() @@ -87,6 +96,7 @@ func (a *Association) condManyToMany(ownMap reflect.Value) builder.Cond { return builder.In(a.SourceCol, pks) } +// Link links the owner (parent) values with the referenced association values func (a *Association) Link(ownMap, refMap, pruneMap, joinMap reflect.Value) { if a.JoinTable != nil { a.linkManyToMany(ownMap, refMap, pruneMap, joinMap) @@ -97,6 +107,7 @@ func (a *Association) Link(ownMap, refMap, pruneMap, joinMap reflect.Value) { } } +// linkBelongsTo links the owner (parent) values with the referenced belongs-to association values func (a *Association) linkBelongsTo(ownMap, refMap, pruneMap reflect.Value) { fkCol := a.OwnTable.GetColumn(a.SourceCol) iter := ownMap.MapRange() @@ -121,6 +132,7 @@ func (a *Association) linkBelongsTo(ownMap, refMap, pruneMap reflect.Value) { } } +// linkBelongsTo links the owner (parent) values with the referenced has-one or has-many association values func (a *Association) linkHasOneOrMany(ownMap, refMap, pruneMap reflect.Value) { hasMany := a.OwnColumn.FieldType.Kind() == reflect.Slice fkCol := a.RefTable.GetColumn(a.TargetCol) @@ -148,6 +160,7 @@ func (a *Association) linkHasOneOrMany(ownMap, refMap, pruneMap reflect.Value) { } } +// linkManyToMany links the owner (parent) values with the referenced many-to-many association values func (a *Association) linkManyToMany(ownMap, refMap, pruneMap, joinMap reflect.Value) { iter := refMap.MapRange() for iter.Next() { diff --git a/session.go b/session.go index 56644bd1..a268de75 100644 --- a/session.go +++ b/session.go @@ -87,7 +87,7 @@ type Session struct { ctx context.Context sessionType sessionType - preloadNode *PreloadNode + preloadNode *PreloadTreeNode } func newSessionID() string { @@ -805,7 +805,7 @@ func (session *Session) NoVersionCheck() *Session { // Preloads adds preloads func (session *Session) Preloads(preloads ...*Preload) *Session { if session.preloadNode == nil { - session.preloadNode = NewPreloadNode() + session.preloadNode = NewPreloadTeeeNode() } for _, preload := range preloads { if err := session.preloadNode.Add(preload); err != nil { diff --git a/session_find.go b/session_find.go index 5ef68e1a..c00832ab 100644 --- a/session_find.go +++ b/session_find.go @@ -153,7 +153,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } // we need the columns required for the preloads if !session.statement.ColumnMap.IsEmpty() { - for _, k := range session.preloadNode.ExtraCols { + for _, k := range session.preloadNode.extraCols { session.statement.ColumnMap.Add(k) } } diff --git a/session_get.go b/session_get.go index 7f257c4f..80f62098 100644 --- a/session_get.go +++ b/session_get.go @@ -87,7 +87,7 @@ func (session *Session) get(beans ...interface{}) (bool, error) { } // we need the columns required for the preloads if !session.statement.ColumnMap.IsEmpty() { - for _, k := range session.preloadNode.ExtraCols { + for _, k := range session.preloadNode.extraCols { session.statement.ColumnMap.Add(k) } } From 4f8ee6f9191cbb8997f83e5beedf837e64c2dbcb Mon Sep 17 00:00:00 2001 From: Diego Sogari Date: Sat, 1 Apr 2023 13:50:32 -0300 Subject: [PATCH 3/5] Fix go vet errors --- schemas/association.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/schemas/association.go b/schemas/association.go index 56cbca6d..147dcfe0 100644 --- a/schemas/association.go +++ b/schemas/association.go @@ -30,12 +30,12 @@ func (a *Association) MakeJoinMap() reflect.Value { // MakeRefMap creates a map to hold the result of an association query func (a *Association) MakeRefMap() reflect.Value { - return reflect.MakeMap(reflect.MapOf(a.RefPkType, reflect.PointerTo(a.RefTable.Type))) + return reflect.MakeMap(reflect.MapOf(a.RefPkType, reflect.PtrTo(a.RefTable.Type))) } // ValidateOwnMap validates the type of the owner map (parent of an association) func (a *Association) ValidateOwnMap(ownMap reflect.Value) error { - if ownMap.Type() != reflect.MapOf(a.OwnPkType, reflect.PointerTo(a.OwnTable.Type)) { + if ownMap.Type() != reflect.MapOf(a.OwnPkType, reflect.PtrTo(a.OwnTable.Type)) { return fmt.Errorf("wrong map type: %v", ownMap.Type()) } return nil @@ -60,7 +60,7 @@ func (a *Association) condBelongsTo(ownMap reflect.Value) builder.Cond { for iter.Next() { structPtr := iter.Value() fk, _ := fkCol.ValueOfV(&structPtr) - if fk.Type().Kind() == reflect.Pointer { + if fk.Type().Kind() == reflect.Ptr { if fk.IsNil() { continue } @@ -114,7 +114,7 @@ func (a *Association) linkBelongsTo(ownMap, refMap, pruneMap reflect.Value) { for iter.Next() { structPtr := iter.Value() fk, _ := fkCol.ValueOfV(&structPtr) - if fk.Type().Kind() == reflect.Pointer { + if fk.Type().Kind() == reflect.Ptr { if fk.IsNil() { continue } @@ -140,7 +140,7 @@ func (a *Association) linkHasOneOrMany(ownMap, refMap, pruneMap reflect.Value) { for iter.Next() { refStructPtr := iter.Value() fk, _ := fkCol.ValueOfV(&refStructPtr) - if fk.Type().Kind() == reflect.Pointer { + if fk.Type().Kind() == reflect.Ptr { if fk.IsNil() { continue } From 17358d66b63f4995cc0278b7e139a4e1e363e637 Mon Sep 17 00:00:00 2001 From: Diego Sogari Date: Sat, 1 Apr 2023 14:14:53 -0300 Subject: [PATCH 4/5] Remove the `NoPrune` method --- preload.go | 15 ++++----------- preload_test.go | 2 +- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/preload.go b/preload.go index dddc6a35..9185aa7c 100644 --- a/preload.go +++ b/preload.go @@ -10,10 +10,9 @@ import ( // Preload is the representation of an association preload type Preload struct { - path []string - cols []string - cond builder.Cond - noPrune bool + path []string + cols []string + cond builder.Cond } // NewPreload creates a new preload with the specified path @@ -36,12 +35,6 @@ func (p *Preload) Where(cond builder.Cond) *Preload { return p } -// NoPrune sets a flag to avoid pruning empty associations -func (p *Preload) NoPrune() *Preload { - p.noPrune = true - return p -} - // PreloadTreeNode is a tree node for the association preloads type PreloadTreeNode struct { preload *Preload @@ -150,7 +143,7 @@ func (node *PreloadTreeNode) compute(session *Session, ownMap, pruneMap reflect. } var refPruneMap reflect.Value - if len(node.children) > 0 && !(node.preload != nil && (len(node.preload.cols) > 0 || node.preload.noPrune)) { + if len(node.children) > 0 && !(node.preload != nil && len(node.preload.cols) > 0) { refPruneMap = reflect.MakeMap(reflect.MapOf(refMap.Type().Key(), reflect.TypeOf(true))) refIter := refMap.MapRange() for refIter.Next() { diff --git a/preload_test.go b/preload_test.go index 9d87ad72..e7d05764 100644 --- a/preload_test.go +++ b/preload_test.go @@ -62,7 +62,7 @@ insert into employee_indication values (1,2), (1,3), (2,3), (2,4), (2,5), (3,5), var employee Employee _, err = engine.Preloads( engine.Preload("Indications.Buddy").Cols("name"), - engine.Preload("Indications").NoPrune(), + engine.Preload("Indications").Cols("id"), ).Cols("name").Where(builder.Eq{"id": 2}).Get(&employee) require.NoError(t, err) From 89a266addf2f3f87b448cd9cd37db6e25eab3c9d Mon Sep 17 00:00:00 2001 From: Diego Sogari Date: Sun, 9 Apr 2023 15:27:49 -0300 Subject: [PATCH 5/5] Allow selection of all columns in a leaf preload by not specifying any column --- internal/statements/select.go | 4 ++++ preload.go | 9 +++++++-- preload_test.go | 4 +++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/internal/statements/select.go b/internal/statements/select.go index 2bd2e94d..f2f4ba4e 100644 --- a/internal/statements/select.go +++ b/internal/statements/select.go @@ -98,6 +98,10 @@ func (statement *Statement) genColumnStr() string { continue } + if col.Association != nil { + continue + } + if buf.Len() != 0 { buf.WriteString(", ") } diff --git a/preload.go b/preload.go index 9185aa7c..eaf4743d 100644 --- a/preload.go +++ b/preload.go @@ -134,9 +134,14 @@ func (node *PreloadTreeNode) compute(session *Session, ownMap, pruneMap reflect. } refMap := node.association.MakeRefMap() - preloadSession := session.Engine().Cols(node.extraCols...).Where(cond) + preloadSession := session.Engine().Where(cond) if node.preload != nil { - preloadSession.Cols(node.preload.cols...).Where(node.preload.cond) + if len(node.preload.cols) > 0 { + preloadSession.Cols(node.extraCols...).Cols(node.preload.cols...) + } + preloadSession.Where(node.preload.cond) + } else { + preloadSession.Cols(node.extraCols...) } if err := preloadSession.Find(refMap.Interface()); err != nil { return err diff --git a/preload_test.go b/preload_test.go index e7d05764..ba94f2a5 100644 --- a/preload_test.go +++ b/preload_test.go @@ -124,10 +124,12 @@ insert into employee_indication values (1,2), (1,3), (2,3), (2,4), (2,5), (3,5), Subordinates: []*Employee{ { Id: 3, + Name: "Alice", ManagerId: &[]int64{1}[0], IndicatedBy: []*Employee{ { - Id: 1, + Id: 1, + Name: "John", Indications: []*Employee{ { Id: 2,