From 23a143717cf1e19f35fa5f2d014b9ed821182aa0 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 16 Jun 2021 11:15:13 +0800 Subject: [PATCH] detect the scan type according type --- convert.go | 4 +++- schemas/column.go | 5 +++-- session.go | 11 +++++------ session_get.go | 32 ++++++++++++++++++++++++++------ tags/parser.go | 1 - 5 files changed, 37 insertions(+), 16 deletions(-) diff --git a/convert.go b/convert.go index 081e799f..81c52f86 100644 --- a/convert.go +++ b/convert.go @@ -611,8 +611,10 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { } return v, nil } + default: + return vv.Interface(), nil } - return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) + return nil, fmt.Errorf("asKind unsupported type: %v, %v", tp, vv) } func asBool(bs []byte) (bool, error) { diff --git a/schemas/column.go b/schemas/column.go index 4bbb6c2d..78a28cfd 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -22,8 +22,9 @@ const ( type Column struct { Name string TableName string - FieldName string // Available only when parsed from a struct - FieldIndex []int // Available only when parsed from a struct + FieldName string // Available only when parsed from a struct + FieldIndex []int // Available only when parsed from a struct + Type reflect.Type // Available only when parsed from a struct SQLType SQLType IsJSON bool Length int diff --git a/session.go b/session.go index 8a645815..6b309659 100644 --- a/session.go +++ b/session.go @@ -449,11 +449,6 @@ func (session *Session) row2Slice(rows *core.Rows, types []*sql.ColumnType, fiel return scanResults, nil } -var ( - scannerTypePlaceHolder sql.Scanner - scannerType = reflect.TypeOf(&scannerTypePlaceHolder).Elem() -) - // convertAssign converts an interface src to dst reflect.Value fieldValue func (session *Session) convertAssign(fieldValue *reflect.Value, columnName string, src interface{}, table *schemas.Table, pk *schemas.PK, idx int) error { if fieldValue == nil { @@ -469,6 +464,10 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri return nil } + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + } + fmt.Printf("----- %v <------ %v \n", fieldValue.Type(), rawValue.Type()) if fieldValue.Type() == rawValue.Type() { fieldValue.Set(rawValue) @@ -518,7 +517,7 @@ func (session *Session) convertAssign(fieldValue *reflect.Value, columnName stri case *sql.NullString: if t.Valid { if fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + //fieldValue.Set(reflect.New(fieldValue.Type().Elem())) structConvert = fieldValue.Interface().(convert.Conversion) } if err := structConvert.FromDB([]byte(t.String)); err != nil { diff --git a/session_get.go b/session_get.go index bb848c06..d8836274 100644 --- a/session_get.go +++ b/session_get.go @@ -298,16 +298,36 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields } func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { - fields, err := rows.Columns() + var scanResults = make([]interface{}, 0, len(types)) + for i, tp := range types { + col := table.GetColumn(fields[i]) + if col == nil { + return true, fmt.Errorf("cannot find column named %v from columns %v", fields[i], table.ColumnsSeq()) + } + if col.Type.Implements(scannerType) { + scanResults = append(scanResults, &sql.RawBytes{}) + } else if col.Type.Implements(conversionType) { + scanResults = append(scanResults, &sql.RawBytes{}) + } else { + v, err := session.engine.driver.GenScanResult(tp.DatabaseTypeName()) + if err != nil { + return true, err + } + scanResults = append(scanResults, v) + } + } + + for _, closure := range session.beforeClosures { + closure(bean) + } + + err := session.engine.scan(rows, types, scanResults...) if err != nil { - // WARN: Alougth rows return true, but get fields failed return true, err } - scanResults, err := session.row2Slice(rows, fields, bean) - if err != nil { - return false, err - } + executeBeforeSet(bean, fields, scanResults) + // close it before convert data rows.Close() diff --git a/tags/parser.go b/tags/parser.go index fc624d9f..b793a8f1 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -271,7 +271,6 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { if t.Kind() == reflect.Ptr { t = t.Elem() v = v.Elem() - fmt.Println("======3333", v) } if t.Kind() != reflect.Struct { return nil, ErrUnsupportedType