multiple In() supports
This commit is contained in:
parent
82bdc0ec5a
commit
bb6a9c24fa
57
base_test.go
57
base_test.go
|
@ -520,6 +520,63 @@ func in(engine *Engine, t *testing.T) {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
fmt.Println(users)
|
fmt.Println(users)
|
||||||
|
|
||||||
|
err = engine.In("(id)", 1).In("(id)", 2).In("departname", "dev").Find(&users)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Println(users)
|
||||||
|
|
||||||
|
cnt, err := engine.In("(id)", 4).Update(&Userinfo{Departname: "dev-"})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if cnt != 1 {
|
||||||
|
err = errors.New("update records not 1")
|
||||||
|
t.Error(err)
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := new(Userinfo)
|
||||||
|
has, err := engine.Id(4).Get(user)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if !has {
|
||||||
|
err = errors.New("get record not 1")
|
||||||
|
t.Error(err)
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if user.Departname != "dev-" {
|
||||||
|
err = errors.New("update not success")
|
||||||
|
t.Error(err)
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cnt, err = engine.In("(id)", 4).Update(&Userinfo{Departname: "dev"})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if cnt != 1 {
|
||||||
|
err = errors.New("update records not 1")
|
||||||
|
t.Error(err)
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cnt, err = engine.In("(id)", 5).Delete(&Userinfo{})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if cnt != 1 {
|
||||||
|
err = errors.New("deleted records not 1")
|
||||||
|
t.Error(err)
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func limit(engine *Engine, t *testing.T) {
|
func limit(engine *Engine, t *testing.T) {
|
||||||
|
|
40
session.go
40
session.go
|
@ -1015,6 +1015,9 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
|
||||||
if columnStr == "" {
|
if columnStr == "" {
|
||||||
columnStr = session.Statement.genColumnStr()
|
columnStr = session.Statement.genColumnStr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
session.Statement.attachInSql()
|
||||||
|
|
||||||
sql = session.Statement.genSelectSql(columnStr)
|
sql = session.Statement.genSelectSql(columnStr)
|
||||||
args = append(session.Statement.Params, session.Statement.BeanArgs...)
|
args = append(session.Statement.Params, session.Statement.BeanArgs...)
|
||||||
} else {
|
} else {
|
||||||
|
@ -2508,7 +2511,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var sql string
|
var sql, inSql string
|
||||||
|
var inArgs []interface{}
|
||||||
if table.Version != "" && session.Statement.checkVersion {
|
if table.Version != "" && session.Statement.checkVersion {
|
||||||
if condition != "" {
|
if condition != "" {
|
||||||
condition = fmt.Sprintf("WHERE (%v) AND %v = ?", condition,
|
condition = fmt.Sprintf("WHERE (%v) AND %v = ?", condition,
|
||||||
|
@ -2516,6 +2520,15 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
} else {
|
} else {
|
||||||
condition = fmt.Sprintf("WHERE %v = ?", session.Engine.Quote(table.Version))
|
condition = fmt.Sprintf("WHERE %v = ?", session.Engine.Quote(table.Version))
|
||||||
}
|
}
|
||||||
|
inSql, inArgs = session.Statement.genInSql()
|
||||||
|
if len(inSql) > 0 {
|
||||||
|
if condition != "" {
|
||||||
|
condition += " AND " + inSql
|
||||||
|
} else {
|
||||||
|
condition = "WHERE " + inSql
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sql = fmt.Sprintf("UPDATE %v SET %v, %v %v",
|
sql = fmt.Sprintf("UPDATE %v SET %v, %v %v",
|
||||||
session.Engine.Quote(session.Statement.TableName()),
|
session.Engine.Quote(session.Statement.TableName()),
|
||||||
strings.Join(colNames, ", "),
|
strings.Join(colNames, ", "),
|
||||||
|
@ -2527,13 +2540,24 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
if condition != "" {
|
if condition != "" {
|
||||||
condition = "WHERE " + condition
|
condition = "WHERE " + condition
|
||||||
}
|
}
|
||||||
|
inSql, inArgs = session.Statement.genInSql()
|
||||||
|
if len(inSql) > 0 {
|
||||||
|
if condition != "" {
|
||||||
|
condition += " AND " + inSql
|
||||||
|
} else {
|
||||||
|
condition = "WHERE " + inSql
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sql = fmt.Sprintf("UPDATE %v SET %v %v",
|
sql = fmt.Sprintf("UPDATE %v SET %v %v",
|
||||||
session.Engine.Quote(session.Statement.TableName()),
|
session.Engine.Quote(session.Statement.TableName()),
|
||||||
strings.Join(colNames, ", "),
|
strings.Join(colNames, ", "),
|
||||||
condition)
|
condition)
|
||||||
}
|
}
|
||||||
|
|
||||||
args = append(append(args, st.Params...), condiArgs...)
|
args = append(args, st.Params...)
|
||||||
|
args = append(args, inArgs...)
|
||||||
|
args = append(args, condiArgs...)
|
||||||
|
|
||||||
res, err := session.exec(sql, args...)
|
res, err := session.exec(sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2660,10 +2684,18 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
|
||||||
if session.Statement.WhereStr != "" {
|
if session.Statement.WhereStr != "" {
|
||||||
condition = session.Statement.WhereStr
|
condition = session.Statement.WhereStr
|
||||||
if len(colNames) > 0 {
|
if len(colNames) > 0 {
|
||||||
condition += " and " + strings.Join(colNames, " and ")
|
condition += " AND " + strings.Join(colNames, " AND ")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
condition = strings.Join(colNames, " and ")
|
condition = strings.Join(colNames, " AND ")
|
||||||
|
}
|
||||||
|
inSql, inArgs := session.Statement.genInSql()
|
||||||
|
if len(inSql) > 0 {
|
||||||
|
if len(condition) > 0 {
|
||||||
|
condition += " AND "
|
||||||
|
}
|
||||||
|
condition += inSql
|
||||||
|
args = append(args, inArgs...)
|
||||||
}
|
}
|
||||||
if len(condition) == 0 {
|
if len(condition) == 0 {
|
||||||
return 0, ErrNeedDeletedCond
|
return 0, ErrNeedDeletedCond
|
||||||
|
|
51
statement.go
51
statement.go
|
@ -39,6 +39,7 @@ type Statement struct {
|
||||||
allUseBool bool
|
allUseBool bool
|
||||||
checkVersion bool
|
checkVersion bool
|
||||||
boolColumnMap map[string]bool
|
boolColumnMap map[string]bool
|
||||||
|
inColumns map[string][]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// init
|
// init
|
||||||
|
@ -67,6 +68,7 @@ func (statement *Statement) Init() {
|
||||||
statement.allUseBool = false
|
statement.allUseBool = false
|
||||||
statement.boolColumnMap = make(map[string]bool)
|
statement.boolColumnMap = make(map[string]bool)
|
||||||
statement.checkVersion = true
|
statement.checkVersion = true
|
||||||
|
statement.inColumns = make(map[string][]interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// add the raw sql statement
|
// add the raw sql statement
|
||||||
|
@ -406,17 +408,45 @@ func (statement *Statement) Id(id int64) *Statement {
|
||||||
|
|
||||||
// Generate "Where column IN (?) " statment
|
// Generate "Where column IN (?) " statment
|
||||||
func (statement *Statement) In(column string, args ...interface{}) *Statement {
|
func (statement *Statement) In(column string, args ...interface{}) *Statement {
|
||||||
inStr := fmt.Sprintf("%v IN (%v)", column, strings.Join(makeArray("?", len(args)), ","))
|
k := strings.ToLower(column)
|
||||||
if statement.WhereStr == "" {
|
if params, ok := statement.inColumns[k]; ok {
|
||||||
statement.WhereStr = inStr
|
statement.inColumns[k] = append(params, args...)
|
||||||
statement.Params = args
|
|
||||||
} else {
|
} else {
|
||||||
statement.WhereStr = statement.WhereStr + " AND " + inStr
|
statement.inColumns[k] = args
|
||||||
statement.Params = append(statement.Params, args...)
|
|
||||||
}
|
}
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) genInSql() (string, []interface{}) {
|
||||||
|
if len(statement.inColumns) == 0 {
|
||||||
|
return "", []interface{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
inStrs := make([]string, 0, len(statement.inColumns))
|
||||||
|
args := make([]interface{}, 0)
|
||||||
|
for column, params := range statement.inColumns {
|
||||||
|
inStrs = append(inStrs, fmt.Sprintf("(%v IN (%v))", statement.Engine.Quote(column),
|
||||||
|
strings.Join(makeArray("?", len(params)), ",")))
|
||||||
|
args = append(args, params...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(statement.inColumns) == 1 {
|
||||||
|
return inStrs[0], args
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("(%v)", strings.Join(inStrs, " AND ")), args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) attachInSql() {
|
||||||
|
inSql, inArgs := statement.genInSql()
|
||||||
|
if len(inSql) > 0 {
|
||||||
|
if statement.ConditionStr != "" {
|
||||||
|
statement.ConditionStr += " AND "
|
||||||
|
}
|
||||||
|
statement.ConditionStr += inSql
|
||||||
|
statement.Params = append(statement.Params, inArgs...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func col2NewCols(columns ...string) []string {
|
func col2NewCols(columns ...string) []string {
|
||||||
newColumns := make([]string, 0)
|
newColumns := make([]string, 0)
|
||||||
for _, col := range columns {
|
for _, col := range columns {
|
||||||
|
@ -608,7 +638,7 @@ func (s *Statement) genDropSQL() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// !nashtsai! REVIEW, Statement is a huge struct why is this method not passing *Statement?
|
// !nashtsai! REVIEW, Statement is a huge struct why is this method not passing *Statement?
|
||||||
func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) {
|
func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) {
|
||||||
table := statement.Engine.autoMap(bean)
|
table := statement.Engine.autoMap(bean)
|
||||||
statement.RefTable = table
|
statement.RefTable = table
|
||||||
|
|
||||||
|
@ -647,7 +677,7 @@ func (s *Statement) genAddUniqueStr(uqeName string, cols []string) (string, []in
|
||||||
return sql, []interface{}{}
|
return sql, []interface{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement Statement) genCountSql(bean interface{}) (string, []interface{}) {
|
func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}) {
|
||||||
table := statement.Engine.autoMap(bean)
|
table := statement.Engine.autoMap(bean)
|
||||||
statement.RefTable = table
|
statement.RefTable = table
|
||||||
|
|
||||||
|
@ -663,7 +693,7 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{})
|
||||||
return statement.genSelectSql(fmt.Sprintf("COUNT(%v) AS %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...)
|
return statement.genSelectSql(fmt.Sprintf("COUNT(%v) AS %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement Statement) genSelectSql(columnStr string) (a string) {
|
func (statement *Statement) genSelectSql(columnStr string) (a string) {
|
||||||
if statement.GroupByStr != "" {
|
if statement.GroupByStr != "" {
|
||||||
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
|
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
|
||||||
statement.GroupByStr = columnStr
|
statement.GroupByStr = columnStr
|
||||||
|
@ -680,11 +710,12 @@ func (statement Statement) genSelectSql(columnStr string) (a string) {
|
||||||
if statement.WhereStr != "" {
|
if statement.WhereStr != "" {
|
||||||
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
|
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
|
||||||
if statement.ConditionStr != "" {
|
if statement.ConditionStr != "" {
|
||||||
a = fmt.Sprintf("%v and %v", a, statement.ConditionStr)
|
a = fmt.Sprintf("%v AND %v", a, statement.ConditionStr)
|
||||||
}
|
}
|
||||||
} else if statement.ConditionStr != "" {
|
} else if statement.ConditionStr != "" {
|
||||||
a = fmt.Sprintf("%v WHERE %v", a, statement.ConditionStr)
|
a = fmt.Sprintf("%v WHERE %v", a, statement.ConditionStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if statement.GroupByStr != "" {
|
if statement.GroupByStr != "" {
|
||||||
a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr)
|
a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue