allow upsert on primary key
Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
parent
4bf7c4d738
commit
4bf706dd0c
|
@ -436,4 +436,48 @@ func TestUpsert(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
})
|
||||
|
||||
t.Run("NoAutoIncrementPK", func(t *testing.T) {
|
||||
type NoAutoIncrementPrimaryKey struct {
|
||||
Name string `xorm:"pk"`
|
||||
Number int `xorm:"pk"`
|
||||
NotUnique string
|
||||
}
|
||||
|
||||
assert.NoError(t, testEngine.Sync2(&NoAutoIncrementPrimaryKey{}))
|
||||
_, _ = testEngine.Exec("DELETE FROM no_primary_unique")
|
||||
|
||||
empty := &NoAutoIncrementPrimaryKey{}
|
||||
|
||||
// Insert default
|
||||
n, err := testEngine.Upsert(empty)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
|
||||
// Insert with 1
|
||||
one := &NoAutoIncrementPrimaryKey{Name: "one", Number: 1}
|
||||
n, err = testEngine.Upsert(one)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
|
||||
// Update default
|
||||
n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{NotUnique: "notunique"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
|
||||
// Update again
|
||||
n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{NotUnique: "again"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
|
||||
// Insert with 2
|
||||
n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{Name: "two", Number: 2})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
|
||||
// Fail reinsert with 2
|
||||
n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{Name: "one", Number: 1, NotUnique: "updated"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -70,6 +70,23 @@ func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args [
|
|||
}
|
||||
case schemas.POSTGRES:
|
||||
if doUpdate {
|
||||
primaryColumnIncluded := false
|
||||
for _, primaryKeyColumn := range table.PrimaryKeys {
|
||||
if _, has := uniqueColValMap[primaryKeyColumn]; !has {
|
||||
continue
|
||||
}
|
||||
primaryColumnIncluded = true
|
||||
}
|
||||
if primaryColumnIncluded {
|
||||
write(" ON CONFLICT (", quote(table.PrimaryKeys[0]))
|
||||
for _, col := range table.PrimaryKeys[1:] {
|
||||
write(", ", quote(col))
|
||||
}
|
||||
write(") DO UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0])
|
||||
for _, column := range updateColumns[1:] {
|
||||
write(", ", column, " = excluded.", column)
|
||||
}
|
||||
}
|
||||
for _, index := range table.Indexes {
|
||||
if index.Type != schemas.UniqueType {
|
||||
continue
|
||||
|
@ -166,6 +183,23 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []
|
|||
write(") AS src ON (")
|
||||
|
||||
countUniques := 0
|
||||
primaryColumnIncluded := false
|
||||
for _, primaryKeyColumn := range table.PrimaryKeys {
|
||||
if _, has := uniqueColValMap[primaryKeyColumn]; !has {
|
||||
continue
|
||||
}
|
||||
if !primaryColumnIncluded {
|
||||
write("(")
|
||||
} else {
|
||||
write(" AND ")
|
||||
}
|
||||
write("src.", quote(primaryKeyColumn), " = target.", quote(primaryKeyColumn))
|
||||
primaryColumnIncluded = true
|
||||
}
|
||||
if primaryColumnIncluded {
|
||||
write(")")
|
||||
countUniques++
|
||||
}
|
||||
for _, index := range table.Indexes {
|
||||
if index.Type != schemas.UniqueType {
|
||||
continue
|
||||
|
|
|
@ -247,13 +247,32 @@ func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, er
|
|||
return n, err
|
||||
}
|
||||
|
||||
func (session *Session) getUniqueColumns(colNames []string, args []interface{}) (uniqueColValMap map[string]interface{}, numberOfUniqueConstraints int, err error) {
|
||||
func (session *Session) getUniqueColumns(argColumns []string, args []interface{}) (uniqueColValMap map[string]interface{}, numberOfUniqueConstraints int, err error) {
|
||||
uniqueColValMap = make(map[string]interface{})
|
||||
table := session.statement.RefTable
|
||||
if len(table.Indexes) == 0 {
|
||||
if len(table.Indexes) == 0 && (len(table.PrimaryKeys) == 0 || (len(table.PrimaryKeys) == 1 && table.AutoIncrement == table.PrimaryKeys[0])) {
|
||||
return nil, 0, fmt.Errorf("provided bean has no unique constraints")
|
||||
}
|
||||
|
||||
// Check the primary key
|
||||
primaryColumnIncluded := false
|
||||
primaryCol:
|
||||
for _, primaryKeyColumn := range table.PrimaryKeys {
|
||||
for i, column := range argColumns {
|
||||
if column == primaryKeyColumn {
|
||||
uniqueColValMap[column] = args[i]
|
||||
primaryColumnIncluded = true
|
||||
continue primaryCol
|
||||
}
|
||||
}
|
||||
if primaryKeyColumn != table.AutoIncrement {
|
||||
primaryColumnIncluded = true
|
||||
}
|
||||
}
|
||||
if primaryColumnIncluded {
|
||||
numberOfUniqueConstraints++
|
||||
}
|
||||
|
||||
// Iterate across the indexes in the provided table
|
||||
for _, index := range table.Indexes {
|
||||
if index.Type != schemas.UniqueType {
|
||||
|
@ -269,7 +288,7 @@ func (session *Session) getUniqueColumns(colNames []string, args []interface{})
|
|||
}
|
||||
|
||||
// Now iterate across colNames and add to the uniqueCols
|
||||
for i, col := range colNames {
|
||||
for i, col := range argColumns {
|
||||
if col == indexColumnName {
|
||||
uniqueColValMap[col] = args[i]
|
||||
continue indexCol
|
||||
|
|
Loading…
Reference in New Issue