From 3538ce1752c88a8715f1a3b7fd5a238c93aa0331 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 3 Jul 2017 17:10:32 +0800 Subject: [PATCH] improve belongs_to --- session.go | 44 +++++++++++++++++++++++++++++++-------- session_associate_test.go | 2 ++ 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/session.go b/session.go index 82741b06..3044f6f3 100644 --- a/session.go +++ b/session.go @@ -17,6 +17,12 @@ import ( "github.com/go-xorm/core" ) +type loadClosure struct { + Func func(core.PK, *reflect.Value) error + pk core.PK + fieldValue *reflect.Value +} + // Session keep a pointer to sql.DB and provides all execution of all // kind of database operations. type Session struct { @@ -658,14 +664,24 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b if fieldValue.IsNil() { fieldValue.Set(reflect.New(fieldValue.Type())) } - if err = session.getByPK(pk, fieldValue); err != nil { - return nil, err - } + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(session *Session, bean interface{}) error { + fieldValue := bean.(*reflect.Value) + return session.getByPK(pk, fieldValue) + }, + session: session, + bean: fieldValue, + }) } else { v := fieldValue.Addr() - if err = session.getByPK(pk, &v); err != nil { - return nil, err - } + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(session *Session, bean interface{}) error { + fieldValue := bean.(*reflect.Value) + return session.getByPK(pk, fieldValue) + }, + session: session, + bean: &v, + }) } hasAssigned = true } else if col.AssociateType == core.AssociateBelongsTo { @@ -696,13 +712,22 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } pk[0] = reflect.ValueOf(pk[0]).Elem().Interface() - if err = session.getByPK(pk, fieldValue); err != nil { - return nil, err - } + fmt.Println("=====", fieldValue, fieldValue.IsNil()) + fmt.Printf("%#v", fieldValue) + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(session *Session, bean interface{}) error { + fieldValue := bean.(*reflect.Value) + return session.getByPK(pk, fieldValue) + }, + session: session, + bean: fieldValue, + }) + hasAssigned = true } else if col.AssociateType == core.AssociateBelongsTo { hasAssigned = true if fieldValue.IsNil() { + // FIXME: find id column structInter := reflect.New(fieldValue.Type().Elem()) fieldValue.Set(structInter) } @@ -871,6 +896,7 @@ func (session *Session) getByPK(pk core.PK, fieldValue *reflect.Value) error { if has { if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { fieldValue.Set(structInter) + fmt.Println("333", fieldValue.IsNil()) } } else { return errors.New("cascade obj is not exist") diff --git a/session_associate_test.go b/session_associate_test.go index 021277d9..58f83cd1 100644 --- a/session_associate_test.go +++ b/session_associate_test.go @@ -215,6 +215,8 @@ func TestBelongsTo_FindPtr(t *testing.T) { err = testEngine.Cascade().Find(&noses2) assert.NoError(t, err) assert.Equal(t, 2, len(noses2)) + assert.NotNil(t, noses2[0].Face) + assert.NotNil(t, noses2[1].Face) assert.Equal(t, face1.Id, noses2[0].Face.Id) assert.Equal(t, face2.Id, noses2[1].Face.Id) assert.Equal(t, "face1", noses2[0].Face.Name)