diff --git a/rows.go b/rows.go index ef3e42b6..a42eedb9 100644 --- a/rows.go +++ b/rows.go @@ -129,7 +129,9 @@ func (rows *Rows) Scan(beans ...interface{}) error { return err } - if err := rows.session.scan(rows.rows, rows.session.statement.RefTable, beanKind, beans, types, fields); err != nil { + columnsSchema := ParseColumnsSchema(fields, types, rows.session.statement.RefTable) + + if err := rows.session.scan(rows.rows, rows.session.statement.RefTable, beanKind, beans, columnsSchema, types, fields); err != nil { return err } diff --git a/session.go b/session.go index af6e4921..14d0781e 100644 --- a/session.go +++ b/session.go @@ -16,8 +16,6 @@ import ( "io" "reflect" "strconv" - "strings" - "xorm.io/xorm/contexts" "xorm.io/xorm/convert" "xorm.io/xorm/core" @@ -395,10 +393,10 @@ func (session *Session) doPrepareTx(sqlStr string) (stmt *core.Stmt, err error) return } -func getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) { - col := table.GetColumnIdx(colName, idx) +func getField(dataStruct *reflect.Value, table *schemas.Table, field *QueryedField) (*schemas.Column, *reflect.Value, error) { + col := field.ColumnSchema if col == nil { - return nil, nil, ErrFieldIsNotExist{colName, table.Name} + return nil, nil, ErrFieldIsNotExist{field.FieldName, table.Name} } fieldValue, err := col.ValueOfV(dataStruct) @@ -406,10 +404,10 @@ func getField(dataStruct *reflect.Value, table *schemas.Table, colName string, i return nil, nil, err } if fieldValue == nil { - return nil, nil, ErrFieldIsNotValid{colName, table.Name} + return nil, nil, ErrFieldIsNotValid{field.FieldName, table.Name} } if !fieldValue.IsValid() || !fieldValue.CanSet() { - return nil, nil, ErrFieldIsNotValid{colName, table.Name} + return nil, nil, ErrFieldIsNotValid{field.FieldName, table.Name} } return col, fieldValue, nil @@ -418,7 +416,7 @@ func getField(dataStruct *reflect.Value, table *schemas.Table, colName string, i // Cell cell is a result of one column field type Cell *interface{} -func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sql.ColumnType, +func (session *Session) rows2Beans(rows *core.Rows, columnsSchema *ColumnsSchema, fields []string, types []*sql.ColumnType, table *schemas.Table, newElemFunc func([]string) reflect.Value, sliceValueSetFunc func(*reflect.Value, schemas.PK) error, ) error { @@ -432,7 +430,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sq if err != nil { return err } - pk, err := session.slice2Bean(scanResults, fields, bean, &dataStruct, table) + pk, err := session.slice2Bean(scanResults, columnsSchema, fields, bean, &dataStruct, table) if err != nil { return err } @@ -705,28 +703,16 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec return convert.AssignValue(fieldValue.Addr(), scanResult) } -func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) { +func (session *Session) slice2Bean(scanResults []interface{}, columnsSchema *ColumnsSchema, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) { defer func() { executeAfterSet(bean, fields, scanResults) }() buildAfterProcessors(session, bean) - tempMap := make(map[string]int) var pk schemas.PK - for i, colName := range fields { - var idx int - lKey := strings.ToLower(colName) - var ok bool - - if idx, ok = tempMap[lKey]; !ok { - idx = 0 - } else { - idx++ - } - tempMap[lKey] = idx - - col, fieldValue, err := getField(dataStruct, table, colName, idx) + for i, field := range columnsSchema.Fields { + col, fieldValue, err := getField(dataStruct, table, field) if _, ok := err.(ErrFieldIsNotExist); ok { continue } else if err != nil { @@ -800,3 +786,7 @@ func (session *Session) NoVersionCheck() *Session { session.statement.CheckVersion = false return session } + +func SetDefaultJSONHandler(jsonHandler json.Interface) { + json.DefaultJSONHandler = jsonHandler +} diff --git a/session_find.go b/session_find.go index d9444aee..1026910c 100644 --- a/session_find.go +++ b/session_find.go @@ -5,8 +5,10 @@ package xorm import ( + "database/sql" "errors" "reflect" + "strings" "xorm.io/builder" "xorm.io/xorm/caches" @@ -161,6 +163,64 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) return session.noCacheFind(table, sliceValue, sqlStr, args...) } +type QueryedField struct { + FieldName string + LowerFieldName string + ColumnType *sql.ColumnType + TempIndex int + ColumnSchema *schemas.Column +} + +type ColumnsSchema struct { + Fields []*QueryedField + FieldNames []string + Types []*sql.ColumnType +} + +func (columnsSchema *ColumnsSchema) ParseTableSchema(table *schemas.Table) { + for _, field := range columnsSchema.Fields { + field.ColumnSchema = table.GetColumnIdx(field.FieldName, field.TempIndex) + } +} + +func ParseColumnsSchema(fieldNames []string, types []*sql.ColumnType, table *schemas.Table) *ColumnsSchema { + var columnsSchema ColumnsSchema + + fields := make([]*QueryedField, 0, len(fieldNames)) + + for i, fieldName := range fieldNames { + field := &QueryedField{ + FieldName: fieldName, + LowerFieldName: strings.ToLower(fieldName), + ColumnType: types[i], + } + fields = append(fields, field) + } + + columnsSchema.Fields = fields + + tempMap := make(map[string]int) + for _, field := range fields { + var idx int + var ok bool + + if idx, ok = tempMap[field.LowerFieldName]; !ok { + idx = 0 + } else { + idx++ + } + + tempMap[field.LowerFieldName] = idx + field.TempIndex = idx + } + + if table != nil { + columnsSchema.ParseTableSchema(table) + } + + return &columnsSchema +} + func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error { elemType := containerValue.Type().Elem() var isPointer bool @@ -238,7 +298,10 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect if err != nil { return err } - err = session.rows2Beans(rows, fields, types, tb, newElemFunc, containerValueSetFunc) + + columnsSchema := ParseColumnsSchema(fields, types, tb) + + err = session.rows2Beans(rows, columnsSchema, fields, types, tb, newElemFunc, containerValueSetFunc) rows.Close() if err != nil { return err diff --git a/session_get.go b/session_get.go index 96e362e9..0d590330 100644 --- a/session_get.go +++ b/session_get.go @@ -164,7 +164,9 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, return true, err } - if err := session.scan(rows, table, beanKind, beans, types, fields); err != nil { + columnsSchema := ParseColumnsSchema(fields, types, table) + + if err := session.scan(rows, table, beanKind, beans, columnsSchema, types, fields); err != nil { return true, err } rows.Close() @@ -172,7 +174,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, return true, session.executeProcessors() } -func (session *Session) scan(rows *core.Rows, table *schemas.Table, firstBeanKind reflect.Kind, beans []interface{}, types []*sql.ColumnType, fields []string) error { +func (session *Session) scan(rows *core.Rows, table *schemas.Table, firstBeanKind reflect.Kind, beans []interface{}, columnsSchema *ColumnsSchema, types []*sql.ColumnType, fields []string) error { if len(beans) == 1 { bean := beans[0] switch firstBeanKind { @@ -186,7 +188,7 @@ func (session *Session) scan(rows *core.Rows, table *schemas.Table, firstBeanKin } dataStruct := utils.ReflectValue(bean) - _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) + _, err = session.slice2Bean(scanResults, columnsSchema, fields, bean, &dataStruct, table) return err case reflect.Slice: return session.getSlice(rows, types, fields, bean)