From 567b13889bdc855adbcb103488b52aa2b078afc1 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 29 May 2017 08:30:20 +0800 Subject: [PATCH] rename EagerLoad -> EagerGet --- engine.go | 6 +- helpers.go | 29 +++---- helpers_test.go | 21 ----- interface.go | 2 +- session.go | 14 +++- session_associate.go | 77 ++++++++++++++++++- ...gs_to_test.go => session_associate_test.go | 4 +- tag.go | 20 +++++ tag_test.go | 19 +++++ 9 files changed, 141 insertions(+), 51 deletions(-) rename blongs_to_test.go => session_associate_test.go (98%) diff --git a/engine.go b/engine.go index 1c0efc51..18422ce5 100644 --- a/engine.go +++ b/engine.go @@ -1501,11 +1501,11 @@ func (engine *Engine) Exist(bean ...interface{}) (bool, error) { return session.Exist(bean...) } -// EagerLoad loads bean's belongs to tag field immedicatlly -func (engine *Engine) EagerLoad(bean interface{}, cols ...string) error { +// EagerGet loads bean's belongs to tag field immedicatlly +func (engine *Engine) EagerGet(bean interface{}, cols ...string) error { session := engine.NewSession() defer session.Close() - return session.EagerLoad(bean, cols...) + return session.EagerGet(bean, cols...) } // Find retrieve records from table, condiBeans's non-empty fields diff --git a/helpers.go b/helpers.go index b505a995..85240b24 100644 --- a/helpers.go +++ b/helpers.go @@ -104,26 +104,6 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) { return v.Interface(), nil } -func splitTag(tag string) (tags []string) { - tag = strings.TrimSpace(tag) - var hasQuote = false - var lastIdx = 0 - for i, t := range tag { - if t == '\'' { - hasQuote = !hasQuote - } else if t == ' ' { - if lastIdx < i && !hasQuote { - tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) - lastIdx = i + 1 - } - } - } - if lastIdx < len(tag) { - tags = append(tags, strings.TrimSpace(tag[lastIdx:])) - } - return -} - type zeroable interface { IsZero() bool } @@ -479,3 +459,12 @@ func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) return false, false } + +func isStringInSlice(s string, slice []string) bool { + for _, e := range slice { + if s == e { + return true + } + } + return false +} diff --git a/helpers_test.go b/helpers_test.go index d57c54ae..31aec506 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -3,24 +3,3 @@ // license that can be found in the LICENSE file. package xorm - -import "testing" - -func TestSplitTag(t *testing.T) { - var cases = []struct { - tag string - tags []string - }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, - {"TEXT", []string{"TEXT"}}, - {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, - {"json binary", []string{"json", "binary"}}, - } - - for _, kase := range cases { - tags := splitTag(kase.tag) - if !sliceEq(tags, kase.tags) { - t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) - } - } -} diff --git a/interface.go b/interface.go index 4adee92e..e15126a7 100644 --- a/interface.go +++ b/interface.go @@ -28,7 +28,7 @@ type Interface interface { Delete(interface{}) (int64, error) Distinct(columns ...string) *Session DropIndexes(bean interface{}) error - EagerLoad(interface{}, ...string) error + EagerGet(interface{}, ...string) error Exec(string, ...interface{}) (sql.Result, error) Exist(bean ...interface{}) (bool, error) Find(interface{}, ...interface{}) error diff --git a/session.go b/session.go index 830011bb..82741b06 100644 --- a/session.go +++ b/session.go @@ -654,8 +654,18 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } pk[0] = reflect.ValueOf(pk[0]).Elem().Interface() - if err = session.getByPK(pk, fieldValue); err != nil { - return nil, err + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type())) + } + if err = session.getByPK(pk, fieldValue); err != nil { + return nil, err + } + } else { + v := fieldValue.Addr() + if err = session.getByPK(pk, &v); err != nil { + return nil, err + } } hasAssigned = true } else if col.AssociateType == core.AssociateBelongsTo { diff --git a/session_associate.go b/session_associate.go index 8e2c6231..e11be89f 100644 --- a/session_associate.go +++ b/session_associate.go @@ -11,8 +11,77 @@ import ( "github.com/go-xorm/core" ) -// EagerLoad load bean's belongs to tag field immedicatlly -func (session *Session) EagerLoad(bean interface{}, cols ...string) error { +// EagerFind load 's belongs to tag field immedicatlly +func (session *Session) EagerFind(slices interface{}, cols ...string) error { + /*v := reflect.ValueOf(slices) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Slice { + return errors.New("only slice is supported") + } + + if v.Len() <= 0 { + return nil + } + + vv := v.Index(0) + if vv.Kind() == reflect.Ptr { + vv = vv.Elem() + } + tb, err := session.Engine.autoMapType(vv) + if err != nil { + return err + } + + var pks = make(map[string][]core.PK) + for i := 0; i < v.Len(); i++ { + ev := v.Index(i) + + for _, col := range tb.Columns() { + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + + if col.AssociateTable != nil { + if col.AssociateType == core.AssociateBelongsTo { + colV, err := col.ValueOfV(&ev) + if err != nil { + return err + } + + pk, err := session.Engine.idOfV(*colV) + if err != nil { + return err + } + var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + } + + if !isZero(pk[0]) { + pks[col.Name] = append(pks[col.Name], pk) + } + } + } + } + } + + for colName, pk := range pks { + slice := reflect.MakeSlice(tp, 0, len(pk)) + err = session.In("", pk).Find(slice.Addr().Interafce()) + if err != nil { + return err + } + + }*/ + return nil +} + +// EagerGet load bean's belongs to tag field immedicatlly +func (session *Session) EagerGet(bean interface{}, cols ...string) error { if session.isAutoClose { defer session.Close() } @@ -24,6 +93,10 @@ func (session *Session) EagerLoad(bean interface{}, cols ...string) error { } for _, col := range tb.Columns() { + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + if col.AssociateTable != nil { if col.AssociateType == core.AssociateBelongsTo { colV, err := col.ValueOfV(&v) diff --git a/blongs_to_test.go b/session_associate_test.go similarity index 98% rename from blongs_to_test.go rename to session_associate_test.go index 8a0cbfcb..021277d9 100644 --- a/blongs_to_test.go +++ b/session_associate_test.go @@ -50,7 +50,7 @@ func TestBelongsTo_Get(t *testing.T) { assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) assert.Equal(t, "", cfgNose.Face.Name) - err = testEngine.EagerLoad(&cfgNose) + err = testEngine.EagerGet(&cfgNose) assert.NoError(t, err) assert.Equal(t, nose.Id, cfgNose.Id) assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) @@ -104,7 +104,7 @@ func TestBelongsTo_GetPtr(t *testing.T) { assert.Equal(t, nose.Id, cfgNose.Id) assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) - err = testEngine.EagerLoad(&cfgNose) + err = testEngine.EagerGet(&cfgNose) assert.NoError(t, err) assert.Equal(t, nose.Id, cfgNose.Id) assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) diff --git a/tag.go b/tag.go index 835c65f3..8c7f5b95 100644 --- a/tag.go +++ b/tag.go @@ -35,6 +35,26 @@ type tagContext struct { ignoreNext bool } +func splitTag(tag string) (tags []string) { + tag = strings.TrimSpace(tag) + var hasQuote = false + var lastIdx = 0 + for i, t := range tag { + if t == '\'' { + hasQuote = !hasQuote + } else if t == ' ' { + if lastIdx < i && !hasQuote { + tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) + lastIdx = i + 1 + } + } + } + if lastIdx < len(tag) { + tags = append(tags, strings.TrimSpace(tag[lastIdx:])) + } + return +} + // tagHandler describes tag handler for XORM type tagHandler func(ctx *tagContext) error diff --git a/tag_test.go b/tag_test.go index c9b76048..f97f6a10 100644 --- a/tag_test.go +++ b/tag_test.go @@ -393,3 +393,22 @@ func TestTagTime(t *testing.T) { assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), strings.Replace(strings.Replace(tm, "T", " ", -1), "Z", "", -1)) } + +func TestSplitTag(t *testing.T) { + var cases = []struct { + tag string + tags []string + }{ + {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, + {"TEXT", []string{"TEXT"}}, + {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, + {"json binary", []string{"json", "binary"}}, + } + + for _, kase := range cases { + tags := splitTag(kase.tag) + if !sliceEq(tags, kase.tags) { + t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) + } + } +}