From a691000f064689a7a77ff395f4520afcb4a12e8e Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 29 Jul 2019 23:32:32 +0800 Subject: [PATCH] fix error when get null var (#890) * fix error when get null var * add support get for null var * fix bug --- session_get.go | 111 +++++++++++++++++++++++++++++++++ session_get_test.go | 147 +++++++++++++++++++++++++++++++++++++++++++- tag_test.go | 2 +- 3 files changed, 258 insertions(+), 2 deletions(-) diff --git a/session_get.go b/session_get.go index a38707c8..ad2627f4 100644 --- a/session_get.go +++ b/session_get.go @@ -114,6 +114,114 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea return true, rows.Scan(&bean) case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString: return true, rows.Scan(bean) + case *string: + var res sql.NullString + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*string)) = res.String + } + return true, nil + case *int: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*int)) = int(res.Int64) + } + return true, nil + case *int8: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*int8)) = int8(res.Int64) + } + return true, nil + case *int16: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*int16)) = int16(res.Int64) + } + return true, nil + case *int32: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*int32)) = int32(res.Int64) + } + return true, nil + case *int64: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*int64)) = int64(res.Int64) + } + return true, nil + case *uint: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*uint)) = uint(res.Int64) + } + return true, nil + case *uint8: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*uint8)) = uint8(res.Int64) + } + return true, nil + case *uint16: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*uint16)) = uint16(res.Int64) + } + return true, nil + case *uint32: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*uint32)) = uint32(res.Int64) + } + return true, nil + case *uint64: + var res sql.NullInt64 + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*uint64)) = uint64(res.Int64) + } + return true, nil + case *bool: + var res sql.NullBool + if err := rows.Scan(&res); err != nil { + return true, err + } + if res.Valid { + *(bean.(*bool)) = res.Bool + } + return true, nil } switch beanKind { @@ -142,6 +250,9 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea err = rows.ScanSlice(bean) case reflect.Map: err = rows.ScanMap(bean) + case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + err = rows.Scan(&bean) default: err = rows.Scan(bean) } diff --git a/session_get_test.go b/session_get_test.go index 7bb84a8b..0679b7ab 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/core" ) func TestGetVar(t *testing.T) { @@ -56,6 +56,69 @@ func TestGetVar(t *testing.T) { assert.Equal(t, true, has) assert.EqualValues(t, 28, age2) + var age3 int8 + has, err = testEngine.Table("get_var").Cols("age").Get(&age3) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age3) + + var age4 int16 + has, err = testEngine.Table("get_var").Cols("age"). + Where("age > ?", 20). + And("age < ?", 30). + Get(&age4) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age4) + + var age5 int32 + has, err = testEngine.Table("get_var").Cols("age"). + Where("age > ?", 20). + And("age < ?", 30). + Get(&age5) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age5) + + var age6 int + has, err = testEngine.Table("get_var").Cols("age").Get(&age6) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age6) + + var age7 int64 + has, err = testEngine.Table("get_var").Cols("age"). + Where("age > ?", 20). + And("age < ?", 30). + Get(&age7) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age7) + + var age8 int8 + has, err = testEngine.Table("get_var").Cols("age").Get(&age8) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age8) + + var age9 int16 + has, err = testEngine.Table("get_var").Cols("age"). + Where("age > ?", 20). + And("age < ?", 30). + Get(&age9) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age9) + + var age10 int32 + has, err = testEngine.Table("get_var").Cols("age"). + Where("age > ?", 20). + And("age < ?", 30). + Get(&age10) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age10) + var id sql.NullInt64 has, err = testEngine.Table("get_var").Cols("id").Get(&id) assert.NoError(t, err) @@ -433,3 +496,85 @@ func TestGetCustomTableInterface(t *testing.T) { assert.NoError(t, err) assert.True(t, has) } + +func TestGetNullVar(t *testing.T) { + type TestGetNullVarStruct struct { + Id int64 + Name string + Age int + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(TestGetNullVarStruct)) + + affected, err := testEngine.Exec("insert into " + testEngine.TableName(new(TestGetNullVarStruct), true) + " (name,age) values (null,null)") + assert.NoError(t, err) + a, _ := affected.RowsAffected() + assert.EqualValues(t, 1, a) + + var name string + has, err := testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("name").Get(&name) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "", name) + + var age int + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age) + + var age2 int8 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age2) + + var age3 int16 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age3) + + var age4 int32 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age4) + + var age5 int64 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age5) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age5) + + var age6 uint + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age6) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age6) + + var age7 uint8 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age7) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age7) + + var age8 int16 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age8) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age8) + + var age9 int32 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age9) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age9) + + var age10 int64 + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age10) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, age10) +} diff --git a/tag_test.go b/tag_test.go index 8dc7fa13..cfb16b3b 100644 --- a/tag_test.go +++ b/tag_test.go @@ -11,8 +11,8 @@ import ( "testing" "time" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/core" ) type UserCU struct {