diff --git a/engine.go b/engine.go index bbe8eb47..4081ec14 100644 --- a/engine.go +++ b/engine.go @@ -553,6 +553,13 @@ func (engine *Engine) Omit(columns ...string) *Session { return session.Omit(columns...) } +// Set null when column is zero-value and nullable for update +func (engine *Engine) Nullable(columns ...string) *Session { + session := engine.NewSession() + session.IsAutoClose = true + return session.Nullable(columns...) +} + // This method will generate "column IN (?, ?)" func (engine *Engine) In(column string, args ...interface{}) *Session { session := engine.NewSession() diff --git a/helpers.go b/helpers.go index 979a67a1..4f3be860 100644 --- a/helpers.go +++ b/helpers.go @@ -39,6 +39,8 @@ func isZero(k interface{}) bool { return k.(uint64) == 0 case string: return k.(string) == "" + case time.Time: + return k.(time.Time).IsZero() } return false } @@ -356,6 +358,14 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, } } + // !evalphobia! set fieldValue as nil when column is nullable and zero-value + if _, ok := session.Statement.nullableMap[lColName]; ok { + if col.Nullable && isZero(fieldValue.Interface()) { + var nilValue *int + fieldValue = reflect.ValueOf(nilValue) + } + } + if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { val, t := session.Engine.NowTime2(col.SQLType.Name) args = append(args, val) diff --git a/session.go b/session.go index b47d2780..a1984805 100644 --- a/session.go +++ b/session.go @@ -223,6 +223,12 @@ func (session *Session) Omit(columns ...string) *Session { return session } +// Set null when column is zero-value and nullable for update +func (session *Session) Nullable(columns ...string) *Session { + session.Statement.Nullable(columns...) + return session +} + // Method NoAutoTime means do not automatically give created field and updated field // the current time on the current session temporarily func (session *Session) NoAutoTime() *Session { @@ -3408,7 +3414,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if session.Statement.ColumnStr == "" { colNames, args = buildUpdates(session.Engine, table, bean, false, false, false, false, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.mustColumnMap, session.Statement.columnMap, true) + session.Statement.mustColumnMap, session.Statement.nullableMap, + session.Statement.columnMap, true) } else { colNames, args, err = genCols(table, session, bean, true, true) if err != nil { diff --git a/statement.go b/statement.go index dcd7fefa..0a8de1ee 100644 --- a/statement.go +++ b/statement.go @@ -70,6 +70,7 @@ type Statement struct { checkVersion bool unscoped bool mustColumnMap map[string]bool + nullableMap map[string]bool inColumns map[string]*inParam incrColumns map[string]incrParam decrColumns map[string]decrParam @@ -105,6 +106,7 @@ func (statement *Statement) Init() { statement.allUseBool = false statement.useAllCols = false statement.mustColumnMap = make(map[string]bool) + statement.nullableMap = make(map[string]bool) statement.checkVersion = true statement.unscoped = false statement.inColumns = make(map[string]*inParam) @@ -176,7 +178,8 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { func buildUpdates(engine *Engine, table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, - mustColumnMap map[string]bool, columnMap map[string]bool, update bool) ([]string, []interface{}) { + mustColumnMap map[string]bool, nullableMap map[string]bool, + columnMap map[string]bool, update bool) ([]string, []interface{}) { colNames := make([]string, 0) var args = make([]interface{}, 0) @@ -215,7 +218,8 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, requiredField := useAllCols includeNil := useAllCols - if b, ok := mustColumnMap[strings.ToLower(col.Name)]; ok { + lColName := strings.ToLower(col.Name) + if b, ok := mustColumnMap[lColName]; ok { if b { requiredField = true } else { @@ -223,6 +227,16 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, } } + // !evalphobia! set fieldValue as nil when column is nullable and zero-value + if b, ok := nullableMap[lColName]; ok { + if b && col.Nullable && isZero(fieldValue.Interface()) { + var nilValue *int + fieldValue = reflect.ValueOf(nilValue) + fieldType = reflect.TypeOf(fieldValue.Interface()) + includeNil = true + } + } + var val interface{} if fieldValue.CanAddr() { @@ -818,6 +832,14 @@ func (statement *Statement) Omit(columns ...string) { statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) } +// Update use only: update columns to null when value is nullable and zero-value +func (statement *Statement) Nullable(columns ...string) { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.nullableMap[strings.ToLower(nc)] = true + } +} + // Generate LIMIT limit statement func (statement *Statement) Top(limit int) *Statement { statement.Limit(limit)