From 68f18c80e2f031a99744275088b25339b7410cc6 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 22 Jul 2021 11:25:37 +0800 Subject: [PATCH] Compile pass --- integrations/session_tag_test.go | 63 ++++++++++++++++++++++++++++++++ schemas/associate.go | 12 ++++++ schemas/column.go | 9 ++++- session.go | 33 +++++++++++++++-- session_associate.go | 42 +++++++++++---------- tags/tag.go | 19 +++++++--- 6 files changed, 148 insertions(+), 30 deletions(-) create mode 100644 integrations/session_tag_test.go create mode 100644 schemas/associate.go diff --git a/integrations/session_tag_test.go b/integrations/session_tag_test.go new file mode 100644 index 00000000..0e273b3e --- /dev/null +++ b/integrations/session_tag_test.go @@ -0,0 +1,63 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package integrations + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtendsTag(t *testing.T) { + assert.NoError(t, prepareEngine()) + + table := testEngine.TableInfo(new(Userdetail)) + assert.NotNil(t, table) + assert.EqualValues(t, 3, len(table.ColumnsSeq())) + assert.EqualValues(t, "id", table.ColumnsSeq()[0]) + assert.EqualValues(t, "intro", table.ColumnsSeq()[1]) + assert.EqualValues(t, "profile", table.ColumnsSeq()[2]) + + table = testEngine.TableInfo(new(Userinfo)) + assert.NotNil(t, table) + assert.EqualValues(t, 8, len(table.ColumnsSeq())) + assert.EqualValues(t, "id", table.ColumnsSeq()[0]) + assert.EqualValues(t, "username", table.ColumnsSeq()[1]) + assert.EqualValues(t, "departname", table.ColumnsSeq()[2]) + assert.EqualValues(t, "created", table.ColumnsSeq()[3]) + assert.EqualValues(t, "detail_id", table.ColumnsSeq()[4]) + assert.EqualValues(t, "height", table.ColumnsSeq()[5]) + assert.EqualValues(t, "avatar", table.ColumnsSeq()[6]) + assert.EqualValues(t, "is_man", table.ColumnsSeq()[7]) + + table = testEngine.TableInfo(new(UserAndDetail)) + assert.NotNil(t, table) + assert.EqualValues(t, 11, len(table.ColumnsSeq())) + assert.EqualValues(t, "id", table.ColumnsSeq()[0]) + assert.EqualValues(t, "username", table.ColumnsSeq()[1]) + assert.EqualValues(t, "departname", table.ColumnsSeq()[2]) + assert.EqualValues(t, "created", table.ColumnsSeq()[3]) + assert.EqualValues(t, "detail_id", table.ColumnsSeq()[4]) + assert.EqualValues(t, "height", table.ColumnsSeq()[5]) + assert.EqualValues(t, "avatar", table.ColumnsSeq()[6]) + assert.EqualValues(t, "is_man", table.ColumnsSeq()[7]) + assert.EqualValues(t, "id", table.ColumnsSeq()[8]) + assert.EqualValues(t, "intro", table.ColumnsSeq()[9]) + assert.EqualValues(t, "profile", table.ColumnsSeq()[10]) + + cols := table.Columns() + assert.EqualValues(t, 11, len(cols)) + assert.EqualValues(t, "Userinfo.Uid", cols[0].FieldName) + assert.EqualValues(t, "Userinfo.Username", cols[1].FieldName) + assert.EqualValues(t, "Userinfo.Departname", cols[2].FieldName) + assert.EqualValues(t, "Userinfo.Created", cols[3].FieldName) + assert.EqualValues(t, "Userinfo.Detail", cols[4].FieldName) + assert.EqualValues(t, "Userinfo.Height", cols[5].FieldName) + assert.EqualValues(t, "Userinfo.Avatar", cols[6].FieldName) + assert.EqualValues(t, "Userinfo.IsMan", cols[7].FieldName) + assert.EqualValues(t, "Userdetail.Id", cols[8].FieldName) + assert.EqualValues(t, "Userdetail.Intro", cols[9].FieldName) + assert.EqualValues(t, "Userdetail.Profile", cols[10].FieldName) +} diff --git a/schemas/associate.go b/schemas/associate.go new file mode 100644 index 00000000..3dfd85f8 --- /dev/null +++ b/schemas/associate.go @@ -0,0 +1,12 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +type AssociateType int + +const ( + AssociateNone AssociateType = iota + AssociateBelongsTo +) diff --git a/schemas/column.go b/schemas/column.go index 4bbb6c2d..7ded9677 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -22,8 +22,9 @@ const ( type Column struct { Name string TableName string - FieldName string // Available only when parsed from a struct - FieldIndex []int // Available only when parsed from a struct + FieldName string // Available only when parsed from a struct + FieldIndex []int // Available only when parsed from a struct + FieldType reflect.Type // Available only when parsed from a struct SQLType SQLType IsJSON bool Length int @@ -45,6 +46,8 @@ type Column struct { DisableTimeZone bool TimeZone *time.Location // column specified time zone Comment string + AssociateType + AssociateTable *Table } // NewColumn creates a new column @@ -71,6 +74,8 @@ func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable DefaultIsEmpty: true, // default should be no default EnumOptions: make(map[string]int), Comment: "", + AssociateType: AssociateNone, + AssociateTable: nil, } } diff --git a/session.go b/session.go index f5b45a73..6891864f 100644 --- a/session.go +++ b/session.go @@ -54,6 +54,22 @@ const ( groupSession sessionType = true ) +type cascadeMode int + +const ( + cascadeCompitable cascadeMode = iota // load field beans with another SQL with no + cascadeEager // load field beans with another SQL + cascadeJoin // load field beans with join + cascadeLazy // don't load anything +) + +type loadClosure struct { + Func func(schemas.PK, *reflect.Value) error + pk schemas.PK + fieldValue *reflect.Value + loaded bool +} + // Session keep a pointer to sql.DB and provides all execution of all // kind of database operations. type Session struct { @@ -86,6 +102,9 @@ type Session struct { ctx context.Context sessionType sessionType + + cascadeMode cascadeMode + cascadeLevel int // load level } func newSessionID() string { @@ -134,7 +153,9 @@ func newSession(engine *Engine) *Session { lastSQL: "", lastSQLArgs: make([]interface{}, 0), - sessionType: engineSession, + sessionType: engineSession, + cascadeMode: cascadeCompitable, + cascadeLevel: 2, } if engine.logSessionID { session.ctx = context.WithValue(session.ctx, log.SessionKey, session) @@ -241,7 +262,7 @@ func (session *Session) Alias(alias string) *Session { // NoCascade indicate that no cascade load child object func (session *Session) NoCascade() *Session { - session.statement.UseCascade = false + session.cascadeMode = cascadeLazy return session } @@ -296,9 +317,15 @@ func (session *Session) Charset(charset string) *Session { // Cascade indicates if loading sub Struct func (session *Session) Cascade(trueOrFalse ...bool) *Session { + var mode = cascadeEager if len(trueOrFalse) >= 1 { - session.statement.UseCascade = trueOrFalse[0] + if trueOrFalse[0] { + mode = cascadeEager + } else { + mode = cascadeLazy + } } + session.cascadeMode = mode return session } diff --git a/session_associate.go b/session_associate.go index 62e924a3..7e6041d2 100644 --- a/session_associate.go +++ b/session_associate.go @@ -8,7 +8,8 @@ import ( "errors" "reflect" - "github.com/go-xorm/core" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" ) // Load loads associated fields from database @@ -25,6 +26,15 @@ func (session *Session) Load(beanOrSlices interface{}, cols ...string) error { return errors.New("unsupported load type, must struct or slice") } +func isStringInSlice(s string, slice []string) bool { + for _, e := range slice { + if s == e { + return true + } + } + return false +} + // loadFind load 's belongs to tag field immedicatlly func (session *Session) loadFind(slices interface{}, cols ...string) error { v := reflect.ValueOf(slices) @@ -43,12 +53,12 @@ func (session *Session) loadFind(slices interface{}, cols ...string) error { if vv.Kind() == reflect.Ptr { vv = vv.Elem() } - tb, err := session.engine.autoMapType(vv) + tb, err := session.engine.tagParser.ParseWithCache(vv) if err != nil { return err } - var pks = make(map[*core.Column][]interface{}) + var pks = make(map[*schemas.Column][]interface{}) for i := 0; i < v.Len(); i++ { ev := v.Index(i) @@ -58,16 +68,13 @@ func (session *Session) loadFind(slices interface{}, cols ...string) error { } if col.AssociateTable != nil { - if col.AssociateType == core.AssociateBelongsTo { + if col.AssociateType == schemas.AssociateBelongsTo { colV, err := col.ValueOfV(&ev) if err != nil { return err } - pk, err := session.engine.idOfV(*colV) - if err != nil { - return err - } + vv := colV.Interface() /*var colPtr reflect.Value if colV.Kind() == reflect.Ptr { colPtr = *colV @@ -75,8 +82,8 @@ func (session *Session) loadFind(slices interface{}, cols ...string) error { colPtr = colV.Addr() }*/ - if !isZero(pk[0]) { - pks[col] = append(pks[col], pk[0]) + if !utils.IsZero(vv) { + pks[col] = append(pks[col], vv) } } } @@ -99,8 +106,8 @@ func (session *Session) loadGet(bean interface{}, cols ...string) error { defer session.Close() } - v := rValue(bean) - tb, err := session.engine.autoMapType(v) + v := reflect.Indirect(reflect.ValueOf(bean)) + tb, err := session.engine.tagParser.ParseWithCache(v) if err != nil { return err } @@ -111,16 +118,13 @@ func (session *Session) loadGet(bean interface{}, cols ...string) error { } if col.AssociateTable != nil { - if col.AssociateType == core.AssociateBelongsTo { + if col.AssociateType == schemas.AssociateBelongsTo { colV, err := col.ValueOfV(&v) if err != nil { return err } - pk, err := session.engine.idOfV(*colV) - if err != nil { - return err - } + vv := colV.Interface() var colPtr reflect.Value if colV.Kind() == reflect.Ptr { colPtr = *colV @@ -128,8 +132,8 @@ func (session *Session) loadGet(bean interface{}, cols ...string) error { colPtr = colV.Addr() } - if !isZero(pk[0]) && session.cascadeLevel > 0 { - has, err := session.ID(pk).NoAutoCondition().get(colPtr.Interface()) + if !utils.IsZero(vv) && session.cascadeLevel > 0 { + has, err := session.ID(vv).NoAutoCondition().get(colPtr.Interface()) if err != nil { return err } diff --git a/tags/tag.go b/tags/tag.go index cb5dde79..c354617b 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -401,13 +401,21 @@ func NoCacheTagHandler(ctx *Context) error { return nil } +func isStruct(t reflect.Type) bool { + return t.Kind() == reflect.Struct || isPtrStruct(t) +} + +func isPtrStruct(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + // BelongsToTagHandler describes belongs_to tag handler func BelongsToTagHandler(ctx *Context) error { if !isStruct(ctx.fieldValue.Type()) { - return errors.New("Tag belongs_to cannot be applied on non-struct field") + return errors.New("tag belongs_to cannot be applied on non-struct field") } - ctx.col.AssociateType = core.AssociateBelongsTo + ctx.col.AssociateType = schemas.AssociateBelongsTo var t reflect.Value if ctx.fieldValue.Kind() == reflect.Struct { t = ctx.fieldValue @@ -419,17 +427,16 @@ func BelongsToTagHandler(ctx *Context) error { t = ctx.fieldValue } } else { - return errors.New("Only struct or ptr to struct field could add belongs_to flag") + return errors.New("only struct or ptr to struct field could add belongs_to flag") } } - belongsT, err := ctx.engine.mapType(ctx.parsingTables, t) + belongsT, err := ctx.parser.ParseWithCache(t) if err != nil { return err } pks := belongsT.PKColumns() if len(pks) != 1 { - panic("unsupported non or composited primary key cascade") return errors.New("blongs_to only should be as a tag of table has one primary key") } @@ -437,7 +444,7 @@ func BelongsToTagHandler(ctx *Context) error { ctx.col.SQLType = pks[0].SQLType if len(ctx.col.Name) == 0 { - ctx.col.Name = ctx.engine.ColumnMapper.Obj2Table(ctx.col.FieldName) + "_id" + ctx.col.Name = ctx.parser.columnMapper.Obj2Table(ctx.col.FieldName) + "_id" } return nil }