allow upsert on primary key
Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
parent
4bf7c4d738
commit
4bf706dd0c
|
@ -436,4 +436,48 @@ func TestUpsert(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, int64(1), n)
|
assert.Equal(t, int64(1), n)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("NoAutoIncrementPK", func(t *testing.T) {
|
||||||
|
type NoAutoIncrementPrimaryKey struct {
|
||||||
|
Name string `xorm:"pk"`
|
||||||
|
Number int `xorm:"pk"`
|
||||||
|
NotUnique string
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, testEngine.Sync2(&NoAutoIncrementPrimaryKey{}))
|
||||||
|
_, _ = testEngine.Exec("DELETE FROM no_primary_unique")
|
||||||
|
|
||||||
|
empty := &NoAutoIncrementPrimaryKey{}
|
||||||
|
|
||||||
|
// Insert default
|
||||||
|
n, err := testEngine.Upsert(empty)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), n)
|
||||||
|
|
||||||
|
// Insert with 1
|
||||||
|
one := &NoAutoIncrementPrimaryKey{Name: "one", Number: 1}
|
||||||
|
n, err = testEngine.Upsert(one)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), n)
|
||||||
|
|
||||||
|
// Update default
|
||||||
|
n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{NotUnique: "notunique"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), n)
|
||||||
|
|
||||||
|
// Update again
|
||||||
|
n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{NotUnique: "again"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), n)
|
||||||
|
|
||||||
|
// Insert with 2
|
||||||
|
n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{Name: "two", Number: 2})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), n)
|
||||||
|
|
||||||
|
// Fail reinsert with 2
|
||||||
|
n, err = testEngine.Upsert(&NoAutoIncrementPrimaryKey{Name: "one", Number: 1, NotUnique: "updated"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), n)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,6 +70,23 @@ func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args [
|
||||||
}
|
}
|
||||||
case schemas.POSTGRES:
|
case schemas.POSTGRES:
|
||||||
if doUpdate {
|
if doUpdate {
|
||||||
|
primaryColumnIncluded := false
|
||||||
|
for _, primaryKeyColumn := range table.PrimaryKeys {
|
||||||
|
if _, has := uniqueColValMap[primaryKeyColumn]; !has {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
primaryColumnIncluded = true
|
||||||
|
}
|
||||||
|
if primaryColumnIncluded {
|
||||||
|
write(" ON CONFLICT (", quote(table.PrimaryKeys[0]))
|
||||||
|
for _, col := range table.PrimaryKeys[1:] {
|
||||||
|
write(", ", quote(col))
|
||||||
|
}
|
||||||
|
write(") DO UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0])
|
||||||
|
for _, column := range updateColumns[1:] {
|
||||||
|
write(", ", column, " = excluded.", column)
|
||||||
|
}
|
||||||
|
}
|
||||||
for _, index := range table.Indexes {
|
for _, index := range table.Indexes {
|
||||||
if index.Type != schemas.UniqueType {
|
if index.Type != schemas.UniqueType {
|
||||||
continue
|
continue
|
||||||
|
@ -166,6 +183,23 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []
|
||||||
write(") AS src ON (")
|
write(") AS src ON (")
|
||||||
|
|
||||||
countUniques := 0
|
countUniques := 0
|
||||||
|
primaryColumnIncluded := false
|
||||||
|
for _, primaryKeyColumn := range table.PrimaryKeys {
|
||||||
|
if _, has := uniqueColValMap[primaryKeyColumn]; !has {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !primaryColumnIncluded {
|
||||||
|
write("(")
|
||||||
|
} else {
|
||||||
|
write(" AND ")
|
||||||
|
}
|
||||||
|
write("src.", quote(primaryKeyColumn), " = target.", quote(primaryKeyColumn))
|
||||||
|
primaryColumnIncluded = true
|
||||||
|
}
|
||||||
|
if primaryColumnIncluded {
|
||||||
|
write(")")
|
||||||
|
countUniques++
|
||||||
|
}
|
||||||
for _, index := range table.Indexes {
|
for _, index := range table.Indexes {
|
||||||
if index.Type != schemas.UniqueType {
|
if index.Type != schemas.UniqueType {
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -247,13 +247,32 @@ func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, er
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *Session) getUniqueColumns(colNames []string, args []interface{}) (uniqueColValMap map[string]interface{}, numberOfUniqueConstraints int, err error) {
|
func (session *Session) getUniqueColumns(argColumns []string, args []interface{}) (uniqueColValMap map[string]interface{}, numberOfUniqueConstraints int, err error) {
|
||||||
uniqueColValMap = make(map[string]interface{})
|
uniqueColValMap = make(map[string]interface{})
|
||||||
table := session.statement.RefTable
|
table := session.statement.RefTable
|
||||||
if len(table.Indexes) == 0 {
|
if len(table.Indexes) == 0 && (len(table.PrimaryKeys) == 0 || (len(table.PrimaryKeys) == 1 && table.AutoIncrement == table.PrimaryKeys[0])) {
|
||||||
return nil, 0, fmt.Errorf("provided bean has no unique constraints")
|
return nil, 0, fmt.Errorf("provided bean has no unique constraints")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check the primary key
|
||||||
|
primaryColumnIncluded := false
|
||||||
|
primaryCol:
|
||||||
|
for _, primaryKeyColumn := range table.PrimaryKeys {
|
||||||
|
for i, column := range argColumns {
|
||||||
|
if column == primaryKeyColumn {
|
||||||
|
uniqueColValMap[column] = args[i]
|
||||||
|
primaryColumnIncluded = true
|
||||||
|
continue primaryCol
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if primaryKeyColumn != table.AutoIncrement {
|
||||||
|
primaryColumnIncluded = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if primaryColumnIncluded {
|
||||||
|
numberOfUniqueConstraints++
|
||||||
|
}
|
||||||
|
|
||||||
// Iterate across the indexes in the provided table
|
// Iterate across the indexes in the provided table
|
||||||
for _, index := range table.Indexes {
|
for _, index := range table.Indexes {
|
||||||
if index.Type != schemas.UniqueType {
|
if index.Type != schemas.UniqueType {
|
||||||
|
@ -269,7 +288,7 @@ func (session *Session) getUniqueColumns(colNames []string, args []interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now iterate across colNames and add to the uniqueCols
|
// Now iterate across colNames and add to the uniqueCols
|
||||||
for i, col := range colNames {
|
for i, col := range argColumns {
|
||||||
if col == indexColumnName {
|
if col == indexColumnName {
|
||||||
uniqueColValMap[col] = args[i]
|
uniqueColValMap[col] = args[i]
|
||||||
continue indexCol
|
continue indexCol
|
||||||
|
|
Loading…
Reference in New Issue