diff --git a/integrations/session_associate_test.go b/integrations/session_associate_test.go new file mode 100644 index 00000000..d34549d9 --- /dev/null +++ b/integrations/session_associate_test.go @@ -0,0 +1,232 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBelongsTo_Get(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type Face1 struct { + Id int64 + Name string + } + + type Nose1 struct { + Id int64 + Face Face1 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose1), new(Face1)) + assert.NoError(t, err) + + var face = Face1{ + Name: "face1", + } + _, err = testEngine.Insert(&face) + assert.NoError(t, err) + + var cfgFace Face1 + has, err := testEngine.Get(&cfgFace) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, face, cfgFace) + + var nose = Nose1{Face: face} + _, err = testEngine.Insert(&nose) + assert.NoError(t, err) + + var cfgNose Nose1 + has, err = testEngine.Get(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "", cfgNose.Face.Name) + + err = testEngine.Load(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "face1", cfgNose.Face.Name) + + var cfgNose2 Nose1 + has, err = testEngine.Cascade().Get(&cfgNose2) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose2.Id) + assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id) + assert.Equal(t, "face1", cfgNose2.Face.Name) +} + +func TestBelongsTo_GetPtr(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type Face2 struct { + Id int64 + Name string + } + + type Nose2 struct { + Id int64 + Face *Face2 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose2), new(Face2)) + assert.NoError(t, err) + + var face = Face2{ + Name: "face1", + } + _, err = testEngine.Insert(&face) + assert.NoError(t, err) + + var cfgFace Face2 + has, err := testEngine.Get(&cfgFace) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, face, cfgFace) + + var nose = Nose2{Face: &face} + _, err = testEngine.Insert(&nose) + assert.NoError(t, err) + + var cfgNose Nose2 + has, err = testEngine.Get(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + + err = testEngine.Load(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "face1", cfgNose.Face.Name) + + var cfgNose2 Nose2 + has, err = testEngine.Cascade().Get(&cfgNose2) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose2.Id) + assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id) + assert.Equal(t, "face1", cfgNose2.Face.Name) +} + +func TestBelongsTo_Find(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type Face3 struct { + Id int64 + Name string + } + + type Nose3 struct { + Id int64 + Face Face3 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose3), new(Face3)) + assert.NoError(t, err) + + var face1 = Face3{ + Name: "face1", + } + var face2 = Face3{ + Name: "face2", + } + _, err = testEngine.Insert(&face1, &face2) + assert.NoError(t, err) + + var noses = []Nose3{ + {Face: face1}, + {Face: face2}, + } + _, err = testEngine.Insert(&noses) + assert.NoError(t, err) + + var noses1 []Nose3 + err = testEngine.Find(&noses1) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses1)) + assert.Equal(t, face1.Id, noses1[0].Face.Id) + assert.Equal(t, face2.Id, noses1[1].Face.Id) + assert.Equal(t, "", noses1[0].Face.Name) + assert.Equal(t, "", noses1[1].Face.Name) + + var noses2 []Nose3 + err = testEngine.Cascade().Find(&noses2) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses2)) + assert.Equal(t, face1.Id, noses2[0].Face.Id) + assert.Equal(t, face2.Id, noses2[1].Face.Id) + assert.Equal(t, "face1", noses2[0].Face.Name) + assert.Equal(t, "face2", noses2[1].Face.Name) + + err = testEngine.Load(noses1, "face") + assert.NoError(t, err) + assert.Equal(t, "face1", noses1[0].Face.Name) + assert.Equal(t, "face2", noses1[1].Face.Name) +} + +func TestBelongsTo_FindPtr(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type Face4 struct { + Id int64 + Name string + } + + type Nose4 struct { + Id int64 + Face *Face4 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose4), new(Face4)) + assert.NoError(t, err) + + var face1 = Face4{ + Name: "face1", + } + var face2 = Face4{ + Name: "face2", + } + _, err = testEngine.Insert(&face1, &face2) + assert.NoError(t, err) + + var noses = []Nose4{ + {Face: &face1}, + {Face: &face2}, + } + _, err = testEngine.Insert(&noses) + assert.NoError(t, err) + + var noses1 []Nose4 + err = testEngine.Find(&noses1) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses1)) + assert.Equal(t, face1.Id, noses1[0].Face.Id) + assert.Equal(t, face2.Id, noses1[1].Face.Id) + assert.Equal(t, "", noses1[0].Face.Name) + assert.Equal(t, "", noses1[1].Face.Name) + + var noses2 []Nose4 + err = testEngine.Cascade().Find(&noses2) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses2)) + assert.NotNil(t, noses2[0].Face) + assert.NotNil(t, noses2[1].Face) + assert.Equal(t, face1.Id, noses2[0].Face.Id) + assert.Equal(t, face2.Id, noses2[1].Face.Id) + assert.Equal(t, "face1", noses2[0].Face.Name) + assert.Equal(t, "face2", noses2[1].Face.Name) + + err = testEngine.Load(noses2, "face") + assert.NoError(t, err) +} diff --git a/internal/statements/associate.go b/internal/statements/associate.go new file mode 100644 index 00000000..5659ddc9 --- /dev/null +++ b/internal/statements/associate.go @@ -0,0 +1,14 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +type cascadeMode int + +const ( + cascadeCompitable cascadeMode = iota // load field beans with another SQL with no + cascadeEager // load field beans with another SQL + cascadeJoin // load field beans with join + cascadeLazy // don't load anything +) diff --git a/session_associate.go b/session_associate.go new file mode 100644 index 00000000..62e924a3 --- /dev/null +++ b/session_associate.go @@ -0,0 +1,145 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "errors" + "reflect" + + "github.com/go-xorm/core" +) + +// Load loads associated fields from database +func (session *Session) Load(beanOrSlices interface{}, cols ...string) error { + v := reflect.ValueOf(beanOrSlices) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() == reflect.Slice { + return session.loadFind(beanOrSlices, cols...) + } else if v.Kind() == reflect.Struct { + return session.loadGet(beanOrSlices, cols...) + } + return errors.New("unsupported load type, must struct or slice") +} + +// loadFind load 's belongs to tag field immedicatlly +func (session *Session) loadFind(slices interface{}, cols ...string) error { + v := reflect.ValueOf(slices) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Slice { + return errors.New("only slice is supported") + } + + if v.Len() <= 0 { + return nil + } + + vv := v.Index(0) + if vv.Kind() == reflect.Ptr { + vv = vv.Elem() + } + tb, err := session.engine.autoMapType(vv) + if err != nil { + return err + } + + var pks = make(map[*core.Column][]interface{}) + for i := 0; i < v.Len(); i++ { + ev := v.Index(i) + + for _, col := range tb.Columns() { + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + + if col.AssociateTable != nil { + if col.AssociateType == core.AssociateBelongsTo { + colV, err := col.ValueOfV(&ev) + if err != nil { + return err + } + + pk, err := session.engine.idOfV(*colV) + if err != nil { + return err + } + /*var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + }*/ + + if !isZero(pk[0]) { + pks[col] = append(pks[col], pk[0]) + } + } + } + } + } + + for col, pk := range pks { + slice := reflect.MakeSlice(col.FieldType, 0, len(pk)) + err = session.In(col.Name, pk...).find(slice.Addr().Interface()) + if err != nil { + return err + } + } + return nil +} + +// loadGet load bean's belongs to tag field immedicatlly +func (session *Session) loadGet(bean interface{}, cols ...string) error { + if session.isAutoClose { + defer session.Close() + } + + v := rValue(bean) + tb, err := session.engine.autoMapType(v) + if err != nil { + return err + } + + for _, col := range tb.Columns() { + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + + if col.AssociateTable != nil { + if col.AssociateType == core.AssociateBelongsTo { + colV, err := col.ValueOfV(&v) + if err != nil { + return err + } + + pk, err := session.engine.idOfV(*colV) + if err != nil { + return err + } + var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + } + + if !isZero(pk[0]) && session.cascadeLevel > 0 { + has, err := session.ID(pk).NoAutoCondition().get(colPtr.Interface()) + if err != nil { + return err + } + if !has { + return errors.New("load bean does not exist") + } + session.cascadeLevel-- + } + } + } + } + return nil +} diff --git a/tags/tag.go b/tags/tag.go index 4e1f1ce7..cb5dde79 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -5,6 +5,7 @@ package tags import ( + "errors" "fmt" "reflect" "strconv" @@ -102,28 +103,29 @@ type Handler func(ctx *Context) error var ( // defaultTagHandlers enumerates all the default tag handler defaultTagHandlers = map[string]Handler{ - "-": IgnoreHandler, - "<-": OnlyFromDBTagHandler, - "->": OnlyToDBTagHandler, - "PK": PKTagHandler, - "NULL": NULLTagHandler, - "NOT": NotTagHandler, - "AUTOINCR": AutoIncrTagHandler, - "DEFAULT": DefaultTagHandler, - "CREATED": CreatedTagHandler, - "UPDATED": UpdatedTagHandler, - "DELETED": DeletedTagHandler, - "VERSION": VersionTagHandler, - "UTC": UTCTagHandler, - "LOCAL": LocalTagHandler, - "NOTNULL": NotNullTagHandler, - "INDEX": IndexTagHandler, - "UNIQUE": UniqueTagHandler, - "CACHE": CacheTagHandler, - "NOCACHE": NoCacheTagHandler, - "COMMENT": CommentTagHandler, - "EXTENDS": ExtendsTagHandler, - "UNSIGNED": UnsignedTagHandler, + "-": IgnoreHandler, + "<-": OnlyFromDBTagHandler, + "->": OnlyToDBTagHandler, + "PK": PKTagHandler, + "NULL": NULLTagHandler, + "NOT": NotTagHandler, + "AUTOINCR": AutoIncrTagHandler, + "DEFAULT": DefaultTagHandler, + "CREATED": CreatedTagHandler, + "UPDATED": UpdatedTagHandler, + "DELETED": DeletedTagHandler, + "VERSION": VersionTagHandler, + "UTC": UTCTagHandler, + "LOCAL": LocalTagHandler, + "NOTNULL": NotNullTagHandler, + "INDEX": IndexTagHandler, + "UNIQUE": UniqueTagHandler, + "CACHE": CacheTagHandler, + "NOCACHE": NoCacheTagHandler, + "COMMENT": CommentTagHandler, + "EXTENDS": ExtendsTagHandler, + "UNSIGNED": UnsignedTagHandler, + "BELONGS_TO": BelongsToTagHandler, } ) @@ -398,3 +400,44 @@ func NoCacheTagHandler(ctx *Context) error { } return nil } + +// BelongsToTagHandler describes belongs_to tag handler +func BelongsToTagHandler(ctx *Context) error { + if !isStruct(ctx.fieldValue.Type()) { + return errors.New("Tag belongs_to cannot be applied on non-struct field") + } + + ctx.col.AssociateType = core.AssociateBelongsTo + var t reflect.Value + if ctx.fieldValue.Kind() == reflect.Struct { + t = ctx.fieldValue + } else { + if ctx.fieldValue.Type().Kind() == reflect.Ptr && ctx.fieldValue.Type().Elem().Kind() == reflect.Struct { + if ctx.fieldValue.IsNil() { + t = reflect.New(ctx.fieldValue.Type().Elem()).Elem() + } else { + t = ctx.fieldValue + } + } else { + return errors.New("Only struct or ptr to struct field could add belongs_to flag") + } + } + + belongsT, err := ctx.engine.mapType(ctx.parsingTables, t) + if err != nil { + return err + } + pks := belongsT.PKColumns() + if len(pks) != 1 { + panic("unsupported non or composited primary key cascade") + return errors.New("blongs_to only should be as a tag of table has one primary key") + } + + ctx.col.AssociateTable = belongsT + ctx.col.SQLType = pks[0].SQLType + + if len(ctx.col.Name) == 0 { + ctx.col.Name = ctx.engine.ColumnMapper.Obj2Table(ctx.col.FieldName) + "_id" + } + return nil +}