From c8ac6bb65c52bf1371ceacf0c16b261d9908af1c Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 21 Feb 2017 11:06:40 +0800 Subject: [PATCH] Ask Get parameter is pointer and prepare for Get non-struct --- session_get.go | 49 +++++++++++++++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/session_get.go b/session_get.go index ac0c5ebb..2b975c83 100644 --- a/session_get.go +++ b/session_get.go @@ -20,7 +20,19 @@ func (session *Session) Get(bean interface{}) (bool, error) { defer session.Close() } - session.Statement.setRefValue(rValue(bean)) + beanValue := reflect.ValueOf(bean) + if beanValue.Kind() != reflect.Ptr { + return false, errors.New("needs a pointer to a struct") + } + + // FIXME: remove this after support non-struct Get + if beanValue.Elem().Kind() != reflect.Struct { + return false, errors.New("needs a pointer to a struct") + } + + if beanValue.Elem().Kind() == reflect.Struct { + session.Statement.setRefValue(beanValue.Elem()) + } var sqlStr string var args []interface{} @@ -46,10 +58,10 @@ func (session *Session) Get(bean interface{}) (bool, error) { } } - return session.nocacheGet(bean, sqlStr, args...) + return session.nocacheGet(beanValue.Elem().Kind(), bean, sqlStr, args...) } -func (session *Session) nocacheGet(bean interface{}, sqlStr string, args ...interface{}) (bool, error) { +func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { var rawRows *core.Rows var err error session.queryPreprocess(&sqlStr, args...) @@ -66,9 +78,22 @@ func (session *Session) nocacheGet(bean interface{}, sqlStr string, args ...inte if rawRows.Next() { fields, err := rawRows.Columns() - if err == nil { - _, err = session.row2Bean(rawRows, fields, len(fields), bean) + if err != nil { + // WARN: Alougth rawRows return true, but get fields failed + return true, err } + + switch beanKind { + case reflect.Struct: + _, err = session.row2Bean(rawRows, fields, len(fields), bean) + case reflect.Slice: + err = rawRows.ScanSlice(bean) + case reflect.Map: + err = rawRows.ScanMap(bean) + default: + err = rawRows.Scan(bean) + } + return true, err } return false, nil @@ -145,20 +170,8 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf } cacheBean := cacher.GetBean(tableName, sid) if cacheBean == nil { - /*newSession := session.Engine.NewSession() - defer newSession.Close() - cacheBean = reflect.New(structValue.Type()).Interface() - newSession.Id(id).NoCache() - if session.Statement.AltTableName != "" { - newSession.Table(session.Statement.AltTableName) - } - if !session.Statement.UseCascade { - newSession.NoCascade() - } - has, err = newSession.Get(cacheBean) - */ cacheBean = bean - has, err = session.nocacheGet(cacheBean, sqlStr, args...) + has, err = session.nocacheGet(reflect.Struct, cacheBean, sqlStr, args...) if err != nil || !has { return has, err }