From e88ca1d01712df75b4aae650077a86c378515ed8 Mon Sep 17 00:00:00 2001 From: Andrew Thornton Date: Mon, 13 Mar 2023 18:42:34 +0000 Subject: [PATCH] only permit upsert if there is only one unique constraint Signed-off-by: Andrew Thornton --- integrations/session_upsert_test.go | 48 +++++------------------------ session_upsert.go | 20 +++++++----- 2 files changed, 20 insertions(+), 48 deletions(-) diff --git a/integrations/session_upsert_test.go b/integrations/session_upsert_test.go index d9bb2557..20176b59 100644 --- a/integrations/session_upsert_test.go +++ b/integrations/session_upsert_test.go @@ -378,53 +378,19 @@ func TestUpsert(t *testing.T) { t.Run("MultiMultiUnique", func(t *testing.T) { type MultiMultiUniqueUpsert struct { - ID int64 `xorm:"pk autoincr"` - Data0 string `xorm:"UNIQUE NOT NULL"` - Data1 string `xorm:"UNIQUE(s) NOT NULL"` - Data2 string `xorm:"UNIQUE(s) NOT NULL"` + ID int64 `xorm:"pk autoincr"` + NotUnique string + Data0 string `xorm:"UNIQUE NOT NULL"` + Data1 string `xorm:"UNIQUE(s) NOT NULL"` + Data2 string `xorm:"UNIQUE(s) NOT NULL"` } assert.NoError(t, testEngine.Sync2(&MultiMultiUniqueUpsert{})) _, _ = testEngine.Exec("DELETE FROM multi_multi_unique") - // Insert with default values + // Cannot upsert if there is more than one unique constraint n, err := testEngine.Upsert(&MultiMultiUniqueUpsert{}) - assert.NoError(t, err) - assert.Equal(t, int64(1), n) - - // Insert with value for t1, - n, err = testEngine.Upsert(&MultiMultiUniqueUpsert{Data1: "test", Data0: "t1"}) - assert.NoError(t, err) - assert.Equal(t, int64(1), n) - - // Fail insert with value for t1, - n, err = testEngine.Upsert(&MultiMultiUniqueUpsert{Data2: "test2", Data0: "t1"}) - assert.NoError(t, err) - assert.Equal(t, int64(0), n) - - // Insert with value for t2, - n, err = testEngine.Upsert(&MultiMultiUniqueUpsert{Data2: "test2", Data0: "t2"}) - assert.NoError(t, err) - assert.Equal(t, int64(1), n) - - // Fail insert with value for t2, - n, err = testEngine.Upsert(&MultiMultiUniqueUpsert{Data2: "test2", Data0: "t2"}) - assert.NoError(t, err) - assert.Equal(t, int64(0), n) - - // Fail insert with value for t2, - n, err = testEngine.Upsert(&MultiMultiUniqueUpsert{Data1: "test", Data0: "t2"}) - assert.NoError(t, err) - assert.Equal(t, int64(0), n) - - // Insert with value for t3, - n, err = testEngine.Upsert(&MultiMultiUniqueUpsert{Data1: "test", Data2: "test2", Data0: "t3"}) - assert.NoError(t, err) - assert.Equal(t, int64(1), n) - - // fail insert with value for t2, - n, err = testEngine.Upsert(&MultiMultiUniqueUpsert{Data1: "test", Data2: "test2", Data0: "t2"}) - assert.NoError(t, err) + assert.Error(t, err) assert.Equal(t, int64(0), n) }) diff --git a/session_upsert.go b/session_upsert.go index 04aa7da4..9b3595e2 100644 --- a/session_upsert.go +++ b/session_upsert.go @@ -117,10 +117,13 @@ func (session *Session) upsertMap(doUpdate bool, columns []string, args []interf return 0, ErrTableNotFound } - uniqueColValMap, err := session.getUniqueColumns(columns, args) + uniqueColValMap, numberOfUniqueConstraints, err := session.getUniqueColumns(columns, args) if err != nil { return 0, err } + if numberOfUniqueConstraints > 1 { + return 0, fmt.Errorf("cannot upsert if there is more than one unique constraint") + } sql, args, err := session.statement.GenUpsertMapSQL(doUpdate, columns, args, uniqueColValMap) if err != nil { @@ -169,10 +172,13 @@ func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, er return 0, err } - uniqueColValMap, err := session.getUniqueColumns(colNames, args) + uniqueColValMap, numberOfUniqueConstraints, err := session.getUniqueColumns(colNames, args) if err != nil { return 0, err } + if numberOfUniqueConstraints > 1 { + return 0, fmt.Errorf("cannot upsert if there is more than one unique constraint") + } sqlStr, args, err := session.statement.GenUpsertSQL(doUpdate, colNames, args, uniqueColValMap) if err != nil { @@ -242,11 +248,11 @@ func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, er return n, err } -func (session *Session) getUniqueColumns(colNames []string, args []interface{}) (uniqueColValMap map[string]interface{}, err error) { +func (session *Session) getUniqueColumns(colNames []string, args []interface{}) (uniqueColValMap map[string]interface{}, numberOfUniqueConstraints int, err error) { uniqueColValMap = make(map[string]interface{}) table := session.statement.RefTable if len(table.Indexes) == 0 { - return nil, fmt.Errorf("provided bean has no unique constraints") + return nil, 0, fmt.Errorf("provided bean has no unique constraints") } // Iterate across the indexes in the provided table @@ -254,7 +260,7 @@ func (session *Session) getUniqueColumns(colNames []string, args []interface{}) if index.Type != schemas.UniqueType { continue } - + numberOfUniqueConstraints++ // index is a Unique constraint indexCol: for _, indexColumnName := range index.Cols { @@ -311,8 +317,8 @@ func (session *Session) getUniqueColumns(colNames []string, args []interface{}) } // FIXME: not sure if there's anything else we can do - return nil, fmt.Errorf("provided bean does not provide a value for unique constraint %s field %s which has no default", index.Name, indexColumnName) + return nil, 0, fmt.Errorf("provided bean does not provide a value for unique constraint %s field %s which has no default", index.Name, indexColumnName) } } - return uniqueColValMap, nil + return uniqueColValMap, numberOfUniqueConstraints, nil }