Join add TableName interface support (#874)

* Join add TableName interface support

* add some tests

* Join add struct support

* more tests
This commit is contained in:
Lunny Xiao 2018-04-08 21:54:24 +08:00 committed by GitHub
parent 149e6abf07
commit 9242b921d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 97 additions and 19 deletions

View File

@ -120,10 +120,8 @@ func TestFind2(t *testing.T) {
assertSync(t, new(Userinfo)) assertSync(t, new(Userinfo))
err := testEngine.Find(&users) err := testEngine.Find(&users)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
for _, user := range users { for _, user := range users {
fmt.Println(user) fmt.Println(user)
} }
@ -139,13 +137,15 @@ type TeamUser struct {
TeamId int64 TeamId int64
} }
func (TeamUser) TableName() string {
return "team_user"
}
func TestFind3(t *testing.T) { func TestFind3(t *testing.T) {
var teamUser = new(TeamUser)
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
err := testEngine.Sync2(new(Team), new(TeamUser)) err := testEngine.Sync2(new(Team), teamUser)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err.Error())
}
var teams []Team var teams []Team
err = testEngine.Cols("`team`.id"). err = testEngine.Cols("`team`.id").
@ -153,10 +153,47 @@ func TestFind3(t *testing.T) {
And("`team_user`.uid=?", 2). And("`team_user`.uid=?", 2).
Join("INNER", "`team_user`", "`team_user`.team_id=`team`.id"). Join("INNER", "`team_user`", "`team_user`.team_id=`team`.id").
Find(&teams) Find(&teams)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err.Error()) teams = make([]Team, 0)
} err = testEngine.Cols("`team`.id").
Where("`team_user`.org_id=?", 1).
And("`team_user`.uid=?", 2).
Join("INNER", teamUser, "`team_user`.team_id=`team`.id").
Find(&teams)
assert.NoError(t, err)
teams = make([]Team, 0)
err = testEngine.Cols("`team`.id").
Where("`team_user`.org_id=?", 1).
And("`team_user`.uid=?", 2).
Join("INNER", []interface{}{teamUser}, "`team_user`.team_id=`team`.id").
Find(&teams)
assert.NoError(t, err)
teams = make([]Team, 0)
err = testEngine.Cols("`team`.id").
Where("`tu`.org_id=?", 1).
And("`tu`.uid=?", 2).
Join("INNER", []string{"team_user", "tu"}, "`tu`.team_id=`team`.id").
Find(&teams)
assert.NoError(t, err)
teams = make([]Team, 0)
err = testEngine.Cols("`team`.id").
Where("`tu`.org_id=?", 1).
And("`tu`.uid=?", 2).
Join("INNER", []interface{}{"team_user", "tu"}, "`tu`.team_id=`team`.id").
Find(&teams)
assert.NoError(t, err)
teams = make([]Team, 0)
err = testEngine.Cols("`team`.id").
Where("`tu`.org_id=?", 1).
And("`tu`.uid=?", 2).
Join("INNER", []interface{}{teamUser, "tu"}, "`tu`.team_id=`team`.id").
Find(&teams)
assert.NoError(t, err)
} }
func TestFindMap(t *testing.T) { func TestFindMap(t *testing.T) {

View File

@ -255,3 +255,27 @@ func TestJSONString(t *testing.T) {
assert.EqualValues(t, 1, len(jss)) assert.EqualValues(t, 1, len(jss))
assert.EqualValues(t, `["1","2"]`, jss[0].Content) assert.EqualValues(t, `["1","2"]`, jss[0].Content)
} }
func TestGetActionMapping(t *testing.T) {
assert.NoError(t, prepareEngine())
type ActionMapping struct {
ActionId string `xorm:"pk"`
ActionName string `xorm:"index"`
ScriptId string `xorm:"unique"`
RollbackId string `xorm:"unique"`
Env string
Tags string
Description string
UpdateTime time.Time `xorm:"updated"`
DeleteTime time.Time `xorm:"deleted"`
}
assertSync(t, new(ActionMapping))
var valuesSlice = make([]string, 2)
_, err := testEngine.Table(new(ActionMapping)).
Cols("script_id", "rollback_id").
ID(1).Get(&valuesSlice)
assert.NoError(t, err)
}

View File

@ -766,12 +766,19 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
var table string var table string
if l > 0 { if l > 0 {
f := t[0] f := t[0]
v := rValue(f) switch f.(type) {
t := v.Type() case string:
if t.Kind() == reflect.String {
table = f.(string) table = f.(string)
} else if t.Kind() == reflect.Struct { case TableName:
table = statement.Engine.tbName(v) table = f.(TableName).TableName()
default:
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
fmt.Fprintf(&buf, statement.Engine.tbName(v))
} else {
fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", f)))
}
} }
} }
if l > 1 { if l > 1 {
@ -780,8 +787,18 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
} else if l == 1 { } else if l == 1 {
fmt.Fprintf(&buf, statement.Engine.Quote(table)) fmt.Fprintf(&buf, statement.Engine.Quote(table))
} }
case TableName:
fmt.Fprintf(&buf, tablename.(TableName).TableName())
case string:
fmt.Fprintf(&buf, tablename.(string))
default: default:
fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename))) v := rValue(tablename)
t := v.Type()
if t.Kind() == reflect.Struct {
fmt.Fprintf(&buf, statement.Engine.tbName(v))
} else {
fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
}
} }
fmt.Fprintf(&buf, " ON %v", condition) fmt.Fprintf(&buf, " ON %v", condition)