diff --git a/base_test.go b/base_test.go index 2cac2d4c..f2102100 100644 --- a/base_test.go +++ b/base_test.go @@ -1359,6 +1359,46 @@ func testDistinct(engine *Engine, t *testing.T) { fmt.Println(users2) } +func testUseBool(engine *Engine, t *testing.T) { + cnt1, err := engine.Count(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } + + users := make([]Userinfo, 0) + err = engine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + var fNumber int64 + for _, u := range users { + if u.IsMan == false { + fNumber += 1 + } + } + + cnt2, err := engine.UseBool().Update(&Userinfo{IsMan: true}) + if err != nil { + t.Error(err) + panic(err) + } + if fNumber != cnt2 { + fmt.Println("cnt1", cnt1, "fNumber", fNumber, "cnt2", cnt2) + /*err = errors.New("Updated number is not corrected.") + t.Error(err) + panic(err)*/ + } + + _, err = engine.Update(&Userinfo{IsMan: true}) + if err == nil { + err = errors.New("error condition") + t.Error(err) + panic(err) + } +} + func testAll(engine *Engine, t *testing.T) { fmt.Println("-------------- directCreateTable --------------") directCreateTable(engine, t) @@ -1447,6 +1487,8 @@ func testAll2(engine *Engine, t *testing.T) { testVersion(engine, t) fmt.Println("-------------- testDistinct --------------") testDistinct(engine, t) + fmt.Println("-------------- testUseBool --------------") + testUseBool(engine, t) fmt.Println("-------------- transaction --------------") transaction(engine, t) } diff --git a/engine.go b/engine.go index 88a26385..117e1ca9 100644 --- a/engine.go +++ b/engine.go @@ -273,6 +273,12 @@ func (engine *Engine) Cols(columns ...string) *Session { return session.Cols(columns...) } +func (engine *Engine) UseBool(columns ...string) *Session { + session := engine.NewSession() + session.IsAutoClose = true + return session.UseBool(columns...) +} + func (engine *Engine) Omit(columns ...string) *Session { session := engine.NewSession() session.IsAutoClose = true diff --git a/session.go b/session.go index 6bf73274..45f74abc 100644 --- a/session.go +++ b/session.go @@ -94,6 +94,11 @@ func (session *Session) Cols(columns ...string) *Session { return session } +func (session *Session) UseBool(columns ...string) *Session { + session.Statement.UseBool(columns...) + return session +} + func (session *Session) Distinct(columns ...string) *Session { session.Statement.Distinct(columns...) return session @@ -871,8 +876,9 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } if len(condiBean) > 0 { - colNames, args := buildConditions(session.Engine, table, condiBean[0], true) - session.Statement.ConditionStr = strings.Join(colNames, " and ") + colNames, args := buildConditions(session.Engine, table, condiBean[0], true, + session.Statement.allUseBool, session.Statement.boolColumnMap) + session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.BeanArgs = args } @@ -1885,7 +1891,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 session.Statement.RefTable = table if session.Statement.ColumnStr == "" { - colNames, args = buildConditions(session.Engine, table, bean, false) + colNames, args = buildConditions(session.Engine, table, bean, false, + session.Statement.allUseBool, session.Statement.boolColumnMap) } else { colNames, args, err = table.genCols(session, bean, true, true) if err != nil { @@ -1918,7 +1925,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var condiArgs []interface{} if len(condiBean) > 0 { - condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true) + condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true, + session.Statement.allUseBool, session.Statement.boolColumnMap) } var condition = "" @@ -2029,7 +2037,8 @@ func (session *Session) Delete(bean interface{}) (int64, error) { table := session.Engine.AutoMap(bean) session.Statement.RefTable = table - colNames, args := buildConditions(session.Engine, table, bean, true) + colNames, args := buildConditions(session.Engine, table, bean, true, + session.Statement.allUseBool, session.Statement.boolColumnMap) var condition = "" if session.Statement.WhereStr != "" { diff --git a/statement.go b/statement.go index 6d2f3cc7..94daa433 100644 --- a/statement.go +++ b/statement.go @@ -10,31 +10,33 @@ import ( ) type Statement struct { - RefTable *Table - Engine *Engine - Start int - LimitN int - WhereStr string - Params []interface{} - OrderStr string - JoinStr string - GroupByStr string - HavingStr string - ColumnStr string - columnMap map[string]bool - OmitStr string - ConditionStr string - AltTableName string - RawSQL string - RawParams []interface{} - UseCascade bool - UseAutoJoin bool - StoreEngine string - Charset string - BeanArgs []interface{} - UseCache bool - UseAutoTime bool - IsDistinct bool + RefTable *Table + Engine *Engine + Start int + LimitN int + WhereStr string + Params []interface{} + OrderStr string + JoinStr string + GroupByStr string + HavingStr string + ColumnStr string + columnMap map[string]bool + OmitStr string + ConditionStr string + AltTableName string + RawSQL string + RawParams []interface{} + UseCascade bool + UseAutoJoin bool + StoreEngine string + Charset string + BeanArgs []interface{} + UseCache bool + UseAutoTime bool + IsDistinct bool + allUseBool bool + boolColumnMap map[string]bool } func (statement *Statement) Init() { @@ -59,6 +61,8 @@ func (statement *Statement) Init() { statement.UseCache = statement.Engine.UseCache statement.UseAutoTime = true statement.IsDistinct = false + statement.allUseBool = false + statement.boolColumnMap = make(map[string]bool) } func (statement *Statement) Sql(querystring string, args ...interface{}) { @@ -99,7 +103,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) { } // Auto generating conditions according a struct -func buildConditions(engine *Engine, table *Table, bean interface{}, includeVersion bool) ([]string, []interface{}) { +func buildConditions(engine *Engine, table *Table, bean interface{}, includeVersion bool, allUseBool bool, boolColumnMap map[string]bool) ([]string, []interface{}) { colNames := make([]string, 0) var args = make([]interface{}, 0) for _, col := range table.Columns { @@ -111,10 +115,15 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers var val interface{} switch fieldType.Kind() { case reflect.Bool: - continue - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - val = fieldValue.Interface() + if allUseBool { + val = fieldValue.Interface() + } else if _, ok := boolColumnMap[col.Name]; ok { + val = fieldValue.Interface() + } else { + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + continue + } case reflect.String: if fieldValue.String() == "" { continue @@ -219,7 +228,7 @@ func (statement *Statement) Id(id int64) { statement.WhereStr = "(id)=?" statement.Params = []interface{}{id} } else { - statement.WhereStr = statement.WhereStr + " and (id)=?" + statement.WhereStr = statement.WhereStr + " AND (id)=?" statement.Params = append(statement.Params, id) } } @@ -230,7 +239,7 @@ func (statement *Statement) In(column string, args ...interface{}) { statement.WhereStr = inStr statement.Params = args } else { - statement.WhereStr = statement.WhereStr + " and " + inStr + statement.WhereStr = statement.WhereStr + " AND " + inStr statement.Params = append(statement.Params, args...) } } @@ -261,6 +270,17 @@ func (statement *Statement) Cols(columns ...string) { statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) } +func (statement *Statement) UseBool(columns ...string) { + if len(columns) > 0 { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.boolColumnMap[nc] = true + } + } else { + statement.allUseBool = true + } +} + func (statement *Statement) Omit(columns ...string) { newColumns := col2NewCols(columns...) for _, nc := range newColumns { @@ -396,8 +416,9 @@ func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) { table := statement.Engine.AutoMap(bean) statement.RefTable = table - colNames, args := buildConditions(statement.Engine, table, bean, true) - statement.ConditionStr = strings.Join(colNames, " and ") + colNames, args := buildConditions(statement.Engine, table, bean, true, + statement.allUseBool, statement.boolColumnMap) + statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args var columnStr string = statement.ColumnStr @@ -433,14 +454,14 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{}) table := statement.Engine.AutoMap(bean) statement.RefTable = table - colNames, args := buildConditions(statement.Engine, table, bean, true) - statement.ConditionStr = strings.Join(colNames, " and ") + colNames, args := buildConditions(statement.Engine, table, bean, true, statement.allUseBool, statement.boolColumnMap) + statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args var id string = "*" if table.PrimaryKey != "" { id = statement.Engine.Quote(table.PrimaryKey) } - return statement.genSelectSql(fmt.Sprintf("count(%v) as %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...) + return statement.genSelectSql(fmt.Sprintf("COUNT(%v) AS %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...) } func (statement Statement) genSelectSql(columnStr string) (a string) {