From 611e378b7be40fb612a08e02430bb4aef7b8a34d Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 10 Jun 2021 23:32:10 +0800 Subject: [PATCH] refactor driver --- dialects/mysql.go | 6 +++ integrations/session_query_test.go | 2 +- session_query.go | 80 ++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 1 deletion(-) diff --git a/dialects/mysql.go b/dialects/mysql.go index a341ce05..548692e0 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -714,6 +714,9 @@ func (p *mysqlDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.Colum case *sql.NullTime: v2 = append(v2, &sql.NullString{}) turnBackIdxes = append(turnBackIdxes, i) + case bool: + v2 = append(v2, new(bool)) + turnBackIdxes = append(turnBackIdxes, i) default: v2 = append(v2, scanResults[i]) } @@ -744,6 +747,9 @@ func (p *mysqlDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.Colum } t.Time = *dt t.Valid = true + case *bool: + var s = *(v2[i].(*bool)) + *t = s } } return nil diff --git a/integrations/session_query_test.go b/integrations/session_query_test.go index ed03ff3e..1e0b0dc6 100644 --- a/integrations/session_query_test.go +++ b/integrations/session_query_test.go @@ -68,7 +68,7 @@ func TestQueryString2(t *testing.T) { assert.Equal(t, 1, len(records)) assert.Equal(t, 2, len(records[0])) assert.Equal(t, "1", records[0]["id"]) - assert.True(t, "0" == records[0]["msg"] || "false" == records[0]["msg"]) + assert.True(t, "0" == records[0]["msg"] || "false" == records[0]["msg"], records[0]) } func toString(i interface{}) string { diff --git a/session_query.go b/session_query.go index fa33496d..f16a498d 100644 --- a/session_query.go +++ b/session_query.go @@ -5,6 +5,19 @@ package xorm import ( +<<<<<<< HEAD +======= +<<<<<<< HEAD +======= + "database/sql" + "errors" +>>>>>>> 6e19325 (refactor driver) + "fmt" + "reflect" + "strconv" + "time" + +>>>>>>> 634f82a (refactor driver) "xorm.io/xorm/core" ) @@ -22,6 +35,69 @@ func (session *Session) Query(sqlOrArgs ...interface{}) ([]map[string][]byte, er return session.queryBytes(sqlStr, args...) } +<<<<<<< HEAD +======= +func value2String(rawValue *reflect.Value) (str string, err error) { + aa := reflect.TypeOf((*rawValue).Interface()) + vv := reflect.ValueOf((*rawValue).Interface()) + switch aa.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + str = strconv.FormatInt(vv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + str = strconv.FormatUint(vv.Uint(), 10) + case reflect.Float32, reflect.Float64: + str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) + case reflect.String: + str = vv.String() + case reflect.Array, reflect.Slice: + switch aa.Elem().Kind() { + case reflect.Uint8: + data := rawValue.Interface().([]byte) + str = string(data) + if str == "\x00" { + str = "0" + } + default: + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + // time type + case reflect.Struct: + if aa.ConvertibleTo(schemas.TimeType) { + str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) + } else { + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + case reflect.Bool: + str = strconv.FormatBool(vv.Bool()) + case reflect.Complex128, reflect.Complex64: + str = fmt.Sprintf("%v", vv.Complex()) + /* TODO: unsupported types below + case reflect.Map: + case reflect.Ptr: + case reflect.Uintptr: + case reflect.UnsafePointer: + case reflect.Chan, reflect.Func, reflect.Interface: + */ + default: + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + return +} + +// genRowsScanResults according +func (session *Session) genRowsScanResults(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { + var scanResults = make([]interface{}, len(types)) + var err error + for i, t := range types { + scanResults[i], err = session.engine.driver.GenScanResult(t.DatabaseTypeName()) + if err != nil { + return nil, err + } + } + return scanResults, nil +} + +>>>>>>> 634f82a (refactor driver) func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) { fields, err := rows.Columns() if err != nil { @@ -31,6 +107,10 @@ func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string if err != nil { return nil, err } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } for rows.Next() { result, err := row2mapStr(rows, types, fields)