From 07b2d15a2676a32b4368bc8069a7f65fe7ef08b0 Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Wed, 8 Jan 2014 18:37:22 +0800 Subject: [PATCH] code tidy up, minor performance improvement --- engine.go | 2 +- session.go | 102 ++++++++++++++++++++++++++++++++++------------------- table.go | 1 + 3 files changed, 68 insertions(+), 37 deletions(-) diff --git a/engine.go b/engine.go index b5fd317f..1e303696 100644 --- a/engine.go +++ b/engine.go @@ -590,7 +590,7 @@ func (engine *Engine) mapType(t reflect.Type) *Table { sqlType := Type2SQLType(fieldType) col = &Column{engine.columnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, true, "", make(map[string]bool), false, false, - TWOSIDES, false, false, false, false} + TWOSIDES, false, false, false, false, nil} } if col.IsAutoIncrement { col.Nullable = false diff --git a/session.go b/session.go index 7623009b..7cb7fc05 100644 --- a/session.go +++ b/session.go @@ -1092,25 +1092,53 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) fieldsCount := len(fields) - for rawRows.Next() { - var newValue reflect.Value + var newElemFunc func() reflect.Value + if sliceElementType.Kind() == reflect.Ptr { + newElemFunc = func() reflect.Value { + return reflect.New(sliceElementType.Elem()) + } + } else { + newElemFunc = func() reflect.Value { + return reflect.New(sliceElementType) + } + } + + var sliceValueSetFunc func(*reflect.Value) + + if sliceValue.Kind() == reflect.Slice { if sliceElementType.Kind() == reflect.Ptr { - newValue = reflect.New(sliceElementType.Elem()) - } else { - newValue = reflect.New(sliceElementType) - } - err := session.row2Bean(rawRows, fields, fieldsCount, newValue.Interface()) - if err != nil { - return err - } - if sliceValue.Kind() == reflect.Slice { - if sliceElementType.Kind() == reflect.Ptr { + sliceValueSetFunc = func(newValue *reflect.Value) { sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) - } else { + } + } else { + sliceValueSetFunc = func(newValue *reflect.Value) { sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) } } } + + for rawRows.Next() { + var newValue reflect.Value = newElemFunc() + // if sliceElementType.Kind() == reflect.Ptr { + // newValue = reflect.New(sliceElementType.Elem()) + // } else { + // newValue = reflect.New(sliceElementType) + // } + if sliceValueSetFunc != nil { + err := session.row2Bean(rawRows, fields, fieldsCount, newValue.Interface()) + if err != nil { + return err + } + sliceValueSetFunc(&newValue) + } + // // if sliceValue.Kind() == reflect.Slice { + // // if sliceElementType.Kind() == reflect.Ptr { + // // sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) + // // } else { + // // sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + // // } + // // } + } } else { resultsSlice, err := session.query(sqlStr, args...) if err != nil { @@ -1347,32 +1375,35 @@ func row2map(rows *sql.Rows, fields []string) (resultsMap map[string][]byte, err func (session *Session) getField(dataStruct *reflect.Value, key string, table *Table) *reflect.Value { - key = strings.ToLower(key) - if _, ok := table.Columns[key]; !ok { + //key = strings.ToLower(key) + if col, ok := table.Columns[key]; !ok { session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. %v", table.Name, key, table.ColumnsSeq)) return nil - } - col := table.Columns[key] - fieldName := col.FieldName - fieldPath := strings.Split(fieldName, ".") - var fieldValue reflect.Value - if len(fieldPath) > 2 { - session.Engine.LogError("Unsupported mutliderive", fieldName) - return nil - } else if len(fieldPath) == 2 { - parentField := dataStruct.FieldByName(fieldPath[0]) - if parentField.IsValid() { - fieldValue = parentField.FieldByName(fieldPath[1]) - } } else { - fieldValue = dataStruct.FieldByName(fieldName) + fieldName := col.FieldName + if col.fieldPath == nil { + col.fieldPath = strings.Split(fieldName, ".") + } + var fieldValue reflect.Value + fieldPathLen := len(col.fieldPath) + if fieldPathLen > 2 { + session.Engine.LogError("Unsupported mutliderive", fieldName) + return nil + } else if fieldPathLen == 2 { + parentField := dataStruct.FieldByName(col.fieldPath[0]) + if parentField.IsValid() { + fieldValue = parentField.FieldByName(col.fieldPath[1]) + } + } else { + fieldValue = dataStruct.FieldByName(fieldName) + } + if !fieldValue.IsValid() || !fieldValue.CanSet() { + session.Engine.LogWarn("table %v's column %v is not valid or cannot set", + table.Name, key) + return nil + } + return &fieldValue } - if !fieldValue.IsValid() || !fieldValue.CanSet() { - session.Engine.LogWarn("table %v's column %v is not valid or cannot set", - table.Name, key) - return nil - } - return &fieldValue } func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount int, bean interface{}) error { @@ -1395,7 +1426,6 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in for ii, key := range fields { if fieldValue := session.getField(&dataStruct, key, table); fieldValue != nil { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) //if row is null then ignore diff --git a/table.go b/table.go index a1ae7cdc..0667a12e 100644 --- a/table.go +++ b/table.go @@ -275,6 +275,7 @@ type Column struct { IsUpdated bool IsCascade bool IsVersion bool + fieldPath []string } // generate column description string according dialect