diff --git a/engine.go b/engine.go index 8b60ba9b..cfb92402 100644 --- a/engine.go +++ b/engine.go @@ -64,6 +64,11 @@ func (engine *Engine) Id(id int) *Engine { return engine } +func (engine *Engine) In(column string, args ...interface{}) *Engine { + engine.Statement.In(column, args...) + return engine +} + func (engine *Engine) Limit(limit int, start ...int) *Engine { engine.Statement.Limit(limit, start...) return engine diff --git a/session.go b/session.go index 2cc6b683..ab14afab 100644 --- a/session.go +++ b/session.go @@ -37,6 +37,11 @@ func (session *Session) Id(id int) *Session { return session } +func (session *Session) In(column string, args ...interface{}) *Session { + session.Statement.In(column, args...) + return session +} + func (session *Session) Limit(limit int, start ...int) *Session { session.Statement.Limit(limit, start...) return session diff --git a/statement.go b/statement.go index 444edf89..f27f3fda 100644 --- a/statement.go +++ b/statement.go @@ -2,6 +2,7 @@ package xorm import ( "fmt" + "strings" ) type Statement struct { @@ -19,6 +20,14 @@ type Statement struct { BeanArgs []interface{} } +func MakeArray(elem string, count int) []string { + res := make([]string, count) + for i := 0; i < count; i++ { + res[i] = elem + } + return res +} + func (statement *Statement) Init() { statement.Table = nil statement.Start = 0 @@ -48,6 +57,17 @@ func (statement *Statement) Id(id int) { } } +func (statement *Statement) In(column string, args ...interface{}) { + inStr := fmt.Sprintf("%v in (%v)", column, strings.Join(MakeArray("?", len(args)), ",")) + if statement.WhereStr == "" { + statement.WhereStr = inStr + statement.Params = args + } else { + statement.WhereStr = statement.WhereStr + " and " + inStr + statement.Params = append(statement.Params, args...) + } +} + func (statement *Statement) Limit(limit int, start ...int) { statement.LimitN = limit if len(start) > 0 { diff --git a/xorm_test.go b/xorm_test.go index 34af5d6e..a5a773b0 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -190,6 +190,23 @@ func where(t *testing.T) { fmt.Println(users) } +func in(t *testing.T) { + users := make([]Userinfo, 0) + err := engine.In("id", 1, 2, 3).Find(&users) + if err != nil { + t.Error(err) + return + } + fmt.Println(users) + + err = engine.Where("id > ?", 2).In("id", 1, 2, 3).Find(&users) + if err != nil { + t.Error(err) + return + } + fmt.Println(users) +} + func limit(t *testing.T) { users := make([]Userinfo, 0) err := engine.Limit(2, 1).Find(&users) @@ -341,6 +358,7 @@ func TestMysql(t *testing.T) { find(t) count(t) where(t) + in(t) limit(t) order(t) join(t) @@ -366,6 +384,7 @@ func TestSqlite(t *testing.T) { find(t) count(t) where(t) + in(t) limit(t) order(t) join(t)