From 23c6999de8fe5f16a5789917554ecd5c87395290 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 11 Sep 2017 15:43:39 +0800 Subject: [PATCH] merge EagerGet and EagerFind to Load --- engine.go | 6 ++-- interface.go | 2 +- session.go | 73 ++++++++++++++++++--------------------- session_associate.go | 27 +++++++++++---- session_associate_test.go | 4 +-- session_convert.go | 14 ++++---- statement.go | 9 +++-- tag_extends_test.go | 36 +++++-------------- 8 files changed, 80 insertions(+), 91 deletions(-) diff --git a/engine.go b/engine.go index 18422ce5..732f876e 100644 --- a/engine.go +++ b/engine.go @@ -1501,11 +1501,11 @@ func (engine *Engine) Exist(bean ...interface{}) (bool, error) { return session.Exist(bean...) } -// EagerGet loads bean's belongs to tag field immedicatlly -func (engine *Engine) EagerGet(bean interface{}, cols ...string) error { +// Load loads bean's belongs to tag field immedicatlly +func (engine *Engine) Load(bean interface{}, cols ...string) error { session := engine.NewSession() defer session.Close() - return session.EagerGet(bean, cols...) + return session.Load(bean, cols...) } // Find retrieve records from table, condiBeans's non-empty fields diff --git a/interface.go b/interface.go index e15126a7..ba548b1d 100644 --- a/interface.go +++ b/interface.go @@ -28,7 +28,7 @@ type Interface interface { Delete(interface{}) (int64, error) Distinct(columns ...string) *Session DropIndexes(bean interface{}) error - EagerGet(interface{}, ...string) error + Load(interface{}, ...string) error Exec(string, ...interface{}) (sql.Result, error) Exist(bean ...interface{}) (bool, error) Find(interface{}, ...interface{}) error diff --git a/session.go b/session.go index 3044f6f3..ff312547 100644 --- a/session.go +++ b/session.go @@ -21,6 +21,7 @@ type loadClosure struct { Func func(core.PK, *reflect.Value) error pk core.PK fieldValue *reflect.Value + loaded bool } // Session keep a pointer to sql.DB and provides all execution of all @@ -57,6 +58,9 @@ type Session struct { lastSQL string lastSQLArgs []interface{} + cascadeMode cascadeMode + cascadeLevel int // load level + err error } @@ -88,6 +92,9 @@ func (session *Session) Init() { session.lastSQL = "" session.lastSQLArgs = []interface{}{} + + session.cascadeMode = cascadeCompitable + session.cascadeLevel = 2 } // Close release the connection from pool @@ -155,7 +162,7 @@ func (session *Session) Alias(alias string) *Session { // NoCascade indicate that no cascade load child object func (session *Session) NoCascade() *Session { - session.statement.cascadeMode = cascadeManually + session.cascadeMode = cascadeLazy return session } @@ -210,16 +217,16 @@ func (session *Session) Charset(charset string) *Session { // Cascade indicates if loading sub Struct func (session *Session) Cascade(trueOrFalse ...bool) *Session { - var mode = cascadeAuto + var mode = cascadeEager if len(trueOrFalse) >= 1 { if trueOrFalse[0] { - mode = cascadeAuto + mode = cascadeEager } else { - mode = cascadeManually + mode = cascadeLazy } } - session.statement.cascadeMode = mode + session.cascadeMode = mode return session } @@ -642,10 +649,10 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b session.engine.logger.Error("sql.Sanner error:", err.Error()) hasAssigned = false } - } else if (col.AssociateType == core.AssociateNone && - session.statement.cascadeMode == cascadeCompitable) || + } else if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone && + session.cascadeMode == cascadeCompitable) || (col.AssociateType == core.AssociateBelongsTo && - session.statement.cascadeMode == cascadeAuto) { + session.cascadeMode == cascadeEager)) { var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) var err error rawValueType := col.AssociateTable.PKColumns()[0].FieldType @@ -660,29 +667,15 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } pk[0] = reflect.ValueOf(pk[0]).Elem().Interface() - if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type())) - } - session.afterProcessors = append(session.afterProcessors, executedProcessor{ - fun: func(session *Session, bean interface{}) error { - fieldValue := bean.(*reflect.Value) - return session.getByPK(pk, fieldValue) - }, - session: session, - bean: fieldValue, - }) - } else { - v := fieldValue.Addr() - session.afterProcessors = append(session.afterProcessors, executedProcessor{ - fun: func(session *Session, bean interface{}) error { - fieldValue := bean.(*reflect.Value) - return session.getByPK(pk, fieldValue) - }, - session: session, - bean: &v, - }) - } + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(session *Session, bean interface{}) error { + fieldValue := bean.(*reflect.Value) + return session.getByPK(pk, fieldValue) + }, + session: session, + bean: fieldValue, + }) + session.cascadeLevel-- hasAssigned = true } else if col.AssociateType == core.AssociateBelongsTo { hasAssigned = true @@ -694,10 +687,10 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } case reflect.Ptr: if fieldType != core.PtrTimeType && fieldType.Elem().Kind() == reflect.Struct { - if (col.AssociateType == core.AssociateNone && - session.statement.cascadeMode == cascadeCompitable) || + if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone && + session.cascadeMode == cascadeCompitable) || (col.AssociateType == core.AssociateBelongsTo && - session.statement.cascadeMode == cascadeAuto) { + session.cascadeMode == cascadeEager)) { var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) var err error rawValueType := col.AssociateTable.ColumnType(col.AssociateTable.PKColumns()[0].FieldName) @@ -712,8 +705,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } pk[0] = reflect.ValueOf(pk[0]).Elem().Interface() - fmt.Println("=====", fieldValue, fieldValue.IsNil()) - fmt.Printf("%#v", fieldValue) session.afterProcessors = append(session.afterProcessors, executedProcessor{ fun: func(session *Session, bean interface{}) error { fieldValue := bean.(*reflect.Value) @@ -723,6 +714,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b bean: fieldValue, }) + session.cascadeLevel-- hasAssigned = true } else if col.AssociateType == core.AssociateBelongsTo { hasAssigned = true @@ -732,7 +724,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b fieldValue.Set(structInter) } - //fieldValue.Elem().FieldByName(table.PKColumns()[0].FieldName).Set(vv) err := convertAssign(fieldValue.Elem().FieldByName(table.PKColumns()[0].FieldName).Addr().Interface(), vv.Interface()) if err != nil { @@ -741,7 +732,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } } else { // !nashtsai! TODO merge duplicated codes above - //typeStr := fieldType.String() switch fieldType { // following types case matching ptr's native type, therefore assign ptr directly case core.PtrStringType: @@ -889,14 +879,17 @@ func (session *Session) getByPK(pk core.PK, fieldValue *reflect.Value) error { structInter = fieldValue.Addr() } - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) + has, err := session.ID(pk).NoAutoCondition().get(structInter.Interface()) if err != nil { return err } if has { if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { fieldValue.Set(structInter) - fmt.Println("333", fieldValue.IsNil()) + } else if fieldValue.Kind() == reflect.Struct { + fieldValue.Set(structInter.Elem()) + } else { + return errors.New("set value failed") } } else { return errors.New("cascade obj is not exist") diff --git a/session_associate.go b/session_associate.go index e11be89f..aa922db2 100644 --- a/session_associate.go +++ b/session_associate.go @@ -11,8 +11,22 @@ import ( "github.com/go-xorm/core" ) -// EagerFind load 's belongs to tag field immedicatlly -func (session *Session) EagerFind(slices interface{}, cols ...string) error { +// Load loads associated fields from database +func (session *Session) Load(beanOrSlices interface{}, cols ...string) error { + v := reflect.ValueOf(beanOrSlices) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() == reflect.Slice { + return session.loadFind(beanOrSlices, cols...) + } else if v.Kind() == reflect.Struct { + return session.loadGet(beanOrSlices, cols...) + } + return errors.New("unsupported load type, must struct or slice") +} + +// loadFind load 's belongs to tag field immedicatlly +func (session *Session) loadFind(slices interface{}, cols ...string) error { /*v := reflect.ValueOf(slices) if v.Kind() == reflect.Ptr { v = v.Elem() @@ -80,8 +94,8 @@ func (session *Session) EagerFind(slices interface{}, cols ...string) error { return nil } -// EagerGet load bean's belongs to tag field immedicatlly -func (session *Session) EagerGet(bean interface{}, cols ...string) error { +// loadGet load bean's belongs to tag field immedicatlly +func (session *Session) loadGet(bean interface{}, cols ...string) error { if session.isAutoClose { defer session.Close() } @@ -115,14 +129,15 @@ func (session *Session) EagerGet(bean interface{}, cols ...string) error { colPtr = colV.Addr() } - if !isZero(pk[0]) { - has, err := session.ID(pk).get(colPtr.Interface()) + if !isZero(pk[0]) && session.cascadeLevel > 0 { + has, err := session.ID(pk).NoAutoCondition().get(colPtr.Interface()) if err != nil { return err } if !has { return errors.New("load bean does not exist") } + session.cascadeLevel-- } } } diff --git a/session_associate_test.go b/session_associate_test.go index 58f83cd1..62e46cbc 100644 --- a/session_associate_test.go +++ b/session_associate_test.go @@ -50,7 +50,7 @@ func TestBelongsTo_Get(t *testing.T) { assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) assert.Equal(t, "", cfgNose.Face.Name) - err = testEngine.EagerGet(&cfgNose) + err = testEngine.Load(&cfgNose) assert.NoError(t, err) assert.Equal(t, nose.Id, cfgNose.Id) assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) @@ -104,7 +104,7 @@ func TestBelongsTo_GetPtr(t *testing.T) { assert.Equal(t, nose.Id, cfgNose.Id) assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) - err = testEngine.EagerGet(&cfgNose) + err = testEngine.Load(&cfgNose) assert.NoError(t, err) assert.Equal(t, nose.Id, cfgNose.Id) assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) diff --git a/session_convert.go b/session_convert.go index e69788b8..de7cfbec 100644 --- a/session_convert.go +++ b/session_convert.go @@ -202,10 +202,10 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } v = x fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) - } else if (col.AssociateType == core.AssociateNone && - session.statement.cascadeMode == cascadeCompitable) || + } else if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone && + session.cascadeMode == cascadeCompitable) || (col.AssociateType == core.AssociateBelongsTo && - session.statement.cascadeMode == cascadeAuto) { + session.cascadeMode == cascadeEager)) { var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) // only 1 PK checked on tag parsing rawValueType := col.AssociateTable.ColumnType(col.AssociateTable.PKColumns()[0].FieldName) @@ -215,6 +215,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, return err } + session.cascadeLevel-- if err = session.getByPK(pk, fieldValue); err != nil { return err } @@ -466,10 +467,10 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, v = x fieldValue.Set(reflect.ValueOf(&x)) default: - if (col.AssociateType == core.AssociateNone && - session.statement.cascadeMode == cascadeCompitable) || + if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone && + session.cascadeMode == cascadeCompitable) || (col.AssociateType == core.AssociateBelongsTo && - session.statement.cascadeMode == cascadeAuto) { + session.cascadeMode == cascadeEager)) { var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) var err error // only 1 PK checked on tag parsing @@ -479,6 +480,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, return err } + session.cascadeLevel-- if err = session.getByPK(pk, fieldValue); err != nil { return err } diff --git a/statement.go b/statement.go index de8a071c..b2d1e90b 100644 --- a/statement.go +++ b/statement.go @@ -36,9 +36,10 @@ type exprParam struct { type cascadeMode int const ( - cascadeCompitable cascadeMode = iota - cascadeAuto - cascadeManually + cascadeCompitable cascadeMode = iota // load field beans with another SQL with no + cascadeEager // load field beans with another SQL + cascadeJoin // load field beans with join + cascadeLazy // don't load anything ) // Statement save all the sql info for executing SQL @@ -62,7 +63,6 @@ type Statement struct { tableName string RawSQL string RawParams []interface{} - cascadeMode cascadeMode UseAutoJoin bool StoreEngine string Charset string @@ -90,7 +90,6 @@ func (statement *Statement) Init() { statement.Start = 0 statement.LimitN = 0 statement.OrderStr = "" - statement.cascadeMode = cascadeCompitable statement.JoinStr = "" statement.joinArgs = make([]interface{}, 0) statement.GroupByStr = "" diff --git a/tag_extends_test.go b/tag_extends_test.go index b70eefe3..6c72a6dc 100644 --- a/tag_extends_test.go +++ b/tag_extends_test.go @@ -169,22 +169,12 @@ func TestExtends(t *testing.T) { tu6 := &tempUser3{&tempUser{0, "extends update"}, ""} _, err = testEngine.ID(tu5.Temp.Id).Update(tu6) - if err != nil { - t.Error(err) - panic(err) - } + assert.Error(t, err) users := make([]tempUser3, 0) err = testEngine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - if len(users) != 1 { - err = errors.New("error get data not 1") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users)) assertSync(t, new(Userinfo), new(Userdetail)) @@ -209,21 +199,11 @@ func TestExtends(t *testing.T) { sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s", qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid)) b, err := testEngine.SQL(sql).NoCascade().Get(&info) - if err != nil { - t.Error(err) - panic(err) - } - if !b { - err = errors.New("should has lest one record") - t.Error(err) - panic(err) - } - fmt.Println(info) - if info.Userinfo.Uid == 0 || info.Userdetail.Id == 0 { - err = errors.New("all of the id should has value") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.True(t, b) + assert.False(t, info.Userinfo.Uid == 0 || info.Userdetail.Id == 0) + + panic("") fmt.Println("----join--info2") var info2 UserAndDetail