diff --git a/helpers.go b/helpers.go index 88ba60a0..a0e894e0 100644 --- a/helpers.go +++ b/helpers.go @@ -281,6 +281,20 @@ func query2(db *core.DB, sqlStr string, params ...interface{}) (resultsSlice []m return rows2Strings(rows) } +func setColumnTime(bean interface{}, col *core.Column, t time.Time) { + v, _ := col.ValueOf(bean) + if v.CanSet() { + switch v.Type().Kind() { + case reflect.Struct: + v.Set(reflect.ValueOf(t).Convert(v.Type())) + case reflect.Int, reflect.Int64, reflect.Int32: + v.SetInt(t.Unix()) + case reflect.Uint, reflect.Uint64, reflect.Uint32: + v.SetUint(uint64(t.Unix())) + } + } +} + func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) { colNames := make([]string, 0) args := make([]interface{}, 0) @@ -338,16 +352,11 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { val, t := session.Engine.NowTime2(col.SQLType.Name) args = append(args, val) + + var colName = col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { - v, _ := col.ValueOf(bean) - switch v.Type().Kind() { - case reflect.Struct: - v.Set(reflect.ValueOf(t)) - case reflect.Int, reflect.Int64, reflect.Int32: - v.SetInt(t.Unix()) - case reflect.Uint, reflect.Uint64, reflect.Uint32: - v.SetUint(uint64(t.Unix())) - } + col := table.GetColumn(colName) + setColumnTime(bean, col, t) }) } else if col.IsVersion && session.Statement.checkVersion { args = append(args, 1) diff --git a/session.go b/session.go index 4dde9bb3..cf566902 100644 --- a/session.go +++ b/session.go @@ -1671,11 +1671,11 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount fieldValue.SetUint(uint64(vv.Int())) } case reflect.Struct: - if fieldType == core.TimeType { + if fieldType.ConvertibleTo(core.TimeType) { if rawValueType == core.TimeType { hasAssigned = true - t := vv.Interface().(time.Time) + t := vv.Convert(core.TimeType).Interface().(time.Time) z, _ := t.Zone() if len(z) == 0 || t.Year() == 0 { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location session.Engine.LogDebug("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) @@ -1684,7 +1684,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount vv = reflect.ValueOf(tt) } // !nashtsai! convert to engine location - t = vv.Interface().(time.Time).In(session.Engine.TZLocation) + t = vv.Convert(core.TimeType).Interface().(time.Time).In(session.Engine.TZLocation) vv = reflect.ValueOf(t) fieldValue.Set(vv) @@ -2059,7 +2059,14 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } } if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { - args = append(args, session.Engine.NowTime(col.SQLType.Name)) + val, t := session.Engine.NowTime2(col.SQLType.Name) + args = append(args, val) + + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t) + }) } else { arg, err := session.value2Interface(col, fieldValue) if err != nil { @@ -2095,7 +2102,14 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } } if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { - args = append(args, session.Engine.NowTime(col.SQLType.Name)) + val, t := session.Engine.NowTime2(col.SQLType.Name) + args = append(args, val) + + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t) + }) } else { arg, err := session.value2Interface(col, fieldValue) if err != nil { @@ -2346,13 +2360,13 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, fieldValue.SetUint(x) //Currently only support Time type case reflect.Struct: - if fieldType == core.TimeType { + if fieldType.ConvertibleTo(core.TimeType) { x, err := session.byte2Time(col, data) if err != nil { return err } v = x - fieldValue.Set(reflect.ValueOf(v)) + fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) } else if session.Statement.UseCascade { table := session.Engine.autoMapType(*fieldValue) if table != nil { @@ -3379,7 +3393,15 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if session.Statement.UseAutoTime && table.Updated != "" { colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?") - args = append(args, session.Engine.NowTime(table.UpdatedColumn().SQLType.Name)) + col := table.UpdatedColumn() + val, t := session.Engine.NowTime2(col.SQLType.Name) + args = append(args, val) + + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t) + }) } //for update action to like "column = column + ?" @@ -3659,7 +3681,15 @@ func (session *Session) Delete(bean interface{}) (int64, error) { session.Statement.Params = append(session.Statement.Params, "") paramsLen := len(session.Statement.Params) copy(session.Statement.Params[1:paramsLen], session.Statement.Params[0:paramsLen-1]) - session.Statement.Params[0] = session.Engine.NowTime(deletedColumn.SQLType.Name) + + val, t := session.Engine.NowTime2(deletedColumn.SQLType.Name) + session.Statement.Params[0] = val + + var colName = deletedColumn.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t) + }) } args = append(session.Statement.Params, args...) diff --git a/statement.go b/statement.go index 75501f40..f645a1d3 100644 --- a/statement.go +++ b/statement.go @@ -583,8 +583,8 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, t := int64(fieldValue.Uint()) val = reflect.ValueOf(&t).Interface() case reflect.Struct: - if fieldType == reflect.TypeOf(time.Now()) { - t := fieldValue.Interface().(time.Time) + if fieldType.ConvertibleTo(core.TimeType) { + t := fieldValue.Convert(core.TimeType).Interface().(time.Time) if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { continue }