multiple In() supports

This commit is contained in:
Lunny Xiao 2013-12-11 16:27:33 +08:00
parent 82bdc0ec5a
commit bb6a9c24fa
3 changed files with 3311 additions and 3191 deletions

View File

@ -520,6 +520,63 @@ func in(engine *Engine, t *testing.T) {
panic(err)
}
fmt.Println(users)
err = engine.In("(id)", 1).In("(id)", 2).In("departname", "dev").Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(users)
cnt, err := engine.In("(id)", 4).Update(&Userinfo{Departname: "dev-"})
if err != nil {
t.Error(err)
panic(err)
}
if cnt != 1 {
err = errors.New("update records not 1")
t.Error(err)
panic(err)
}
user := new(Userinfo)
has, err := engine.Id(4).Get(user)
if err != nil {
t.Error(err)
panic(err)
}
if !has {
err = errors.New("get record not 1")
t.Error(err)
panic(err)
}
if user.Departname != "dev-" {
err = errors.New("update not success")
t.Error(err)
panic(err)
}
cnt, err = engine.In("(id)", 4).Update(&Userinfo{Departname: "dev"})
if err != nil {
t.Error(err)
panic(err)
}
if cnt != 1 {
err = errors.New("update records not 1")
t.Error(err)
panic(err)
}
cnt, err = engine.In("(id)", 5).Delete(&Userinfo{})
if err != nil {
t.Error(err)
panic(err)
}
if cnt != 1 {
err = errors.New("deleted records not 1")
t.Error(err)
panic(err)
}
}
func limit(engine *Engine, t *testing.T) {

View File

@ -1015,6 +1015,9 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
if columnStr == "" {
columnStr = session.Statement.genColumnStr()
}
session.Statement.attachInSql()
sql = session.Statement.genSelectSql(columnStr)
args = append(session.Statement.Params, session.Statement.BeanArgs...)
} else {
@ -2508,7 +2511,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
}
var sql string
var sql, inSql string
var inArgs []interface{}
if table.Version != "" && session.Statement.checkVersion {
if condition != "" {
condition = fmt.Sprintf("WHERE (%v) AND %v = ?", condition,
@ -2516,6 +2520,15 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} else {
condition = fmt.Sprintf("WHERE %v = ?", session.Engine.Quote(table.Version))
}
inSql, inArgs = session.Statement.genInSql()
if len(inSql) > 0 {
if condition != "" {
condition += " AND " + inSql
} else {
condition = "WHERE " + inSql
}
}
sql = fmt.Sprintf("UPDATE %v SET %v, %v %v",
session.Engine.Quote(session.Statement.TableName()),
strings.Join(colNames, ", "),
@ -2527,13 +2540,24 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if condition != "" {
condition = "WHERE " + condition
}
inSql, inArgs = session.Statement.genInSql()
if len(inSql) > 0 {
if condition != "" {
condition += " AND " + inSql
} else {
condition = "WHERE " + inSql
}
}
sql = fmt.Sprintf("UPDATE %v SET %v %v",
session.Engine.Quote(session.Statement.TableName()),
strings.Join(colNames, ", "),
condition)
}
args = append(append(args, st.Params...), condiArgs...)
args = append(args, st.Params...)
args = append(args, inArgs...)
args = append(args, condiArgs...)
res, err := session.exec(sql, args...)
if err != nil {
@ -2660,10 +2684,18 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
if session.Statement.WhereStr != "" {
condition = session.Statement.WhereStr
if len(colNames) > 0 {
condition += " and " + strings.Join(colNames, " and ")
condition += " AND " + strings.Join(colNames, " AND ")
}
} else {
condition = strings.Join(colNames, " and ")
condition = strings.Join(colNames, " AND ")
}
inSql, inArgs := session.Statement.genInSql()
if len(inSql) > 0 {
if len(condition) > 0 {
condition += " AND "
}
condition += inSql
args = append(args, inArgs...)
}
if len(condition) == 0 {
return 0, ErrNeedDeletedCond

View File

@ -39,6 +39,7 @@ type Statement struct {
allUseBool bool
checkVersion bool
boolColumnMap map[string]bool
inColumns map[string][]interface{}
}
// init
@ -67,6 +68,7 @@ func (statement *Statement) Init() {
statement.allUseBool = false
statement.boolColumnMap = make(map[string]bool)
statement.checkVersion = true
statement.inColumns = make(map[string][]interface{})
}
// add the raw sql statement
@ -406,17 +408,45 @@ func (statement *Statement) Id(id int64) *Statement {
// Generate "Where column IN (?) " statment
func (statement *Statement) In(column string, args ...interface{}) *Statement {
inStr := fmt.Sprintf("%v IN (%v)", column, strings.Join(makeArray("?", len(args)), ","))
if statement.WhereStr == "" {
statement.WhereStr = inStr
statement.Params = args
k := strings.ToLower(column)
if params, ok := statement.inColumns[k]; ok {
statement.inColumns[k] = append(params, args...)
} else {
statement.WhereStr = statement.WhereStr + " AND " + inStr
statement.Params = append(statement.Params, args...)
statement.inColumns[k] = args
}
return statement
}
func (statement *Statement) genInSql() (string, []interface{}) {
if len(statement.inColumns) == 0 {
return "", []interface{}{}
}
inStrs := make([]string, 0, len(statement.inColumns))
args := make([]interface{}, 0)
for column, params := range statement.inColumns {
inStrs = append(inStrs, fmt.Sprintf("(%v IN (%v))", statement.Engine.Quote(column),
strings.Join(makeArray("?", len(params)), ",")))
args = append(args, params...)
}
if len(statement.inColumns) == 1 {
return inStrs[0], args
}
return fmt.Sprintf("(%v)", strings.Join(inStrs, " AND ")), args
}
func (statement *Statement) attachInSql() {
inSql, inArgs := statement.genInSql()
if len(inSql) > 0 {
if statement.ConditionStr != "" {
statement.ConditionStr += " AND "
}
statement.ConditionStr += inSql
statement.Params = append(statement.Params, inArgs...)
}
}
func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0)
for _, col := range columns {
@ -608,7 +638,7 @@ func (s *Statement) genDropSQL() string {
}
// !nashtsai! REVIEW, Statement is a huge struct why is this method not passing *Statement?
func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) {
func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) {
table := statement.Engine.autoMap(bean)
statement.RefTable = table
@ -647,7 +677,7 @@ func (s *Statement) genAddUniqueStr(uqeName string, cols []string) (string, []in
return sql, []interface{}{}
}
func (statement Statement) genCountSql(bean interface{}) (string, []interface{}) {
func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}) {
table := statement.Engine.autoMap(bean)
statement.RefTable = table
@ -663,7 +693,7 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{})
return statement.genSelectSql(fmt.Sprintf("COUNT(%v) AS %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...)
}
func (statement Statement) genSelectSql(columnStr string) (a string) {
func (statement *Statement) genSelectSql(columnStr string) (a string) {
if statement.GroupByStr != "" {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
statement.GroupByStr = columnStr
@ -680,11 +710,12 @@ func (statement Statement) genSelectSql(columnStr string) (a string) {
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ConditionStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ConditionStr)
a = fmt.Sprintf("%v AND %v", a, statement.ConditionStr)
}
} else if statement.ConditionStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ConditionStr)
}
if statement.GroupByStr != "" {
a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr)
}