From 4459013c0a224c55f5f2ac5f19a28f26ff134ba3 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 23 Jul 2021 15:41:12 +0800 Subject: [PATCH] refactor conversion --- convert.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++ dialects/dameng.go | 46 ++++++++++++++++++++++++++++++++++++++++--- session_get.go | 3 --- 3 files changed, 92 insertions(+), 6 deletions(-) create mode 100644 convert.go diff --git a/convert.go b/convert.go new file mode 100644 index 00000000..371ec2bc --- /dev/null +++ b/convert.go @@ -0,0 +1,49 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "fmt" + "reflect" + "strconv" +) + +func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { + switch tp.Kind() { + case reflect.Ptr: + return asKind(vv.Elem(), tp.Elem()) + case reflect.Int64: + return vv.Int(), nil + case reflect.Int: + return int(vv.Int()), nil + case reflect.Int32: + return int32(vv.Int()), nil + case reflect.Int16: + return int16(vv.Int()), nil + case reflect.Int8: + return int8(vv.Int()), nil + case reflect.Uint64: + return vv.Uint(), nil + case reflect.Uint: + return uint(vv.Uint()), nil + case reflect.Uint32: + return uint32(vv.Uint()), nil + case reflect.Uint16: + return uint16(vv.Uint()), nil + case reflect.Uint8: + return uint8(vv.Uint()), nil + case reflect.String: + return vv.String(), nil + case reflect.Slice: + if tp.Elem().Kind() == reflect.Uint8 { + v, err := strconv.ParseInt(string(vv.Interface().([]byte)), 10, 64) + if err != nil { + return nil, err + } + return v, nil + } + } + return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) +} diff --git a/dialects/dameng.go b/dialects/dameng.go index c3f7423f..6ffd7cc1 100644 --- a/dialects/dameng.go +++ b/dialects/dameng.go @@ -17,6 +17,7 @@ import ( "gitee.com/travelliu/dm" "xorm.io/xorm/core" + "xorm.io/xorm/internal/convert" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -1001,7 +1002,7 @@ type damengDriver struct { } // Features return features -func (p *damengDriver) Features() *DriverFeatures { +func (d *damengDriver) Features() *DriverFeatures { return &DriverFeatures{ SupportReturnInsertedID: false, } @@ -1009,7 +1010,7 @@ func (p *damengDriver) Features() *DriverFeatures { // Parse parse the datasource // dm://userName:password@ip:port -func (p *damengDriver) Parse(driverName, dataSourceName string) (*URI, error) { +func (d *damengDriver) Parse(driverName, dataSourceName string) (*URI, error) { u, err := url.Parse(dataSourceName) if err != nil { return nil, err @@ -1031,7 +1032,7 @@ func (p *damengDriver) Parse(driverName, dataSourceName string) (*URI, error) { }, nil } -func (g *damengDriver) GenScanResult(colType string) (interface{}, error) { +func (d *damengDriver) GenScanResult(colType string) (interface{}, error) { switch colType { case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB": var s sql.NullString @@ -1050,3 +1051,42 @@ func (g *damengDriver) GenScanResult(colType string) (interface{}, error) { return &r, nil } } + +func (d *damengDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, vv ...interface{}) error { + var scanResults = make([]interface{}, 0, len(types)) + var replaces = make([]bool, 0, len(types)) + var err error + for i, v := range vv { + var replaced bool + var scanResult interface{} + switch types[i].DatabaseTypeName() { + case "CLOB": + scanResult = &dmClobScanner{} + replaced = true + default: + scanResult = v + } + + scanResults = append(scanResults, scanResult) + replaces = append(replaces, replaced) + } + + if err = rows.Scan(scanResults...); err != nil { + return err + } + + for i, replaced := range replaces { + if replaced { + switch t := scanResults[i].(type) { + case *dmClobScanner: + if err := convert.Assign(vv[i], t.data, ctx.DBLocation, ctx.UserLocation); err != nil { + return err + } + default: + return fmt.Errorf("don't support convert %T to %T", t, vv[i]) + } + } + } + + return nil +} diff --git a/session_get.go b/session_get.go index 08172524..22b116a9 100644 --- a/session_get.go +++ b/session_get.go @@ -130,9 +130,6 @@ var ( valuerTypePlaceHolder driver.Valuer valuerType = reflect.TypeOf(&valuerTypePlaceHolder).Elem() - scannerTypePlaceHolder sql.Scanner - scannerType = reflect.TypeOf(&scannerTypePlaceHolder).Elem() - conversionTypePlaceHolder convert.Conversion conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem() )