add Scan features to Get method

This commit is contained in:
Lunny Xiao 2017-04-01 10:09:00 +08:00
parent 6687a2b4e8
commit 5ebae720bd
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
6 changed files with 141 additions and 19 deletions

View File

@ -145,6 +145,18 @@ has, err := engine.Get(&user)
// SELECT * FROM user LIMIT 1 // SELECT * FROM user LIMIT 1
has, err := engine.Where("name = ?", name).Desc("id").Get(&user) has, err := engine.Where("name = ?", name).Desc("id").Get(&user)
// SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1 // SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1
var name string
has, err := engine.Where("id = ?", id).Cols("name").Get(&name)
// SELECT name FROM user WHERE id = ?
var id int64
has, err := engine.Where("name = ?", name).Cols("id").Get(&id)
// SELECT id FROM user WHERE name = ?
var valuesMap = make(map[string]string)
has, err := engine.Where("id = ?", id).Get(&valuesMap)
// SELECT * FROM user WHERE id = ?
var valuesSlice = make([]interface{}, len(cols))
has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice)
// SELECT col1, col2, col3 FROM user WHERE id = ?
``` ```
* Query multiple records from database, also you can use join and extends * Query multiple records from database, also you can use join and extends

View File

@ -149,6 +149,18 @@ has, err := engine.Get(&user)
// SELECT * FROM user LIMIT 1 // SELECT * FROM user LIMIT 1
has, err := engine.Where("name = ?", name).Desc("id").Get(&user) has, err := engine.Where("name = ?", name).Desc("id").Get(&user)
// SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1 // SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1
var name string
has, err := engine.Where("id = ?", id).Cols("name").Get(&name)
// SELECT name FROM user WHERE id = ?
var id int64
has, err := engine.Where("name = ?", name).Cols("id").Get(&id)
// SELECT id FROM user WHERE name = ?
var valuesMap = make(map[string]string)
has, err := engine.Where("id = ?", id).Get(&valuesMap)
// SELECT * FROM user WHERE id = ?
var valuesSlice = make([]interface{}, len(cols))
has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice)
// SELECT col1, col2, col3 FROM user WHERE id = ?
``` ```
* 查询多条记录当然可以使用Join和extends来组合使用 * 查询多条记录当然可以使用Join和extends来组合使用

View File

@ -22,12 +22,7 @@ func (session *Session) Get(bean interface{}) (bool, error) {
beanValue := reflect.ValueOf(bean) beanValue := reflect.ValueOf(bean)
if beanValue.Kind() != reflect.Ptr { if beanValue.Kind() != reflect.Ptr {
return false, errors.New("needs a pointer to a struct") return false, errors.New("needs a pointer")
}
// FIXME: remove this after support non-struct Get
if beanValue.Elem().Kind() != reflect.Struct {
return false, errors.New("needs a pointer to a struct")
} }
if beanValue.Elem().Kind() == reflect.Struct { if beanValue.Elem().Kind() == reflect.Struct {
@ -48,7 +43,7 @@ func (session *Session) Get(bean interface{}) (bool, error) {
args = session.Statement.RawParams args = session.Statement.RawParams
} }
if session.canCache() { if session.canCache() && beanValue.Elem().Kind() == reflect.Struct {
if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil && if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil &&
!session.Statement.unscoped { !session.Statement.unscoped {
has, err := session.cacheGet(bean, sqlStr, args...) has, err := session.cacheGet(bean, sqlStr, args...)
@ -62,9 +57,10 @@ func (session *Session) Get(bean interface{}) (bool, error) {
} }
func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
session.queryPreprocess(&sqlStr, args...)
var rawRows *core.Rows var rawRows *core.Rows
var err error var err error
session.queryPreprocess(&sqlStr, args...)
if session.IsAutoCommit { if session.IsAutoCommit {
_, rawRows, err = session.innerQuery(sqlStr, args...) _, rawRows, err = session.innerQuery(sqlStr, args...)
} else { } else {
@ -77,14 +73,13 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlS
defer rawRows.Close() defer rawRows.Close()
if rawRows.Next() { if rawRows.Next() {
fields, err := rawRows.Columns()
if err != nil {
// WARN: Alougth rawRows return true, but get fields failed
return true, err
}
switch beanKind { switch beanKind {
case reflect.Struct: case reflect.Struct:
fields, err := rawRows.Columns()
if err != nil {
// WARN: Alougth rawRows return true, but get fields failed
return true, err
}
dataStruct := rValue(bean) dataStruct := rValue(bean)
session.Statement.setRefValue(dataStruct) session.Statement.setRefValue(dataStruct)
_, err = session.row2Bean(rawRows, fields, len(fields), bean, &dataStruct, session.Statement.RefTable) _, err = session.row2Bean(rawRows, fields, len(fields), bean, &dataStruct, session.Statement.RefTable)

91
session_get_test.go Normal file
View File

@ -0,0 +1,91 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestGetVar(t *testing.T) {
assert.NoError(t, prepareEngine())
type GetVar struct {
Id int64 `xorm:"autoincr pk"`
Msg string `xorm:"varchar(255)"`
Age int
Money float32
Created time.Time `xorm:"created"`
}
assert.NoError(t, testEngine.Sync2(new(GetVar)))
var data = GetVar{
Msg: "hi",
Age: 28,
Money: 1.5,
}
_, err := testEngine.InsertOne(data)
assert.NoError(t, err)
var msg string
has, err := testEngine.Table("get_var").Cols("msg").Get(&msg)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, "hi", msg)
var age int
has, err = testEngine.Table("get_var").Cols("age").Get(&age)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, 28, age)
var money float64
has, err = testEngine.Table("get_var").Cols("money").Get(&money)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money))
var valuesString = make(map[string]string)
has, err = testEngine.Table("get_var").Get(&valuesString)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, 5, len(valuesString))
assert.Equal(t, "1", valuesString["id"])
assert.Equal(t, "hi", valuesString["msg"])
assert.Equal(t, "28", valuesString["age"])
assert.Equal(t, "1.5", valuesString["money"])
var valuesInter = make(map[string]interface{})
has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, 5, len(valuesInter))
assert.EqualValues(t, 1, valuesInter["id"])
assert.Equal(t, "hi", fmt.Sprintf("%s", valuesInter["msg"]))
assert.EqualValues(t, 28, valuesInter["age"])
assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesInter["money"]))
var valuesSliceString = make([]string, 5)
has, err = testEngine.Table("get_var").Get(&valuesSliceString)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, "1", valuesSliceString[0])
assert.Equal(t, "hi", valuesSliceString[1])
assert.Equal(t, "28", valuesSliceString[2])
assert.Equal(t, "1.5", valuesSliceString[3])
var valuesSliceInter = make([]interface{}, 5)
has, err = testEngine.Table("get_var").Get(&valuesSliceInter)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 1, valuesSliceInter[0])
assert.Equal(t, "hi", fmt.Sprintf("%s", valuesSliceInter[1]))
assert.EqualValues(t, 28, valuesSliceInter[2])
assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesSliceInter[3]))
}

View File

@ -1114,7 +1114,11 @@ func (statement *Statement) genConds(bean interface{}) (string, []interface{}, e
} }
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) { func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) {
statement.setRefValue(rValue(bean)) v := rValue(bean)
isStruct := v.Kind() == reflect.Struct
if isStruct {
statement.setRefValue(v)
}
var columnStr = statement.ColumnStr var columnStr = statement.ColumnStr
if len(statement.selectStr) > 0 { if len(statement.selectStr) > 0 {
@ -1133,14 +1137,22 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{})
if len(columnStr) == 0 { if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 { if len(statement.GroupByStr) > 0 {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
} else {
columnStr = "*"
} }
} }
} }
} }
condSQL, condArgs, _ := statement.genConds(bean) if len(columnStr) == 0 {
columnStr = "*"
}
var condSQL string
var condArgs []interface{}
if isStruct {
condSQL, condArgs, _ = statement.genConds(bean)
} else {
condSQL, condArgs, _ = builder.ToSQL(statement.cond)
}
return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...) return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...)
} }

View File

@ -17,7 +17,7 @@ import (
const ( const (
// Version show the xorm's version // Version show the xorm's version
Version string = "0.6.2.0326" Version string = "0.6.2.0401"
) )
func regDrvsNDialects() bool { func regDrvsNDialects() bool {