refactor and add setjson function (#1997)

Fix #1992

Reviewed-on: https://gitea.com/xorm/xorm/pulls/1997
Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com>
Co-committed-by: Lunny Xiao <xiaolunwen@gmail.com>
This commit is contained in:
Lunny Xiao 2021-07-19 13:43:53 +08:00
parent 5950824e37
commit 86775af2ec
17 changed files with 148 additions and 184 deletions

View File

@ -193,6 +193,8 @@ func asFloat64(src interface{}) (float64, error) {
return float64(v.Int32), nil return float64(v.Int32), nil
case *sql.NullInt64: case *sql.NullInt64:
return float64(v.Int64), nil return float64(v.Int64), nil
case *sql.NullFloat64:
return v.Float64, nil
} }
rv := reflect.ValueOf(src) rv := reflect.ValueOf(src)
@ -717,6 +719,8 @@ func convertAssignV(dv reflect.Value, src interface{}) error {
func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
switch tp.Kind() { switch tp.Kind() {
case reflect.Ptr:
return asKind(vv.Elem(), tp.Elem())
case reflect.Int64: case reflect.Int64:
return vv.Int(), nil return vv.Int(), nil
case reflect.Int: case reflect.Int:

View File

@ -8,11 +8,16 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"time" "time"
"xorm.io/xorm/internal/utils"
) )
// String2Time converts a string to time with original location // String2Time converts a string to time with original location
func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) { func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) {
if len(s) == 19 { if len(s) == 19 {
if s == utils.ZeroTime0 || s == utils.ZeroTime1 {
return &time.Time{}, nil
}
dt, err := time.ParseInLocation("2006-01-02 15:04:05", s, originalLocation) dt, err := time.ParseInLocation("2006-01-02 15:04:05", s, originalLocation)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -118,6 +118,9 @@ func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query stri
defer rows.Close() defer rows.Close()
if rows.Next() { if rows.Next() {
if rows.Err() != nil {
return true, rows.Err()
}
return true, nil return true, nil
} }
return false, nil return false, nil

View File

@ -456,6 +456,9 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
cols := make(map[string]*schemas.Column) cols := make(map[string]*schemas.Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, nil, rows.Err()
}
var name, ctype, vdefault string var name, ctype, vdefault string
var maxLen, precision, scale int var maxLen, precision, scale int
var nullable, isPK, defaultIsNull, isIncrement bool var nullable, isPK, defaultIsNull, isIncrement bool
@ -524,6 +527,9 @@ func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema
tables := make([]*schemas.Table, 0) tables := make([]*schemas.Table, 0)
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, rows.Err()
}
table := schemas.NewEmptyTable() table := schemas.NewEmptyTable()
var name string var name string
err = rows.Scan(&name) err = rows.Scan(&name)
@ -558,6 +564,9 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
indexes := make(map[string]*schemas.Index, 0) indexes := make(map[string]*schemas.Index, 0)
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, rows.Err()
}
var indexType int var indexType int
var indexName, colName, isUnique string var indexName, colName, isUnique string

View File

@ -405,6 +405,9 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
cols := make(map[string]*schemas.Column) cols := make(map[string]*schemas.Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, nil, rows.Err()
}
col := new(schemas.Column) col := new(schemas.Column)
col.Indexes = make(map[string]int) col.Indexes = make(map[string]int)
@ -519,6 +522,9 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema
tables := make([]*schemas.Table, 0) tables := make([]*schemas.Table, 0)
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, rows.Err()
}
table := schemas.NewEmptyTable() table := schemas.NewEmptyTable()
var name, engine string var name, engine string
var autoIncr, comment *string var autoIncr, comment *string
@ -566,6 +572,9 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName
indexes := make(map[string]*schemas.Index, 0) indexes := make(map[string]*schemas.Index, 0)
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, rows.Err()
}
var indexType int var indexType int
var indexName, colName, nonUnique string var indexName, colName, nonUnique string
err = rows.Scan(&indexName, &nonUnique, &colName) err = rows.Scan(&indexName, &nonUnique, &colName)

View File

@ -677,6 +677,9 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
cols := make(map[string]*schemas.Column) cols := make(map[string]*schemas.Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, nil, rows.Err()
}
col := new(schemas.Column) col := new(schemas.Column)
col.Indexes = make(map[string]int) col.Indexes = make(map[string]int)
@ -772,6 +775,9 @@ func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem
tables := make([]*schemas.Table, 0) tables := make([]*schemas.Table, 0)
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, rows.Err()
}
table := schemas.NewEmptyTable() table := schemas.NewEmptyTable()
err = rows.Scan(&table.Name) err = rows.Scan(&table.Name)
if err != nil { if err != nil {
@ -796,6 +802,9 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam
indexes := make(map[string]*schemas.Index, 0) indexes := make(map[string]*schemas.Index, 0)
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, rows.Err()
}
var indexType int var indexType int
var indexName, colName, uniqueness string var indexName, colName, uniqueness string

View File

@ -810,7 +810,7 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas
var version string var version string
if !rows.Next() { if !rows.Next() {
return nil, errors.New("Unknow version") return nil, errors.New("unknow version")
} }
if err := rows.Scan(&version); err != nil { if err := rows.Scan(&version); err != nil {
@ -1098,6 +1098,9 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, nil, rows.Err()
}
col := new(schemas.Column) col := new(schemas.Column)
col.Indexes = make(map[string]int) col.Indexes = make(map[string]int)
@ -1192,7 +1195,7 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
} }
} }
if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok { if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, fmt.Errorf("Unknown colType: %s - %s", dataType, col.SQLType.Name) return nil, nil, fmt.Errorf("unknown colType: %s - %s", dataType, col.SQLType.Name)
} }
col.Length = maxLen col.Length = maxLen
@ -1200,13 +1203,13 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
if !col.DefaultIsEmpty { if !col.DefaultIsEmpty {
if col.SQLType.IsText() { if col.SQLType.IsText() {
if strings.HasSuffix(col.Default, "::character varying") { if strings.HasSuffix(col.Default, "::character varying") {
col.Default = strings.TrimRight(col.Default, "::character varying") col.Default = strings.TrimSuffix(col.Default, "::character varying")
} else if !strings.HasPrefix(col.Default, "'") { } else if !strings.HasPrefix(col.Default, "'") {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
} }
} else if col.SQLType.IsTime() { } else if col.SQLType.IsTime() {
if strings.HasSuffix(col.Default, "::timestamp without time zone") { if strings.HasSuffix(col.Default, "::timestamp without time zone") {
col.Default = strings.TrimRight(col.Default, "::timestamp without time zone") col.Default = strings.TrimSuffix(col.Default, "::timestamp without time zone")
} }
} }
} }
@ -1234,6 +1237,9 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch
tables := make([]*schemas.Table, 0) tables := make([]*schemas.Table, 0)
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, rows.Err()
}
table := schemas.NewEmptyTable() table := schemas.NewEmptyTable()
var name string var name string
err = rows.Scan(&name) err = rows.Scan(&name)
@ -1259,7 +1265,7 @@ func getIndexColName(indexdef string) []string {
func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") s := "SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1"
if len(db.getSchema()) != 0 { if len(db.getSchema()) != 0 {
args = append(args, db.getSchema()) args = append(args, db.getSchema())
s = s + " AND schemaname=$2" s = s + " AND schemaname=$2"
@ -1271,8 +1277,11 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN
} }
defer rows.Close() defer rows.Close()
indexes := make(map[string]*schemas.Index, 0) indexes := make(map[string]*schemas.Index)
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, rows.Err()
}
var indexType int var indexType int
var indexName, indexdef string var indexName, indexdef string
var colNames []string var colNames []string
@ -1450,6 +1459,9 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri
} }
defer rows.Close() defer rows.Close()
if rows.Next() { if rows.Next() {
if rows.Err() != nil {
return "", rows.Err()
}
var defaultSchema string var defaultSchema string
if err = rows.Scan(&defaultSchema); err != nil { if err = rows.Scan(&defaultSchema); err != nil {
return "", err return "", err
@ -1458,5 +1470,5 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri
return strings.TrimSpace(parts[len(parts)-1]), nil return strings.TrimSpace(parts[len(parts)-1]), nil
} }
return "", errors.New("No default schema") return "", errors.New("no default schema")
} }

View File

@ -415,12 +415,14 @@ func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableNa
defer rows.Close() defer rows.Close()
var name string var name string
for rows.Next() { if rows.Next() {
if rows.Err() != nil {
return nil, nil, rows.Err()
}
err = rows.Scan(&name) err = rows.Scan(&name)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
break
} }
if name == "" { if name == "" {
@ -496,8 +498,11 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa
} }
defer rows.Close() defer rows.Close()
indexes := make(map[string]*schemas.Index, 0) indexes := make(map[string]*schemas.Index)
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, rows.Err()
}
var tmpSQL sql.NullString var tmpSQL sql.NullString
err = rows.Scan(&tmpSQL) err = rows.Scan(&tmpSQL)
if err != nil { if err != nil {

View File

@ -551,6 +551,9 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
sess := engine.NewSession() sess := engine.NewSession()
defer sess.Close() defer sess.Close()
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return rows.Err()
}
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
if err != nil { if err != nil {
return err return err

View File

@ -286,6 +286,9 @@ func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) {
return nil, err return nil, err
} }
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, rows.Err()
}
result, err := row2mapBytes(rows, types, fields) result, err := row2mapBytes(rows, types, fields)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -364,25 +364,24 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt,
return return
} }
func (session *Session) getField(dataStruct *reflect.Value, key string, table *schemas.Table, idx int) (*reflect.Value, error) { func (session *Session) getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) {
var col *schemas.Column var col = table.GetColumnIdx(colName, idx)
if col = table.GetColumnIdx(key, idx); col == nil { if col == nil {
return nil, ErrFieldIsNotExist{key, table.Name} return nil, nil, ErrFieldIsNotExist{colName, table.Name}
} }
fieldValue, err := col.ValueOfV(dataStruct) fieldValue, err := col.ValueOfV(dataStruct)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
if fieldValue == nil { if fieldValue == nil {
return nil, ErrFieldIsNotValid{key, table.Name} return nil, nil, ErrFieldIsNotValid{colName, table.Name}
} }
if !fieldValue.IsValid() || !fieldValue.CanSet() { if !fieldValue.IsValid() || !fieldValue.CanSet() {
return nil, ErrFieldIsNotValid{key, table.Name} return nil, nil, ErrFieldIsNotValid{colName, table.Name}
} }
return fieldValue, nil return col, fieldValue, nil
} }
// Cell cell is a result of one column field // Cell cell is a result of one column field
@ -392,6 +391,9 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sq
table *schemas.Table, newElemFunc func([]string) reflect.Value, table *schemas.Table, newElemFunc func([]string) reflect.Value,
sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error {
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return rows.Err()
}
var newValue = newElemFunc(fields) var newValue = newElemFunc(fields)
bean := newValue.Interface() bean := newValue.Interface()
dataStruct := newValue.Elem() dataStruct := newValue.Elem()
@ -435,6 +437,36 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, types []*sql
return scanResults, nil return scanResults, nil
} }
func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Type, scanResult interface{}) error {
bs, ok := asBytes(scanResult)
if !ok {
return fmt.Errorf("unsupported database data type: %#v", scanResult)
}
if len(bs) == 0 {
return nil
}
if fieldType.Kind() == reflect.String {
fieldValue.SetString(string(bs))
return nil
}
if fieldValue.CanAddr() {
err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
return err
}
} else {
x := reflect.New(fieldType)
err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface())
if err != nil {
return err
}
fieldValue.Set(x.Elem())
}
return nil
}
func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value,
scanResult interface{}, table *schemas.Table) error { scanResult interface{}, table *schemas.Table) error {
v, ok := scanResult.(*interface{}) v, ok := scanResult.(*interface{})
@ -445,12 +477,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
return nil return nil
} }
rawValue := reflect.Indirect(reflect.ValueOf(scanResult))
// if row is null then ignore
if rawValue.Interface() == nil {
return nil
}
if fieldValue.CanAddr() { if fieldValue.CanAddr() {
if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
data, ok := asBytes(scanResult) data, ok := asBytes(scanResult)
@ -477,40 +503,11 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
return structConvert.FromDB(data) return structConvert.FromDB(data)
} }
rawValueType := reflect.TypeOf(rawValue.Interface()) vv := reflect.ValueOf(scanResult)
vv := reflect.ValueOf(rawValue.Interface())
fieldType := fieldValue.Type() fieldType := fieldValue.Type()
if col.IsJSON { if col.IsJSON {
var bs []byte return session.setJSON(fieldValue, fieldType, scanResult)
if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(schemas.BytesType) {
bs = vv.Bytes()
} else {
return fmt.Errorf("unsupported database data type: %s %v", col.Name, rawValueType.Kind())
}
if len(bs) > 0 {
if fieldType.Kind() == reflect.String {
fieldValue.SetString(string(bs))
return nil
}
if fieldValue.CanAddr() {
err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
return err
}
} else {
x := reflect.New(fieldType)
err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface())
if err != nil {
return err
}
fieldValue.Set(x.Elem())
}
}
return nil
} }
switch fieldType.Kind() { switch fieldType.Kind() {
@ -529,30 +526,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
} }
return nil return nil
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
// TODO: reimplement this return session.setJSON(fieldValue, fieldType, scanResult)
var bs []byte
if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(schemas.BytesType) {
bs = vv.Bytes()
}
if len(bs) > 0 {
if fieldValue.CanAddr() {
err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil {
return err
}
} else {
x := reflect.New(fieldType)
err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface())
if err != nil {
return err
}
fieldValue.Set(x.Elem())
}
}
return nil
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
bs, ok := asBytes(scanResult) bs, ok := asBytes(scanResult)
if ok && fieldType.Elem().Kind() == reflect.Uint8 { if ok && fieldType.Elem().Kind() == reflect.Uint8 {
@ -602,33 +576,11 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
fieldValue.Set(reflect.ValueOf(*t).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(*t).Convert(fieldType))
return nil return nil
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
err := nulVal.Scan(vv.Interface()) err := nulVal.Scan(scanResult)
if err == nil { if err == nil {
return nil return nil
} }
session.engine.logger.Errorf("sql.Sanner error: %v", err) session.engine.logger.Errorf("sql.Sanner error: %v", err)
} else if col.IsJSON {
if rawValueType.Kind() == reflect.String {
x := reflect.New(fieldType)
if len([]byte(vv.String())) > 0 {
err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface())
if err != nil {
return err
}
fieldValue.Set(x.Elem())
}
return nil
} else if rawValueType.Kind() == reflect.Slice {
x := reflect.New(fieldType)
if len(vv.Bytes()) > 0 {
err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface())
if err != nil {
return err
}
fieldValue.Set(x.Elem())
}
return nil
}
} else if session.statement.UseCascade { } else if session.statement.UseCascade {
table, err := session.engine.tagParser.ParseWithCache(*fieldValue) table, err := session.engine.tagParser.ParseWithCache(*fieldValue)
if err != nil { if err != nil {
@ -639,7 +591,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
return errors.New("unsupported non or composited primary key cascade") return errors.New("unsupported non or composited primary key cascade")
} }
var pk = make(schemas.PK, len(table.PrimaryKeys)) var pk = make(schemas.PK, len(table.PrimaryKeys))
pk[0], err = asKind(vv, rawValueType) pk[0], err = asKind(vv, reflect.TypeOf(scanResult))
if err != nil { if err != nil {
return err return err
} }
@ -675,9 +627,9 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
var tempMap = make(map[string]int) var tempMap = make(map[string]int)
var pk schemas.PK var pk schemas.PK
for ii, key := range fields { for i, colName := range fields {
var idx int var idx int
var lKey = strings.ToLower(key) var lKey = strings.ToLower(colName)
var ok bool var ok bool
if idx, ok = tempMap[lKey]; !ok { if idx, ok = tempMap[lKey]; !ok {
@ -685,13 +637,9 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
} else { } else {
idx = idx + 1 idx = idx + 1
} }
tempMap[lKey] = idx tempMap[lKey] = idx
col := table.GetColumnIdx(key, idx)
var scanResult = scanResults[ii] col, fieldValue, err := session.getField(dataStruct, table, colName, idx)
fieldValue, err := session.getField(dataStruct, key, table, idx)
if err != nil { if err != nil {
if _, ok := err.(ErrFieldIsNotValid); !ok { if _, ok := err.(ErrFieldIsNotValid); !ok {
session.engine.logger.Warnf("%v", err) session.engine.logger.Warnf("%v", err)
@ -702,11 +650,11 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
continue continue
} }
if err := session.convertBeanField(col, fieldValue, scanResult, table); err != nil { if err := session.convertBeanField(col, fieldValue, scanResults[i], table); err != nil {
return nil, err return nil, err
} }
if col.IsPrimaryKey { if col.IsPrimaryKey {
pk = append(pk, scanResult) pk = append(pk, scanResults[i])
} }
} }
return pk, nil return pk, nil

View File

@ -1,70 +0,0 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"fmt"
"strconv"
"strings"
"time"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
func (session *Session) str2Time(col *schemas.Column, data string) (outTime time.Time, outErr error) {
sdata := strings.TrimSpace(data)
var x time.Time
var err error
var parseLoc = session.engine.DatabaseTZ
if col.TimeZone != nil {
parseLoc = col.TimeZone
}
if sdata == utils.ZeroTime0 || sdata == utils.ZeroTime1 {
} else if !strings.ContainsAny(sdata, "- :") { // !nashtsai! has only found that mymysql driver is using this for time type column
// time stamp
sd, err := strconv.ParseInt(sdata, 10, 64)
if err == nil {
x = time.Unix(sd, 0)
}
} else if len(sdata) > 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc)
session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.Name, x, sdata)
if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc)
}
if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc)
}
} else if len(sdata) == 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc)
} else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' {
x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc)
} else if col.SQLType.Name == schemas.Time {
if strings.Contains(sdata, " ") {
ssd := strings.Split(sdata, " ")
sdata = ssd[1]
}
sdata = strings.TrimSpace(sdata)
if session.engine.dialect.URI().DBType == schemas.MYSQL && len(sdata) > 8 {
sdata = sdata[len(sdata)-8:]
}
st := fmt.Sprintf("2006-01-02 %v", sdata)
x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc)
} else {
outErr = fmt.Errorf("unsupported time format %v", sdata)
return
}
if err != nil {
outErr = fmt.Errorf("unsupported time format %v: %v", sdata, err)
return
}
outTime = x.In(session.engine.TZLocation)
return
}

View File

@ -255,6 +255,9 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
} }
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return rows.Err()
}
var newValue = newElemFunc(fields) var newValue = newElemFunc(fields)
bean := newValue.Interface() bean := newValue.Interface()
@ -322,6 +325,9 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
var i int var i int
ids = make([]schemas.PK, 0) ids = make([]schemas.PK, 0)
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return rows.Err()
}
i++ i++
if i > 500 { if i > 500 {
session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache") session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache")

View File

@ -313,9 +313,12 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
defer rows.Close() defer rows.Close()
if rows.Next() { if rows.Next() {
if rows.Err() != nil {
return true, rows.Err()
}
err = rows.ScanSlice(&res) err = rows.ScanSlice(&res)
if err != nil { if err != nil {
return false, err return true, err
} }
} else { } else {
return false, ErrCacheFailed return false, ErrCacheFailed

View File

@ -43,6 +43,9 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error {
i := 0 i := 0
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return rows.Err()
}
b := reflect.New(rows.beanType).Interface() b := reflect.New(rows.beanType).Interface()
err = rows.Scan(b) err = rows.Scan(b)
if err != nil { if err != nil {

View File

@ -33,6 +33,9 @@ func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string
} }
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, rows.Err()
}
result, err := session.engine.row2mapStr(rows, types, fields) result, err := session.engine.row2mapStr(rows, types, fields)
if err != nil { if err != nil {
return nil, err return nil, err
@ -54,6 +57,9 @@ func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]stri
} }
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, rows.Err()
}
record, err := session.engine.row2sliceStr(rows, types, fields) record, err := session.engine.row2sliceStr(rows, types, fields)
if err != nil { if err != nil {
return nil, err return nil, err
@ -114,6 +120,9 @@ func (session *Session) rows2Interfaces(rows *core.Rows) (resultsSlice []map[str
return nil, err return nil, err
} }
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return nil, rows.Err()
}
result, err := session.engine.row2mapInterface(rows, types, fields) result, err := session.engine.row2mapInterface(rows, types, fields)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -59,6 +59,9 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
ids = make([]schemas.PK, 0) ids = make([]schemas.PK, 0)
for rows.Next() { for rows.Next() {
if rows.Err() != nil {
return rows.Err()
}
var res = make([]string, len(table.PrimaryKeys)) var res = make([]string, len(table.PrimaryKeys))
err = rows.ScanSlice(&res) err = rows.ScanSlice(&res)
if err != nil { if err != nil {