diff --git a/session.go b/session.go index 9a82b1aa..90cce292 100644 --- a/session.go +++ b/session.go @@ -269,8 +269,8 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) statement := session.Statement defer session.Statement.Init() sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) - if sliceValue.Kind() != reflect.Slice { - return errors.New("needs a pointer to a slice") + if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { + return errors.New("needs a pointer to a slice or a map") } sliceElementType := sliceValue.Type().Elem() @@ -290,13 +290,27 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) return err } - for _, results := range resultsSlice { + for i, results := range resultsSlice { newValue := reflect.New(sliceElementType) err := session.scanMapIntoStruct(newValue.Interface(), results) if err != nil { return err } - sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + if sliceValue.Kind() == reflect.Slice { + sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + } else if sliceValue.Kind() == reflect.Map { + var key int64 + if table.PrimaryKey != "" { + x, err := strconv.ParseInt(string(results[table.PrimaryKey]), 10, 64) + if err != nil { + return errors.New("pk " + table.PrimaryKey + " as int64: " + err.Error()) + } + key = x + } else { + key = int64(i) + } + sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(newValue.Interface()))) + } } return nil } diff --git a/xorm_test.go b/xorm_test.go index 381d771c..81af8189 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -186,6 +186,16 @@ func find(t *testing.T) { fmt.Println(users) } +func findMap(t *testing.T) { + users := make(map[int64]Userinfo) + + err := engine.Find(&users) + if err != nil { + t.Error(err) + } + fmt.Println(users) +} + func count(t *testing.T) { user := Userinfo{Departname: "dev"} total, err := engine.Count(&user) @@ -430,6 +440,7 @@ func TestMysql(t *testing.T) { delete(t) get(t) find(t) + findMap(t) count(t) where(t) in(t) @@ -460,6 +471,7 @@ func TestSqlite(t *testing.T) { delete(t) get(t) find(t) + findMap(t) count(t) where(t) in(t)