diff --git a/convert/conversion.go b/convert/conversion.go index 096fcfaf..78a9fd78 100644 --- a/convert/conversion.go +++ b/convert/conversion.go @@ -325,6 +325,9 @@ func AssignValue(dv reflect.Value, src interface{}) error { if src == nil { return nil } + if v, ok := src.(*interface{}); ok { + return AssignValue(dv, *v) + } if dv.Type().Implements(scannerType) { return dv.Interface().(sql.Scanner).Scan(src) diff --git a/internal/utils/new.go b/internal/utils/new.go new file mode 100644 index 00000000..e3b4eae8 --- /dev/null +++ b/internal/utils/new.go @@ -0,0 +1,25 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import "reflect" + +// New creates a value according type +func New(tp reflect.Type, length, cap int) reflect.Value { + switch tp.Kind() { + case reflect.Slice: + slice := reflect.MakeSlice(tp, length, cap) + x := reflect.New(slice.Type()) + x.Elem().Set(slice) + return x + case reflect.Map: + mp := reflect.MakeMapWithSize(tp, cap) + x := reflect.New(mp.Type()) + x.Elem().Set(mp) + return x + default: + return reflect.New(tp) + } +} diff --git a/session_find.go b/session_find.go index df3bd85d..47a3d308 100644 --- a/session_find.go +++ b/session_find.go @@ -161,6 +161,16 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error { + elemType := containerValue.Type().Elem() + var isPointer bool + if elemType.Kind() == reflect.Ptr { + isPointer = true + elemType = elemType.Elem() + } + if elemType.Kind() == reflect.Ptr { + return errors.New("pointer to pointer is not supported") + } + rows, err := session.queryRows(sqlStr, args...) if err != nil { return err @@ -177,31 +187,8 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect return err } - var newElemFunc func(fields []string) reflect.Value - elemType := containerValue.Type().Elem() - var isPointer bool - if elemType.Kind() == reflect.Ptr { - isPointer = true - elemType = elemType.Elem() - } - if elemType.Kind() == reflect.Ptr { - return errors.New("pointer to pointer is not supported") - } - - newElemFunc = func(fields []string) reflect.Value { - switch elemType.Kind() { - case reflect.Slice: - slice := reflect.MakeSlice(elemType, len(fields), len(fields)) - x := reflect.New(slice.Type()) - x.Elem().Set(slice) - return x - case reflect.Map: - mp := reflect.MakeMap(elemType) - x := reflect.New(mp.Type()) - x.Elem().Set(mp) - return x - } - return reflect.New(elemType) + var newElemFunc = func(fields []string) reflect.Value { + return utils.New(elemType, len(fields), len(fields)) } var containerValueSetFunc func(*reflect.Value, schemas.PK) error @@ -226,10 +213,15 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect containerValueSetFunc = func(newValue *reflect.Value, pk schemas.PK) error { keyValue := reflect.New(keyType) - err := convertPKToValue(table, keyValue.Interface(), pk) - if err != nil { - return err + cols := table.PKColumns() + if len(cols) == 1 { + if err := convert.AssignValue(keyValue, pk[0]); err != nil { + return err + } + } else { + keyValue.Set(reflect.ValueOf(&pk)) } + if isPointer { containerValue.SetMapIndex(keyValue.Elem(), newValue.Elem().Addr()) } else { @@ -241,8 +233,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect if elemType.Kind() == reflect.Struct { var newValue = newElemFunc(fields) - dataStruct := utils.ReflectValue(newValue.Interface()) - tb, err := session.engine.tagParser.ParseWithCache(dataStruct) + tb, err := session.engine.tagParser.ParseWithCache(newValue) if err != nil { return err } @@ -266,7 +257,6 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect default: err = rows.Scan(bean) } - if err != nil { return err } @@ -278,16 +268,6 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect return rows.Err() } -func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error { - cols := table.PKColumns() - if len(cols) == 1 { - return convert.Assign(dst, pk[0], nil, nil) - } - - dst = pk - return nil -} - func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) { if !session.canCache() || utils.IndexNoCase(sqlStr, "having") != -1 ||