diff --git a/rows.go b/rows.go index e464b101..81b3a29f 100644 --- a/rows.go +++ b/rows.go @@ -129,9 +129,12 @@ func (rows *Rows) Scan(beans ...interface{}) error { return err } - allColumn := ParseQueryRows(fields, types) + columnsSchema, parseError := ParseColumnsSchema(fields, types, rows.session.statement.RefTable) + if parseError != nil { + return parseError + } - if err := rows.session.scan(rows.rows, rows.session.statement.RefTable, beanKind, beans, allColumn, types, fields); err != nil { + if err := rows.session.scan(rows.rows, rows.session.statement.RefTable, beanKind, beans, columnsSchema, types, fields); err != nil { return err } diff --git a/session.go b/session.go index a9b9f863..5e3d0ac9 100644 --- a/session.go +++ b/session.go @@ -416,7 +416,7 @@ func getField(dataStruct *reflect.Value, table *schemas.Table, field *QueryedFie // Cell cell is a result of one column field type Cell *interface{} -func (session *Session) rows2Beans(rows *core.Rows, allColumn *AllColumn, fields []string, types []*sql.ColumnType, +func (session *Session) rows2Beans(rows *core.Rows, columnsSchema *ColumnsSchema, fields []string, types []*sql.ColumnType, table *schemas.Table, newElemFunc func([]string) reflect.Value, sliceValueSetFunc func(*reflect.Value, schemas.PK) error, ) error { @@ -426,11 +426,11 @@ func (session *Session) rows2Beans(rows *core.Rows, allColumn *AllColumn, fields dataStruct := newValue.Elem() // handle beforeClosures - scanResults, err := session.row2Slice(rows, allColumn, fields, types, bean) + scanResults, err := session.row2Slice(rows, fields, types, bean) if err != nil { return err } - pk, err := session.slice2Bean(scanResults, allColumn, fields, bean, &dataStruct, table) + pk, err := session.slice2Bean(scanResults, columnsSchema, fields, bean, &dataStruct, table) if err != nil { return err } @@ -445,7 +445,7 @@ func (session *Session) rows2Beans(rows *core.Rows, allColumn *AllColumn, fields return rows.Err() } -func (session *Session) row2Slice(rows *core.Rows, allColumn *AllColumn, fields []string, types []*sql.ColumnType, bean interface{}) ([]interface{}, error) { +func (session *Session) row2Slice(rows *core.Rows, fields []string, types []*sql.ColumnType, bean interface{}) ([]interface{}, error) { for _, closure := range session.beforeClosures { closure(bean) } @@ -703,7 +703,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec return convert.AssignValue(fieldValue.Addr(), scanResult) } -func (session *Session) slice2Bean(scanResults []interface{}, allColum *AllColumn, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) { +func (session *Session) slice2Bean(scanResults []interface{}, columnsSchema *ColumnsSchema, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) { defer func() { executeAfterSet(bean, fields, scanResults) }() @@ -711,7 +711,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, allColum *AllColum buildAfterProcessors(session, bean) var pk schemas.PK - for i, field := range allColum.Fields { + for i, field := range columnsSchema.Fields { col, fieldValue, err := getField(dataStruct, table, field) if _, ok := err.(ErrFieldIsNotExist); ok { continue diff --git a/session_find.go b/session_find.go index 0dcb2da7..a64dd90f 100644 --- a/session_find.go +++ b/session_find.go @@ -171,14 +171,14 @@ type QueryedField struct { ColumnSchema *schemas.Column } -type AllColumn struct { +type ColumnsSchema struct { Fields []*QueryedField FieldNames []string Types []*sql.ColumnType } -func (allColumn *AllColumn) ParseTableSchema(table *schemas.Table) error { - for _, field := range allColumn.Fields { +func (columnsSchema *ColumnsSchema) ParseTableSchema(table *schemas.Table) error { + for _, field := range columnsSchema.Fields { col := table.GetColumnIdx(field.FieldName, field.TempIndex) if col == nil { return ErrFieldIsNotExist{FieldName: field.FieldName, TableName: table.Name} @@ -190,8 +190,8 @@ func (allColumn *AllColumn) ParseTableSchema(table *schemas.Table) error { return nil } -func ParseQueryRows(fieldNames []string, types []*sql.ColumnType) *AllColumn { - var allColumn AllColumn +func ParseColumnsSchema(fieldNames []string, types []*sql.ColumnType, table *schemas.Table) (*ColumnsSchema, error) { + var columnsSchema ColumnsSchema fields := make([]*QueryedField, 0, len(fieldNames)) @@ -204,7 +204,7 @@ func ParseQueryRows(fieldNames []string, types []*sql.ColumnType) *AllColumn { fields = append(fields, field) } - allColumn.Fields = fields + columnsSchema.Fields = fields tempMap := make(map[string]int) for _, field := range fields { @@ -221,7 +221,12 @@ func ParseQueryRows(fieldNames []string, types []*sql.ColumnType) *AllColumn { field.TempIndex = idx } - return &allColumn + err := columnsSchema.ParseTableSchema(table) + if err != nil { + return nil, err + } + + return &columnsSchema, nil } func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error { @@ -251,8 +256,6 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect return err } - allColumn := ParseQueryRows(fields, types) - newElemFunc := func(fields []string) reflect.Value { return utils.New(elemType, len(fields), len(fields)) } @@ -304,12 +307,12 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect return err } - parseTableSchemaError := allColumn.ParseTableSchema(tb) - if parseTableSchemaError != nil { - return parseTableSchemaError + columnsSchema, parseError := ParseColumnsSchema(fields, types, tb) + if parseError != nil { + return parseError } - err = session.rows2Beans(rows, allColumn, fields, types, tb, newElemFunc, containerValueSetFunc) + err = session.rows2Beans(rows, columnsSchema, fields, types, tb, newElemFunc, containerValueSetFunc) rows.Close() if err != nil { return err diff --git a/session_get.go b/session_get.go index cea6ce5b..74a14479 100644 --- a/session_get.go +++ b/session_get.go @@ -164,9 +164,12 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, return true, err } - allColumn := ParseQueryRows(fields, types) + columnsSchema, parseError := ParseColumnsSchema(fields, types, table) + if parseError != nil { + return true, parseError + } - if err := session.scan(rows, table, beanKind, beans, allColumn, types, fields); err != nil { + if err := session.scan(rows, table, beanKind, beans, columnsSchema, types, fields); err != nil { return true, err } rows.Close() @@ -174,7 +177,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, return true, session.executeProcessors() } -func (session *Session) scan(rows *core.Rows, table *schemas.Table, firstBeanKind reflect.Kind, beans []interface{}, allColumn *AllColumn, types []*sql.ColumnType, fields []string) error { +func (session *Session) scan(rows *core.Rows, table *schemas.Table, firstBeanKind reflect.Kind, beans []interface{}, columnsSchema *ColumnsSchema, types []*sql.ColumnType, fields []string) error { if len(beans) == 1 { bean := beans[0] switch firstBeanKind { @@ -182,13 +185,13 @@ func (session *Session) scan(rows *core.Rows, table *schemas.Table, firstBeanKin if !isScannableStruct(bean, len(types)) { break } - scanResults, err := session.row2Slice(rows, allColumn, fields, types, bean) + scanResults, err := session.row2Slice(rows, fields, types, bean) if err != nil { return err } dataStruct := utils.ReflectValue(bean) - _, err = session.slice2Bean(scanResults, allColumn, fields, bean, &dataStruct, table) + _, err = session.slice2Bean(scanResults, columnsSchema, fields, bean, &dataStruct, table) return err case reflect.Slice: return session.getSlice(rows, types, fields, bean)