fix tablename bug (#887)

* fix tablename bug

* fix test
This commit is contained in:
Lunny Xiao 2018-04-11 18:09:16 +08:00 committed by GitHub
parent 5c2af83817
commit bfdf773629
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 88 additions and 11 deletions

View File

@ -45,16 +45,17 @@ func (session *Session) tbNameNoSchema(table *core.Table) string {
} }
func (engine *Engine) tbNameForMap(v reflect.Value) string { func (engine *Engine) tbNameForMap(v reflect.Value) string {
t := v.Type() if v.Type().Implements(tpTableName) {
if tb, ok := v.Interface().(TableName); ok { return v.Interface().(TableName).TableName()
return tb.TableName()
} }
if v.CanAddr() { if v.Kind() == reflect.Ptr {
if tb, ok := v.Addr().Interface().(TableName); ok { v = v.Elem()
return tb.TableName() if v.Type().Implements(tpTableName) {
return v.Interface().(TableName).TableName()
} }
} }
return engine.TableMapper.Obj2Table(t.Name())
return engine.TableMapper.Obj2Table(v.Type().Name())
} }
func (engine *Engine) tbNameNoSchema(tablename interface{}) string { func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
@ -97,6 +98,9 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
return tablename.(TableName).TableName() return tablename.(TableName).TableName()
case string: case string:
return tablename.(string) return tablename.(string)
case reflect.Value:
v := tablename.(reflect.Value)
return engine.tbNameForMap(v)
default: default:
v := rValue(tablename) v := rValue(tablename)
t := v.Type() t := v.Type()

View File

@ -75,7 +75,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct { if sliceElementType.Elem().Kind() == reflect.Struct {
pv := reflect.New(sliceElementType.Elem()) pv := reflect.New(sliceElementType.Elem())
if err := session.statement.setRefValue(pv.Elem()); err != nil { if err := session.statement.setRefValue(pv); err != nil {
return err return err
} }
} else { } else {
@ -83,7 +83,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
} else if sliceElementType.Kind() == reflect.Struct { } else if sliceElementType.Kind() == reflect.Struct {
pv := reflect.New(sliceElementType) pv := reflect.New(sliceElementType)
if err := session.statement.setRefValue(pv.Elem()); err != nil { if err := session.statement.setRefValue(pv); err != nil {
return err return err
} }
} else { } else {

View File

@ -584,3 +584,76 @@ func TestFindAndCountOneFunc(t *testing.T) {
assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, 1, len(results))
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
} }
type FindMapDevice struct {
Deviceid string `xorm:"pk"`
Status int
}
func (device *FindMapDevice) TableName() string {
return "devices"
}
func TestFindMapStringId(t *testing.T) {
assert.NoError(t, prepareEngine())
assertSync(t, new(FindMapDevice))
cnt, err := testEngine.Insert(&FindMapDevice{
Deviceid: "1",
Status: 1,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
deviceIDs := []string{"1"}
deviceMaps := make(map[string]*FindMapDevice, len(deviceIDs))
err = testEngine.
Where("status = ?", 1).
In("deviceid", deviceIDs).
Find(&deviceMaps)
assert.NoError(t, err)
deviceMaps2 := make(map[string]FindMapDevice, len(deviceIDs))
err = testEngine.
Where("status = ?", 1).
In("deviceid", deviceIDs).
Find(&deviceMaps2)
assert.NoError(t, err)
devices := make([]*FindMapDevice, 0, len(deviceIDs))
err = testEngine.Find(&devices)
assert.NoError(t, err)
devices2 := make([]FindMapDevice, 0, len(deviceIDs))
err = testEngine.Find(&devices2)
assert.NoError(t, err)
var device FindMapDevice
has, err := testEngine.Get(&device)
assert.NoError(t, err)
assert.True(t, has)
has, err = testEngine.Exist(&FindMapDevice{})
assert.NoError(t, err)
assert.True(t, has)
cnt, err = testEngine.Count(new(FindMapDevice))
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
cnt, err = testEngine.ID("1").Update(&FindMapDevice{
Status: 2,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
sum, err := testEngine.SumInt(new(FindMapDevice), "status")
assert.NoError(t, err)
assert.EqualValues(t, 2, sum)
cnt, err = testEngine.ID("1").Delete(new(FindMapDevice))
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
}

View File

@ -66,7 +66,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, errors.New("could not insert a empty slice") return 0, errors.New("could not insert a empty slice")
} }
if err := session.statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil { if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil {
return 0, err return 0, err
} }

View File

@ -208,7 +208,7 @@ func (statement *Statement) setRefValue(v reflect.Value) error {
if err != nil { if err != nil {
return err return err
} }
statement.tableName = statement.Engine.TableName(v.Interface(), true) statement.tableName = statement.Engine.TableName(v, true)
return nil return nil
} }