diff --git a/integrations/session_insert_test.go b/integrations/session_insert_test.go index 47789b8a..a2e2da7f 100644 --- a/integrations/session_insert_test.go +++ b/integrations/session_insert_test.go @@ -7,8 +7,10 @@ package integrations import ( "fmt" "reflect" + "strings" "testing" "time" + "xorm.io/xorm/schemas" "xorm.io/xorm" @@ -45,6 +47,26 @@ func TestInsertMulti(t *testing.T) { append([]TestMulti{}, TestMulti{1, "test1"}, TestMulti{2, "test2"}, TestMulti{3, "test3"})) assert.NoError(t, err) assert.EqualValues(t, 3, num) + + if schemas.DBType(strings.ToLower(dbType)) == schemas.POSTGRES { + type TestMultiPG struct { + Id int64 `xorm:"int(11) pk"` + Name string `xorm:"varchar(255)"` + } + assert.NoError(t, testEngine.Sync2(new(TestMultiPG))) + + var data []TestMultiPG + for i := 1; i < 655360; i++ { + data = append(data, TestMultiPG{ + Id: int64(i), + Name: fmt.Sprintf("test %d", i), + }) + } + + num, err := insertMultiDatas(655359, data) + assert.NoError(t, err) + assert.EqualValues(t, 655359, num) + } } func insertMultiDatas(step int, datas interface{}) (num int64, err error) { diff --git a/session_insert.go b/session_insert.go index 5f968151..6373f54a 100644 --- a/session_insert.go +++ b/session_insert.go @@ -19,6 +19,9 @@ import ( // ErrNoElementsOnSlice represents an error there is no element when insert var ErrNoElementsOnSlice = errors.New("No element on slice when insert") +// maxPgParams pg only support max 65535 placeholder params +const maxPgParams = 65535 + // Insert insert one or more beans func (session *Session) Insert(beans ...interface{}) (int64, error) { var affected int64 @@ -112,113 +115,135 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } var ( - table = session.statement.RefTable - size = sliceValue.Len() - colNames []string - colMultiPlaces []string - args []interface{} - cols []*schemas.Column + table = session.statement.RefTable + size = sliceValue.Len() + cols []*schemas.Column + colNames []string + step = size + affectRows int64 ) - for i := 0; i < size; i++ { - v := sliceValue.Index(i) - var vv reflect.Value - switch v.Kind() { - case reflect.Interface: - vv = reflect.Indirect(v.Elem()) - default: - vv = reflect.Indirect(v) - } - elemValue := v.Interface() - var colPlaces []string + if session.engine.dialect.URI().DBType == schemas.POSTGRES { + step = maxPgParams / len(session.statement.RefTable.Columns()) + } - // handle BeforeInsertProcessor - // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi?? - for _, closure := range session.beforeClosures { - closure(elemValue) + for i := 0; i < size; i += step { + var colMultiPlaces []string + var args []interface{} + + stepSize := i + step + if stepSize > size { + stepSize = size } - if processor, ok := interface{}(elemValue).(BeforeInsertProcessor); ok { - processor.BeforeInsert() - } - // -- + for j := i; j < stepSize; j++ { + v := sliceValue.Index(j) + var vv reflect.Value + switch v.Kind() { + case reflect.Interface: + vv = reflect.Indirect(v.Elem()) + default: + vv = reflect.Indirect(v) + } + elemValue := v.Interface() + var colPlaces []string - for _, col := range table.Columns() { - ptrFieldValue, err := col.ValueOfV(&vv) - if err != nil { - return 0, err + // handle BeforeInsertProcessor + // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi?? + for _, closure := range session.beforeClosures { + closure(elemValue) } - fieldValue := *ptrFieldValue - if col.IsAutoIncrement && utils.IsZero(fieldValue.Interface()) { - continue - } - if col.MapType == schemas.ONLYFROMDB { - continue - } - if col.IsDeleted { - continue - } - if session.statement.OmitColumnMap.Contain(col.Name) { - continue - } - if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { - continue - } - if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { - val, t := session.engine.nowTime(col) - args = append(args, val) - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnTime(bean, col, t) - }) - } else if col.IsVersion && session.statement.CheckVersion { - args = append(args, 1) - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnInt(bean, col, 1) - }) - } else { - arg, err := session.statement.Value2Interface(col, fieldValue) + if processor, ok := interface{}(elemValue).(BeforeInsertProcessor); ok { + processor.BeforeInsert() + } + // -- + + for _, col := range table.Columns() { + ptrFieldValue, err := col.ValueOfV(&vv) if err != nil { - return 0, err + return affectRows, err } - args = append(args, arg) + fieldValue := *ptrFieldValue + if col.IsAutoIncrement && utils.IsZero(fieldValue.Interface()) { + continue + } + if col.MapType == schemas.ONLYFROMDB { + continue + } + if col.IsDeleted { + continue + } + if session.statement.OmitColumnMap.Contain(col.Name) { + continue + } + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { + continue + } + if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { + val, t := session.engine.nowTime(col) + args = append(args, val) + + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t) + }) + } else if col.IsVersion && session.statement.CheckVersion { + args = append(args, 1) + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnInt(bean, col, 1) + }) + } else { + arg, err := session.statement.Value2Interface(col, fieldValue) + if err != nil { + return affectRows, err + } + args = append(args, arg) + } + + if j == 0 { + colNames = append(colNames, col.Name) + cols = append(cols, col) + } + colPlaces = append(colPlaces, "?") } - if i == 0 { - colNames = append(colNames, col.Name) - cols = append(cols, col) - } - colPlaces = append(colPlaces, "?") + colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", ")) } + cleanupProcessorsClosures(&session.beforeClosures) - colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", ")) - } - cleanupProcessorsClosures(&session.beforeClosures) - - quoter := session.engine.dialect.Quoter() - var sql string - colStr := quoter.Join(colNames, ",") - if session.engine.dialect.URI().DBType == schemas.ORACLE { - temp := fmt.Sprintf(") INTO %s (%v) VALUES (", - quoter.Quote(tableName), - colStr) - sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL", - quoter.Quote(tableName), - colStr, - strings.Join(colMultiPlaces, temp)) - } else { - sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)", - quoter.Quote(tableName), - colStr, - strings.Join(colMultiPlaces, "),(")) - } - res, err := session.exec(sql, args...) - if err != nil { - return 0, err + quoter := session.engine.dialect.Quoter() + var sql string + colStr := quoter.Join(colNames, ",") + if session.engine.dialect.URI().DBType == schemas.ORACLE { + temp := fmt.Sprintf(") INTO %s (%v) VALUES (", + quoter.Quote(tableName), + colStr) + sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL", + quoter.Quote(tableName), + colStr, + strings.Join(colMultiPlaces, temp)) + } else { + sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)", + quoter.Quote(tableName), + colStr, + strings.Join(colMultiPlaces, "),(")) + } + res, err := session.exec(sql, args...) + if err != nil { + return affectRows, err + } + af, err := res.RowsAffected() + if err != nil { + return affectRows, err + } + affectRows += af + if stepSize == size { + break + } } session.cacheInsert(tableName) @@ -254,7 +279,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } cleanupProcessorsClosures(&session.afterClosures) - return res.RowsAffected() + return affectRows, nil } // InsertMulti insert multiple records