Add UseBool method for MUST use bool on a struct as condition or update column

This commit is contained in:
Lunny Xiao 2013-11-15 10:16:08 +08:00
parent 1a64d60e06
commit 74ec8ba9d2
4 changed files with 120 additions and 42 deletions

View File

@ -1359,6 +1359,46 @@ func testDistinct(engine *Engine, t *testing.T) {
fmt.Println(users2) fmt.Println(users2)
} }
func testUseBool(engine *Engine, t *testing.T) {
cnt1, err := engine.Count(&Userinfo{})
if err != nil {
t.Error(err)
panic(err)
}
users := make([]Userinfo, 0)
err = engine.Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
var fNumber int64
for _, u := range users {
if u.IsMan == false {
fNumber += 1
}
}
cnt2, err := engine.UseBool().Update(&Userinfo{IsMan: true})
if err != nil {
t.Error(err)
panic(err)
}
if fNumber != cnt2 {
fmt.Println("cnt1", cnt1, "fNumber", fNumber, "cnt2", cnt2)
/*err = errors.New("Updated number is not corrected.")
t.Error(err)
panic(err)*/
}
_, err = engine.Update(&Userinfo{IsMan: true})
if err == nil {
err = errors.New("error condition")
t.Error(err)
panic(err)
}
}
func testAll(engine *Engine, t *testing.T) { func testAll(engine *Engine, t *testing.T) {
fmt.Println("-------------- directCreateTable --------------") fmt.Println("-------------- directCreateTable --------------")
directCreateTable(engine, t) directCreateTable(engine, t)
@ -1447,6 +1487,8 @@ func testAll2(engine *Engine, t *testing.T) {
testVersion(engine, t) testVersion(engine, t)
fmt.Println("-------------- testDistinct --------------") fmt.Println("-------------- testDistinct --------------")
testDistinct(engine, t) testDistinct(engine, t)
fmt.Println("-------------- testUseBool --------------")
testUseBool(engine, t)
fmt.Println("-------------- transaction --------------") fmt.Println("-------------- transaction --------------")
transaction(engine, t) transaction(engine, t)
} }

View File

@ -273,6 +273,12 @@ func (engine *Engine) Cols(columns ...string) *Session {
return session.Cols(columns...) return session.Cols(columns...)
} }
func (engine *Engine) UseBool(columns ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
return session.UseBool(columns...)
}
func (engine *Engine) Omit(columns ...string) *Session { func (engine *Engine) Omit(columns ...string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.IsAutoClose = true

View File

@ -94,6 +94,11 @@ func (session *Session) Cols(columns ...string) *Session {
return session return session
} }
func (session *Session) UseBool(columns ...string) *Session {
session.Statement.UseBool(columns...)
return session
}
func (session *Session) Distinct(columns ...string) *Session { func (session *Session) Distinct(columns ...string) *Session {
session.Statement.Distinct(columns...) session.Statement.Distinct(columns...)
return session return session
@ -871,8 +876,9 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
if len(condiBean) > 0 { if len(condiBean) > 0 {
colNames, args := buildConditions(session.Engine, table, condiBean[0], true) colNames, args := buildConditions(session.Engine, table, condiBean[0], true,
session.Statement.ConditionStr = strings.Join(colNames, " and ") session.Statement.allUseBool, session.Statement.boolColumnMap)
session.Statement.ConditionStr = strings.Join(colNames, " AND ")
session.Statement.BeanArgs = args session.Statement.BeanArgs = args
} }
@ -1885,7 +1891,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
session.Statement.RefTable = table session.Statement.RefTable = table
if session.Statement.ColumnStr == "" { if session.Statement.ColumnStr == "" {
colNames, args = buildConditions(session.Engine, table, bean, false) colNames, args = buildConditions(session.Engine, table, bean, false,
session.Statement.allUseBool, session.Statement.boolColumnMap)
} else { } else {
colNames, args, err = table.genCols(session, bean, true, true) colNames, args, err = table.genCols(session, bean, true, true)
if err != nil { if err != nil {
@ -1918,7 +1925,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var condiArgs []interface{} var condiArgs []interface{}
if len(condiBean) > 0 { if len(condiBean) > 0 {
condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true) condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true,
session.Statement.allUseBool, session.Statement.boolColumnMap)
} }
var condition = "" var condition = ""
@ -2029,7 +2037,8 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
table := session.Engine.AutoMap(bean) table := session.Engine.AutoMap(bean)
session.Statement.RefTable = table session.Statement.RefTable = table
colNames, args := buildConditions(session.Engine, table, bean, true) colNames, args := buildConditions(session.Engine, table, bean, true,
session.Statement.allUseBool, session.Statement.boolColumnMap)
var condition = "" var condition = ""
if session.Statement.WhereStr != "" { if session.Statement.WhereStr != "" {

View File

@ -10,31 +10,33 @@ import (
) )
type Statement struct { type Statement struct {
RefTable *Table RefTable *Table
Engine *Engine Engine *Engine
Start int Start int
LimitN int LimitN int
WhereStr string WhereStr string
Params []interface{} Params []interface{}
OrderStr string OrderStr string
JoinStr string JoinStr string
GroupByStr string GroupByStr string
HavingStr string HavingStr string
ColumnStr string ColumnStr string
columnMap map[string]bool columnMap map[string]bool
OmitStr string OmitStr string
ConditionStr string ConditionStr string
AltTableName string AltTableName string
RawSQL string RawSQL string
RawParams []interface{} RawParams []interface{}
UseCascade bool UseCascade bool
UseAutoJoin bool UseAutoJoin bool
StoreEngine string StoreEngine string
Charset string Charset string
BeanArgs []interface{} BeanArgs []interface{}
UseCache bool UseCache bool
UseAutoTime bool UseAutoTime bool
IsDistinct bool IsDistinct bool
allUseBool bool
boolColumnMap map[string]bool
} }
func (statement *Statement) Init() { func (statement *Statement) Init() {
@ -59,6 +61,8 @@ func (statement *Statement) Init() {
statement.UseCache = statement.Engine.UseCache statement.UseCache = statement.Engine.UseCache
statement.UseAutoTime = true statement.UseAutoTime = true
statement.IsDistinct = false statement.IsDistinct = false
statement.allUseBool = false
statement.boolColumnMap = make(map[string]bool)
} }
func (statement *Statement) Sql(querystring string, args ...interface{}) { func (statement *Statement) Sql(querystring string, args ...interface{}) {
@ -99,7 +103,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) {
} }
// Auto generating conditions according a struct // Auto generating conditions according a struct
func buildConditions(engine *Engine, table *Table, bean interface{}, includeVersion bool) ([]string, []interface{}) { func buildConditions(engine *Engine, table *Table, bean interface{}, includeVersion bool, allUseBool bool, boolColumnMap map[string]bool) ([]string, []interface{}) {
colNames := make([]string, 0) colNames := make([]string, 0)
var args = make([]interface{}, 0) var args = make([]interface{}, 0)
for _, col := range table.Columns { for _, col := range table.Columns {
@ -111,10 +115,15 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, includeVers
var val interface{} var val interface{}
switch fieldType.Kind() { switch fieldType.Kind() {
case reflect.Bool: case reflect.Bool:
continue if allUseBool {
// if a bool in a struct, it will not be as a condition because it default is false, val = fieldValue.Interface()
// please use Where() instead } else if _, ok := boolColumnMap[col.Name]; ok {
val = fieldValue.Interface() val = fieldValue.Interface()
} else {
// if a bool in a struct, it will not be as a condition because it default is false,
// please use Where() instead
continue
}
case reflect.String: case reflect.String:
if fieldValue.String() == "" { if fieldValue.String() == "" {
continue continue
@ -219,7 +228,7 @@ func (statement *Statement) Id(id int64) {
statement.WhereStr = "(id)=?" statement.WhereStr = "(id)=?"
statement.Params = []interface{}{id} statement.Params = []interface{}{id}
} else { } else {
statement.WhereStr = statement.WhereStr + " and (id)=?" statement.WhereStr = statement.WhereStr + " AND (id)=?"
statement.Params = append(statement.Params, id) statement.Params = append(statement.Params, id)
} }
} }
@ -230,7 +239,7 @@ func (statement *Statement) In(column string, args ...interface{}) {
statement.WhereStr = inStr statement.WhereStr = inStr
statement.Params = args statement.Params = args
} else { } else {
statement.WhereStr = statement.WhereStr + " and " + inStr statement.WhereStr = statement.WhereStr + " AND " + inStr
statement.Params = append(statement.Params, args...) statement.Params = append(statement.Params, args...)
} }
} }
@ -261,6 +270,17 @@ func (statement *Statement) Cols(columns ...string) {
statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
} }
func (statement *Statement) UseBool(columns ...string) {
if len(columns) > 0 {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.boolColumnMap[nc] = true
}
} else {
statement.allUseBool = true
}
}
func (statement *Statement) Omit(columns ...string) { func (statement *Statement) Omit(columns ...string) {
newColumns := col2NewCols(columns...) newColumns := col2NewCols(columns...)
for _, nc := range newColumns { for _, nc := range newColumns {
@ -396,8 +416,9 @@ func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) {
table := statement.Engine.AutoMap(bean) table := statement.Engine.AutoMap(bean)
statement.RefTable = table statement.RefTable = table
colNames, args := buildConditions(statement.Engine, table, bean, true) colNames, args := buildConditions(statement.Engine, table, bean, true,
statement.ConditionStr = strings.Join(colNames, " and ") statement.allUseBool, statement.boolColumnMap)
statement.ConditionStr = strings.Join(colNames, " AND ")
statement.BeanArgs = args statement.BeanArgs = args
var columnStr string = statement.ColumnStr var columnStr string = statement.ColumnStr
@ -433,14 +454,14 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{})
table := statement.Engine.AutoMap(bean) table := statement.Engine.AutoMap(bean)
statement.RefTable = table statement.RefTable = table
colNames, args := buildConditions(statement.Engine, table, bean, true) colNames, args := buildConditions(statement.Engine, table, bean, true, statement.allUseBool, statement.boolColumnMap)
statement.ConditionStr = strings.Join(colNames, " and ") statement.ConditionStr = strings.Join(colNames, " AND ")
statement.BeanArgs = args statement.BeanArgs = args
var id string = "*" var id string = "*"
if table.PrimaryKey != "" { if table.PrimaryKey != "" {
id = statement.Engine.Quote(table.PrimaryKey) id = statement.Engine.Quote(table.PrimaryKey)
} }
return statement.genSelectSql(fmt.Sprintf("count(%v) as %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...) return statement.genSelectSql(fmt.Sprintf("COUNT(%v) AS %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...)
} }
func (statement Statement) genSelectSql(columnStr string) (a string) { func (statement Statement) genSelectSql(columnStr string) (a string) {