allow upsert on primary key

Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
Andrew Thornton 2023-03-14 22:52:53 +00:00
parent 4bf7c4d738
commit 4bf706dd0c
No known key found for this signature in database
GPG Key ID: 3CDE74631F13A748
3 changed files with 100 additions and 3 deletions

View File

@ -436,4 +436,48 @@ func TestUpsert(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, int64(1), n) assert.Equal(t, int64(1), n)
}) })
t.Run("NoAutoIncrementPK", func(t *testing.T) {
type NoAutoIncrementPrimaryKey struct {
Name string `xorm:"pk"`
Number int `xorm:"pk"`
NotUnique string
}
assert.NoError(t, testEngine.Sync2(&NoAutoIncrementPrimaryKey{}))
_, _ = testEngine.Exec("DELETE FROM no_primary_unique")
empty := &NoAutoIncrementPrimaryKey{}
// Insert default
n, err := testEngine.Upsert(empty)
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
// Insert with 1
one := &NoAutoIncrementPrimaryKey{Name: "one", Number: 1}
n, err = testEngine.Upsert(one)
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
// Update default
n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{NotUnique: "notunique"})
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
// Update again
n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{NotUnique: "again"})
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
// Insert with 2
n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{Name: "two", Number: 2})
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
// Fail reinsert with 2
n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{Name: "one", Number: 1, NotUnique: "updated"})
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
})
} }

View File

@ -70,6 +70,23 @@ func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args [
} }
case schemas.POSTGRES: case schemas.POSTGRES:
if doUpdate { if doUpdate {
primaryColumnIncluded := false
for _, primaryKeyColumn := range table.PrimaryKeys {
if _, has := uniqueColValMap[primaryKeyColumn]; !has {
continue
}
primaryColumnIncluded = true
}
if primaryColumnIncluded {
write(" ON CONFLICT (", quote(table.PrimaryKeys[0]))
for _, col := range table.PrimaryKeys[1:] {
write(", ", quote(col))
}
write(") DO UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0])
for _, column := range updateColumns[1:] {
write(", ", column, " = excluded.", column)
}
}
for _, index := range table.Indexes { for _, index := range table.Indexes {
if index.Type != schemas.UniqueType { if index.Type != schemas.UniqueType {
continue continue
@ -166,6 +183,23 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []
write(") AS src ON (") write(") AS src ON (")
countUniques := 0 countUniques := 0
primaryColumnIncluded := false
for _, primaryKeyColumn := range table.PrimaryKeys {
if _, has := uniqueColValMap[primaryKeyColumn]; !has {
continue
}
if !primaryColumnIncluded {
write("(")
} else {
write(" AND ")
}
write("src.", quote(primaryKeyColumn), " = target.", quote(primaryKeyColumn))
primaryColumnIncluded = true
}
if primaryColumnIncluded {
write(")")
countUniques++
}
for _, index := range table.Indexes { for _, index := range table.Indexes {
if index.Type != schemas.UniqueType { if index.Type != schemas.UniqueType {
continue continue

View File

@ -247,13 +247,32 @@ func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, er
return n, err return n, err
} }
func (session *Session) getUniqueColumns(colNames []string, args []interface{}) (uniqueColValMap map[string]interface{}, numberOfUniqueConstraints int, err error) { func (session *Session) getUniqueColumns(argColumns []string, args []interface{}) (uniqueColValMap map[string]interface{}, numberOfUniqueConstraints int, err error) {
uniqueColValMap = make(map[string]interface{}) uniqueColValMap = make(map[string]interface{})
table := session.statement.RefTable table := session.statement.RefTable
if len(table.Indexes) == 0 { if len(table.Indexes) == 0 && (len(table.PrimaryKeys) == 0 || (len(table.PrimaryKeys) == 1 && table.AutoIncrement == table.PrimaryKeys[0])) {
return nil, 0, fmt.Errorf("provided bean has no unique constraints") return nil, 0, fmt.Errorf("provided bean has no unique constraints")
} }
// Check the primary key
primaryColumnIncluded := false
primaryCol:
for _, primaryKeyColumn := range table.PrimaryKeys {
for i, column := range argColumns {
if column == primaryKeyColumn {
uniqueColValMap[column] = args[i]
primaryColumnIncluded = true
continue primaryCol
}
}
if primaryKeyColumn != table.AutoIncrement {
primaryColumnIncluded = true
}
}
if primaryColumnIncluded {
numberOfUniqueConstraints++
}
// Iterate across the indexes in the provided table // Iterate across the indexes in the provided table
for _, index := range table.Indexes { for _, index := range table.Indexes {
if index.Type != schemas.UniqueType { if index.Type != schemas.UniqueType {
@ -269,7 +288,7 @@ func (session *Session) getUniqueColumns(colNames []string, args []interface{})
} }
// Now iterate across colNames and add to the uniqueCols // Now iterate across colNames and add to the uniqueCols
for i, col := range colNames { for i, col := range argColumns {
if col == indexColumnName { if col == indexColumnName {
uniqueColValMap[col] = args[i] uniqueColValMap[col] = args[i]
continue indexCol continue indexCol