diff --git a/integrations/session_upsert_test.go b/integrations/session_upsert_test.go index 20176b59..6ae46cc0 100644 --- a/integrations/session_upsert_test.go +++ b/integrations/session_upsert_test.go @@ -436,4 +436,48 @@ func TestUpsert(t *testing.T) { assert.NoError(t, err) assert.Equal(t, int64(1), n) }) + + t.Run("NoAutoIncrementPK", func(t *testing.T) { + type NoAutoIncrementPrimaryKey struct { + Name string `xorm:"pk"` + Number int `xorm:"pk"` + NotUnique string + } + + assert.NoError(t, testEngine.Sync2(&NoAutoIncrementPrimaryKey{})) + _, _ = testEngine.Exec("DELETE FROM no_primary_unique") + + empty := &NoAutoIncrementPrimaryKey{} + + // Insert default + n, err := testEngine.Upsert(empty) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + + // Insert with 1 + one := &NoAutoIncrementPrimaryKey{Name: "one", Number: 1} + n, err = testEngine.Upsert(one) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + + // Update default + n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{NotUnique: "notunique"}) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + + // Update again + n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{NotUnique: "again"}) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + + // Insert with 2 + n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{Name: "two", Number: 2}) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + + // Fail reinsert with 2 + n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{Name: "one", Number: 1, NotUnique: "updated"}) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + }) } diff --git a/internal/statements/upsert.go b/internal/statements/upsert.go index 70eb95c9..0c75eb1b 100644 --- a/internal/statements/upsert.go +++ b/internal/statements/upsert.go @@ -70,6 +70,23 @@ func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args [ } case schemas.POSTGRES: if doUpdate { + primaryColumnIncluded := false + for _, primaryKeyColumn := range table.PrimaryKeys { + if _, has := uniqueColValMap[primaryKeyColumn]; !has { + continue + } + primaryColumnIncluded = true + } + if primaryColumnIncluded { + write(" ON CONFLICT (", quote(table.PrimaryKeys[0])) + for _, col := range table.PrimaryKeys[1:] { + write(", ", quote(col)) + } + write(") DO UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0]) + for _, column := range updateColumns[1:] { + write(", ", column, " = excluded.", column) + } + } for _, index := range table.Indexes { if index.Type != schemas.UniqueType { continue @@ -166,6 +183,23 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args [] write(") AS src ON (") countUniques := 0 + primaryColumnIncluded := false + for _, primaryKeyColumn := range table.PrimaryKeys { + if _, has := uniqueColValMap[primaryKeyColumn]; !has { + continue + } + if !primaryColumnIncluded { + write("(") + } else { + write(" AND ") + } + write("src.", quote(primaryKeyColumn), " = target.", quote(primaryKeyColumn)) + primaryColumnIncluded = true + } + if primaryColumnIncluded { + write(")") + countUniques++ + } for _, index := range table.Indexes { if index.Type != schemas.UniqueType { continue diff --git a/session_upsert.go b/session_upsert.go index 6a70a0f2..c6ccba36 100644 --- a/session_upsert.go +++ b/session_upsert.go @@ -247,13 +247,32 @@ func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, er return n, err } -func (session *Session) getUniqueColumns(colNames []string, args []interface{}) (uniqueColValMap map[string]interface{}, numberOfUniqueConstraints int, err error) { +func (session *Session) getUniqueColumns(argColumns []string, args []interface{}) (uniqueColValMap map[string]interface{}, numberOfUniqueConstraints int, err error) { uniqueColValMap = make(map[string]interface{}) table := session.statement.RefTable - if len(table.Indexes) == 0 { + if len(table.Indexes) == 0 && (len(table.PrimaryKeys) == 0 || (len(table.PrimaryKeys) == 1 && table.AutoIncrement == table.PrimaryKeys[0])) { return nil, 0, fmt.Errorf("provided bean has no unique constraints") } + // Check the primary key + primaryColumnIncluded := false +primaryCol: + for _, primaryKeyColumn := range table.PrimaryKeys { + for i, column := range argColumns { + if column == primaryKeyColumn { + uniqueColValMap[column] = args[i] + primaryColumnIncluded = true + continue primaryCol + } + } + if primaryKeyColumn != table.AutoIncrement { + primaryColumnIncluded = true + } + } + if primaryColumnIncluded { + numberOfUniqueConstraints++ + } + // Iterate across the indexes in the provided table for _, index := range table.Indexes { if index.Type != schemas.UniqueType { @@ -269,7 +288,7 @@ func (session *Session) getUniqueColumns(colNames []string, args []interface{}) } // Now iterate across colNames and add to the uniqueCols - for i, col := range colNames { + for i, col := range argColumns { if col == indexColumnName { uniqueColValMap[col] = args[i] continue indexCol