diff --git a/convert/time.go b/convert/time.go new file mode 100644 index 00000000..8901279b --- /dev/null +++ b/convert/time.go @@ -0,0 +1,30 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "fmt" + "time" +) + +// String2Time converts a string to time with original location +func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) { + if len(s) == 19 { + dt, err := time.ParseInLocation("2006-01-02 15:04:05", s, originalLocation) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' { + dt, err := time.ParseInLocation("2006-01-02T15:04:05Z", s, originalLocation) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } + return nil, fmt.Errorf("unsupported convertion from %s to time", s) +} diff --git a/dialects/driver.go b/dialects/driver.go index bb46a936..c511b665 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -5,12 +5,24 @@ package dialects import ( + "database/sql" "fmt" + "time" + + "xorm.io/xorm/core" ) +// ScanContext represents a context when Scan +type ScanContext struct { + DBLocation *time.Location + UserLocation *time.Location +} + // Driver represents a database driver type Driver interface { Parse(string, string) (*URI, error) + GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface + Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error } var ( @@ -59,3 +71,9 @@ func OpenDialect(driverName, connstr string) (Dialect, error) { return dialect, nil } + +type baseDriver struct{} + +func (b *baseDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, v ...interface{}) error { + return rows.Scan(v...) +} diff --git a/dialects/mssql.go b/dialects/mssql.go index 7e922e62..c3c15077 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -6,6 +6,7 @@ package dialects import ( "context" + "database/sql" "errors" "fmt" "net/url" @@ -624,6 +625,7 @@ func (db *mssql) Filters() []Filter { } type odbcDriver struct { + baseDriver } func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) { @@ -652,3 +654,26 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) { } return &URI{DBName: dbName, DBType: schemas.MSSQL}, nil } + +func (p *odbcDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "VARCHAR", "TEXT", "CHAR", "NVARCHAR", "NCHAR", "NTEXT": + fallthrough + case "DATE", "DATETIME", "DATETIME2", "TIME": + var s sql.NullString + return &s, nil + case "FLOAT", "REAL": + var s sql.NullFloat64 + return &s, nil + case "BIGINT", "DATETIMEOFFSET": + var s sql.NullInt64 + return &s, nil + case "TINYINT", "SMALLINT", "INT": + var s sql.NullInt32 + return &s, nil + + default: + var r sql.RawBytes + return &r, nil + } +} diff --git a/dialects/mysql.go b/dialects/mysql.go index a169b901..6cf40608 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -7,6 +7,7 @@ package dialects import ( "context" "crypto/tls" + "database/sql" "errors" "fmt" "regexp" @@ -14,6 +15,7 @@ import ( "strings" "time" + "xorm.io/xorm/convert" "xorm.io/xorm/core" "xorm.io/xorm/schemas" ) @@ -630,7 +632,119 @@ func (db *mysql) Filters() []Filter { return []Filter{} } +type mysqlDriver struct { +} + +func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { + dsnPattern := regexp.MustCompile( + `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] + `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] + `\/(?P.*?)` + // /dbname + `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] + matches := dsnPattern.FindStringSubmatch(dataSourceName) + // tlsConfigRegister := make(map[string]*tls.Config) + names := dsnPattern.SubexpNames() + + uri := &URI{DBType: schemas.MYSQL} + + for i, match := range matches { + switch names[i] { + case "dbname": + uri.DBName = match + case "params": + if len(match) > 0 { + kvs := strings.Split(match, "&") + for _, kv := range kvs { + splits := strings.Split(kv, "=") + if len(splits) == 2 { + switch splits[0] { + case "charset": + uri.Charset = splits[1] + } + } + } + } + + } + } + return uri, nil +} + +func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "VARCHAR", "TEXT": + var s sql.NullString + return &s, nil + case "BIGINT": + var s sql.NullInt64 + return &s, nil + case "TINYINT", "INT": + var s sql.NullInt32 + return &s, nil + case "FLOAT": + var s sql.NullFloat64 + return &s, nil + case "DATETIME": + var s sql.NullTime + return &s, nil + case "BIT": + var s sql.RawBytes + return &s, nil + default: + fmt.Printf("unknow mysql database type: %v\n", colType) + var r sql.RawBytes + return &r, nil + } +} + +func (p *mysqlDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, scanResults ...interface{}) error { + var v2 = make([]interface{}, 0, len(scanResults)) + var turnBackIdxes = make([]int, 0, 5) + for i, vv := range scanResults { + switch vv.(type) { + case *time.Time: + v2 = append(v2, &sql.NullString{}) + turnBackIdxes = append(turnBackIdxes, i) + case *sql.NullTime: + v2 = append(v2, &sql.NullString{}) + turnBackIdxes = append(turnBackIdxes, i) + default: + v2 = append(v2, scanResults[i]) + } + } + if err := rows.Scan(v2...); err != nil { + return err + } + for _, i := range turnBackIdxes { + switch t := scanResults[i].(type) { + case *time.Time: + var s = *(v2[i].(*sql.NullString)) + if !s.Valid { + break + } + dt, err := convert.String2Time(s.String, ctx.DBLocation, ctx.UserLocation) + if err != nil { + return err + } + *t = *dt + case *sql.NullTime: + var s = *(v2[i].(*sql.NullString)) + if !s.Valid { + break + } + dt, err := convert.String2Time(s.String, ctx.DBLocation, ctx.UserLocation) + if err != nil { + return err + } + t.Time = *dt + t.Valid = true + } + } + return nil +} + type mymysqlDriver struct { + mysqlDriver } func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { @@ -681,41 +795,3 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { return uri, nil } - -type mysqlDriver struct { -} - -func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { - dsnPattern := regexp.MustCompile( - `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] - `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] - `\/(?P.*?)` + // /dbname - `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] - matches := dsnPattern.FindStringSubmatch(dataSourceName) - // tlsConfigRegister := make(map[string]*tls.Config) - names := dsnPattern.SubexpNames() - - uri := &URI{DBType: schemas.MYSQL} - - for i, match := range matches { - switch names[i] { - case "dbname": - uri.DBName = match - case "params": - if len(match) > 0 { - kvs := strings.Split(match, "&") - for _, kv := range kvs { - splits := strings.Split(kv, "=") - if len(splits) == 2 { - switch splits[0] { - case "charset": - uri.Charset = splits[1] - } - } - } - } - - } - } - return uri, nil -} diff --git a/dialects/oracle.go b/dialects/oracle.go index 0b06c4c6..fe3e0a2f 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -6,6 +6,7 @@ package dialects import ( "context" + "database/sql" "errors" "fmt" "regexp" @@ -823,6 +824,7 @@ func (db *oracle) Filters() []Filter { } type godrorDriver struct { + baseDriver } func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) { @@ -848,7 +850,19 @@ func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) return db, nil } +func (p *godrorDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "VARCHAR", "TEXT": + var s sql.NullString + return &s, nil + default: + var r sql.RawBytes + return &r, nil + } +} + type oci8Driver struct { + godrorDriver } // dataSourceName=user/password@ipv4:port/dbname diff --git a/dialects/postgres.go b/dialects/postgres.go index 9acf763a..9a5beddb 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -6,6 +6,7 @@ package dialects import ( "context" + "database/sql" "errors" "fmt" "net/url" @@ -1298,6 +1299,7 @@ func (db *postgres) Filters() []Filter { } type pqDriver struct { + baseDriver } type values map[string]string @@ -1374,6 +1376,33 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { return db, nil } +func (p *pqDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "VARCHAR", "TEXT": + var s sql.NullString + return &s, nil + case "BIGINT": + var s sql.NullInt64 + return &s, nil + case "TINYINT", "INT": + var s sql.NullInt32 + return &s, nil + case "FLOAT": + var s sql.NullFloat64 + return &s, nil + case "DATETIME": + var s sql.NullTime + return &s, nil + case "BIT": + var s sql.RawBytes + return &s, nil + default: + fmt.Printf("unknow postgres database type: %v\n", colType) + var r sql.RawBytes + return &r, nil + } +} + type pqDriverPgx struct { pqDriver } diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index a42aad48..bca0bc18 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -540,6 +540,7 @@ func (db *sqlite3) Filters() []Filter { } type sqlite3Driver struct { + baseDriver } func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) { @@ -549,3 +550,30 @@ func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) { return &URI{DBType: schemas.SQLITE, DBName: dataSourceName}, nil } + +func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "TEXT": + var s sql.NullString + return &s, nil + case "INTEGER": + var s sql.NullInt64 + return &s, nil + case "DATETIME": + var s sql.NullTime + return &s, nil + case "REAL": + var s sql.NullFloat64 + return &s, nil + case "NUMERIC": + var s sql.NullString + return &s, nil + case "BLOB": + var s sql.RawBytes + return &s, nil + default: + fmt.Printf("====unknow handle db type: %v \n", colType) + var r sql.NullString + return &r, nil + } +} diff --git a/engine.go b/engine.go index 0eb429b1..1064e8e1 100644 --- a/engine.go +++ b/engine.go @@ -35,6 +35,7 @@ type Engine struct { cacherMgr *caches.Manager defaultContext context.Context dialect dialects.Dialect + driver dialects.Driver engineGroup *EngineGroup logger log.ContextLogger tagParser *tags.Parser @@ -72,6 +73,7 @@ func newEngine(driverName, dataSourceName string, dialect dialects.Dialect, db * engine := &Engine{ dialect: dialect, + driver: dialects.QueryDriver(driverName), TZLocation: time.Local, defaultContext: context.Background(), cacherMgr: cacherMgr, diff --git a/integrations/session_query_test.go b/integrations/session_query_test.go index 5f3a0797..2338acb0 100644 --- a/integrations/session_query_test.go +++ b/integrations/session_query_test.go @@ -132,10 +132,10 @@ func TestQueryInterface(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 5, len(records[0])) - assert.EqualValues(t, 1, toInt64(records[0]["id"])) - assert.Equal(t, "hi", toString(records[0]["msg"])) + assert.EqualValues(t, int64(1), records[0]["id"]) + assert.Equal(t, "hi", records[0]["msg"]) assert.EqualValues(t, 28, toInt64(records[0]["age"])) - assert.EqualValues(t, 1.5, toFloat64(records[0]["money"])) + assert.EqualValues(t, 1.5, records[0]["money"]) } func TestQueryNoParams(t *testing.T) { diff --git a/scan.go b/scan.go index e19037a0..d5b5698d 100644 --- a/scan.go +++ b/scan.go @@ -6,8 +6,10 @@ package xorm import ( "database/sql" + "fmt" "xorm.io/xorm/core" + "xorm.io/xorm/dialects" ) func (engine *Engine) row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { @@ -65,3 +67,49 @@ func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fie } return results, nil } + +func (engine *Engine) row2mapInterface(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]interface{}, error) { + var resultsMap = make(map[string]interface{}, len(fields)) + var scanResultContainers = make([]interface{}, len(fields)) + for i := 0; i < len(fields); i++ { + scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) + if err != nil { + return nil, err + } + scanResultContainers[i] = scanResult + } + if err := engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResultContainers...); err != nil { + return nil, err + } + + for ii, key := range fields { + switch t := scanResultContainers[ii].(type) { + case *sql.NullInt32: + resultsMap[key] = t.Int32 + case *sql.NullInt64: + resultsMap[key] = t.Int64 + case *sql.NullFloat64: + resultsMap[key] = t.Float64 + case *sql.NullString: + resultsMap[key] = t.String + case *sql.NullTime: + if t.Valid { + resultsMap[key] = t.Time.In(engine.TZLocation).Format("2006-01-02 15:04:05") + } else { + resultsMap[key] = nil + } + case *sql.RawBytes: + if t == nil { + resultsMap[key] = nil + } else { + resultsMap[key] = []byte(*t) + } + default: + return nil, fmt.Errorf("unknow type: %v", t) + } + } + return resultsMap, nil +} diff --git a/session_query.go b/session_query.go index 379ad0e1..24ec86d3 100644 --- a/session_query.go +++ b/session_query.go @@ -157,30 +157,17 @@ func (session *Session) QuerySliceString(sqlOrArgs ...interface{}) ([][]string, return session.rows2SliceString(rows) } -func row2mapInterface(rows *core.Rows, fields []string) (resultsMap map[string]interface{}, err error) { - resultsMap = make(map[string]interface{}, len(fields)) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for ii, key := range fields { - resultsMap[key] = reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])).Interface() - } - return -} - -func rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, err error) { +func (session *Session) rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, err error) { fields, err := rows.Columns() if err != nil { return nil, err } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } for rows.Next() { - result, err := row2mapInterface(rows, fields) + result, err := session.engine.row2mapInterface(rows, types, fields) if err != nil { return nil, err } @@ -207,5 +194,5 @@ func (session *Session) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]i } defer rows.Close() - return rows2Interfaces(rows) + return session.rows2Interfaces(rows) }