diff --git a/session.go b/session.go index 67fedd70..a9b9f863 100644 --- a/session.go +++ b/session.go @@ -393,10 +393,10 @@ func (session *Session) doPrepareTx(sqlStr string) (stmt *core.Stmt, err error) return } -func getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) { - col := table.GetColumnIdx(colName, idx) +func getField(dataStruct *reflect.Value, table *schemas.Table, field *QueryedField) (*schemas.Column, *reflect.Value, error) { + col := field.ColumnSchema if col == nil { - return nil, nil, ErrFieldIsNotExist{colName, table.Name} + return nil, nil, ErrFieldIsNotExist{field.FieldName, table.Name} } fieldValue, err := col.ValueOfV(dataStruct) @@ -404,10 +404,10 @@ func getField(dataStruct *reflect.Value, table *schemas.Table, colName string, i return nil, nil, err } if fieldValue == nil { - return nil, nil, ErrFieldIsNotValid{colName, table.Name} + return nil, nil, ErrFieldIsNotValid{field.FieldName, table.Name} } if !fieldValue.IsValid() || !fieldValue.CanSet() { - return nil, nil, ErrFieldIsNotValid{colName, table.Name} + return nil, nil, ErrFieldIsNotValid{field.FieldName, table.Name} } return col, fieldValue, nil @@ -712,7 +712,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, allColum *AllColum var pk schemas.PK for i, field := range allColum.Fields { - col, fieldValue, err := getField(dataStruct, table, field.FieldName, field.TempIndex) + col, fieldValue, err := getField(dataStruct, table, field) if _, ok := err.(ErrFieldIsNotExist); ok { continue } else if err != nil { diff --git a/session_find.go b/session_find.go index e0b8d295..0dcb2da7 100644 --- a/session_find.go +++ b/session_find.go @@ -168,6 +168,7 @@ type QueryedField struct { LowerFieldName string ColumnType *sql.ColumnType TempIndex int + ColumnSchema *schemas.Column } type AllColumn struct { @@ -176,6 +177,19 @@ type AllColumn struct { Types []*sql.ColumnType } +func (allColumn *AllColumn) ParseTableSchema(table *schemas.Table) error { + for _, field := range allColumn.Fields { + col := table.GetColumnIdx(field.FieldName, field.TempIndex) + if col == nil { + return ErrFieldIsNotExist{FieldName: field.FieldName, TableName: table.Name} + } + + field.ColumnSchema = col + } + + return nil +} + func ParseQueryRows(fieldNames []string, types []*sql.ColumnType) *AllColumn { var allColumn AllColumn @@ -289,6 +303,12 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect if err != nil { return err } + + parseTableSchemaError := allColumn.ParseTableSchema(tb) + if parseTableSchemaError != nil { + return parseTableSchemaError + } + err = session.rows2Beans(rows, allColumn, fields, types, tb, newElemFunc, containerValueSetFunc) rows.Close() if err != nil {