multiple In() supports
This commit is contained in:
parent
82bdc0ec5a
commit
bb6a9c24fa
57
base_test.go
57
base_test.go
|
@ -520,6 +520,63 @@ func in(engine *Engine, t *testing.T) {
|
|||
panic(err)
|
||||
}
|
||||
fmt.Println(users)
|
||||
|
||||
err = engine.In("(id)", 1).In("(id)", 2).In("departname", "dev").Find(&users)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println(users)
|
||||
|
||||
cnt, err := engine.In("(id)", 4).Update(&Userinfo{Departname: "dev-"})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
if cnt != 1 {
|
||||
err = errors.New("update records not 1")
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
user := new(Userinfo)
|
||||
has, err := engine.Id(4).Get(user)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
if !has {
|
||||
err = errors.New("get record not 1")
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
if user.Departname != "dev-" {
|
||||
err = errors.New("update not success")
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
cnt, err = engine.In("(id)", 4).Update(&Userinfo{Departname: "dev"})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
if cnt != 1 {
|
||||
err = errors.New("update records not 1")
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
cnt, err = engine.In("(id)", 5).Delete(&Userinfo{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
if cnt != 1 {
|
||||
err = errors.New("deleted records not 1")
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func limit(engine *Engine, t *testing.T) {
|
||||
|
|
40
session.go
40
session.go
|
@ -1015,6 +1015,9 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
|
|||
if columnStr == "" {
|
||||
columnStr = session.Statement.genColumnStr()
|
||||
}
|
||||
|
||||
session.Statement.attachInSql()
|
||||
|
||||
sql = session.Statement.genSelectSql(columnStr)
|
||||
args = append(session.Statement.Params, session.Statement.BeanArgs...)
|
||||
} else {
|
||||
|
@ -2508,7 +2511,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
}
|
||||
}
|
||||
|
||||
var sql string
|
||||
var sql, inSql string
|
||||
var inArgs []interface{}
|
||||
if table.Version != "" && session.Statement.checkVersion {
|
||||
if condition != "" {
|
||||
condition = fmt.Sprintf("WHERE (%v) AND %v = ?", condition,
|
||||
|
@ -2516,6 +2520,15 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
} else {
|
||||
condition = fmt.Sprintf("WHERE %v = ?", session.Engine.Quote(table.Version))
|
||||
}
|
||||
inSql, inArgs = session.Statement.genInSql()
|
||||
if len(inSql) > 0 {
|
||||
if condition != "" {
|
||||
condition += " AND " + inSql
|
||||
} else {
|
||||
condition = "WHERE " + inSql
|
||||
}
|
||||
}
|
||||
|
||||
sql = fmt.Sprintf("UPDATE %v SET %v, %v %v",
|
||||
session.Engine.Quote(session.Statement.TableName()),
|
||||
strings.Join(colNames, ", "),
|
||||
|
@ -2527,13 +2540,24 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
if condition != "" {
|
||||
condition = "WHERE " + condition
|
||||
}
|
||||
inSql, inArgs = session.Statement.genInSql()
|
||||
if len(inSql) > 0 {
|
||||
if condition != "" {
|
||||
condition += " AND " + inSql
|
||||
} else {
|
||||
condition = "WHERE " + inSql
|
||||
}
|
||||
}
|
||||
|
||||
sql = fmt.Sprintf("UPDATE %v SET %v %v",
|
||||
session.Engine.Quote(session.Statement.TableName()),
|
||||
strings.Join(colNames, ", "),
|
||||
condition)
|
||||
}
|
||||
|
||||
args = append(append(args, st.Params...), condiArgs...)
|
||||
args = append(args, st.Params...)
|
||||
args = append(args, inArgs...)
|
||||
args = append(args, condiArgs...)
|
||||
|
||||
res, err := session.exec(sql, args...)
|
||||
if err != nil {
|
||||
|
@ -2660,10 +2684,18 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
|
|||
if session.Statement.WhereStr != "" {
|
||||
condition = session.Statement.WhereStr
|
||||
if len(colNames) > 0 {
|
||||
condition += " and " + strings.Join(colNames, " and ")
|
||||
condition += " AND " + strings.Join(colNames, " AND ")
|
||||
}
|
||||
} else {
|
||||
condition = strings.Join(colNames, " and ")
|
||||
condition = strings.Join(colNames, " AND ")
|
||||
}
|
||||
inSql, inArgs := session.Statement.genInSql()
|
||||
if len(inSql) > 0 {
|
||||
if len(condition) > 0 {
|
||||
condition += " AND "
|
||||
}
|
||||
condition += inSql
|
||||
args = append(args, inArgs...)
|
||||
}
|
||||
if len(condition) == 0 {
|
||||
return 0, ErrNeedDeletedCond
|
||||
|
|
51
statement.go
51
statement.go
|
@ -39,6 +39,7 @@ type Statement struct {
|
|||
allUseBool bool
|
||||
checkVersion bool
|
||||
boolColumnMap map[string]bool
|
||||
inColumns map[string][]interface{}
|
||||
}
|
||||
|
||||
// init
|
||||
|
@ -67,6 +68,7 @@ func (statement *Statement) Init() {
|
|||
statement.allUseBool = false
|
||||
statement.boolColumnMap = make(map[string]bool)
|
||||
statement.checkVersion = true
|
||||
statement.inColumns = make(map[string][]interface{})
|
||||
}
|
||||
|
||||
// add the raw sql statement
|
||||
|
@ -406,17 +408,45 @@ func (statement *Statement) Id(id int64) *Statement {
|
|||
|
||||
// Generate "Where column IN (?) " statment
|
||||
func (statement *Statement) In(column string, args ...interface{}) *Statement {
|
||||
inStr := fmt.Sprintf("%v IN (%v)", column, strings.Join(makeArray("?", len(args)), ","))
|
||||
if statement.WhereStr == "" {
|
||||
statement.WhereStr = inStr
|
||||
statement.Params = args
|
||||
k := strings.ToLower(column)
|
||||
if params, ok := statement.inColumns[k]; ok {
|
||||
statement.inColumns[k] = append(params, args...)
|
||||
} else {
|
||||
statement.WhereStr = statement.WhereStr + " AND " + inStr
|
||||
statement.Params = append(statement.Params, args...)
|
||||
statement.inColumns[k] = args
|
||||
}
|
||||
return statement
|
||||
}
|
||||
|
||||
func (statement *Statement) genInSql() (string, []interface{}) {
|
||||
if len(statement.inColumns) == 0 {
|
||||
return "", []interface{}{}
|
||||
}
|
||||
|
||||
inStrs := make([]string, 0, len(statement.inColumns))
|
||||
args := make([]interface{}, 0)
|
||||
for column, params := range statement.inColumns {
|
||||
inStrs = append(inStrs, fmt.Sprintf("(%v IN (%v))", statement.Engine.Quote(column),
|
||||
strings.Join(makeArray("?", len(params)), ",")))
|
||||
args = append(args, params...)
|
||||
}
|
||||
|
||||
if len(statement.inColumns) == 1 {
|
||||
return inStrs[0], args
|
||||
}
|
||||
return fmt.Sprintf("(%v)", strings.Join(inStrs, " AND ")), args
|
||||
}
|
||||
|
||||
func (statement *Statement) attachInSql() {
|
||||
inSql, inArgs := statement.genInSql()
|
||||
if len(inSql) > 0 {
|
||||
if statement.ConditionStr != "" {
|
||||
statement.ConditionStr += " AND "
|
||||
}
|
||||
statement.ConditionStr += inSql
|
||||
statement.Params = append(statement.Params, inArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
func col2NewCols(columns ...string) []string {
|
||||
newColumns := make([]string, 0)
|
||||
for _, col := range columns {
|
||||
|
@ -608,7 +638,7 @@ func (s *Statement) genDropSQL() string {
|
|||
}
|
||||
|
||||
// !nashtsai! REVIEW, Statement is a huge struct why is this method not passing *Statement?
|
||||
func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) {
|
||||
func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) {
|
||||
table := statement.Engine.autoMap(bean)
|
||||
statement.RefTable = table
|
||||
|
||||
|
@ -647,7 +677,7 @@ func (s *Statement) genAddUniqueStr(uqeName string, cols []string) (string, []in
|
|||
return sql, []interface{}{}
|
||||
}
|
||||
|
||||
func (statement Statement) genCountSql(bean interface{}) (string, []interface{}) {
|
||||
func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}) {
|
||||
table := statement.Engine.autoMap(bean)
|
||||
statement.RefTable = table
|
||||
|
||||
|
@ -663,7 +693,7 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{})
|
|||
return statement.genSelectSql(fmt.Sprintf("COUNT(%v) AS %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...)
|
||||
}
|
||||
|
||||
func (statement Statement) genSelectSql(columnStr string) (a string) {
|
||||
func (statement *Statement) genSelectSql(columnStr string) (a string) {
|
||||
if statement.GroupByStr != "" {
|
||||
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
|
||||
statement.GroupByStr = columnStr
|
||||
|
@ -680,11 +710,12 @@ func (statement Statement) genSelectSql(columnStr string) (a string) {
|
|||
if statement.WhereStr != "" {
|
||||
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
|
||||
if statement.ConditionStr != "" {
|
||||
a = fmt.Sprintf("%v and %v", a, statement.ConditionStr)
|
||||
a = fmt.Sprintf("%v AND %v", a, statement.ConditionStr)
|
||||
}
|
||||
} else if statement.ConditionStr != "" {
|
||||
a = fmt.Sprintf("%v WHERE %v", a, statement.ConditionStr)
|
||||
}
|
||||
|
||||
if statement.GroupByStr != "" {
|
||||
a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue