diff --git a/session.go b/session.go index b93ebfc8..918ab9ab 100644 --- a/session.go +++ b/session.go @@ -910,32 +910,27 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean)))) } } else if sliceValue.Kind() == reflect.Map { - var key core.PK - if table.PrimaryKeys[0] != "" { - key = ids[j] - } - + var key core.PK = ids[j] + keyType := sliceValue.Type().Key() + var ikey interface{} if len(key) == 1 { - ikey, err := strconv.ParseInt(fmt.Sprintf("%v", key[0]), 10, 64) + ikey, err = Atot(fmt.Sprintf("%v", key[0]), keyType) if err != nil { return err } - if t.Kind() == reflect.Ptr { - sliceValue.SetMapIndex(reflect.ValueOf(ikey), reflect.ValueOf(bean)) - } else { - sliceValue.SetMapIndex(reflect.ValueOf(ikey), reflect.Indirect(reflect.ValueOf(bean))) - } } else { - return errors.New("table have multiple primary keys") + if keyType.Kind() != reflect.Slice { + return errors.New("table have multiple primary keys, key is not core.PK or slice") + } + ikey = key + } + + if t.Kind() == reflect.Ptr { + sliceValue.SetMapIndex(reflect.ValueOf(ikey), reflect.ValueOf(bean)) + } else { + sliceValue.SetMapIndex(reflect.ValueOf(ikey), reflect.Indirect(reflect.ValueOf(bean))) } } - /*} else { - session.Engine.LogDebug("[xorm:cacheFind] cache delete:", tableName, ides[j]) - cacher.DelBean(tableName, ids[j]) - - session.Engine.LogDebug("[xorm:cacheFind] cache clear:", tableName) - cacher.ClearIds(tableName) - }*/ } return nil @@ -1096,6 +1091,75 @@ func (session *Session) Count(bean interface{}) (int64, error) { return int64(total), err } +func Atot(s string, tp reflect.Type) (interface{}, error) { + var err error + var result interface{} + switch tp.Kind() { + case reflect.Int: + result, err = strconv.Atoi(s) + if err != nil { + return nil, errors.New("convert " + s + " as int: " + err.Error()) + } + case reflect.Int8: + x, err := strconv.Atoi(s) + if err != nil { + return nil, errors.New("convert " + s + " as int16: " + err.Error()) + } + result = int8(x) + case reflect.Int16: + x, err := strconv.Atoi(s) + if err != nil { + return nil, errors.New("convert " + s + " as int16: " + err.Error()) + } + result = int16(x) + case reflect.Int32: + x, err := strconv.Atoi(s) + if err != nil { + return nil, errors.New("convert " + s + " as int32: " + err.Error()) + } + result = int32(x) + case reflect.Int64: + result, err = strconv.ParseInt(s, 10, 64) + if err != nil { + return nil, errors.New("convert " + s + " as int64: " + err.Error()) + } + case reflect.Uint: + x, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return nil, errors.New("convert " + s + " as uint: " + err.Error()) + } + result = uint(x) + case reflect.Uint8: + x, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return nil, errors.New("convert " + s + " as uint8: " + err.Error()) + } + result = uint8(x) + case reflect.Uint16: + x, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return nil, errors.New("convert " + s + " as uint16: " + err.Error()) + } + result = uint16(x) + case reflect.Uint32: + x, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return nil, errors.New("convert " + s + " as uint32: " + err.Error()) + } + result = uint32(x) + case reflect.Uint64: + result, err = strconv.ParseUint(s, 10, 64) + if err != nil { + return nil, errors.New("convert " + s + " as uint64: " + err.Error()) + } + case reflect.String: + result = s + default: + panic("unsupported convert type") + } + return result, nil +} + // Find retrieve records from table, condiBeans's non-empty fields // are conditions. beans could be []Struct, []*Struct, map[int64]Struct // map[int64]*Struct @@ -1259,7 +1323,9 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) return err } - for i, results := range resultsSlice { + keyType := sliceValue.Type().Key() + + for _, results := range resultsSlice { var newValue reflect.Value if sliceElementType.Kind() == reflect.Ptr { newValue = reflect.New(sliceElementType.Elem()) @@ -1270,18 +1336,29 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) if err != nil { return err } - var key int64 + var key interface{} // if there is only one pk, we can put the id as map key. - // TODO: should know if the column is ints if len(table.PrimaryKeys) == 1 { - x, err := strconv.ParseInt(string(results[table.PrimaryKeys[0]]), 10, 64) + key, err = Atot(string(results[table.PrimaryKeys[0]]), keyType) if err != nil { - return errors.New("pk " + table.PrimaryKeys[0] + " as int64: " + err.Error()) + return err } - key = x } else { - key = int64(i) + if keyType.Kind() != reflect.Slice { + panic("don't support multiple primary key's map has non-slice key type") + } else { + keys := core.PK{} + for _, pk := range table.PrimaryKeys { + skey, err := Atot(string(results[pk]), keyType) + if err != nil { + return err + } + keys = append(keys, skey) + } + key = keys + } } + if sliceElementType.Kind() == reflect.Ptr { sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValue.Interface())) } else { diff --git a/xorm.go b/xorm.go index b19662c7..4673d783 100644 --- a/xorm.go +++ b/xorm.go @@ -93,5 +93,5 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { // clone an engine func (engine *Engine) Clone() (*Engine, error) { - return NewEngine(engine.dialect.DriverName(), engine.dialect.DataSourceName()) + return NewEngine(engine.DriverName(), engine.DataSourceName()) }