From 7b84aa150b0cf868b8e7880f93d60fa1b0324ff3 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 11 Jun 2021 13:35:50 +0800 Subject: [PATCH] Improve get bean --- convert.go | 3 +-- convert/string.go | 38 ++++++++++++++++++++++++++ dialects/mysql.go | 7 +++++ dialects/sqlite3.go | 6 +++++ engine.go | 7 +---- session_query.go | 66 --------------------------------------------- 6 files changed, 53 insertions(+), 74 deletions(-) create mode 100644 convert/string.go diff --git a/convert.go b/convert.go index f7d733ad..a2124dd3 100644 --- a/convert.go +++ b/convert.go @@ -462,7 +462,6 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve } var sv reflect.Value - switch d := dest.(type) { case *string: sv = reflect.ValueOf(src) @@ -496,7 +495,7 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, convertedLocation *time.Location) error { if dpv.Kind() != reflect.Ptr { - return errors.New("destination not a pointer") + return fmt.Errorf("destination %s not a pointer", dpv.Kind()) } if dpv.IsNil() { return errNilPtr diff --git a/convert/string.go b/convert/string.go new file mode 100644 index 00000000..a9d9ee98 --- /dev/null +++ b/convert/string.go @@ -0,0 +1,38 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "database/sql" + "fmt" + "strconv" +) + +func ConvertAssignString(v interface{}) (string, error) { + switch vv := v.(type) { + case *sql.NullString: + if vv.Valid { + return vv.String, nil + } + return "", nil + case *int64: + if vv != nil { + return strconv.FormatInt(*vv, 10), nil + } + return "", nil + case *int8: + if vv != nil { + return strconv.FormatInt(int64(*vv), 10), nil + } + return "", nil + case *sql.RawBytes: + if vv != nil && len([]byte(*vv)) > 0 { + return string(*vv), nil + } + return "", nil + default: + return "", fmt.Errorf("unsupported type: %#v", vv) + } +} diff --git a/dialects/mysql.go b/dialects/mysql.go index 548692e0..7ba188b5 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -671,6 +671,12 @@ func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { return uri, nil } +func (b *mysqlDriver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: false, + } +} + func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) { switch colType { case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET": @@ -698,6 +704,7 @@ func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) { var r sql.RawBytes return &r, nil default: + fmt.Printf("unknow mysql database type: %v\n", colType) var r sql.RawBytes return &r, nil } diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 1bc0b218..a43cccaf 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -551,6 +551,12 @@ func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) { return &URI{DBType: schemas.SQLITE, DBName: dataSourceName}, nil } +func (b *sqlite3Driver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: false, + } +} + func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) { switch colType { case "TEXT": diff --git a/engine.go b/engine.go index 1064e8e1..c40eb042 100644 --- a/engine.go +++ b/engine.go @@ -82,12 +82,7 @@ func newEngine(driverName, dataSourceName string, dialect dialects.Dialect, db * dataSourceName: dataSourceName, db: db, logSessionID: false, - } - - if dialect.URI().DBType == schemas.SQLITE { - engine.DatabaseTZ = time.UTC - } else { - engine.DatabaseTZ = time.Local + DatabaseTZ: time.Local, } logger := log.NewSimpleLogger(os.Stdout) diff --git a/session_query.go b/session_query.go index f16a498d..9594da25 100644 --- a/session_query.go +++ b/session_query.go @@ -5,19 +5,8 @@ package xorm import ( -<<<<<<< HEAD -======= -<<<<<<< HEAD -======= "database/sql" - "errors" ->>>>>>> 6e19325 (refactor driver) - "fmt" - "reflect" - "strconv" - "time" ->>>>>>> 634f82a (refactor driver) "xorm.io/xorm/core" ) @@ -35,55 +24,6 @@ func (session *Session) Query(sqlOrArgs ...interface{}) ([]map[string][]byte, er return session.queryBytes(sqlStr, args...) } -<<<<<<< HEAD -======= -func value2String(rawValue *reflect.Value) (str string, err error) { - aa := reflect.TypeOf((*rawValue).Interface()) - vv := reflect.ValueOf((*rawValue).Interface()) - switch aa.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - str = strconv.FormatInt(vv.Int(), 10) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - str = strconv.FormatUint(vv.Uint(), 10) - case reflect.Float32, reflect.Float64: - str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) - case reflect.String: - str = vv.String() - case reflect.Array, reflect.Slice: - switch aa.Elem().Kind() { - case reflect.Uint8: - data := rawValue.Interface().([]byte) - str = string(data) - if str == "\x00" { - str = "0" - } - default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - // time type - case reflect.Struct: - if aa.ConvertibleTo(schemas.TimeType) { - str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) - } else { - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - case reflect.Bool: - str = strconv.FormatBool(vv.Bool()) - case reflect.Complex128, reflect.Complex64: - str = fmt.Sprintf("%v", vv.Complex()) - /* TODO: unsupported types below - case reflect.Map: - case reflect.Ptr: - case reflect.Uintptr: - case reflect.UnsafePointer: - case reflect.Chan, reflect.Func, reflect.Interface: - */ - default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - return -} - // genRowsScanResults according func (session *Session) genRowsScanResults(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { var scanResults = make([]interface{}, len(types)) @@ -97,7 +37,6 @@ func (session *Session) genRowsScanResults(rows *core.Rows, types []*sql.ColumnT return scanResults, nil } ->>>>>>> 634f82a (refactor driver) func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) { fields, err := rows.Columns() if err != nil { @@ -107,11 +46,6 @@ func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string if err != nil { return nil, err } - types, err := rows.ColumnTypes() - if err != nil { - return nil, err - } - for rows.Next() { result, err := row2mapStr(rows, types, fields) if err != nil {