diff --git a/engine.go b/engine.go index 94ac29e3..c33e762b 100644 --- a/engine.go +++ b/engine.go @@ -560,6 +560,13 @@ func (engine *Engine) Omit(columns ...string) *Session { return session.Omit(columns...) } +// Set null when column is zero-value and nullable for update +func (engine *Engine) Nullable(columns ...string) *Session { + session := engine.NewSession() + session.IsAutoClose = true + return session.Nullable(columns...) +} + // This method will generate "column IN (?, ?)" func (engine *Engine) In(column string, args ...interface{}) *Session { session := engine.NewSession() diff --git a/helpers.go b/helpers.go index 979a67a1..7e8978f0 100644 --- a/helpers.go +++ b/helpers.go @@ -37,8 +37,16 @@ func isZero(k interface{}) bool { return k.(uint32) == 0 case uint64: return k.(uint64) == 0 + case float32: + return k.(float32) == 0 + case float64: + return k.(float64) == 0 + case bool: + return k.(bool) == false case string: return k.(string) == "" + case time.Time: + return k.(time.Time).IsZero() } return false } @@ -356,6 +364,14 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, } } + // !evalphobia! set fieldValue as nil when column is nullable and zero-value + if _, ok := session.Statement.nullableMap[lColName]; ok { + if col.Nullable && isZero(fieldValue.Interface()) { + var nilValue *int + fieldValue = reflect.ValueOf(nilValue) + } + } + if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { val, t := session.Engine.NowTime2(col.SQLType.Name) args = append(args, val) diff --git a/session.go b/session.go index adf0aa8b..d7975ecf 100644 --- a/session.go +++ b/session.go @@ -229,6 +229,12 @@ func (session *Session) Omit(columns ...string) *Session { return session } +// Set null when column is zero-value and nullable for update +func (session *Session) Nullable(columns ...string) *Session { + session.Statement.Nullable(columns...) + return session +} + // Method NoAutoTime means do not automatically give created field and updated field // the current time on the current session temporarily func (session *Session) NoAutoTime() *Session { @@ -3414,7 +3420,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if session.Statement.ColumnStr == "" { colNames, args = buildUpdates(session.Engine, table, bean, false, false, false, false, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.mustColumnMap, session.Statement.columnMap, true) + session.Statement.mustColumnMap, session.Statement.nullableMap, + session.Statement.columnMap, true) } else { colNames, args, err = genCols(table, session, bean, true, true) if err != nil { @@ -3888,6 +3895,7 @@ func (s *Session) Sync2(beans ...interface{}) error { } var foundIndexNames = make(map[string]bool) + var addedNames = make(map[string]*core.Index) for name, index := range table.Indexes { var oriIndex *core.Index @@ -3911,20 +3919,7 @@ func (s *Session) Sync2(beans ...interface{}) error { } if oriIndex == nil { - if index.Type == core.UniqueType { - session := engine.NewSession() - session.Statement.RefTable = table - defer session.Close() - err = session.addUnique(table.Name, name) - } else if index.Type == core.IndexType { - session := engine.NewSession() - session.Statement.RefTable = table - defer session.Close() - err = session.addIndex(table.Name, name) - } - if err != nil { - return err - } + addedNames[name] = index } } @@ -3937,6 +3932,23 @@ func (s *Session) Sync2(beans ...interface{}) error { } } } + + for name, index := range addedNames { + if index.Type == core.UniqueType { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addUnique(table.Name, name) + } else if index.Type == core.IndexType { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addIndex(table.Name, name) + } + if err != nil { + return err + } + } } } diff --git a/statement.go b/statement.go index 3278662d..956671e0 100644 --- a/statement.go +++ b/statement.go @@ -70,6 +70,7 @@ type Statement struct { checkVersion bool unscoped bool mustColumnMap map[string]bool + nullableMap map[string]bool inColumns map[string]*inParam incrColumns map[string]incrParam decrColumns map[string]decrParam @@ -107,6 +108,7 @@ func (statement *Statement) Init() { statement.allUseBool = false statement.useAllCols = false statement.mustColumnMap = make(map[string]bool) + statement.nullableMap = make(map[string]bool) statement.checkVersion = true statement.unscoped = false statement.inColumns = make(map[string]*inParam) @@ -191,7 +193,8 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { func buildUpdates(engine *Engine, table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, - mustColumnMap map[string]bool, columnMap map[string]bool, update bool) ([]string, []interface{}) { + mustColumnMap map[string]bool, nullableMap map[string]bool, + columnMap map[string]bool, update bool) ([]string, []interface{}) { colNames := make([]string, 0) var args = make([]interface{}, 0) @@ -230,7 +233,8 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, requiredField := useAllCols includeNil := useAllCols - if b, ok := mustColumnMap[strings.ToLower(col.Name)]; ok { + lColName := strings.ToLower(col.Name) + if b, ok := mustColumnMap[lColName]; ok { if b { requiredField = true } else { @@ -238,6 +242,16 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, } } + // !evalphobia! set fieldValue as nil when column is nullable and zero-value + if b, ok := nullableMap[lColName]; ok { + if b && col.Nullable && isZero(fieldValue.Interface()) { + var nilValue *int + fieldValue = reflect.ValueOf(nilValue) + fieldType = reflect.TypeOf(fieldValue.Interface()) + includeNil = true + } + } + var val interface{} if fieldValue.CanAddr() { @@ -833,6 +847,14 @@ func (statement *Statement) Omit(columns ...string) { statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) } +// Update use only: update columns to null when value is nullable and zero-value +func (statement *Statement) Nullable(columns ...string) { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.nullableMap[strings.ToLower(nc)] = true + } +} + // Generate LIMIT limit statement func (statement *Statement) Top(limit int) *Statement { statement.Limit(limit) diff --git a/xorm.go b/xorm.go index 036d420f..8e630b9a 100644 --- a/xorm.go +++ b/xorm.go @@ -17,7 +17,7 @@ import ( ) const ( - Version string = "0.4.3.0526" + Version string = "0.4.3.0627" ) func regDrvsNDialects() bool { @@ -39,7 +39,7 @@ func regDrvsNDialects() bool { for driverName, v := range providedDrvsNDialects { if driver := core.QueryDriver(driverName); driver == nil { core.RegisterDriver(driverName, v.getDriver()) - core.RegisterDialect(v.dbType, v.getDialect()) + core.RegisterDialect(v.dbType, v.getDialect) } } return true