diff --git a/session_insert.go b/session_insert.go index 2c8ad782..1b3f7fb9 100644 --- a/session_insert.go +++ b/session_insert.go @@ -67,7 +67,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return 0, errors.New("could not insert a empty slice") } - if err := session.Statement.setRefValue(sliceValue.Index(0)); err != nil { + if err := session.Statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil { return 0, err } diff --git a/session_insert_test.go b/session_insert_test.go index b232d3f7..b58d9bdb 100644 --- a/session_insert_test.go +++ b/session_insert_test.go @@ -5,6 +5,8 @@ package xorm import ( + "fmt" + "reflect" "testing" "time" @@ -26,3 +28,79 @@ func TestInsertOne(t *testing.T) { _, err := testEngine.InsertOne(data) assert.NoError(t, err) } + +func TestInsertMulti(t *testing.T) { + + assert.NoError(t, prepareEngine()) + type TestMulti struct { + Id int64 `xorm:"int(11) pk"` + Name string `xorm:"varchar(255)"` + } + + assert.NoError(t, testEngine.Sync2(new(TestMulti))) + + num, err := insertMultiDatas(1, + append([]TestMulti{}, TestMulti{1, "test1"}, TestMulti{2, "test2"}, TestMulti{3, "test3"})) + assert.NoError(t, err) + assert.EqualValues(t, 3, num) +} + +func insertMultiDatas(step int, datas interface{}) (num int64, err error) { + sliceValue := reflect.Indirect(reflect.ValueOf(datas)) + var iLen int64 + if sliceValue.Kind() != reflect.Slice { + return 0, fmt.Errorf("not silce") + } + iLen = int64(sliceValue.Len()) + if iLen == 0 { + return + } + + session := testEngine.NewSession() + defer session.Close() + + if err = callbackLooper(datas, step, + func(innerDatas interface{}) error { + n, e := session.InsertMulti(innerDatas) + if e != nil { + return e + } + num += n + return nil + }); err != nil { + return 0, err + } else if num != iLen { + return 0, fmt.Errorf("num error: %d - %d", num, iLen) + } + return +} + +func callbackLooper(datas interface{}, step int, actionFunc func(interface{}) error) (err error) { + + sliceValue := reflect.Indirect(reflect.ValueOf(datas)) + if sliceValue.Kind() != reflect.Slice { + return fmt.Errorf("not slice") + } + if sliceValue.Len() <= 0 { + return + } + + tempLen := 0 + processedLen := sliceValue.Len() + for i := 0; i < sliceValue.Len(); i += step { + if processedLen > step { + tempLen = i + step + } else { + tempLen = sliceValue.Len() + } + var tempInterface []interface{} + for j := i; j < tempLen; j++ { + tempInterface = append(tempInterface, sliceValue.Index(j).Interface()) + } + if err = actionFunc(tempInterface); err != nil { + return + } + processedLen = processedLen - step + } + return +}