From b88257902642a881683bf2ebc555d45fe5b71fd5 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 22 Jul 2021 13:22:35 +0800 Subject: [PATCH] Fix test --- session.go | 58 ++++++++----------- session_associate.go | 130 ++++++++++++++++++++++++++++++++----------- session_get.go | 36 ++++++++++++ 3 files changed, 157 insertions(+), 67 deletions(-) diff --git a/session.go b/session.go index 6891864f..a2781d75 100644 --- a/session.go +++ b/session.go @@ -10,7 +10,6 @@ import ( "crypto/sha256" "database/sql" "encoding/hex" - "errors" "fmt" "hash/crc32" "io" @@ -631,9 +630,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } fieldValue.Set(reflect.ValueOf(v).Elem().Convert(fieldType)) return nil - } - - if fieldType.ConvertibleTo(schemas.TimeType) { + } else if fieldType.ConvertibleTo(schemas.TimeType) { dbTZ := session.engine.DatabaseTZ if col.TimeZone != nil { dbTZ = col.TimeZone @@ -647,42 +644,35 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec fieldValue.Set(reflect.ValueOf(*t).Convert(fieldType)) return nil } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - err := nulVal.Scan(scanResult) - if err == nil { - return nil - } - session.engine.logger.Errorf("sql.Sanner error: %v", err) - } else if session.statement.UseCascade { - table, err := session.engine.tagParser.ParseWithCache(*fieldValue) - if err != nil { - return err - } - - if len(table.PrimaryKeys) != 1 { - return errors.New("unsupported non or composited primary key cascade") - } - var pk = make(schemas.PK, len(table.PrimaryKeys)) + return nulVal.Scan(scanResult) + } else if session.cascadeLevel > 0 && ((col.AssociateType == schemas.AssociateNone && + session.cascadeMode == cascadeCompitable) || + (col.AssociateType == schemas.AssociateBelongsTo && + session.cascadeMode == cascadeEager)) { + var pk = make(schemas.PK, len(col.AssociateTable.PrimaryKeys)) + var err error pk[0], err = asKind(vv, reflect.TypeOf(scanResult)) if err != nil { return err } - if !pk.IsZero() { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return err - } - if has { - fieldValue.Set(structInter.Elem()) - } else { - return errors.New("cascade obj is not exist") - } - } + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(session *Session, bean interface{}) error { + fieldValue := bean.(*reflect.Value) + return session.getStructByPK(pk, fieldValue) + }, + session: session, + bean: fieldValue, + }) + session.cascadeLevel-- return nil + } else if col.AssociateType == schemas.AssociateBelongsTo { + pkCols := col.AssociateTable.PKColumns() + colV, err := pkCols[0].ValueOfV(fieldValue) + if err != nil { + return err + } + return convertAssignV(*colV, scanResult) } } // switch fieldType.Kind() diff --git a/session_associate.go b/session_associate.go index 7e6041d2..2951b6fc 100644 --- a/session_associate.go +++ b/session_associate.go @@ -19,11 +19,13 @@ func (session *Session) Load(beanOrSlices interface{}, cols ...string) error { v = v.Elem() } if v.Kind() == reflect.Slice { - return session.loadFind(beanOrSlices, cols...) + return session.loadFindSlice(v, cols...) + } else if v.Kind() == reflect.Map { + return session.loadFindMap(v, cols...) } else if v.Kind() == reflect.Struct { - return session.loadGet(beanOrSlices, cols...) + return session.loadGet(v, cols...) } - return errors.New("unsupported load type, must struct or slice") + return errors.New("unsupported load type, must struct, slice or map") } func isStringInSlice(s string, slice []string) bool { @@ -36,11 +38,7 @@ func isStringInSlice(s string, slice []string) bool { } // loadFind load 's belongs to tag field immedicatlly -func (session *Session) loadFind(slices interface{}, cols ...string) error { - v := reflect.ValueOf(slices) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } +func (session *Session) loadFindSlice(v reflect.Value, cols ...string) error { if v.Kind() != reflect.Slice { return errors.New("only slice is supported") } @@ -100,13 +98,73 @@ func (session *Session) loadFind(slices interface{}, cols ...string) error { return nil } +// loadFindMap load 's belongs to tag field immedicatlly +func (session *Session) loadFindMap(v reflect.Value, cols ...string) error { + if v.Kind() != reflect.Map { + return errors.New("only map is supported") + } + + if v.Len() <= 0 { + return nil + } + + vv := v.Index(0) + if vv.Kind() == reflect.Ptr { + vv = vv.Elem() + } + tb, err := session.engine.tagParser.ParseWithCache(vv) + if err != nil { + return err + } + + var pks = make(map[*schemas.Column][]interface{}) + for i := 0; i < v.Len(); i++ { + ev := v.Index(i) + + for _, col := range tb.Columns() { + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + + if col.AssociateTable != nil { + if col.AssociateType == schemas.AssociateBelongsTo { + colV, err := col.ValueOfV(&ev) + if err != nil { + return err + } + + vv := colV.Interface() + /*var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + }*/ + + if !utils.IsZero(vv) { + pks[col] = append(pks[col], vv) + } + } + } + } + } + + for col, pk := range pks { + slice := reflect.MakeSlice(col.FieldType, 0, len(pk)) + err = session.In(col.Name, pk...).find(slice.Addr().Interface()) + if err != nil { + return err + } + } + return nil +} + // loadGet load bean's belongs to tag field immedicatlly -func (session *Session) loadGet(bean interface{}, cols ...string) error { +func (session *Session) loadGet(v reflect.Value, cols ...string) error { if session.isAutoClose { defer session.Close() } - v := reflect.Indirect(reflect.ValueOf(bean)) tb, err := session.engine.tagParser.ParseWithCache(v) if err != nil { return err @@ -117,32 +175,38 @@ func (session *Session) loadGet(bean interface{}, cols ...string) error { continue } - if col.AssociateTable != nil { - if col.AssociateType == schemas.AssociateBelongsTo { - colV, err := col.ValueOfV(&v) - if err != nil { - return err - } + if col.AssociateTable == nil || col.AssociateType != schemas.AssociateBelongsTo { + continue + } - vv := colV.Interface() - var colPtr reflect.Value - if colV.Kind() == reflect.Ptr { - colPtr = *colV - } else { - colPtr = colV.Addr() - } + colV, err := col.ValueOfV(&v) + if err != nil { + return err + } - if !utils.IsZero(vv) && session.cascadeLevel > 0 { - has, err := session.ID(vv).NoAutoCondition().get(colPtr.Interface()) - if err != nil { - return err - } - if !has { - return errors.New("load bean does not exist") - } - session.cascadeLevel-- - } + var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + } + + pks := col.AssociateTable.PKColumns() + pkV, err := pks[0].ValueOfV(colV) + if err != nil { + return err + } + vv := pkV.Interface() + + if !utils.IsZero(vv) && session.cascadeLevel > 0 { + has, err := session.ID(vv).NoAutoCondition().get(colPtr.Interface()) + if err != nil { + return err } + if !has { + return errors.New("load bean does not exist") + } + session.cascadeLevel-- } } return nil diff --git a/session_get.go b/session_get.go index 48616a6b..9bf916ce 100644 --- a/session_get.go +++ b/session_get.go @@ -280,6 +280,42 @@ func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields } } +func (session *Session) getStructByPK(pk schemas.PK, fieldValue *reflect.Value) error { + if pk.IsZero() { + return errors.New("getStructByPK: primary key is zero") + } + + var structInter reflect.Value + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + structInter = reflect.New(fieldValue.Type().Elem()) + } else { + structInter = *fieldValue + } + } else { + structInter = fieldValue.Addr() + } + + has, err := session.ID(pk).NoAutoCondition().get(structInter.Interface()) + if err != nil { + return err + } + if !has { + return errors.New("cascade obj is not exist") + } + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(structInter) + fmt.Println("getByPK value ptr:", fieldValue.Interface()) + return nil + } else if fieldValue.Kind() == reflect.Struct { + fieldValue.Set(structInter.Elem()) + fmt.Println("getByPK value:", fieldValue.Interface()) + return nil + } + return errors.New("set value failed") + +} + func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { // if has no reftable, then don't use cache currently if !session.canCache() {