more improvements

This commit is contained in:
Lunny Xiao 2021-06-14 16:09:03 +08:00
parent 1f771b1764
commit baa3fdb549
12 changed files with 600 additions and 395 deletions

View File

@ -42,6 +42,13 @@ func asString(src interface{}) string {
return string(v) return string(v)
case *sql.NullString: case *sql.NullString:
return v.String return v.String
case *sql.RawBytes:
return string(*v)
case *sql.NullBool:
if v.Valid {
return strconv.FormatBool(v.Bool)
}
return ""
case *sql.NullInt32: case *sql.NullInt32:
return fmt.Sprintf("%d", v.Int32) return fmt.Sprintf("%d", v.Int32)
case *sql.NullInt64: case *sql.NullInt64:
@ -236,6 +243,7 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
// An error is returned if the copy would result in loss of information. // An error is returned if the copy would result in loss of information.
// dest should be a pointer type. // dest should be a pointer type.
func convertAssign(dest, src interface{}, originalLocation *time.Location, convertedLocation *time.Location) error { func convertAssign(dest, src interface{}, originalLocation *time.Location, convertedLocation *time.Location) error {
fmt.Printf("======= %#v ------> %#v \n", src, dest)
// Common cases, without reflect. // Common cases, without reflect.
switch s := src.(type) { switch s := src.(type) {
case string: case string:
@ -253,6 +261,17 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve
*d = []byte(s) *d = []byte(s)
return nil return nil
} }
case int64:
switch d := dest.(type) {
case *uint64:
*d = uint64(s)
return nil
}
fmt.Println("======", src, dest)
case int, int32, int16, int8:
fmt.Println("22222222")
case uint, uint32, uint64, uint8, uint16:
fmt.Println("3333333")
case []byte: case []byte:
switch d := dest.(type) { switch d := dest.(type) {
case *string: case *string:
@ -274,7 +293,6 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve
*d = cloneBytes(s) *d = cloneBytes(s)
return nil return nil
} }
case time.Time: case time.Time:
switch d := dest.(type) { switch d := dest.(type) {
case *string: case *string:
@ -527,6 +545,7 @@ func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, conver
} }
dv.Set(reflect.New(dv.Type().Elem())) dv.Set(reflect.New(dv.Type().Elem()))
fmt.Println("333333")
return convertAssign(dv.Interface(), src, originalLocation, convertedLocation) return convertAssign(dv.Interface(), src, originalLocation, convertedLocation)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
i64, err := asInt64(src) i64, err := asInt64(src)
@ -592,7 +611,6 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
} }
return v, nil return v, nil
} }
} }
return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv)
} }

View File

@ -62,3 +62,52 @@ func Interface2String(v interface{}) (string, error) {
return "", fmt.Errorf("convert assign string unsupported type: %#v", vv) return "", fmt.Errorf("convert assign string unsupported type: %#v", vv)
} }
} }
func Interface2Interface(v interface{}) (interface{}, error) {
if v == nil {
return nil, nil
}
switch vv := v.(type) {
case *int64:
return *vv, nil
case *int8:
return *vv, nil
case *sql.NullString:
if vv.Valid {
return vv.String, nil
}
return "", nil
case *sql.RawBytes:
if len([]byte(*vv)) > 0 {
return []byte(*vv), nil
}
return nil, nil
case *sql.NullInt32:
if vv.Valid {
return vv.Int32, nil
}
return 0, nil
case *sql.NullInt64:
if vv.Valid {
return vv.Int64, nil
}
return 0, nil
case *sql.NullFloat64:
if vv.Valid {
return vv.Float64, nil
}
return 0, nil
case *sql.NullBool:
if vv.Valid {
return vv.Bool, nil
}
return nil, nil
case *sql.NullTime:
if vv.Valid {
return vv.Time.Format("2006-01-02 15:04:05"), nil
}
return "", nil
default:
return "", fmt.Errorf("convert assign string unsupported type: %#v", vv)
}
}

View File

@ -1201,7 +1201,7 @@ func TestTagTime(t *testing.T) {
has, err = testEngine.Table("tag_u_t_c_struct").Cols("created").Get(&tm) has, err = testEngine.Table("tag_u_t_c_struct").Cols("created").Get(&tm)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), assert.EqualValues(t, s.Created.Format("2006-01-02 15:04:05"),
strings.Replace(strings.Replace(tm, "T", " ", -1), "Z", "", -1)) strings.Replace(strings.Replace(tm, "T", " ", -1), "Z", "", -1))
} }

View File

@ -109,6 +109,7 @@ func TestGetBytes(t *testing.T) {
type ConvString string type ConvString string
func (s *ConvString) FromDB(data []byte) error { func (s *ConvString) FromDB(data []byte) error {
fmt.Printf("======= %#v,,,,,, %#v\n", s, data)
*s = ConvString("prefix---" + string(data)) *s = ConvString("prefix---" + string(data))
return nil return nil
} }

View File

@ -4,6 +4,8 @@
package xorm package xorm
import "reflect"
// BeforeInsertProcessor executed before an object is initially persisted to the database // BeforeInsertProcessor executed before an object is initially persisted to the database
type BeforeInsertProcessor interface { type BeforeInsertProcessor interface {
BeforeInsert() BeforeInsert()
@ -94,7 +96,8 @@ func executeBeforeClosures(session *Session, bean interface{}) {
func executeBeforeSet(bean interface{}, fields []string, scanResults []interface{}) { func executeBeforeSet(bean interface{}, fields []string, scanResults []interface{}) {
if b, hasBeforeSet := bean.(BeforeSetProcessor); hasBeforeSet { if b, hasBeforeSet := bean.(BeforeSetProcessor); hasBeforeSet {
for ii, key := range fields { for ii, key := range fields {
b.BeforeSet(key, Cell(scanResults[ii].(*interface{})))
b.BeforeSet(key, Cell(reflect.ValueOf(scanResults[ii]).Elem().Interface()))
} }
} }
} }
@ -102,7 +105,7 @@ func executeBeforeSet(bean interface{}, fields []string, scanResults []interface
func executeAfterSet(bean interface{}, fields []string, scanResults []interface{}) { func executeAfterSet(bean interface{}, fields []string, scanResults []interface{}) {
if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet { if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet {
for ii, key := range fields { for ii, key := range fields {
b.AfterSet(key, Cell(scanResults[ii].(*interface{}))) b.AfterSet(key, Cell(reflect.ValueOf(scanResults[ii]).Elem().Interface()))
} }
} }
} }

View File

@ -130,7 +130,12 @@ func (rows *Rows) Scan(bean interface{}) error {
return err return err
} }
scanResults, err := rows.session.row2Slice(rows.rows, fields, bean) types, err := rows.rows.ColumnTypes()
if err != nil {
return err
}
scanResults, err := rows.session.row2Slice(rows.rows, types, fields, bean)
if err != nil { if err != nil {
return err return err
} }

View File

@ -5,6 +5,7 @@
package schemas package schemas
import ( import (
"database/sql"
"reflect" "reflect"
"sort" "sort"
"strings" "strings"
@ -222,24 +223,28 @@ var (
// !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparison // !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparison
var ( var (
emptyString string emptyString string
boolDefault bool boolDefault bool
byteDefault byte byteDefault byte
complex64Default complex64 complex64Default complex64
complex128Default complex128 complex128Default complex128
float32Default float32 float32Default float32
float64Default float64 float64Default float64
int64Default int64 int64Default int64
uint64Default uint64 uint64Default uint64
int32Default int32 int32Default int32
uint32Default uint32 uint32Default uint32
int16Default int16 int16Default int16
uint16Default uint16 uint16Default uint16
int8Default int8 int8Default int8
uint8Default uint8 uint8Default uint8
intDefault int intDefault int
uintDefault uint uintDefault uint
timeDefault time.Time timeDefault time.Time
nullInt64Default sql.NullInt64
nullFloat64Default sql.NullFloat64
nullBoolDefault sql.NullBool
nullTimeDefault sql.NullTime
) )
// enumerates all types // enumerates all types
@ -268,6 +273,11 @@ var (
BytesType = reflect.SliceOf(ByteType) BytesType = reflect.SliceOf(ByteType)
TimeType = reflect.TypeOf(timeDefault) TimeType = reflect.TypeOf(timeDefault)
NullInt64Type = reflect.TypeOf(nullInt64Default)
NullFloat64Type = reflect.TypeOf(nullFloat64Default)
NullBoolType = reflect.TypeOf(nullBoolDefault)
NullTimeType = reflect.TypeOf(nullTimeDefault)
) )
// enumerates all types // enumerates all types
@ -295,6 +305,9 @@ var (
PtrByteType = reflect.PtrTo(ByteType) PtrByteType = reflect.PtrTo(ByteType)
PtrTimeType = reflect.PtrTo(TimeType) PtrTimeType = reflect.PtrTo(TimeType)
PtrNullInt64Type = reflect.PtrTo(NullInt64Type)
PtrNullFloat64Type = reflect.PtrTo(NullFloat64Type)
) )
// Type2SQLType generate SQLType acorrding Go's type // Type2SQLType generate SQLType acorrding Go's type
@ -327,6 +340,14 @@ func Type2SQLType(t reflect.Type) (st SQLType) {
case reflect.Struct: case reflect.Struct:
if t.ConvertibleTo(TimeType) { if t.ConvertibleTo(TimeType) {
st = SQLType{DateTime, 0, 0} st = SQLType{DateTime, 0, 0}
} else if t.ConvertibleTo(NullInt64Type) {
st = SQLType{BigInt, 0, 0}
} else if t.ConvertibleTo(NullFloat64Type) {
st = SQLType{Double, 0, 0}
} else if t.ConvertibleTo(NullBoolType) {
st = SQLType{Bool, 0, 0}
} else if t.ConvertibleTo(NullTimeType) {
st = SQLType{DateTime, 0, 0}
} else { } else {
// TODO need to handle association struct // TODO need to handle association struct
st = SQLType{Text, 0, 0} st = SQLType{Text, 0, 0}

View File

@ -15,6 +15,7 @@ import (
"hash/crc32" "hash/crc32"
"io" "io"
"reflect" "reflect"
"strconv"
"strings" "strings"
"time" "time"
@ -387,9 +388,9 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s
} }
// Cell cell is a result of one column field // Cell cell is a result of one column field
type Cell *interface{} type Cell interface{}
func (session *Session) rows2Beans(rows *core.Rows, fields []string, func (session *Session) rows2Beans(rows *core.Rows, types []*sql.ColumnType, fields []string,
table *schemas.Table, newElemFunc func([]string) reflect.Value, table *schemas.Table, newElemFunc func([]string) reflect.Value,
sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error {
for rows.Next() { for rows.Next() {
@ -398,7 +399,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string,
dataStruct := newValue.Elem() dataStruct := newValue.Elem()
// handle beforeClosures // handle beforeClosures
scanResults, err := session.row2Slice(rows, fields, bean) scanResults, err := session.row2Slice(rows, types, fields, bean)
if err != nil { if err != nil {
return err return err
} }
@ -417,17 +418,29 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string,
return nil return nil
} }
func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interface{}) ([]interface{}, error) { func (session *Session) genScanResultsByTypes(types []*sql.ColumnType) ([]interface{}, error) {
scanResults := make([]interface{}, len(types))
for i := 0; i < len(types); i++ {
result, err := session.engine.driver.GenScanResult(types[i].DatabaseTypeName())
if err != nil {
return nil, err
}
scanResults[i] = result
}
return scanResults, nil
}
func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) ([]interface{}, error) {
for _, closure := range session.beforeClosures { for _, closure := range session.beforeClosures {
closure(bean) closure(bean)
} }
scanResults := make([]interface{}, len(fields)) scanResults, err := session.genScanResultsByTypes(types)
for i := 0; i < len(fields); i++ { if err != nil {
var cell interface{} return nil, err
scanResults[i] = &cell
} }
if err := rows.Scan(scanResults...); err != nil {
if err := session.engine.scan(rows, types, scanResults...); err != nil {
return nil, err return nil, err
} }
@ -436,206 +449,370 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa
return scanResults, nil return scanResults, nil
} }
func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) { var (
defer func() { scannerTypePlaceHolder sql.Scanner
executeAfterSet(bean, fields, scanResults) scannerType = reflect.TypeOf(&scannerTypePlaceHolder).Elem()
}() )
buildAfterProcessors(session, bean) // convertAssign converts an interface src to dst reflect.Value fieldValue
func (session *Session) convertAssign(fieldValue *reflect.Value, columnName string, src interface{}, table *schemas.Table, pk *schemas.PK, idx int) error {
if fieldValue == nil {
return nil
}
// if row is null then ignore
if src == nil {
return nil
}
var tempMap = make(map[string]int) rawValue := reflect.Indirect(reflect.ValueOf(src))
var pk schemas.PK if rawValue.Interface() == nil {
for ii, key := range fields { return nil
var idx int }
var ok bool
var lKey = strings.ToLower(key)
if idx, ok = tempMap[lKey]; !ok {
idx = 0
} else {
idx = idx + 1
}
tempMap[lKey] = idx
fieldValue, err := session.getField(dataStruct, key, table, idx) fmt.Printf("----- %v <------ %v \n", fieldValue.Type(), rawValue.Type())
if err != nil { if fieldValue.Type() == rawValue.Type() {
if !strings.Contains(err.Error(), "is not valid") { fieldValue.Set(rawValue)
session.engine.logger.Warnf("%v", err) return nil
} }
continue
}
if fieldValue == nil {
continue
}
rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii]))
// if row is null then ignore if fieldValue.CanAddr() {
if rawValue.Interface() == nil { if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
continue fmt.Printf("%s, ===========00000000000 %#v <----- %#v \n", columnName, fieldValue.Addr().Interface(), src)
return scanner.Scan(src)
} }
if fieldValue.CanAddr() { if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { switch t := src.(type) {
if data, err := value2Bytes(&rawValue); err == nil { case *sql.RawBytes:
if err := structConvert.FromDB(data); err != nil { if err := structConvert.FromDB([]byte(*t)); err != nil {
return nil, err return err
}
} else {
return nil, err
} }
continue case *sql.NullString:
if t.Valid {
if err := structConvert.FromDB([]byte(t.String)); err != nil {
return err
}
}
default:
return fmt.Errorf("unsupported type: %#v on column %v", t, columnName)
} }
return nil
}
}
if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok {
if scanner, ok := fieldValue.Interface().(sql.Scanner); ok {
fmt.Println("===========111111111111")
return scanner.Scan(src)
} }
if _, ok := fieldValue.Interface().(convert.Conversion); ok { switch t := src.(type) {
if data, err := value2Bytes(&rawValue); err == nil { case *sql.RawBytes:
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { if fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
structConvert = fieldValue.Interface().(convert.Conversion)
}
if err := structConvert.FromDB([]byte(*t)); err != nil {
return err
}
case *sql.NullString:
if t.Valid {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem())) fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
structConvert = fieldValue.Interface().(convert.Conversion)
}
if err := structConvert.FromDB([]byte(t.String)); err != nil {
return err
} }
fieldValue.Interface().(convert.Conversion).FromDB(data)
} else {
return nil, err
} }
continue default:
return fmt.Errorf("unsupported type: %#v", t)
} }
return nil
}
rawValueType := reflect.TypeOf(rawValue.Interface()) rawValueType := reflect.TypeOf(rawValue.Interface())
vv := reflect.ValueOf(rawValue.Interface()) vv := reflect.ValueOf(rawValue.Interface())
col := table.GetColumnIdx(key, idx) col := table.GetColumnIdx(columnName, idx)
if col.IsPrimaryKey { if col.IsPrimaryKey {
pk = append(pk, rawValue.Interface()) *pk = append(*pk, rawValue.Interface())
} }
fieldType := fieldValue.Type() fieldType := fieldValue.Type()
hasAssigned := false hasAssigned := false
if col.IsJSON { if col.IsJSON {
var bs []byte var bs []byte
switch t := src.(type) {
case *sql.NullString:
if t.Valid {
bs = []byte(t.String)
}
case *sql.RawBytes:
bs = *t
default:
if rawValueType.Kind() == reflect.String { if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String()) bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(schemas.BytesType) { } else if rawValueType.ConvertibleTo(schemas.BytesType) {
bs = vv.Bytes() bs = vv.Bytes()
} else { } else {
return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) return fmt.Errorf("unsupported database data type: %s %v ---> %#v", columnName, rawValueType.Kind(), src)
} }
hasAssigned = true
if len(bs) > 0 {
if fieldType.Kind() == reflect.String {
fieldValue.SetString(string(bs))
continue
}
if fieldValue.CanAddr() {
err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
return nil, err
}
} else {
x := reflect.New(fieldType)
err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
}
}
continue
} }
switch fieldType.Kind() { if len(bs) > 0 {
case reflect.Complex64, reflect.Complex128: if fieldType.Kind() == reflect.String {
// TODO: reimplement this fieldValue.SetString(string(bs))
var bs []byte return nil
if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(schemas.BytesType) {
bs = vv.Bytes()
} }
if fieldValue.CanAddr() {
err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
return err
}
} else {
x := reflect.New(fieldType)
err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface())
if err != nil {
return err
}
fieldValue.Set(x.Elem())
}
}
return nil
}
var kind = fieldType.Kind()
if reflect.Ptr == kind {
kind = fieldType.Elem().Kind()
}
switch kind {
case reflect.Complex64, reflect.Complex128:
// TODO: reimplement this
var bs []byte
if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(schemas.BytesType) {
bs = vv.Bytes()
}
if len(bs) > 0 {
if fieldValue.CanAddr() {
err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
return err
}
} else {
x := reflect.New(fieldType)
err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface())
if err != nil {
return err
}
fieldValue.Set(x.Elem())
}
}
return nil
case reflect.Slice, reflect.Array:
switch t := src.(type) {
case *sql.NullString:
hasAssigned = true hasAssigned = true
if len(bs) > 0 { fmt.Printf("====== %#v <-------- %#v \n", fieldValue.Interface(), t)
if fieldValue.CanAddr() { if t.Valid {
err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) if fieldType.Elem().Kind() == reflect.Uint8 {
if err != nil { fieldValue.SetBytes([]byte(t.String))
return nil, err
}
} else {
x := reflect.New(fieldType)
err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
} }
} }
}
switch rawValueType.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
switch rawValueType.Kind() { switch rawValueType.Elem().Kind() {
case reflect.Slice, reflect.Array: case reflect.Uint8:
switch rawValueType.Elem().Kind() { if fieldType.Elem().Kind() == reflect.Uint8 {
case reflect.Uint8: hasAssigned = true
if fieldType.Elem().Kind() == reflect.Uint8 { if col.SQLType.IsText() {
hasAssigned = true x := reflect.New(fieldType)
if col.SQLType.IsText() { err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface())
x := reflect.New(fieldType) if err != nil {
err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) return err
if err != nil { }
return nil, err fieldValue.Set(x.Elem())
} else {
if fieldValue.Len() > 0 {
for i := 0; i < fieldValue.Len(); i++ {
if i < vv.Len() {
fieldValue.Index(i).Set(vv.Index(i))
}
} }
fieldValue.Set(x.Elem())
} else { } else {
if fieldValue.Len() > 0 { for i := 0; i < vv.Len(); i++ {
for i := 0; i < fieldValue.Len(); i++ { fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i)))
if i < vv.Len() {
fieldValue.Index(i).Set(vv.Index(i))
}
}
} else {
for i := 0; i < vv.Len(); i++ {
fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i)))
}
} }
} }
} }
} }
} }
case reflect.String: }
if rawValueType.Kind() == reflect.String { case reflect.String:
hasAssigned = true switch t := src.(type) {
fieldValue.SetString(vv.String()) case *sql.NullString:
if t.Valid {
fieldValue.SetString(t.String)
} }
case reflect.Bool: return nil
if rawValueType.Kind() == reflect.Bool { case sql.NullString:
hasAssigned = true if t.Valid {
fieldValue.SetBool(vv.Bool()) fieldValue.SetString(t.String)
} }
return nil
case *sql.NullTime:
if t.Valid {
fieldValue.SetString(t.Time.In(session.engine.TZLocation).Format("2006-01-02 15:04:05"))
}
return nil
}
if rawValueType.Kind() == reflect.String {
fieldValue.SetString(vv.String())
return nil
}
case reflect.Bool:
switch t := src.(type) {
case *sql.NullBool:
if t.Valid {
fieldValue.SetBool(t.Bool)
}
return nil
case *sql.NullInt64:
if t.Valid {
fieldValue.SetBool(t.Int64 > 0)
}
return nil
}
if rawValueType.Kind() == reflect.Bool {
fieldValue.SetBool(vv.Bool())
return nil
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch t := src.(type) {
case *sql.NullInt64:
if t.Valid {
fieldValue.SetInt(t.Int64)
}
return nil
case *sql.NullString:
if t.Valid {
tv, _ := strconv.ParseInt(t.String, 10, 64)
fieldValue.SetInt(tv)
}
return nil
}
switch rawValueType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch rawValueType.Kind() { fieldValue.SetInt(vv.Int())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return nil
hasAssigned = true default:
fieldValue.SetInt(vv.Int()) return fmt.Errorf("unsupported convert %v to %v", rawValueType.Kind(), fieldType.Kind())
}
case reflect.Float32, reflect.Float64:
switch t := src.(type) {
case *sql.NullFloat64:
if t.Valid {
fieldValue.SetFloat(t.Float64)
} }
case reflect.Float32, reflect.Float64: return nil
switch rawValueType.Kind() { case *sql.NullString:
case reflect.Float32, reflect.Float64: if t.Valid {
hasAssigned = true vv, err := strconv.ParseFloat(t.String, 10)
fieldValue.SetFloat(vv.Float()) if err != nil {
} return err
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
switch rawValueType.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
hasAssigned = true
fieldValue.SetUint(vv.Uint())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
hasAssigned = true
fieldValue.SetUint(uint64(vv.Int()))
}
case reflect.Struct:
if fieldType.ConvertibleTo(schemas.TimeType) {
dbTZ := session.engine.DatabaseTZ
if col.TimeZone != nil {
dbTZ = col.TimeZone
} }
fieldValue.SetFloat(vv)
}
return nil
}
switch rawValueType.Kind() {
case reflect.Float32, reflect.Float64:
fieldValue.SetFloat(vv.Float())
return nil
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
switch t := src.(type) {
case *sql.NullInt64:
if t.Valid {
fieldValue.SetUint(uint64(t.Int64))
}
return nil
case *sql.NullString:
if t.Valid {
vv, err := strconv.ParseUint(t.String, 10, 64)
if err != nil {
return err
}
fieldValue.SetUint(vv)
}
return nil
}
switch rawValueType.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
fieldValue.SetUint(vv.Uint())
return nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
hasAssigned = true
return nil
}
case reflect.Struct:
if fieldType.ConvertibleTo(schemas.TimeType) {
dbTZ := session.engine.DatabaseTZ
if col.TimeZone != nil {
dbTZ = col.TimeZone
}
fmt.Printf("99999999 %#v\n", src)
switch d := src.(type) {
case []byte:
hasAssigned = true
t, err := session.byte2Time(col, d)
if err != nil {
session.engine.logger.Errorf("byte2Time error: %v", err)
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
}
case string:
hasAssigned = true
t, err := session.str2Time(col, d)
if err != nil {
session.engine.logger.Errorf("byte2Time error: %v", err)
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
}
case *sql.NullInt64:
hasAssigned = true
if d.Valid {
t := time.Unix(d.Int64, 0).In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
}
case *sql.NullString:
hasAssigned = true
if d.Valid {
t, err := session.str2Time(col, d.String)
if err != nil {
session.engine.logger.Errorf("byte2Time error: %v", err)
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
}
}
case *sql.NullTime:
hasAssigned = true
if d.Valid {
fieldValue.Set(reflect.ValueOf(d.Time).Convert(fieldType))
}
case nil:
hasAssigned = true
default:
if rawValueType == schemas.TimeType { if rawValueType == schemas.TimeType {
hasAssigned = true hasAssigned = true
@ -644,7 +821,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
z, _ := t.Zone() z, _ := t.Zone()
// set new location if database don't save timezone or give an incorrect timezone // set new location if database don't save timezone or give an incorrect timezone
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", columnName, t, z, *t.Location())
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
t.Minute(), t.Second(), t.Nanosecond(), dbTZ) t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
} }
@ -658,216 +835,123 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else { } else {
if d, ok := vv.Interface().([]uint8); ok { return fmt.Errorf("rawValueType is %v, value is %v", rawValueType, src)
hasAssigned = true
t, err := session.byte2Time(col, d)
if err != nil {
session.engine.logger.Errorf("byte2Time error: %v", err)
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
}
} else if d, ok := vv.Interface().(string); ok {
hasAssigned = true
t, err := session.str2Time(col, d)
if err != nil {
session.engine.logger.Errorf("byte2Time error: %v", err)
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
}
} else {
return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface())
}
} }
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { }
// !<winxxp>! 增加支持sql.Scanner接口的结构如sql.NullString } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
if err := nulVal.Scan(src); err != nil {
return fmt.Errorf("sql.Sanner error: %v", err)
}
return nil
} else if col.IsJSON {
if rawValueType.Kind() == reflect.String {
hasAssigned = true hasAssigned = true
if err := nulVal.Scan(vv.Interface()); err != nil { x := reflect.New(fieldType)
session.engine.logger.Errorf("sql.Sanner error: %v", err) if len([]byte(vv.String())) > 0 {
hasAssigned = false err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface())
} if err != nil {
} else if col.IsJSON { return err
if rawValueType.Kind() == reflect.String {
hasAssigned = true
x := reflect.New(fieldType)
if len([]byte(vv.String())) > 0 {
err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
} }
} else if rawValueType.Kind() == reflect.Slice { fieldValue.Set(x.Elem())
hasAssigned = true }
x := reflect.New(fieldType) } else if rawValueType.Kind() == reflect.Slice {
if len(vv.Bytes()) > 0 { hasAssigned = true
err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) x := reflect.New(fieldType)
if err != nil { if len(vv.Bytes()) > 0 {
return nil, err err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface())
} if err != nil {
fieldValue.Set(x.Elem()) return err
} }
fieldValue.Set(x.Elem())
} }
} else if session.statement.UseCascade { }
table, err := session.engine.tagParser.ParseWithCache(*fieldValue) } else if session.statement.UseCascade {
if err != nil { fmt.Printf("5565666======= %#v \n", *fieldValue)
return nil, err table, err := session.engine.tagParser.ParseWithCache(*fieldValue)
} if err != nil {
return err
}
hasAssigned = true hasAssigned = true
if len(table.PrimaryKeys) != 1 { if len(table.PrimaryKeys) != 1 {
return nil, errors.New("unsupported non or composited primary key cascade") return errors.New("unsupported non or composited primary key cascade")
}
var pk = make(schemas.PK, len(table.PrimaryKeys))
switch t := src.(type) {
case int64:
pk[0] = t
case *sql.NullInt64:
if t.Valid {
pk[0] = t.Int64
} }
var pk = make(schemas.PK, len(table.PrimaryKeys)) default:
pk[0], err = asKind(vv, rawValueType) pk[0], err = asKind(vv, rawValueType)
if err != nil { if err != nil {
return nil, err return err
}
if !pk.IsZero() {
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily
structInter := reflect.New(fieldValue.Type())
has, err := session.ID(pk).NoCascade().get(structInter.Interface())
if err != nil {
return nil, err
}
if has {
fieldValue.Set(structInter.Elem())
} else {
return nil, errors.New("cascade obj is not exist")
}
} }
} }
case reflect.Ptr:
// !nashtsai! TODO merge duplicated codes above
switch fieldType {
// following types case matching ptr's native type, therefore assign ptr directly
case schemas.PtrStringType:
if rawValueType.Kind() == reflect.String {
x := vv.String()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case schemas.PtrBoolType:
if rawValueType.Kind() == reflect.Bool {
x := vv.Bool()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case schemas.PtrTimeType:
if rawValueType == schemas.PtrTimeType {
hasAssigned = true
var x = rawValue.Interface().(time.Time)
fieldValue.Set(reflect.ValueOf(&x))
}
case schemas.PtrFloat64Type:
if rawValueType.Kind() == reflect.Float64 {
x := vv.Float()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case schemas.PtrUint64Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint64(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case schemas.PtrInt64Type:
if rawValueType.Kind() == reflect.Int64 {
x := vv.Int()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case schemas.PtrFloat32Type:
if rawValueType.Kind() == reflect.Float64 {
var x = float32(vv.Float())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case schemas.PtrIntType:
if rawValueType.Kind() == reflect.Int64 {
var x = int(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case schemas.PtrInt32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case schemas.PtrInt8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case schemas.PtrInt16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case schemas.PtrUintType:
if rawValueType.Kind() == reflect.Int64 {
var x = uint(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case schemas.PtrUint32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case schemas.Uint8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case schemas.Uint16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case schemas.Complex64Type:
var x complex64
if len([]byte(vv.String())) > 0 {
err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x)
if err != nil {
return nil, err
}
fieldValue.Set(reflect.ValueOf(&x))
}
hasAssigned = true
case schemas.Complex128Type:
var x complex128
if len([]byte(vv.String())) > 0 {
err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x)
if err != nil {
return nil, err
}
fieldValue.Set(reflect.ValueOf(&x))
}
hasAssigned = true
} // switch fieldType
} // switch fieldType.Kind()
// !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value if !pk.IsZero() {
if !hasAssigned { // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
data, err := value2Bytes(&rawValue) // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
if err != nil { // property to be fetched lazily
return nil, err structInter := reflect.New(fieldValue.Type())
has, err := session.ID(pk).NoCascade().get(structInter.Interface())
if err != nil {
return err
}
if has {
fieldValue.Set(structInter.Elem())
} else {
return errors.New("cascade obj is not exist")
}
} }
} else if fieldType.ConvertibleTo(reflect.TypeOf(&sql.NullString{})) {
fmt.Println("=====3333======")
} else if fieldType.ConvertibleTo(reflect.TypeOf(&sql.NullInt64{})) {
fmt.Println("=====4444======")
}
case reflect.Ptr:
return errors.New("unsupported pointer to pointer")
} // switch fieldType.Kind()
if err = session.bytes2Value(col, fieldValue, data); err != nil { if !hasAssigned {
return nil, err return fmt.Errorf("unsupported convertion from %#v to %#v", src, fieldValue.Interface())
}
return nil
}
func (session *Session) slice2Bean(scanResults []interface{}, columnNames []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) {
defer func() {
executeAfterSet(bean, columnNames, scanResults)
}()
buildAfterProcessors(session, bean)
var tempMap = make(map[string]int)
var pk schemas.PK
for i, columnName := range columnNames {
var idx int
var ok bool
var lKey = strings.ToLower(columnName)
if idx, ok = tempMap[lKey]; !ok {
idx = 0
} else {
idx = idx + 1
}
tempMap[lKey] = idx
fieldValue, err := session.getField(dataStruct, columnName, table, idx)
if err != nil {
if !strings.Contains(err.Error(), "is not valid") {
session.engine.logger.Warnf("%v", err)
} }
continue
}
fmt.Printf("88888====== %#v \n ", scanResults[i])
if err := session.convertAssign(fieldValue, columnName, scanResults[i], table, &pk, idx); err != nil {
return nil, err
} }
} }
return pk, nil return pk, nil

View File

@ -172,6 +172,11 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
return err return err
} }
types, err := rows.ColumnTypes()
if err != nil {
return err
}
var newElemFunc func(fields []string) reflect.Value var newElemFunc func(fields []string) reflect.Value
elemType := containerValue.Type().Elem() elemType := containerValue.Type().Elem()
var isPointer bool var isPointer bool
@ -241,7 +246,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
if err != nil { if err != nil {
return err return err
} }
err = session.rows2Beans(rows, fields, tb, newElemFunc, containerValueSetFunc) err = session.rows2Beans(rows, types, fields, tb, newElemFunc, containerValueSetFunc)
rows.Close() rows.Close()
if err != nil { if err != nil {
return err return err
@ -274,6 +279,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
} }
func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error { func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error {
fmt.Printf("----- %#v \n", pk)
cols := table.PKColumns() cols := table.PKColumns()
if len(cols) == 1 { if len(cols) == 1 {
return convertAssign(dst, pk[0], nil, nil) return convertAssign(dst, pk[0], nil, nil)

View File

@ -124,6 +124,31 @@ var (
conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem() conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem()
) )
func (session *Session) genScanResultsByBeanStruct(bean interface{}, fields ...string) ([]interface{}, error) {
structV := reflect.ValueOf(bean)
table, err := session.engine.tagParser.ParseWithCache(structV)
if err != nil {
return nil, err
}
structV = structV.Elem()
var dstResults = make([]interface{}, 0, len(fields))
for _, field := range fields {
var fieldName string
for _, col := range table.Columns() {
if col.Name == field {
fieldName = col.FieldName
break
}
}
v := structV.FieldByName(fieldName)
if !v.IsValid() {
return nil, fmt.Errorf("get field named %v failed", field)
}
dstResults = append(dstResults, v.Addr())
}
return dstResults, nil
}
func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
rows, err := session.queryRows(sqlStr, args...) rows, err := session.queryRows(sqlStr, args...)
if err != nil { if err != nil {

View File

@ -122,14 +122,6 @@ func value2String(rawValue *reflect.Value) (str string, err error) {
return return
} }
func value2Bytes(rawValue *reflect.Value) ([]byte, error) {
str, err := value2String(rawValue)
if err != nil {
return nil, err
}
return []byte(str), nil
}
func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) { func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) {
rows, err := session.queryRows(sqlStr, args...) rows, err := session.queryRows(sqlStr, args...)
if err != nil { if err != nil {

View File

@ -271,6 +271,7 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
if t.Kind() == reflect.Ptr { if t.Kind() == reflect.Ptr {
t = t.Elem() t = t.Elem()
v = v.Elem() v = v.Elem()
fmt.Println("======3333", v)
} }
if t.Kind() != reflect.Struct { if t.Kind() != reflect.Struct {
return nil, ErrUnsupportedType return nil, ErrUnsupportedType