Compile pass

This commit is contained in:
Lunny Xiao 2021-07-22 11:25:37 +08:00
parent 06d4d50e82
commit 68f18c80e2
6 changed files with 148 additions and 30 deletions

View File

@ -0,0 +1,63 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package integrations
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestExtendsTag(t *testing.T) {
assert.NoError(t, prepareEngine())
table := testEngine.TableInfo(new(Userdetail))
assert.NotNil(t, table)
assert.EqualValues(t, 3, len(table.ColumnsSeq()))
assert.EqualValues(t, "id", table.ColumnsSeq()[0])
assert.EqualValues(t, "intro", table.ColumnsSeq()[1])
assert.EqualValues(t, "profile", table.ColumnsSeq()[2])
table = testEngine.TableInfo(new(Userinfo))
assert.NotNil(t, table)
assert.EqualValues(t, 8, len(table.ColumnsSeq()))
assert.EqualValues(t, "id", table.ColumnsSeq()[0])
assert.EqualValues(t, "username", table.ColumnsSeq()[1])
assert.EqualValues(t, "departname", table.ColumnsSeq()[2])
assert.EqualValues(t, "created", table.ColumnsSeq()[3])
assert.EqualValues(t, "detail_id", table.ColumnsSeq()[4])
assert.EqualValues(t, "height", table.ColumnsSeq()[5])
assert.EqualValues(t, "avatar", table.ColumnsSeq()[6])
assert.EqualValues(t, "is_man", table.ColumnsSeq()[7])
table = testEngine.TableInfo(new(UserAndDetail))
assert.NotNil(t, table)
assert.EqualValues(t, 11, len(table.ColumnsSeq()))
assert.EqualValues(t, "id", table.ColumnsSeq()[0])
assert.EqualValues(t, "username", table.ColumnsSeq()[1])
assert.EqualValues(t, "departname", table.ColumnsSeq()[2])
assert.EqualValues(t, "created", table.ColumnsSeq()[3])
assert.EqualValues(t, "detail_id", table.ColumnsSeq()[4])
assert.EqualValues(t, "height", table.ColumnsSeq()[5])
assert.EqualValues(t, "avatar", table.ColumnsSeq()[6])
assert.EqualValues(t, "is_man", table.ColumnsSeq()[7])
assert.EqualValues(t, "id", table.ColumnsSeq()[8])
assert.EqualValues(t, "intro", table.ColumnsSeq()[9])
assert.EqualValues(t, "profile", table.ColumnsSeq()[10])
cols := table.Columns()
assert.EqualValues(t, 11, len(cols))
assert.EqualValues(t, "Userinfo.Uid", cols[0].FieldName)
assert.EqualValues(t, "Userinfo.Username", cols[1].FieldName)
assert.EqualValues(t, "Userinfo.Departname", cols[2].FieldName)
assert.EqualValues(t, "Userinfo.Created", cols[3].FieldName)
assert.EqualValues(t, "Userinfo.Detail", cols[4].FieldName)
assert.EqualValues(t, "Userinfo.Height", cols[5].FieldName)
assert.EqualValues(t, "Userinfo.Avatar", cols[6].FieldName)
assert.EqualValues(t, "Userinfo.IsMan", cols[7].FieldName)
assert.EqualValues(t, "Userdetail.Id", cols[8].FieldName)
assert.EqualValues(t, "Userdetail.Intro", cols[9].FieldName)
assert.EqualValues(t, "Userdetail.Profile", cols[10].FieldName)
}

12
schemas/associate.go Normal file
View File

@ -0,0 +1,12 @@
// Copyright 2021 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
type AssociateType int
const (
AssociateNone AssociateType = iota
AssociateBelongsTo
)

View File

@ -22,8 +22,9 @@ const (
type Column struct {
Name string
TableName string
FieldName string // Available only when parsed from a struct
FieldIndex []int // Available only when parsed from a struct
FieldName string // Available only when parsed from a struct
FieldIndex []int // Available only when parsed from a struct
FieldType reflect.Type // Available only when parsed from a struct
SQLType SQLType
IsJSON bool
Length int
@ -45,6 +46,8 @@ type Column struct {
DisableTimeZone bool
TimeZone *time.Location // column specified time zone
Comment string
AssociateType
AssociateTable *Table
}
// NewColumn creates a new column
@ -71,6 +74,8 @@ func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable
DefaultIsEmpty: true, // default should be no default
EnumOptions: make(map[string]int),
Comment: "",
AssociateType: AssociateNone,
AssociateTable: nil,
}
}

View File

@ -54,6 +54,22 @@ const (
groupSession sessionType = true
)
type cascadeMode int
const (
cascadeCompitable cascadeMode = iota // load field beans with another SQL with no
cascadeEager // load field beans with another SQL
cascadeJoin // load field beans with join
cascadeLazy // don't load anything
)
type loadClosure struct {
Func func(schemas.PK, *reflect.Value) error
pk schemas.PK
fieldValue *reflect.Value
loaded bool
}
// Session keep a pointer to sql.DB and provides all execution of all
// kind of database operations.
type Session struct {
@ -86,6 +102,9 @@ type Session struct {
ctx context.Context
sessionType sessionType
cascadeMode cascadeMode
cascadeLevel int // load level
}
func newSessionID() string {
@ -134,7 +153,9 @@ func newSession(engine *Engine) *Session {
lastSQL: "",
lastSQLArgs: make([]interface{}, 0),
sessionType: engineSession,
sessionType: engineSession,
cascadeMode: cascadeCompitable,
cascadeLevel: 2,
}
if engine.logSessionID {
session.ctx = context.WithValue(session.ctx, log.SessionKey, session)
@ -241,7 +262,7 @@ func (session *Session) Alias(alias string) *Session {
// NoCascade indicate that no cascade load child object
func (session *Session) NoCascade() *Session {
session.statement.UseCascade = false
session.cascadeMode = cascadeLazy
return session
}
@ -296,9 +317,15 @@ func (session *Session) Charset(charset string) *Session {
// Cascade indicates if loading sub Struct
func (session *Session) Cascade(trueOrFalse ...bool) *Session {
var mode = cascadeEager
if len(trueOrFalse) >= 1 {
session.statement.UseCascade = trueOrFalse[0]
if trueOrFalse[0] {
mode = cascadeEager
} else {
mode = cascadeLazy
}
}
session.cascadeMode = mode
return session
}

View File

@ -8,7 +8,8 @@ import (
"errors"
"reflect"
"github.com/go-xorm/core"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
// Load loads associated fields from database
@ -25,6 +26,15 @@ func (session *Session) Load(beanOrSlices interface{}, cols ...string) error {
return errors.New("unsupported load type, must struct or slice")
}
func isStringInSlice(s string, slice []string) bool {
for _, e := range slice {
if s == e {
return true
}
}
return false
}
// loadFind load 's belongs to tag field immedicatlly
func (session *Session) loadFind(slices interface{}, cols ...string) error {
v := reflect.ValueOf(slices)
@ -43,12 +53,12 @@ func (session *Session) loadFind(slices interface{}, cols ...string) error {
if vv.Kind() == reflect.Ptr {
vv = vv.Elem()
}
tb, err := session.engine.autoMapType(vv)
tb, err := session.engine.tagParser.ParseWithCache(vv)
if err != nil {
return err
}
var pks = make(map[*core.Column][]interface{})
var pks = make(map[*schemas.Column][]interface{})
for i := 0; i < v.Len(); i++ {
ev := v.Index(i)
@ -58,16 +68,13 @@ func (session *Session) loadFind(slices interface{}, cols ...string) error {
}
if col.AssociateTable != nil {
if col.AssociateType == core.AssociateBelongsTo {
if col.AssociateType == schemas.AssociateBelongsTo {
colV, err := col.ValueOfV(&ev)
if err != nil {
return err
}
pk, err := session.engine.idOfV(*colV)
if err != nil {
return err
}
vv := colV.Interface()
/*var colPtr reflect.Value
if colV.Kind() == reflect.Ptr {
colPtr = *colV
@ -75,8 +82,8 @@ func (session *Session) loadFind(slices interface{}, cols ...string) error {
colPtr = colV.Addr()
}*/
if !isZero(pk[0]) {
pks[col] = append(pks[col], pk[0])
if !utils.IsZero(vv) {
pks[col] = append(pks[col], vv)
}
}
}
@ -99,8 +106,8 @@ func (session *Session) loadGet(bean interface{}, cols ...string) error {
defer session.Close()
}
v := rValue(bean)
tb, err := session.engine.autoMapType(v)
v := reflect.Indirect(reflect.ValueOf(bean))
tb, err := session.engine.tagParser.ParseWithCache(v)
if err != nil {
return err
}
@ -111,16 +118,13 @@ func (session *Session) loadGet(bean interface{}, cols ...string) error {
}
if col.AssociateTable != nil {
if col.AssociateType == core.AssociateBelongsTo {
if col.AssociateType == schemas.AssociateBelongsTo {
colV, err := col.ValueOfV(&v)
if err != nil {
return err
}
pk, err := session.engine.idOfV(*colV)
if err != nil {
return err
}
vv := colV.Interface()
var colPtr reflect.Value
if colV.Kind() == reflect.Ptr {
colPtr = *colV
@ -128,8 +132,8 @@ func (session *Session) loadGet(bean interface{}, cols ...string) error {
colPtr = colV.Addr()
}
if !isZero(pk[0]) && session.cascadeLevel > 0 {
has, err := session.ID(pk).NoAutoCondition().get(colPtr.Interface())
if !utils.IsZero(vv) && session.cascadeLevel > 0 {
has, err := session.ID(vv).NoAutoCondition().get(colPtr.Interface())
if err != nil {
return err
}

View File

@ -401,13 +401,21 @@ func NoCacheTagHandler(ctx *Context) error {
return nil
}
func isStruct(t reflect.Type) bool {
return t.Kind() == reflect.Struct || isPtrStruct(t)
}
func isPtrStruct(t reflect.Type) bool {
return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct
}
// BelongsToTagHandler describes belongs_to tag handler
func BelongsToTagHandler(ctx *Context) error {
if !isStruct(ctx.fieldValue.Type()) {
return errors.New("Tag belongs_to cannot be applied on non-struct field")
return errors.New("tag belongs_to cannot be applied on non-struct field")
}
ctx.col.AssociateType = core.AssociateBelongsTo
ctx.col.AssociateType = schemas.AssociateBelongsTo
var t reflect.Value
if ctx.fieldValue.Kind() == reflect.Struct {
t = ctx.fieldValue
@ -419,17 +427,16 @@ func BelongsToTagHandler(ctx *Context) error {
t = ctx.fieldValue
}
} else {
return errors.New("Only struct or ptr to struct field could add belongs_to flag")
return errors.New("only struct or ptr to struct field could add belongs_to flag")
}
}
belongsT, err := ctx.engine.mapType(ctx.parsingTables, t)
belongsT, err := ctx.parser.ParseWithCache(t)
if err != nil {
return err
}
pks := belongsT.PKColumns()
if len(pks) != 1 {
panic("unsupported non or composited primary key cascade")
return errors.New("blongs_to only should be as a tag of table has one primary key")
}
@ -437,7 +444,7 @@ func BelongsToTagHandler(ctx *Context) error {
ctx.col.SQLType = pks[0].SQLType
if len(ctx.col.Name) == 0 {
ctx.col.Name = ctx.engine.ColumnMapper.Obj2Table(ctx.col.FieldName) + "_id"
ctx.col.Name = ctx.parser.columnMapper.Obj2Table(ctx.col.FieldName) + "_id"
}
return nil
}