diff --git a/base_test.go b/base_test.go index 92eb6290..53d15bfc 100644 --- a/base_test.go +++ b/base_test.go @@ -370,6 +370,53 @@ func update(engine *Engine, t *testing.T) { panic(err) return } + + type UpdateAllCols struct { + Id int64 + Bool bool + String string + } + + col1 := &UpdateAllCols{} + err = engine.Sync(col1) + if err != nil { + t.Error(err) + panic(err) + } + + _, err = engine.Insert(col1) + if err != nil { + t.Error(err) + panic(err) + } + + col2 := &UpdateAllCols{col1.Id, true, ""} + _, err = engine.Id(col2.Id).AllCols().Update(col2) + if err != nil { + t.Error(err) + panic(err) + } + + col3 := &UpdateAllCols{} + has, err := engine.Id(col2.Id).Get(col3) + if err != nil { + t.Error(err) + panic(err) + } + + if !has { + err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) + t.Error(err) + panic(err) + return + } + + if *col2 != *col3 { + err = errors.New(fmt.Sprintf("col2 should eq col3")) + t.Error(err) + panic(err) + return + } } func updateSameMapper(engine *Engine, t *testing.T) { diff --git a/session.go b/session.go index 825acc46..731027ce 100644 --- a/session.go +++ b/session.go @@ -126,6 +126,11 @@ func (session *Session) Cols(columns ...string) *Session { return session } +func (session *Session) AllCols() *Session { + session.Statement.AllCols() + return session +} + func (session *Session) NoCascade() *Session { session.Statement.UseCascade = false return session @@ -1023,7 +1028,8 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) if len(condiBean) > 0 { colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true, - false, true, session.Statement.allUseBool, session.Statement.boolColumnMap) + false, true, session.Statement.allUseBool, session.Statement.useAllCols, + session.Statement.boolColumnMap) session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.BeanArgs = args } @@ -2838,7 +2844,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if session.Statement.ColumnStr == "" { colNames, args = buildConditions(session.Engine, table, bean, false, false, - false, false, session.Statement.allUseBool, session.Statement.boolColumnMap) + false, false, session.Statement.allUseBool, session.Statement.useAllCols, + session.Statement.boolColumnMap) } else { colNames, args, err = table.genCols(session, bean, true, true) if err != nil { @@ -2872,7 +2879,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condiBean) > 0 { condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true, true, - false, true, session.Statement.allUseBool, session.Statement.boolColumnMap) + false, true, session.Statement.allUseBool, session.Statement.useAllCols, + session.Statement.boolColumnMap) } var condition = "" @@ -3060,7 +3068,8 @@ func (session *Session) Delete(bean interface{}) (int64, error) { table := session.Engine.autoMap(bean) session.Statement.RefTable = table colNames, args := buildConditions(session.Engine, table, bean, true, true, - false, true, session.Statement.allUseBool, session.Statement.boolColumnMap) + false, true, session.Statement.allUseBool, session.Statement.useAllCols, + session.Statement.boolColumnMap) var condition = "" diff --git a/statement.go b/statement.go index 4bde5c7b..80c5dccb 100644 --- a/statement.go +++ b/statement.go @@ -25,6 +25,7 @@ type Statement struct { HavingStr string ColumnStr string columnMap map[string]bool + useAllCols bool OmitStr string ConditionStr string AltTableName string @@ -239,7 +240,8 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { // Auto generating conditions according a struct func buildConditions(engine *Engine, table *Table, bean interface{}, - includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, + includeVersion bool, includeUpdated bool, includeNil bool, + includeAutoIncr bool, allUseBool bool, useAllCols bool, boolColumnMap map[string]bool) ([]string, []interface{}) { colNames := make([]string, 0) @@ -262,7 +264,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, fieldValue := col.ValueOf(bean) fieldType := reflect.TypeOf(fieldValue.Interface()) - requiredField := false + requiredField := useAllCols if fieldType.Kind() == reflect.Ptr { if fieldValue.IsNil() { if includeNil { @@ -517,6 +519,11 @@ func (statement *Statement) Cols(columns ...string) *Statement { return statement } +func (statement *Statement) AllCols() *Statement { + statement.useAllCols = true + return statement +} + // indicates that use bool fields as update contents and query contiditions func (statement *Statement) UseBool(columns ...string) *Statement { if len(columns) > 0 { @@ -719,7 +726,8 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) statement.RefTable = table colNames, args := buildConditions(statement.Engine, table, bean, true, true, - false, true, statement.allUseBool, statement.boolColumnMap) + false, true, statement.allUseBool, statement.useAllCols, + statement.boolColumnMap) statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args @@ -758,7 +766,7 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{} statement.RefTable = table colNames, args := buildConditions(statement.Engine, table, bean, true, true, false, - true, statement.allUseBool, statement.boolColumnMap) + true, statement.allUseBool, statement.useAllCols,statement.boolColumnMap) statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args