From bd994cb726ac48eb8b9dec3aa5fe44d4d12626f1 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 9 Jul 2016 13:34:34 +0800 Subject: [PATCH] resolved #209 --- helpers.go | 15 +++++++++++++++ session.go | 37 ++++++++++++++++++++++++++++--------- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/helpers.go b/helpers.go index 3b4d01f1..9a461c0e 100644 --- a/helpers.go +++ b/helpers.go @@ -457,6 +457,21 @@ func query2(db *core.DB, sqlStr string, params ...interface{}) (resultsSlice []m return rows2Strings(rows) } +func setColumnInt(bean interface{}, col *core.Column, t int64) { + v, err := col.ValueOf(bean) + if err != nil { + return + } + if v.CanSet() { + switch v.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32: + v.SetInt(t) + case reflect.Uint, reflect.Uint64, reflect.Uint32: + v.SetUint(uint64(t)) + } + } +} + func setColumnTime(bean interface{}, col *core.Column, t time.Time) { v, err := col.ValueOf(bean) if err != nil { diff --git a/session.go b/session.go index a354b6d1..6997dcc5 100644 --- a/session.go +++ b/session.go @@ -2292,6 +2292,13 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error col := table.GetColumn(colName) setColumnTime(bean, col, t) }) + } else if col.IsVersion && session.Statement.checkVersion { + args = append(args, 1) + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnInt(bean, col, 1) + }) } else { arg, err := session.value2Interface(col, fieldValue) if err != nil { @@ -2340,6 +2347,13 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error col := table.GetColumn(colName) setColumnTime(bean, col, t) }) + } else if col.IsVersion && session.Statement.checkVersion { + args = append(args, 1) + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnInt(bean, col, 1) + }) } else { arg, err := session.value2Interface(col, fieldValue) if err != nil { @@ -2400,24 +2414,29 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } } } + cleanupProcessorsClosures(&session.afterClosures) return res.RowsAffected() } // InsertMulti insert multiple records func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { + defer session.resetStatement() + if session.IsAutoClose { + defer session.Close() + } + sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) - if sliceValue.Kind() == reflect.Slice { - if sliceValue.Len() > 0 { - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - return session.innerInsertMulti(rowsSlicePtr) - } + if sliceValue.Kind() != reflect.Slice { + return 0, ErrParamsType + + } + + if sliceValue.Len() <= 0 { return 0, nil } - return 0, ErrParamsType + + return session.innerInsertMulti(rowsSlicePtr) } func (session *Session) str2Time(col *core.Column, data string) (outTime time.Time, outErr error) {