From 188da20272dcc467e096cc08797b0f6a59150685 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 9 Mar 2020 08:03:59 +0000 Subject: [PATCH] Move value2interface from session to statement package (#1587) Fix zero Fix tests Move value2interface from session to statement package Reviewed-on: https://gitea.com/xorm/xorm/pulls/1587 --- internal/statements/values.go | 151 ++++++++++++++++++++++++++++++++++ internal/utils/zero.go | 50 +++++++++-- internal/utils/zero_test.go | 73 ++++++++++++++++ session_convert.go | 136 ------------------------------ session_insert.go | 27 ++---- session_update.go | 23 +----- 6 files changed, 275 insertions(+), 185 deletions(-) create mode 100644 internal/statements/values.go create mode 100644 internal/utils/zero_test.go diff --git a/internal/statements/values.go b/internal/statements/values.go new file mode 100644 index 00000000..b545a605 --- /dev/null +++ b/internal/statements/values.go @@ -0,0 +1,151 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "time" + + "xorm.io/xorm/convert" + "xorm.io/xorm/dialects" + "xorm.io/xorm/internal/json" + "xorm.io/xorm/schemas" +) + +var ( + nullFloatType = reflect.TypeOf(sql.NullFloat64{}) +) + +// Value2Interface convert a field value of a struct to interface for puting into database +func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue reflect.Value) (interface{}, error) { + if fieldValue.CanAddr() { + if fieldConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { + data, err := fieldConvert.ToDB() + if err != nil { + return nil, err + } + if col.SQLType.IsBlob() { + return data, nil + } + return string(data), nil + } + } + + if fieldConvert, ok := fieldValue.Interface().(convert.Conversion); ok { + data, err := fieldConvert.ToDB() + if err != nil { + return nil, err + } + if col.SQLType.IsBlob() { + return data, nil + } + return string(data), nil + } + + fieldType := fieldValue.Type() + k := fieldType.Kind() + if k == reflect.Ptr { + if fieldValue.IsNil() { + return nil, nil + } else if !fieldValue.IsValid() { + return nil, nil + } else { + // !nashtsai! deference pointer type to instance type + fieldValue = fieldValue.Elem() + fieldType = fieldValue.Type() + k = fieldType.Kind() + } + } + + switch k { + case reflect.Bool: + return fieldValue.Bool(), nil + case reflect.String: + return fieldValue.String(), nil + case reflect.Struct: + if fieldType.ConvertibleTo(schemas.TimeType) { + t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) + tf := dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) + return tf, nil + } else if fieldType.ConvertibleTo(nullFloatType) { + t := fieldValue.Convert(nullFloatType).Interface().(sql.NullFloat64) + if !t.Valid { + return nil, nil + } + return t.Float64, nil + } + + if !col.SQLType.IsJson() { + // !! 增加支持driver.Valuer接口的结构,如sql.NullString + if v, ok := fieldValue.Interface().(driver.Valuer); ok { + return v.Value() + } + + fieldTable, err := statement.tagParser.ParseWithCache(fieldValue) + if err != nil { + return nil, err + } + if len(fieldTable.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName) + return pkField.Interface(), nil + } + return nil, fmt.Errorf("no primary key for col %v", col.Name) + } + + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + return string(bytes), nil + } else if col.SQLType.IsBlob() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + return bytes, nil + } + return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type()) + case reflect.Complex64, reflect.Complex128: + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + return string(bytes), nil + case reflect.Array, reflect.Slice, reflect.Map: + if !fieldValue.IsValid() { + return fieldValue.Interface(), nil + } + + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + return string(bytes), nil + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (k == reflect.Slice) && + (fieldValue.Type().Elem().Kind() == reflect.Uint8) { + bytes = fieldValue.Bytes() + } else { + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + } + return bytes, nil + } + return nil, ErrUnSupportedType + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + return int64(fieldValue.Uint()), nil + default: + return fieldValue.Interface(), nil + } +} diff --git a/internal/utils/zero.go b/internal/utils/zero.go index 5415fc15..8f033c60 100644 --- a/internal/utils/zero.go +++ b/internal/utils/zero.go @@ -13,7 +13,14 @@ type Zeroable interface { IsZero() bool } +var nilTime *time.Time + +// IsZero returns false if k is nil or has a zero value func IsZero(k interface{}) bool { + if k == nil { + return true + } + switch k.(type) { case int: return k.(int) == 0 @@ -43,28 +50,57 @@ func IsZero(k interface{}) bool { return k.(bool) == false case string: return k.(string) == "" + case *time.Time: + return k.(*time.Time) == nilTime || IsTimeZero(*k.(*time.Time)) + case time.Time: + return IsTimeZero(k.(time.Time)) case Zeroable: - return k.(Zeroable).IsZero() + return k.(Zeroable) == nil || k.(Zeroable).IsZero() + case reflect.Value: // for go version less than 1.13 because reflect.Value has no method IsZero + return IsValueZero(k.(reflect.Value)) } - return false + + return IsValueZero(reflect.ValueOf(k)) } +var zeroType = reflect.TypeOf((*Zeroable)(nil)).Elem() + func IsValueZero(v reflect.Value) bool { - if IsZero(v.Interface()) { - return true - } switch v.Kind() { - case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Slice: return v.IsNil() + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: + return v.Int() == 0 + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: + return v.Uint() == 0 + case reflect.String: + return v.Len() == 0 + case reflect.Ptr: + if v.IsNil() { + return true + } + return IsValueZero(v.Elem()) + case reflect.Struct: + return IsStructZero(v) + case reflect.Array: + return IsArrayZero(v) } return false } func IsStructZero(v reflect.Value) bool { - if !v.IsValid() { + if !v.IsValid() || v.NumField() == 0 { return true } + if v.Type().Implements(zeroType) { + f := v.MethodByName("IsZero") + if f.IsValid() { + res := f.Call(nil) + return len(res) == 1 && res[0].Bool() + } + } + for i := 0; i < v.NumField(); i++ { field := v.Field(i) switch field.Kind() { diff --git a/internal/utils/zero_test.go b/internal/utils/zero_test.go new file mode 100644 index 00000000..a5f4912a --- /dev/null +++ b/internal/utils/zero_test.go @@ -0,0 +1,73 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type MyInt int +type ZeroStruct struct{} + +func TestZero(t *testing.T) { + var zeroValues = []interface{}{ + int8(0), + int16(0), + int(0), + int32(0), + int64(0), + uint8(0), + uint16(0), + uint(0), + uint32(0), + uint64(0), + MyInt(0), + reflect.ValueOf(0), + nil, + time.Time{}, + &time.Time{}, + nilTime, + ZeroStruct{}, + &ZeroStruct{}, + } + + for _, v := range zeroValues { + t.Run(fmt.Sprintf("%#v", v), func(t *testing.T) { + assert.True(t, IsZero(v)) + }) + } +} + +func TestIsValueZero(t *testing.T) { + var zeroReflectValues = []reflect.Value{ + reflect.ValueOf(int8(0)), + reflect.ValueOf(int16(0)), + reflect.ValueOf(int(0)), + reflect.ValueOf(int32(0)), + reflect.ValueOf(int64(0)), + reflect.ValueOf(uint8(0)), + reflect.ValueOf(uint16(0)), + reflect.ValueOf(uint(0)), + reflect.ValueOf(uint32(0)), + reflect.ValueOf(uint64(0)), + reflect.ValueOf(MyInt(0)), + reflect.ValueOf(time.Time{}), + reflect.ValueOf(&time.Time{}), + reflect.ValueOf(nilTime), + reflect.ValueOf(ZeroStruct{}), + reflect.ValueOf(&ZeroStruct{}), + } + + for _, v := range zeroReflectValues { + t.Run(fmt.Sprintf("%#v", v), func(t *testing.T) { + assert.True(t, IsValueZero(v)) + }) + } +} diff --git a/session_convert.go b/session_convert.go index 28866d4d..a6839947 100644 --- a/session_convert.go +++ b/session_convert.go @@ -6,7 +6,6 @@ package xorm import ( "database/sql" - "database/sql/driver" "errors" "fmt" "reflect" @@ -15,7 +14,6 @@ import ( "time" "xorm.io/xorm/convert" - "xorm.io/xorm/dialects" "xorm.io/xorm/internal/json" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" @@ -88,10 +86,6 @@ func (session *Session) byte2Time(col *schemas.Column, data []byte) (outTime tim return session.str2Time(col, string(data)) } -var ( - nullFloatType = reflect.TypeOf(sql.NullFloat64{}) -) - // convert a db data([]byte) to a field value func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Value, data []byte) error { if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { @@ -533,133 +527,3 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val return nil } - -// convert a field value of a struct to interface for put into db -func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.Value) (interface{}, error) { - if fieldValue.CanAddr() { - if fieldConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - data, err := fieldConvert.ToDB() - if err != nil { - return 0, err - } - if col.SQLType.IsBlob() { - return data, nil - } - return string(data), nil - } - } - - if fieldConvert, ok := fieldValue.Interface().(convert.Conversion); ok { - data, err := fieldConvert.ToDB() - if err != nil { - return 0, err - } - if col.SQLType.IsBlob() { - return data, nil - } - return string(data), nil - } - - fieldType := fieldValue.Type() - k := fieldType.Kind() - if k == reflect.Ptr { - if fieldValue.IsNil() { - return nil, nil - } else if !fieldValue.IsValid() { - session.engine.logger.Warnf("the field [%s] is invalid", col.FieldName) - return nil, nil - } else { - // !nashtsai! deference pointer type to instance type - fieldValue = fieldValue.Elem() - fieldType = fieldValue.Type() - k = fieldType.Kind() - } - } - - switch k { - case reflect.Bool: - return fieldValue.Bool(), nil - case reflect.String: - return fieldValue.String(), nil - case reflect.Struct: - if fieldType.ConvertibleTo(schemas.TimeType) { - t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) - tf := dialects.FormatColumnTime(session.engine.dialect, session.engine.DatabaseTZ, col, t) - return tf, nil - } else if fieldType.ConvertibleTo(nullFloatType) { - t := fieldValue.Convert(nullFloatType).Interface().(sql.NullFloat64) - if !t.Valid { - return nil, nil - } - return t.Float64, nil - } - - if !col.SQLType.IsJson() { - // !! 增加支持driver.Valuer接口的结构,如sql.NullString - if v, ok := fieldValue.Interface().(driver.Valuer); ok { - return v.Value() - } - - fieldTable, err := session.engine.tagParser.ParseWithCache(fieldValue) - if err != nil { - return nil, err - } - if len(fieldTable.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName) - return pkField.Interface(), nil - } - return 0, fmt.Errorf("no primary key for col %v", col.Name) - } - - if col.SQLType.IsText() { - bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return 0, err - } - return string(bytes), nil - } else if col.SQLType.IsBlob() { - bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return 0, err - } - return bytes, nil - } - return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type()) - case reflect.Complex64, reflect.Complex128: - bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return 0, err - } - return string(bytes), nil - case reflect.Array, reflect.Slice, reflect.Map: - if !fieldValue.IsValid() { - return fieldValue.Interface(), nil - } - - if col.SQLType.IsText() { - bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return 0, err - } - return string(bytes), nil - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (k == reflect.Slice) && - (fieldValue.Type().Elem().Kind() == reflect.Uint8) { - bytes = fieldValue.Bytes() - } else { - bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return 0, err - } - } - return bytes, nil - } - return nil, ErrUnSupportedType - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - return int64(fieldValue.Uint()), nil - default: - return fieldValue.Interface(), nil - } -} diff --git a/session_insert.go b/session_insert.go index b2e92309..12483aa3 100644 --- a/session_insert.go +++ b/session_insert.go @@ -176,7 +176,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error setColumnInt(bean, col, 1) }) } else { - arg, err := session.value2Interface(col, fieldValue) + arg, err := session.statement.Value2Interface(col, fieldValue) if err != nil { return 0, err } @@ -227,7 +227,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error setColumnInt(bean, col, 1) }) } else { - arg, err := session.value2Interface(col, fieldValue) + arg, err := session.statement.Value2Interface(col, fieldValue) if err != nil { return 0, err } @@ -567,25 +567,8 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } fieldValue := *fieldValuePtr - if col.IsAutoIncrement { - switch fieldValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: - if fieldValue.Int() == 0 { - continue - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: - if fieldValue.Uint() == 0 { - continue - } - case reflect.String: - if len(fieldValue.String()) == 0 { - continue - } - case reflect.Ptr: - if fieldValue.Pointer() == 0 { - continue - } - } + if col.IsAutoIncrement && utils.IsValueZero(fieldValue) { + continue } // !evalphobia! set fieldValue as nil when column is nullable and zero-value @@ -609,7 +592,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) } else { - arg, err := session.value2Interface(col, fieldValue) + arg, err := session.statement.Value2Interface(col, fieldValue) if err != nil { return colNames, args, err } diff --git a/session_update.go b/session_update.go index f60f48e3..dadfaaca 100644 --- a/session_update.go +++ b/session_update.go @@ -473,25 +473,8 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac } fieldValue := *fieldValuePtr - if col.IsAutoIncrement { - switch fieldValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: - if fieldValue.Int() == 0 { - continue - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: - if fieldValue.Uint() == 0 { - continue - } - case reflect.String: - if len(fieldValue.String()) == 0 { - continue - } - case reflect.Ptr: - if fieldValue.Pointer() == 0 { - continue - } - } + if col.IsAutoIncrement && utils.IsValueZero(fieldValue) { + continue } if (col.IsDeleted && !session.statement.GetUnscoped()) || col.IsCreated { @@ -532,7 +515,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) } else { - arg, err := session.value2Interface(col, fieldValue) + arg, err := session.statement.Value2Interface(col, fieldValue) if err != nil { return colNames, args, err }