only permit upsert if there is only one unique constraint

Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
Andrew Thornton 2023-03-13 18:42:34 +00:00
parent 5b92ebc141
commit e88ca1d017
No known key found for this signature in database
GPG Key ID: 3CDE74631F13A748
2 changed files with 20 additions and 48 deletions

View File

@ -378,53 +378,19 @@ func TestUpsert(t *testing.T) {
t.Run("MultiMultiUnique", func(t *testing.T) { t.Run("MultiMultiUnique", func(t *testing.T) {
type MultiMultiUniqueUpsert struct { type MultiMultiUniqueUpsert struct {
ID int64 `xorm:"pk autoincr"` ID int64 `xorm:"pk autoincr"`
Data0 string `xorm:"UNIQUE NOT NULL"` NotUnique string
Data1 string `xorm:"UNIQUE(s) NOT NULL"` Data0 string `xorm:"UNIQUE NOT NULL"`
Data2 string `xorm:"UNIQUE(s) NOT NULL"` Data1 string `xorm:"UNIQUE(s) NOT NULL"`
Data2 string `xorm:"UNIQUE(s) NOT NULL"`
} }
assert.NoError(t, testEngine.Sync2(&MultiMultiUniqueUpsert{})) assert.NoError(t, testEngine.Sync2(&MultiMultiUniqueUpsert{}))
_, _ = testEngine.Exec("DELETE FROM multi_multi_unique") _, _ = testEngine.Exec("DELETE FROM multi_multi_unique")
// Insert with default values // Cannot upsert if there is more than one unique constraint
n, err := testEngine.Upsert(&MultiMultiUniqueUpsert{}) n, err := testEngine.Upsert(&MultiMultiUniqueUpsert{})
assert.NoError(t, err) assert.Error(t, err)
assert.Equal(t, int64(1), n)
// Insert with value for t1, <test, "">
n, err = testEngine.Upsert(&MultiMultiUniqueUpsert{Data1: "test", Data0: "t1"})
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
// Fail insert with value for t1, <test2, "">
n, err = testEngine.Upsert(&MultiMultiUniqueUpsert{Data2: "test2", Data0: "t1"})
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
// Insert with value for t2, <test2, "">
n, err = testEngine.Upsert(&MultiMultiUniqueUpsert{Data2: "test2", Data0: "t2"})
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
// Fail insert with value for t2, <test2, "">
n, err = testEngine.Upsert(&MultiMultiUniqueUpsert{Data2: "test2", Data0: "t2"})
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
// Fail insert with value for t2, <test, "">
n, err = testEngine.Upsert(&MultiMultiUniqueUpsert{Data1: "test", Data0: "t2"})
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
// Insert with value for t3, <test, test2>
n, err = testEngine.Upsert(&MultiMultiUniqueUpsert{Data1: "test", Data2: "test2", Data0: "t3"})
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
// fail insert with value for t2, <test, test2>
n, err = testEngine.Upsert(&MultiMultiUniqueUpsert{Data1: "test", Data2: "test2", Data0: "t2"})
assert.NoError(t, err)
assert.Equal(t, int64(0), n) assert.Equal(t, int64(0), n)
}) })

View File

@ -117,10 +117,13 @@ func (session *Session) upsertMap(doUpdate bool, columns []string, args []interf
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
uniqueColValMap, err := session.getUniqueColumns(columns, args) uniqueColValMap, numberOfUniqueConstraints, err := session.getUniqueColumns(columns, args)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if numberOfUniqueConstraints > 1 {
return 0, fmt.Errorf("cannot upsert if there is more than one unique constraint")
}
sql, args, err := session.statement.GenUpsertMapSQL(doUpdate, columns, args, uniqueColValMap) sql, args, err := session.statement.GenUpsertMapSQL(doUpdate, columns, args, uniqueColValMap)
if err != nil { if err != nil {
@ -169,10 +172,13 @@ func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, er
return 0, err return 0, err
} }
uniqueColValMap, err := session.getUniqueColumns(colNames, args) uniqueColValMap, numberOfUniqueConstraints, err := session.getUniqueColumns(colNames, args)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if numberOfUniqueConstraints > 1 {
return 0, fmt.Errorf("cannot upsert if there is more than one unique constraint")
}
sqlStr, args, err := session.statement.GenUpsertSQL(doUpdate, colNames, args, uniqueColValMap) sqlStr, args, err := session.statement.GenUpsertSQL(doUpdate, colNames, args, uniqueColValMap)
if err != nil { if err != nil {
@ -242,11 +248,11 @@ func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, er
return n, err return n, err
} }
func (session *Session) getUniqueColumns(colNames []string, args []interface{}) (uniqueColValMap map[string]interface{}, err error) { func (session *Session) getUniqueColumns(colNames []string, args []interface{}) (uniqueColValMap map[string]interface{}, numberOfUniqueConstraints int, err error) {
uniqueColValMap = make(map[string]interface{}) uniqueColValMap = make(map[string]interface{})
table := session.statement.RefTable table := session.statement.RefTable
if len(table.Indexes) == 0 { if len(table.Indexes) == 0 {
return nil, fmt.Errorf("provided bean has no unique constraints") return nil, 0, fmt.Errorf("provided bean has no unique constraints")
} }
// Iterate across the indexes in the provided table // Iterate across the indexes in the provided table
@ -254,7 +260,7 @@ func (session *Session) getUniqueColumns(colNames []string, args []interface{})
if index.Type != schemas.UniqueType { if index.Type != schemas.UniqueType {
continue continue
} }
numberOfUniqueConstraints++
// index is a Unique constraint // index is a Unique constraint
indexCol: indexCol:
for _, indexColumnName := range index.Cols { for _, indexColumnName := range index.Cols {
@ -311,8 +317,8 @@ func (session *Session) getUniqueColumns(colNames []string, args []interface{})
} }
// FIXME: not sure if there's anything else we can do // FIXME: not sure if there's anything else we can do
return nil, fmt.Errorf("provided bean does not provide a value for unique constraint %s field %s which has no default", index.Name, indexColumnName) return nil, 0, fmt.Errorf("provided bean does not provide a value for unique constraint %s field %s which has no default", index.Name, indexColumnName)
} }
} }
return uniqueColValMap, nil return uniqueColValMap, numberOfUniqueConstraints, nil
} }