This commit is contained in:
Lunny Xiao 2021-07-22 15:17:38 +08:00
parent b882579026
commit f2a1e6ea2f
2 changed files with 53 additions and 20 deletions

View File

@ -169,7 +169,7 @@ func TestBelongsTo_Find(t *testing.T) {
assert.Equal(t, "face1", noses2[0].Face.Name) assert.Equal(t, "face1", noses2[0].Face.Name)
assert.Equal(t, "face2", noses2[1].Face.Name) assert.Equal(t, "face2", noses2[1].Face.Name)
err = testEngine.Load(noses1, "face") err = testEngine.Load(noses1, "face_id")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "face1", noses1[0].Face.Name) assert.Equal(t, "face1", noses1[0].Face.Name)
assert.Equal(t, "face2", noses1[1].Face.Name) assert.Equal(t, "face2", noses1[1].Face.Name)

View File

@ -6,6 +6,7 @@ package xorm
import ( import (
"errors" "errors"
"fmt"
"reflect" "reflect"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
@ -56,44 +57,76 @@ func (session *Session) loadFindSlice(v reflect.Value, cols ...string) error {
return err return err
} }
var pks = make(map[*schemas.Column][]interface{}) type Va struct {
v reflect.Value
pk []interface{}
col *schemas.Column
}
var pks = make(map[*schemas.Column]*Va)
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
ev := v.Index(i) ev := v.Index(i)
fmt.Println("1====", ev.Interface(), tb.Name, len(tb.Columns()))
for _, col := range tb.Columns() { for _, col := range tb.Columns() {
fmt.Println("====", cols, col.Name)
if len(cols) > 0 && !isStringInSlice(col.Name, cols) { if len(cols) > 0 && !isStringInSlice(col.Name, cols) {
continue continue
} }
if col.AssociateTable != nil { fmt.Println("3------", col.Name, col.AssociateTable)
if col.AssociateType == schemas.AssociateBelongsTo {
colV, err := col.ValueOfV(&ev)
if err != nil {
return err
}
vv := colV.Interface() if col.AssociateTable == nil || col.AssociateType != schemas.AssociateBelongsTo {
/*var colPtr reflect.Value continue
if colV.Kind() == reflect.Ptr { }
colPtr = *colV
} else {
colPtr = colV.Addr()
}*/
if !utils.IsZero(vv) { colV, err := col.ValueOfV(&ev)
pks[col] = append(pks[col], vv) if err != nil {
return err
}
pkCols := col.AssociateTable.PKColumns()
pkV, err := pkCols[0].ValueOfV(colV)
if err != nil {
return err
}
vv := pkV.Interface()
fmt.Println("2====", vv)
if !utils.IsZero(vv) {
va, ok := pks[col]
if !ok {
va = &Va{
v: ev,
col: pkCols[0],
} }
pks[col] = va
} }
va.pk = append(va.pk, vv)
} }
} }
} }
for col, pk := range pks { for col, va := range pks {
slice := reflect.MakeSlice(col.FieldType, 0, len(pk)) slice := reflect.MakeSlice(reflect.SliceOf(col.FieldType), 0, len(va.pk))
err = session.In(col.Name, pk...).find(slice.Addr().Interface()) err = session.In(va.col.Name, va.pk...).find(slice.Interface())
if err != nil { if err != nil {
return err return err
} }
/*vv, err := col.ValueOfV(&va.v)
if err != nil {
return err
}
vv.Set()
for i := 0; i < slice.Len(); i++ {
va.col.ValueOfV(slice.Index(i))
}*/
} }
return nil return nil
} }