From 9242b921d8ff5e08b569c846c4c1446f3b357bed Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 8 Apr 2018 21:54:24 +0800 Subject: [PATCH] Join add TableName interface support (#874) * Join add TableName interface support * add some tests * Join add struct support * more tests --- session_find_test.go | 63 +++++++++++++++++++++++++++++++++++--------- session_get_test.go | 24 +++++++++++++++++ statement.go | 29 +++++++++++++++----- 3 files changed, 97 insertions(+), 19 deletions(-) diff --git a/session_find_test.go b/session_find_test.go index fe81597e..04fdb030 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -120,10 +120,8 @@ func TestFind2(t *testing.T) { assertSync(t, new(Userinfo)) err := testEngine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) + for _, user := range users { fmt.Println(user) } @@ -139,13 +137,15 @@ type TeamUser struct { TeamId int64 } +func (TeamUser) TableName() string { + return "team_user" +} + func TestFind3(t *testing.T) { + var teamUser = new(TeamUser) assert.NoError(t, prepareEngine()) - err := testEngine.Sync2(new(Team), new(TeamUser)) - if err != nil { - t.Error(err) - panic(err.Error()) - } + err := testEngine.Sync2(new(Team), teamUser) + assert.NoError(t, err) var teams []Team err = testEngine.Cols("`team`.id"). @@ -153,10 +153,47 @@ func TestFind3(t *testing.T) { And("`team_user`.uid=?", 2). Join("INNER", "`team_user`", "`team_user`.team_id=`team`.id"). Find(&teams) - if err != nil { - t.Error(err) - panic(err.Error()) - } + assert.NoError(t, err) + + teams = make([]Team, 0) + err = testEngine.Cols("`team`.id"). + Where("`team_user`.org_id=?", 1). + And("`team_user`.uid=?", 2). + Join("INNER", teamUser, "`team_user`.team_id=`team`.id"). + Find(&teams) + assert.NoError(t, err) + + teams = make([]Team, 0) + err = testEngine.Cols("`team`.id"). + Where("`team_user`.org_id=?", 1). + And("`team_user`.uid=?", 2). + Join("INNER", []interface{}{teamUser}, "`team_user`.team_id=`team`.id"). + Find(&teams) + assert.NoError(t, err) + + teams = make([]Team, 0) + err = testEngine.Cols("`team`.id"). + Where("`tu`.org_id=?", 1). + And("`tu`.uid=?", 2). + Join("INNER", []string{"team_user", "tu"}, "`tu`.team_id=`team`.id"). + Find(&teams) + assert.NoError(t, err) + + teams = make([]Team, 0) + err = testEngine.Cols("`team`.id"). + Where("`tu`.org_id=?", 1). + And("`tu`.uid=?", 2). + Join("INNER", []interface{}{"team_user", "tu"}, "`tu`.team_id=`team`.id"). + Find(&teams) + assert.NoError(t, err) + + teams = make([]Team, 0) + err = testEngine.Cols("`team`.id"). + Where("`tu`.org_id=?", 1). + And("`tu`.uid=?", 2). + Join("INNER", []interface{}{teamUser, "tu"}, "`tu`.team_id=`team`.id"). + Find(&teams) + assert.NoError(t, err) } func TestFindMap(t *testing.T) { diff --git a/session_get_test.go b/session_get_test.go index 61398d1f..e27e6de9 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -255,3 +255,27 @@ func TestJSONString(t *testing.T) { assert.EqualValues(t, 1, len(jss)) assert.EqualValues(t, `["1","2"]`, jss[0].Content) } + +func TestGetActionMapping(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type ActionMapping struct { + ActionId string `xorm:"pk"` + ActionName string `xorm:"index"` + ScriptId string `xorm:"unique"` + RollbackId string `xorm:"unique"` + Env string + Tags string + Description string + UpdateTime time.Time `xorm:"updated"` + DeleteTime time.Time `xorm:"deleted"` + } + + assertSync(t, new(ActionMapping)) + + var valuesSlice = make([]string, 2) + _, err := testEngine.Table(new(ActionMapping)). + Cols("script_id", "rollback_id"). + ID(1).Get(&valuesSlice) + assert.NoError(t, err) +} diff --git a/statement.go b/statement.go index 35c4a472..02d73559 100644 --- a/statement.go +++ b/statement.go @@ -766,12 +766,19 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition var table string if l > 0 { f := t[0] - v := rValue(f) - t := v.Type() - if t.Kind() == reflect.String { + switch f.(type) { + case string: table = f.(string) - } else if t.Kind() == reflect.Struct { - table = statement.Engine.tbName(v) + case TableName: + table = f.(TableName).TableName() + default: + v := rValue(f) + t := v.Type() + if t.Kind() == reflect.Struct { + fmt.Fprintf(&buf, statement.Engine.tbName(v)) + } else { + fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", f))) + } } } if l > 1 { @@ -780,8 +787,18 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition } else if l == 1 { fmt.Fprintf(&buf, statement.Engine.Quote(table)) } + case TableName: + fmt.Fprintf(&buf, tablename.(TableName).TableName()) + case string: + fmt.Fprintf(&buf, tablename.(string)) default: - fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename))) + v := rValue(tablename) + t := v.Type() + if t.Kind() == reflect.Struct { + fmt.Fprintf(&buf, statement.Engine.tbName(v)) + } else { + fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename))) + } } fmt.Fprintf(&buf, " ON %v", condition)