From 5ebae720bd8ed2ddea750c5a51c20a00c3d8900f Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 1 Apr 2017 10:09:00 +0800 Subject: [PATCH] add Scan features to Get method --- README.md | 12 ++++++ README_CN.md | 12 ++++++ session_get.go | 23 +++++------- session_get_test.go | 91 +++++++++++++++++++++++++++++++++++++++++++++ statement.go | 20 ++++++++-- xorm.go | 2 +- 6 files changed, 141 insertions(+), 19 deletions(-) create mode 100644 session_get_test.go diff --git a/README.md b/README.md index bdcbab68..8a209823 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,18 @@ has, err := engine.Get(&user) // SELECT * FROM user LIMIT 1 has, err := engine.Where("name = ?", name).Desc("id").Get(&user) // SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1 +var name string +has, err := engine.Where("id = ?", id).Cols("name").Get(&name) +// SELECT name FROM user WHERE id = ? +var id int64 +has, err := engine.Where("name = ?", name).Cols("id").Get(&id) +// SELECT id FROM user WHERE name = ? +var valuesMap = make(map[string]string) +has, err := engine.Where("id = ?", id).Get(&valuesMap) +// SELECT * FROM user WHERE id = ? +var valuesSlice = make([]interface{}, len(cols)) +has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice) +// SELECT col1, col2, col3 FROM user WHERE id = ? ``` * Query multiple records from database, also you can use join and extends diff --git a/README_CN.md b/README_CN.md index fafef1a5..f4c7e219 100644 --- a/README_CN.md +++ b/README_CN.md @@ -149,6 +149,18 @@ has, err := engine.Get(&user) // SELECT * FROM user LIMIT 1 has, err := engine.Where("name = ?", name).Desc("id").Get(&user) // SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1 +var name string +has, err := engine.Where("id = ?", id).Cols("name").Get(&name) +// SELECT name FROM user WHERE id = ? +var id int64 +has, err := engine.Where("name = ?", name).Cols("id").Get(&id) +// SELECT id FROM user WHERE name = ? +var valuesMap = make(map[string]string) +has, err := engine.Where("id = ?", id).Get(&valuesMap) +// SELECT * FROM user WHERE id = ? +var valuesSlice = make([]interface{}, len(cols)) +has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice) +// SELECT col1, col2, col3 FROM user WHERE id = ? ``` * 查询多条记录,当然可以使用Join和extends来组合使用 diff --git a/session_get.go b/session_get.go index 0c78ed94..bf61963c 100644 --- a/session_get.go +++ b/session_get.go @@ -22,12 +22,7 @@ func (session *Session) Get(bean interface{}) (bool, error) { beanValue := reflect.ValueOf(bean) if beanValue.Kind() != reflect.Ptr { - return false, errors.New("needs a pointer to a struct") - } - - // FIXME: remove this after support non-struct Get - if beanValue.Elem().Kind() != reflect.Struct { - return false, errors.New("needs a pointer to a struct") + return false, errors.New("needs a pointer") } if beanValue.Elem().Kind() == reflect.Struct { @@ -48,7 +43,7 @@ func (session *Session) Get(bean interface{}) (bool, error) { args = session.Statement.RawParams } - if session.canCache() { + if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil && !session.Statement.unscoped { has, err := session.cacheGet(bean, sqlStr, args...) @@ -62,9 +57,10 @@ func (session *Session) Get(bean interface{}) (bool, error) { } func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { + session.queryPreprocess(&sqlStr, args...) + var rawRows *core.Rows var err error - session.queryPreprocess(&sqlStr, args...) if session.IsAutoCommit { _, rawRows, err = session.innerQuery(sqlStr, args...) } else { @@ -77,14 +73,13 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlS defer rawRows.Close() if rawRows.Next() { - fields, err := rawRows.Columns() - if err != nil { - // WARN: Alougth rawRows return true, but get fields failed - return true, err - } - switch beanKind { case reflect.Struct: + fields, err := rawRows.Columns() + if err != nil { + // WARN: Alougth rawRows return true, but get fields failed + return true, err + } dataStruct := rValue(bean) session.Statement.setRefValue(dataStruct) _, err = session.row2Bean(rawRows, fields, len(fields), bean, &dataStruct, session.Statement.RefTable) diff --git a/session_get_test.go b/session_get_test.go new file mode 100644 index 00000000..4c25dbd4 --- /dev/null +++ b/session_get_test.go @@ -0,0 +1,91 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestGetVar(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type GetVar struct { + Id int64 `xorm:"autoincr pk"` + Msg string `xorm:"varchar(255)"` + Age int + Money float32 + Created time.Time `xorm:"created"` + } + + assert.NoError(t, testEngine.Sync2(new(GetVar))) + + var data = GetVar{ + Msg: "hi", + Age: 28, + Money: 1.5, + } + _, err := testEngine.InsertOne(data) + assert.NoError(t, err) + + var msg string + has, err := testEngine.Table("get_var").Cols("msg").Get(&msg) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, "hi", msg) + + var age int + has, err = testEngine.Table("get_var").Cols("age").Get(&age) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, 28, age) + + var money float64 + has, err = testEngine.Table("get_var").Cols("money").Get(&money) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money)) + + var valuesString = make(map[string]string) + has, err = testEngine.Table("get_var").Get(&valuesString) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, 5, len(valuesString)) + assert.Equal(t, "1", valuesString["id"]) + assert.Equal(t, "hi", valuesString["msg"]) + assert.Equal(t, "28", valuesString["age"]) + assert.Equal(t, "1.5", valuesString["money"]) + + var valuesInter = make(map[string]interface{}) + has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, 5, len(valuesInter)) + assert.EqualValues(t, 1, valuesInter["id"]) + assert.Equal(t, "hi", fmt.Sprintf("%s", valuesInter["msg"])) + assert.EqualValues(t, 28, valuesInter["age"]) + assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesInter["money"])) + + var valuesSliceString = make([]string, 5) + has, err = testEngine.Table("get_var").Get(&valuesSliceString) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, "1", valuesSliceString[0]) + assert.Equal(t, "hi", valuesSliceString[1]) + assert.Equal(t, "28", valuesSliceString[2]) + assert.Equal(t, "1.5", valuesSliceString[3]) + + var valuesSliceInter = make([]interface{}, 5) + has, err = testEngine.Table("get_var").Get(&valuesSliceInter) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 1, valuesSliceInter[0]) + assert.Equal(t, "hi", fmt.Sprintf("%s", valuesSliceInter[1])) + assert.EqualValues(t, 28, valuesSliceInter[2]) + assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesSliceInter[3])) +} diff --git a/statement.go b/statement.go index 82101ff2..71e9a7fe 100644 --- a/statement.go +++ b/statement.go @@ -1114,7 +1114,11 @@ func (statement *Statement) genConds(bean interface{}) (string, []interface{}, e } func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) { - statement.setRefValue(rValue(bean)) + v := rValue(bean) + isStruct := v.Kind() == reflect.Struct + if isStruct { + statement.setRefValue(v) + } var columnStr = statement.ColumnStr if len(statement.selectStr) > 0 { @@ -1133,14 +1137,22 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) if len(columnStr) == 0 { if len(statement.GroupByStr) > 0 { columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) - } else { - columnStr = "*" } } } } - condSQL, condArgs, _ := statement.genConds(bean) + if len(columnStr) == 0 { + columnStr = "*" + } + + var condSQL string + var condArgs []interface{} + if isStruct { + condSQL, condArgs, _ = statement.genConds(bean) + } else { + condSQL, condArgs, _ = builder.ToSQL(statement.cond) + } return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...) } diff --git a/xorm.go b/xorm.go index 2cfbe9ec..0d9debdf 100644 --- a/xorm.go +++ b/xorm.go @@ -17,7 +17,7 @@ import ( const ( // Version show the xorm's version - Version string = "0.6.2.0326" + Version string = "0.6.2.0401" ) func regDrvsNDialects() bool {