From 3a4fbeaa6fa64ef6ecaae4eec59971dadfea6294 Mon Sep 17 00:00:00 2001 From: Andrew Thornton Date: Sun, 12 Mar 2023 11:31:08 +0000 Subject: [PATCH] Add InsertOnConflictDoNothing and Upsert functionality This PR adds functionality for xorm to perform InsertOnConflictDoNothing and Upserts. Signed-off-by: Andrew Thornton --- engine.go | 14 ++ integrations/session_insert_test.go | 254 ++++++++++++++++++++- interface.go | 4 +- internal/statements/expr.go | 10 +- internal/statements/insert.go | 338 ++++++++++++++-------------- internal/statements/upsert.go | 271 ++++++++++++++++++++++ internal/utils/map.go | 110 +++++++++ session_insert.go | 330 ++++++++++++--------------- session_upsert.go | 298 ++++++++++++++++++++++++ 9 files changed, 1262 insertions(+), 367 deletions(-) create mode 100644 internal/statements/upsert.go create mode 100644 internal/utils/map.go create mode 100644 session_upsert.go diff --git a/engine.go b/engine.go index 389819e7..a8edbfc4 100644 --- a/engine.go +++ b/engine.go @@ -1224,6 +1224,13 @@ func (engine *Engine) InsertOne(bean interface{}) (int64, error) { return session.InsertOne(bean) } +// InsertOnConflictDoNothing attempt to insert a record but on conflict do nothing +func (engine *Engine) InsertOnConflictDoNothing(beans ...interface{}) (int64, error) { + session := engine.NewSession() + defer session.Close() + return session.InsertOnConflictDoNothing(beans...) +} + // Update records, bean's non-empty fields are updated contents, // condiBean' non-empty filds are conditions // CAUTION: @@ -1237,6 +1244,13 @@ func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64 return session.Update(bean, condiBeans...) } +// Upsert attempt to insert a record but on conflict do update +func (engine *Engine) Upsert(beans ...interface{}) (int64, error) { + session := engine.NewSession() + defer session.Close() + return session.Upsert(beans...) +} + // Delete records, bean's non-empty fields are conditions // At least one condition must be set. func (engine *Engine) Delete(beans ...interface{}) (int64, error) { diff --git a/integrations/session_insert_test.go b/integrations/session_insert_test.go index 084deb38..40b52806 100644 --- a/integrations/session_insert_test.go +++ b/integrations/session_insert_test.go @@ -16,6 +16,231 @@ import ( "github.com/stretchr/testify/assert" ) +func TestInsertOnConflictDoNothing(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + t.Run("NoUnique", func(t *testing.T) { + // InsertOnConflictDoNothing does not work if there is no unique constraint + type NoUniques struct { + ID int64 `xorm:"pk autoincr"` + Data string + } + assert.NoError(t, testEngine.Sync(new(NoUniques))) + + toInsert := &NoUniques{Data: "shouldErr"} + n, err := testEngine.InsertOnConflictDoNothing(toInsert) + assert.Error(t, err) + assert.Equal(t, int64(0), n) + assert.Equal(t, int64(0), toInsert.ID) + + toInsert = &NoUniques{Data: ""} + n, err = testEngine.InsertOnConflictDoNothing(toInsert) + assert.Error(t, err) + assert.Equal(t, int64(0), n) + assert.Equal(t, int64(0), toInsert.ID) + }) + + t.Run("OneUnique", func(t *testing.T) { + type OneUnique struct { + ID int64 `xorm:"pk autoincr"` + Data string `xorm:"UNIQUE NOT NULL"` + } + + assert.NoError(t, testEngine.Sync2(&OneUnique{})) + _, _ = testEngine.Exec("DELETE FROM one_unique") + + // Insert with the default value for the unique field + toInsert := &OneUnique{} + n, err := testEngine.InsertOnConflictDoNothing(toInsert) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + assert.NotEqual(t, int64(0), toInsert.ID) + + // but not twice + toInsert = &OneUnique{} + n, err = testEngine.InsertOnConflictDoNothing(toInsert) + assert.NoError(t, err) + assert.Equal(t, int64(0), n) + assert.Equal(t, int64(0), toInsert.ID) + + // Successfully insert test + toInsert = &OneUnique{Data: "test"} + n, err = testEngine.InsertOnConflictDoNothing(toInsert) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + assert.NotEqual(t, int64(0), toInsert.ID) + + // Successfully insert test2 + toInsert = &OneUnique{Data: "test2"} + n, err = testEngine.InsertOnConflictDoNothing(toInsert) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + assert.NotEqual(t, int64(0), toInsert.ID) + + // Successfully don't reinsert test + toInsert = &OneUnique{Data: "test"} + n, err = testEngine.InsertOnConflictDoNothing(toInsert) + assert.NoError(t, err) + assert.Equal(t, int64(0), n) + assert.Equal(t, int64(0), toInsert.ID) + }) + + t.Run("MultiUnique", func(t *testing.T) { + type MultiUnique struct { + ID int64 `xorm:"pk autoincr"` + NotUnique string + Data1 string `xorm:"UNIQUE(s) NOT NULL"` + Data2 string `xorm:"UNIQUE(s) NOT NULL"` + } + + assert.NoError(t, testEngine.Sync2(&MultiUnique{})) + _, _ = testEngine.Exec("DELETE FROM multi_unique") + + // Insert with default values + toInsert := &MultiUnique{} + n, err := testEngine.InsertOnConflictDoNothing(toInsert) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + assert.NotEqual(t, int64(0), toInsert.ID) + + // successfully insert test, t1 + toInsert = &MultiUnique{Data1: "test", NotUnique: "t1"} + n, err = testEngine.InsertOnConflictDoNothing(toInsert) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + assert.NotEqual(t, int64(0), toInsert.ID) + + // successfully insert test2, t1 + toInsert = &MultiUnique{Data1: "test2", NotUnique: "t1"} + n, err = testEngine.InsertOnConflictDoNothing(toInsert) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + assert.NotEqual(t, int64(0), toInsert.ID) + + // successfully don't insert test2, t2 + toInsert = &MultiUnique{Data1: "test2", NotUnique: "t2"} + n, err = testEngine.InsertOnConflictDoNothing(toInsert) + assert.NoError(t, err) + assert.Equal(t, int64(0), n) + assert.Equal(t, int64(0), toInsert.ID) + + // successfully don't insert test, t2 + toInsert = &MultiUnique{Data1: "test", NotUnique: "t2"} + n, err = testEngine.InsertOnConflictDoNothing(toInsert) + assert.NoError(t, err) + assert.Equal(t, int64(0), n) + assert.Equal(t, int64(0), toInsert.ID) + + // successfully insert test/test2, t2 + toInsert = &MultiUnique{Data1: "test", Data2: "test2", NotUnique: "t1"} + n, err = testEngine.InsertOnConflictDoNothing(toInsert) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + assert.NotEqual(t, int64(0), toInsert.ID) + + // successfully don't insert test/test2, t2 + toInsert = &MultiUnique{Data1: "test", Data2: "test2", NotUnique: "t2"} + n, err = testEngine.InsertOnConflictDoNothing(toInsert) + assert.NoError(t, err) + assert.Equal(t, int64(0), n) + assert.Equal(t, int64(0), toInsert.ID) + }) + + t.Run("MultiMultiUnique", func(t *testing.T) { + type MultiMultiUnique struct { + ID int64 `xorm:"pk autoincr"` + Data0 string `xorm:"UNIQUE NOT NULL"` + Data1 string `xorm:"UNIQUE(s) NOT NULL"` + Data2 string `xorm:"UNIQUE(s) NOT NULL"` + } + + assert.NoError(t, testEngine.Sync2(&MultiMultiUnique{})) + _, _ = testEngine.Exec("DELETE FROM multi_multi_unique") + + // Insert with default values + n, err := testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{}) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + + // Insert with value for t1, + n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data1: "test", Data0: "t1"}) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + + // Fail insert with value for t1, + n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data2: "test2", Data0: "t1"}) + assert.NoError(t, err) + assert.Equal(t, int64(0), n) + + // Insert with value for t2, + n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data2: "test2", Data0: "t2"}) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + + // Fail insert with value for t2, + n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data2: "test2", Data0: "t2"}) + assert.NoError(t, err) + assert.Equal(t, int64(0), n) + + // Fail insert with value for t2, + n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data1: "test", Data0: "t2"}) + assert.NoError(t, err) + assert.Equal(t, int64(0), n) + + // Insert with value for t3, + n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data1: "test", Data2: "test2", Data0: "t3"}) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + + // fail insert with value for t2, + n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data1: "test", Data2: "test2", Data0: "t2"}) + assert.NoError(t, err) + assert.Equal(t, int64(0), n) + }) + + t.Run("NoPK", func(t *testing.T) { + type NoPrimaryKey struct { + NotID int64 + Uniqued string `xorm:"UNIQUE"` + } + + assert.NoError(t, testEngine.Sync2(&NoPrimaryKey{})) + _, _ = testEngine.Exec("DELETE FROM no_primary_unique") + + empty := &NoPrimaryKey{} + + // Insert default + n, err := testEngine.InsertOnConflictDoNothing(empty) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + + // Insert with 1 + n, err = testEngine.InsertOnConflictDoNothing(&NoPrimaryKey{Uniqued: "1"}) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + + // Fail reinsert default + n, err = testEngine.InsertOnConflictDoNothing(&NoPrimaryKey{NotID: 1}) + assert.NoError(t, err) + assert.Equal(t, int64(0), n) + + // Fail reinsert default + n, err = testEngine.InsertOnConflictDoNothing(&NoPrimaryKey{NotID: 2}) + assert.NoError(t, err) + assert.Equal(t, int64(0), n) + + // Insert with 2 + n, err = testEngine.InsertOnConflictDoNothing(&NoPrimaryKey{NotID: 2, Uniqued: "2"}) + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + + // Fail reinsert with 2 + n, err = testEngine.InsertOnConflictDoNothing(&NoPrimaryKey{NotID: 1, Uniqued: "2"}) + assert.NoError(t, err) + assert.Equal(t, int64(0), n) + }) +} + func TestInsertOne(t *testing.T) { assert.NoError(t, PrepareEngine()) @@ -142,8 +367,13 @@ func TestInsert(t *testing.T) { assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) - user := Userinfo{0, "xiaolunwen", "dev", "lunny", time.Now(), - Userdetail{Id: 1}, 1.78, []byte{1, 2, 3}, true} + user := Userinfo{ + 0, "xiaolunwen", "dev", "lunny", time.Now(), + Userdetail{Id: 1}, + 1.78, + []byte{1, 2, 3}, + true, + } cnt, err := testEngine.Insert(&user) assert.NoError(t, err) assert.EqualValues(t, 1, cnt, "insert not returned 1") @@ -161,8 +391,10 @@ func TestInsertAutoIncr(t *testing.T) { assertSync(t, new(Userinfo)) // auto increment insert - user := Userinfo{Username: "xiaolunwen2", Departname: "dev", Alias: "lunny", Created: time.Now(), - Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true} + user := Userinfo{ + Username: "xiaolunwen2", Departname: "dev", Alias: "lunny", Created: time.Now(), + Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true, + } cnt, err := testEngine.Insert(&user) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) @@ -184,7 +416,7 @@ func TestInsertDefault(t *testing.T) { err := testEngine.Sync(di) assert.NoError(t, err) - var di2 = DefaultInsert{Name: "test"} + di2 := DefaultInsert{Name: "test"} _, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("Status")).Insert(&di2) assert.NoError(t, err) @@ -210,7 +442,7 @@ func TestInsertDefault2(t *testing.T) { err := testEngine.Sync(di) assert.NoError(t, err) - var di2 = DefaultInsert2{Name: "test"} + di2 := DefaultInsert2{Name: "test"} _, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("CheckTime")).Insert(&di2) assert.NoError(t, err) @@ -438,7 +670,7 @@ func TestCreatedJsonTime(t *testing.T) { assert.True(t, has) assert.EqualValues(t, time.Time(ci5.Created).Unix(), time.Time(di5.Created).Unix()) - var dis = make([]MyJSONTime, 0) + dis := make([]MyJSONTime, 0) err = testEngine.Find(&dis) assert.NoError(t, err) } @@ -762,7 +994,7 @@ func TestInsertWhere(t *testing.T) { assert.NoError(t, PrepareEngine()) assertSync(t, new(InsertWhere)) - var i = InsertWhere{ + i := InsertWhere{ RepoId: 1, Width: 10, Height: 20, @@ -872,7 +1104,7 @@ func TestInsertExpr2(t *testing.T) { assertSync(t, new(InsertExprsRelease)) - var ie = InsertExprsRelease{ + ie := InsertExprsRelease{ RepoId: 1, IsTag: true, } @@ -1047,7 +1279,7 @@ func TestInsertIntSlice(t *testing.T) { assert.NoError(t, testEngine.Sync(new(InsertIntSlice))) - var v = InsertIntSlice{ + v := InsertIntSlice{ NameIDs: []int{1, 2}, } cnt, err := testEngine.Insert(&v) @@ -1064,7 +1296,7 @@ func TestInsertIntSlice(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var v3 = InsertIntSlice{ + v3 := InsertIntSlice{ NameIDs: nil, } cnt, err = testEngine.Insert(&v3) diff --git a/interface.go b/interface.go index d10abe9e..4485485d 100644 --- a/interface.go +++ b/interface.go @@ -44,7 +44,8 @@ type Interface interface { In(string, ...interface{}) *Session Incr(column string, arg ...interface{}) *Session Insert(...interface{}) (int64, error) - InsertOne(interface{}) (int64, error) + InsertOne(interface{}) (int64, error) // Deprecated: Please use Insert directly + InsertOnConflictDoNothing(beans ...interface{}) (int64, error) IsTableEmpty(bean interface{}) (bool, error) IsTableExist(beanOrTableName interface{}) (bool, error) Iterate(interface{}, IterFunc) error @@ -71,6 +72,7 @@ type Interface interface { Table(tableNameOrBean interface{}) *Session Unscoped() *Session Update(bean interface{}, condiBeans ...interface{}) (int64, error) + Upsert(beans ...interface{}) (int64, error) UseBool(...string) *Session Where(interface{}, ...interface{}) *Session } diff --git a/internal/statements/expr.go b/internal/statements/expr.go index c2a2e1cc..2a5e11ea 100644 --- a/internal/statements/expr.go +++ b/internal/statements/expr.go @@ -59,13 +59,21 @@ func (expr *Expr) WriteArgs(w *builder.BytesWriter) error { type exprParams []Expr func (exprs exprParams) ColNames() []string { - var cols = make([]string, 0, len(exprs)) + cols := make([]string, 0, len(exprs)) for _, expr := range exprs { cols = append(cols, expr.ColName) } return cols } +func (exprs exprParams) ColNamesTrim() []string { + cols := make([]string, 0, len(exprs)) + for _, expr := range exprs { + cols = append(cols, schemas.CommonQuoter.Trim(expr.ColName)) + } + return cols +} + func (exprs *exprParams) Add(name string, arg interface{}) { *exprs = append(*exprs, Expr{name, arg}) } diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 91a33319..8475fc57 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -30,7 +30,6 @@ func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schem func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) (string, []interface{}, error) { var ( buf = builder.NewWriter() - exprs = statement.ExprColumns table = statement.RefTable tableName = statement.TableName() ) @@ -43,129 +42,8 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - var hasInsertColumns = len(colNames) > 0 - var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG) - if needSeq { - for _, col := range colNames { - if strings.EqualFold(col, table.AutoIncrement) { - needSeq = false - break - } - } - } - - if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE && - statement.dialect.URI().DBType != schemas.DAMENG { - if statement.dialect.URI().DBType == schemas.MYSQL { - if _, err := buf.WriteString(" VALUES ()"); err != nil { - return "", nil, err - } - } else { - if err := statement.writeInsertOutput(buf.Builder, table); err != nil { - return "", nil, err - } - if _, err := buf.WriteString(" DEFAULT VALUES"); err != nil { - return "", nil, err - } - } - } else { - if _, err := buf.WriteString(" ("); err != nil { - return "", nil, err - } - - if needSeq { - colNames = append(colNames, table.AutoIncrement) - } - - if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames()...), ","); err != nil { - return "", nil, err - } - - if _, err := buf.WriteString(")"); err != nil { - return "", nil, err - } - if err := statement.writeInsertOutput(buf.Builder, table); err != nil { - return "", nil, err - } - - if statement.Conds().IsValid() { - if _, err := buf.WriteString(" SELECT "); err != nil { - return "", nil, err - } - - if err := statement.WriteArgs(buf, args); err != nil { - return "", nil, err - } - - if needSeq { - if len(args) > 0 { - if _, err := buf.WriteString(","); err != nil { - return "", nil, err - } - } - if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil { - return "", nil, err - } - } - if len(exprs) > 0 { - if _, err := buf.WriteString(","); err != nil { - return "", nil, err - } - if err := exprs.WriteArgs(buf); err != nil { - return "", nil, err - } - } - - if _, err := buf.WriteString(" FROM "); err != nil { - return "", nil, err - } - - if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil { - return "", nil, err - } - - if _, err := buf.WriteString(" WHERE "); err != nil { - return "", nil, err - } - - if err := statement.Conds().WriteTo(buf); err != nil { - return "", nil, err - } - } else { - if _, err := buf.WriteString(" VALUES ("); err != nil { - return "", nil, err - } - - if err := statement.WriteArgs(buf, args); err != nil { - return "", nil, err - } - - // Insert tablename (id) Values(seq_tablename.nextval) - if needSeq { - if hasInsertColumns { - if _, err := buf.WriteString(","); err != nil { - return "", nil, err - } - } - if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil { - return "", nil, err - } - } - - if len(exprs) > 0 { - if _, err := buf.WriteString(","); err != nil { - return "", nil, err - } - } - - if err := exprs.WriteArgs(buf); err != nil { - return "", nil, err - } - - if _, err := buf.WriteString(")"); err != nil { - return "", nil, err - } - } + if err := statement.genInsertValues(buf, colNames, args); err != nil { + return "", nil, err } if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES { @@ -180,6 +58,169 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return buf.String(), buf.Args(), nil } +func (statement *Statement) includeAutoIncrement(colNames []string) bool { + includesAutoIncrement := len(statement.RefTable.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG) + if includesAutoIncrement { + for _, col := range colNames { + if strings.EqualFold(col, statement.RefTable.AutoIncrement) { + includesAutoIncrement = false + break + } + } + } + return includesAutoIncrement +} + +func (statement *Statement) genInsertValues(buf *builder.BytesWriter, colNames []string, args []interface{}) error { + var ( + exprs = statement.ExprColumns + table = statement.RefTable + ) + + hasInsertColumns := len(colNames) > 0 + includeAutoIncrement := statement.includeAutoIncrement(colNames) + + // Empty insert - i.e. insert default values only + if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE && + statement.dialect.URI().DBType != schemas.DAMENG { + + if statement.dialect.URI().DBType == schemas.MYSQL { + // MySQL doesn't have DEFAULT VALUES and uses VALUES () for this. + if _, err := buf.WriteString(" VALUES ()"); err != nil { + return err + } + return nil + } + + // (MSSQL: return the inserted values) + if err := statement.writeInsertOutput(buf.Builder, table); err != nil { + return err + } + + // All others use DEFAULT VALUES + if _, err := buf.WriteString(" DEFAULT VALUES"); err != nil { + return err + } + return nil + } + + // We have some values - Write the column names we need to insert: + if _, err := buf.WriteString(" ("); err != nil { + return err + } + + if includeAutoIncrement { + colNames = append(colNames, table.AutoIncrement) + } + + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames()...), ","); err != nil { + return err + } + + if _, err := buf.WriteString(")"); err != nil { + return err + } + + // (MSSQL: return the inserted values) + if err := statement.writeInsertOutput(buf.Builder, table); err != nil { + return err + } + + return statement.genInsertValuesValues(buf, includeAutoIncrement, colNames, args) +} + +func (statement *Statement) genInsertValuesValues(buf *builder.BytesWriter, includeAutoIncrement bool, colNames []string, args []interface{}) error { + var ( + exprs = statement.ExprColumns + tableName = statement.TableName() + ) + hasInsertColumns := len(colNames) > 0 + + if statement.Conds().IsValid() { + // We have conditions which we're trying to insert + if _, err := buf.WriteString(" SELECT "); err != nil { + return err + } + + if err := statement.WriteArgs(buf, args); err != nil { + return err + } + + if includeAutoIncrement { + if len(args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return err + } + } + if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil { + return err + } + } + + if len(exprs) > 0 { + if _, err := buf.WriteString(","); err != nil { + return err + } + if err := exprs.WriteArgs(buf); err != nil { + return err + } + } + + if _, err := buf.WriteString(" FROM "); err != nil { + return err + } + + if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil { + return err + } + + if _, err := buf.WriteString(" WHERE "); err != nil { + return err + } + + if err := statement.Conds().WriteTo(buf); err != nil { + return err + } + return nil + } + + // Direct insertion of values + if _, err := buf.WriteString(" VALUES ("); err != nil { + return err + } + + if err := statement.WriteArgs(buf, args); err != nil { + return err + } + + // Insert tablename (id) Values(seq_tablename.nextval) + if includeAutoIncrement { + if hasInsertColumns { + if _, err := buf.WriteString(","); err != nil { + return err + } + } + if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil { + return err + } + } + + if len(exprs) > 0 { + if _, err := buf.WriteString(","); err != nil { + return err + } + } + + if err := exprs.WriteArgs(buf); err != nil { + return err + } + + if _, err := buf.WriteString(")"); err != nil { + return err + } + return nil +} + // GenInsertMapSQL generates insert map SQL func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{}) (string, []interface{}, error) { var ( @@ -196,51 +237,12 @@ func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{} return "", nil, err } - // if insert where - if statement.Conds().IsValid() { - if _, err := buf.WriteString(") SELECT "); err != nil { - return "", nil, err - } + if _, err := buf.WriteString(")"); err != nil { + return "", nil, err + } - if err := statement.WriteArgs(buf, args); err != nil { - return "", nil, err - } - - if len(exprs) > 0 { - if _, err := buf.WriteString(","); err != nil { - return "", nil, err - } - if err := exprs.WriteArgs(buf); err != nil { - return "", nil, err - } - } - - if _, err := buf.WriteString(fmt.Sprintf(" FROM %s WHERE ", statement.quote(tableName))); err != nil { - return "", nil, err - } - - if err := statement.Conds().WriteTo(buf); err != nil { - return "", nil, err - } - } else { - if _, err := buf.WriteString(") VALUES ("); err != nil { - return "", nil, err - } - if err := statement.WriteArgs(buf, args); err != nil { - return "", nil, err - } - - if len(exprs) > 0 { - if _, err := buf.WriteString(","); err != nil { - return "", nil, err - } - if err := exprs.WriteArgs(buf); err != nil { - return "", nil, err - } - } - if _, err := buf.WriteString(")"); err != nil { - return "", nil, err - } + if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil { + return "", nil, err } return buf.String(), buf.Args(), nil diff --git a/internal/statements/upsert.go b/internal/statements/upsert.go new file mode 100644 index 00000000..abcdaa14 --- /dev/null +++ b/internal/statements/upsert.go @@ -0,0 +1,271 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + + "xorm.io/builder" + "xorm.io/xorm/schemas" +) + +// GenUpsertSQL generates upsert beans SQL +func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) { + if statement.dialect.URI().DBType == schemas.MSSQL || + statement.dialect.URI().DBType == schemas.DAMENG || + statement.dialect.URI().DBType == schemas.ORACLE { + return statement.genMergeSQL(doUpdate, columns, args, uniqueColValMap) + } + + var ( + buf = builder.NewWriter() + table = statement.RefTable + tableName = statement.TableName() + ) + + if statement.dialect.URI().DBType == schemas.MYSQL && !doUpdate { + if _, err := buf.WriteString("INSERT IGNORE INTO "); err != nil { + return "", nil, err + } + } else { + if _, err := buf.WriteString("INSERT INTO "); err != nil { + return "", nil, err + } + } + + if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil { + return "", nil, err + } + + if err := statement.genInsertValues(buf, columns, args); err != nil { + return "", nil, err + } + + switch statement.dialect.URI().DBType { + case schemas.SQLITE, schemas.POSTGRES: + if _, err := buf.WriteString(" ON CONFLICT DO "); err != nil { + return "", nil, err + } + if doUpdate { + return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT + } else { + if _, err := buf.WriteString("NOTHING"); err != nil { + return "", nil, err + } + } + case schemas.MYSQL: + if doUpdate { + return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT + } + default: + return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT + } + + if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES { + if _, err := buf.WriteString(" RETURNING "); err != nil { + return "", nil, err + } + if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil { + return "", nil, err + } + } + + return buf.String(), buf.Args(), nil +} + +func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) { + var ( + buf = builder.NewWriter() + table = statement.RefTable + tableName = statement.TableName() + ) + + quote := statement.dialect.Quoter().Quote + write := func(args ...string) { + for _, arg := range args { + _, _ = buf.WriteString(arg) + } + } + + write("MERGE INTO ", quote(tableName)) + if statement.dialect.URI().DBType == schemas.MSSQL { + write("WITH (HOLDLOCK) AS target ") + } + write("USING (SELECT ") + + uniqueCols := make([]string, 0, len(uniqueColValMap)) + for colName := range uniqueColValMap { + uniqueCols = append(uniqueCols, colName) + } + for i, colName := range uniqueCols { + if err := statement.WriteArg(buf, uniqueColValMap[colName]); err != nil { + return "", nil, err + } + write(" AS ", quote(colName)) + if i < len(uniqueCols)-1 { + write(", ") + } + } + write(") AS src ON (") + + countUniques := 0 + for _, index := range table.Indexes { + if index.Type != schemas.UniqueType { + continue + } + if countUniques > 0 { + write(" OR ") + } + countUniques++ + write("(") + write("src.", quote(index.Cols[0]), "= target.", quote(index.Cols[0])) + for _, col := range index.Cols[1:] { + write(" AND src.", quote(col), "= target.", quote(col)) + } + write(")") + } + if doUpdate { + return "", nil, fmt.Errorf("unimplemented") + } + write(") WHEN NOT MATCHED THEN INSERT") + if err := statement.genInsertValues(buf, columns, args); err != nil { + return "", nil, err + } + return buf.String(), buf.Args(), nil +} + +// GenUpsertMapSQL generates insert map SQL +func (statement *Statement) GenUpsertMapSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) { + if statement.dialect.URI().DBType == schemas.MSSQL || + statement.dialect.URI().DBType == schemas.DAMENG || + statement.dialect.URI().DBType == schemas.ORACLE { + return statement.genMergeMapSQL(doUpdate, columns, args, uniqueColValMap) + } + var ( + buf = builder.NewWriter() + exprs = statement.ExprColumns + table = statement.RefTable + tableName = statement.TableName() + ) + quote := statement.dialect.Quoter().Quote + write := func(args ...string) { + for _, arg := range args { + _, _ = buf.WriteString(arg) + } + } + + if statement.dialect.URI().DBType == schemas.MYSQL && !doUpdate { + write("INSERT IGNORE INTO ", quote(tableName), " (") + } else { + write("INSERT INTO ", quote(tableName), " (") + } + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil { + return "", nil, err + } + write(")") + + if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil { + return "", nil, err + } + + switch statement.dialect.URI().DBType { + case schemas.SQLITE, schemas.POSTGRES: + if _, err := buf.WriteString(" ON CONFLICT DO "); err != nil { + return "", nil, err + } + if doUpdate { + return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT + } else { + if _, err := buf.WriteString("NOTHING"); err != nil { + return "", nil, err + } + } + case schemas.MYSQL: + if doUpdate { + return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT + } + default: + return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT + } + + if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES { + if _, err := buf.WriteString(" RETURNING "); err != nil { + return "", nil, err + } + if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil { + return "", nil, err + } + } + + return buf.String(), buf.Args(), nil +} + +func (statement *Statement) genMergeMapSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) { + var ( + buf = builder.NewWriter() + table = statement.RefTable + exprs = statement.ExprColumns + tableName = statement.TableName() + ) + + quote := statement.dialect.Quoter().Quote + write := func(args ...string) { + for _, arg := range args { + _, _ = buf.WriteString(arg) + } + } + + write("MERGE INTO ", quote(tableName)) + if statement.dialect.URI().DBType == schemas.MSSQL { + write("WITH (HOLDLOCK) AS target ") + } + write("USING (SELECT ") + + uniqueCols := make([]string, 0, len(uniqueColValMap)) + for colName := range uniqueColValMap { + uniqueCols = append(uniqueCols, colName) + } + for i, colName := range uniqueCols { + if err := statement.WriteArg(buf, uniqueColValMap[colName]); err != nil { + return "", nil, err + } + write(" AS ", quote(colName)) + if i < len(uniqueCols)-1 { + write(", ") + } + } + write(") AS src ON (") + + countUniques := 0 + for _, index := range table.Indexes { + if index.Type != schemas.UniqueType { + continue + } + if countUniques > 0 { + write(" OR ") + } + countUniques++ + write("(") + write("src.", quote(index.Cols[0]), "= target.", quote(index.Cols[0])) + for _, col := range index.Cols[1:] { + write(" AND src.", quote(col), "= target.", quote(col)) + } + write(")") + } + if doUpdate { + return "", nil, fmt.Errorf("unimplemented") + } + write(") WHEN NOT MATCHED THEN INSERT") + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil { + return "", nil, err + } + write(")") + + if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil { + return "", nil, err + } + + return buf.String(), buf.Args(), nil +} diff --git a/internal/utils/map.go b/internal/utils/map.go new file mode 100644 index 00000000..8c99cd9b --- /dev/null +++ b/internal/utils/map.go @@ -0,0 +1,110 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import ( + "sort" + "strings" +) + +func MapToSlices(m map[string]interface{}, exclude []string, trimmer func(string) string) ([]string, []interface{}) { + columns := make([]string, 0, len(m)) + +outer: + for colName := range m { + trimmed := trimmer(colName) + for _, excluded := range exclude { + if strings.EqualFold(excluded, trimmed) { + continue outer + } + } + columns = append(columns, colName) + } + sort.Strings(columns) + + args := make([]interface{}, 0, len(columns)) + for _, colName := range columns { + args = append(args, m[colName]) + } + + return columns, args +} + +func MapStringToSlices(m map[string]string, exclude []string, trimmer func(string) string) ([]string, []interface{}) { + columns := make([]string, 0, len(m)) + +outer: + for colName := range m { + trimmed := trimmer(colName) + for _, excluded := range exclude { + if strings.EqualFold(excluded, trimmed) { + continue outer + } + } + columns = append(columns, colName) + } + sort.Strings(columns) + + args := make([]interface{}, 0, len(columns)) + for _, colName := range columns { + args = append(args, m[colName]) + } + + return columns, args +} + +func MultipleMapToSlices(maps []map[string]interface{}, exclude []string, trimmer func(string) string) ([]string, [][]interface{}) { + columns := make([]string, 0, len(maps[0])) + +outer: + for colName := range maps[0] { + trimmed := trimmer(colName) + for _, excluded := range exclude { + if strings.EqualFold(excluded, trimmed) { + continue outer + } + } + columns = append(columns, colName) + } + sort.Strings(columns) + + argss := make([][]interface{}, 0, len(maps)) + for _, m := range maps { + args := make([]interface{}, 0, len(m)) + for _, colName := range columns { + args = append(args, m[colName]) + } + argss = append(argss, args) + } + + return columns, argss +} + +func MultipleMapStringToSlices(maps []map[string]string, exclude []string, trimmer func(string) string) ([]string, [][]interface{}) { + columns := make([]string, 0, len(maps[0])) + +outer: + for colName := range maps[0] { + trimmed := trimmer(colName) + for _, excluded := range exclude { + if strings.EqualFold(excluded, trimmed) { + continue outer + } + } + columns = append(columns, colName) + } + sort.Strings(columns) + + argss := make([][]interface{}, 0, len(maps)) + for _, m := range maps { + args := make([]interface{}, 0, len(m)) + for _, colName := range columns { + args = append(args, m[colName]) + } + argss = append(argss, args) + } + + return columns, argss +} diff --git a/session_insert.go b/session_insert.go index fc025613..97a1789b 100644 --- a/session_insert.go +++ b/session_insert.go @@ -5,10 +5,10 @@ package xorm import ( + "database/sql" "errors" "fmt" "reflect" - "sort" "strings" "time" @@ -156,14 +156,14 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, e } args = append(args, val) - var colName = col.Name + colName := col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnTime(bean, col, t) }) } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) - var colName = col.Name + colName := col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnInt(bean, col, 1) @@ -276,7 +276,7 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { processor.BeforeInsert() } - var tableName = session.statement.TableName() + tableName := session.statement.TableName() table := session.statement.RefTable colNames, args, err := session.genInsertColumns(bean) @@ -290,102 +290,9 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { } sqlStr = session.engine.dialect.Quoter().Replace(sqlStr) - handleAfterInsertProcessorFunc := func(bean interface{}) { - if session.isAutoCommit { - for _, closure := range session.afterClosures { - closure(bean) - } - if processor, ok := interface{}(bean).(AfterInsertProcessor); ok { - processor.AfterInsert() - } - } else { - lenAfterClosures := len(session.afterClosures) - if lenAfterClosures > 0 { - if value, has := session.afterInsertBeans[bean]; has && value != nil { - *value = append(*value, session.afterClosures...) - } else { - afterClosures := make([]func(interface{}), lenAfterClosures) - copy(afterClosures, session.afterClosures) - session.afterInsertBeans[bean] = &afterClosures - } - } else { - if _, ok := interface{}(bean).(AfterInsertProcessor); ok { - session.afterInsertBeans[bean] = nil - } - } - } - cleanupProcessorsClosures(&session.afterClosures) // cleanup after used - } - // if there is auto increment column and driver don't support return it if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID { - var sql string - var newArgs []interface{} - var needCommit bool - var id int64 - if session.engine.dialect.URI().DBType == schemas.ORACLE || session.engine.dialect.URI().DBType == schemas.DAMENG { - if session.isAutoCommit { // if it's not in transaction - if err := session.Begin(); err != nil { - return 0, err - } - needCommit = true - } - _, err := session.exec(sqlStr, args...) - if err != nil { - return 0, err - } - i := utils.IndexSlice(colNames, table.AutoIncrement) - if i > -1 { - id, err = convert.AsInt64(args[i]) - if err != nil { - return 0, err - } - } else { - sql = fmt.Sprintf("select %s.currval from dual", utils.SeqName(tableName)) - } - } else { - sql = sqlStr - newArgs = args - } - - if id == 0 { - err := session.queryRow(sql, newArgs...).Scan(&id) - if err != nil { - return 0, err - } - if needCommit { - if err := session.Commit(); err != nil { - return 0, err - } - } - if id == 0 { - return 0, errors.New("insert successfully but not returned id") - } - } - - defer handleAfterInsertProcessorFunc(bean) - - _ = session.cacheInsert(tableName) - - if table.Version != "" && session.statement.CheckVersion { - verValue, err := table.VersionColumn().ValueOf(bean) - if err != nil { - session.engine.logger.Errorf("%v", err) - } else if verValue.IsValid() && verValue.CanSet() { - session.incrVersionFieldValue(verValue) - } - } - - aiValue, err := table.AutoIncrColumn().ValueOf(bean) - if err != nil { - session.engine.logger.Errorf("%v", err) - } - - if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { - return 1, nil - } - - return 1, convert.AssignValue(*aiValue, id) + return session.execInsertSqlNoAutoReturn(sqlStr, bean, colNames, args) } res, err := session.exec(sqlStr, args...) @@ -393,7 +300,7 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { return 0, err } - defer handleAfterInsertProcessorFunc(bean) + defer session.handleAfterInsertProcessorFunc(bean) _ = session.cacheInsert(tableName) @@ -432,31 +339,6 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { return res.RowsAffected() } -// InsertOne insert only one struct into database as a record. -// The in parameter bean must a struct or a point to struct. The return -// parameter is inserted and error -// Deprecated: Please use Insert directly -func (session *Session) InsertOne(bean interface{}) (int64, error) { - if session.isAutoClose { - defer session.Close() - } - - return session.insertStruct(bean) -} - -func (session *Session) cacheInsert(table string) error { - if !session.statement.UseCache { - return nil - } - cacher := session.engine.cacherMgr.GetCacher(table) - if cacher == nil { - return nil - } - session.engine.logger.Debugf("[cache] clear SQL: %v", table) - cacher.ClearIds(table) - return nil -} - // genInsertColumns generates insert needed columns func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) { table := session.statement.RefTable @@ -517,7 +399,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } args = append(args, val) - var colName = col.Name + colName := col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnTime(bean, col, t) @@ -537,6 +419,112 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac return colNames, args, nil } +func (session *Session) execInsertSqlNoAutoReturn(sqlStr string, bean interface{}, colNames []string, args []interface{}) (int64, error) { + var newSql string + var newArgs []interface{} + var needCommit bool + var id int64 + + tableName := session.statement.TableName() + table := session.statement.RefTable + + if session.engine.dialect.URI().DBType == schemas.ORACLE || session.engine.dialect.URI().DBType == schemas.DAMENG { + if session.isAutoCommit { // if it's not in transaction + if err := session.Begin(); err != nil { + return 0, err + } + needCommit = true + } + res, err := session.exec(sqlStr, args...) + if err != nil { + return 0, err + } + n, err := res.RowsAffected() + if err != nil { + return 0, err + } + if n == 0 { + return 0, sql.ErrNoRows + } + i := utils.IndexSlice(colNames, table.AutoIncrement) + if i > -1 { + id, err = convert.AsInt64(args[i]) + if err != nil { + return 0, err + } + } else { + newSql = fmt.Sprintf("select %s.currval from dual", utils.SeqName(tableName)) + } + } else { + newSql = sqlStr + newArgs = args + } + + if id == 0 { + err := session.queryRow(newSql, newArgs...).Scan(&id) + if err != nil { + return 0, err + } + if needCommit { + if err := session.Commit(); err != nil { + return 0, err + } + } + if id == 0 { + return 0, errors.New("insert successfully but not returned id") + } + } + + defer session.handleAfterInsertProcessorFunc(bean) + + _ = session.cacheInsert(tableName) + + if table.Version != "" && session.statement.CheckVersion { + verValue, err := table.VersionColumn().ValueOf(bean) + if err != nil { + session.engine.logger.Errorf("%v", err) + } else if verValue.IsValid() && verValue.CanSet() { + session.incrVersionFieldValue(verValue) + } + } + + aiValue, err := table.AutoIncrColumn().ValueOf(bean) + if err != nil { + session.engine.logger.Errorf("%v", err) + } + + if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { + return 1, nil + } + + return 1, convert.AssignValue(*aiValue, id) +} + +// InsertOne insert only one struct into database as a record. +// The in parameter bean must a struct or a point to struct. The return +// parameter is inserted and error +// Deprecated: Please use Insert directly +func (session *Session) InsertOne(bean interface{}) (int64, error) { + if session.isAutoClose { + defer session.Close() + } + + return session.insertStruct(bean) +} + +func (session *Session) cacheInsert(table string) error { + if !session.statement.UseCache { + return nil + } + cacher := session.engine.cacherMgr.GetCacher(table) + if cacher == nil { + return nil + } + session.engine.logger.Debugf("[cache] clear SQL: %v", table) + cacher.ClearIds(table) + return nil +} + func (session *Session) insertMapInterface(m map[string]interface{}) (int64, error) { if len(m) == 0 { return 0, ErrParamsType @@ -547,19 +535,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err return 0, ErrTableNotFound } - var columns = make([]string, 0, len(m)) - exprs := session.statement.ExprColumns - for k := range m { - if !exprs.IsColExist(k) { - columns = append(columns, k) - } - } - sort.Strings(columns) - - var args = make([]interface{}, 0, len(m)) - for _, colName := range columns { - args = append(args, m[colName]) - } + columns, args := utils.MapToSlices(m, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim) return session.insertMap(columns, args) } @@ -574,23 +550,7 @@ func (session *Session) insertMultipleMapInterface(maps []map[string]interface{} return 0, ErrTableNotFound } - var columns = make([]string, 0, len(maps[0])) - exprs := session.statement.ExprColumns - for k := range maps[0] { - if !exprs.IsColExist(k) { - columns = append(columns, k) - } - } - sort.Strings(columns) - - var argss = make([][]interface{}, 0, len(maps)) - for _, m := range maps { - var args = make([]interface{}, 0, len(m)) - for _, colName := range columns { - args = append(args, m[colName]) - } - argss = append(argss, args) - } + columns, argss := utils.MultipleMapToSlices(maps, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim) return session.insertMultipleMap(columns, argss) } @@ -605,20 +565,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { return 0, ErrTableNotFound } - var columns = make([]string, 0, len(m)) - exprs := session.statement.ExprColumns - for k := range m { - if !exprs.IsColExist(k) { - columns = append(columns, k) - } - } - - sort.Strings(columns) - - var args = make([]interface{}, 0, len(m)) - for _, colName := range columns { - args = append(args, m[colName]) - } + columns, args := utils.MapStringToSlices(m, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim) return session.insertMap(columns, args) } @@ -633,23 +580,7 @@ func (session *Session) insertMultipleMapString(maps []map[string]string) (int64 return 0, ErrTableNotFound } - var columns = make([]string, 0, len(maps[0])) - exprs := session.statement.ExprColumns - for k := range maps[0] { - if !exprs.IsColExist(k) { - columns = append(columns, k) - } - } - sort.Strings(columns) - - var argss = make([][]interface{}, 0, len(maps)) - for _, m := range maps { - var args = make([]interface{}, 0, len(m)) - for _, colName := range columns { - args = append(args, m[colName]) - } - argss = append(argss, args) - } + columns, argss := utils.MultipleMapStringToSlices(maps, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim) return session.insertMultipleMap(columns, argss) } @@ -707,3 +638,30 @@ func (session *Session) insertMultipleMap(columns []string, argss [][]interface{ } return affected, nil } + +func (session *Session) handleAfterInsertProcessorFunc(bean interface{}) { + if session.isAutoCommit { + for _, closure := range session.afterClosures { + closure(bean) + } + if processor, ok := interface{}(bean).(AfterInsertProcessor); ok { + processor.AfterInsert() + } + } else { + lenAfterClosures := len(session.afterClosures) + if lenAfterClosures > 0 { + if value, has := session.afterInsertBeans[bean]; has && value != nil { + *value = append(*value, session.afterClosures...) + } else { + afterClosures := make([]func(interface{}), lenAfterClosures) + copy(afterClosures, session.afterClosures) + session.afterInsertBeans[bean] = &afterClosures + } + } else { + if _, ok := interface{}(bean).(AfterInsertProcessor); ok { + session.afterInsertBeans[bean] = nil + } + } + } + cleanupProcessorsClosures(&session.afterClosures) // cleanup after used +} diff --git a/session_upsert.go b/session_upsert.go new file mode 100644 index 00000000..695b7b69 --- /dev/null +++ b/session_upsert.go @@ -0,0 +1,298 @@ +package xorm + +import ( + "database/sql" + "fmt" + "reflect" + + "xorm.io/xorm/convert" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" +) + +func (session *Session) InsertOnConflictDoNothing(beans ...interface{}) (int64, error) { + return session.upsert(false, beans...) +} + +func (session *Session) Upsert(beans ...interface{}) (int64, error) { + return session.upsert(true, beans...) +} + +func (session *Session) upsert(doUpdate bool, beans ...interface{}) (int64, error) { + var affected int64 + var err error + + if session.isAutoClose { + defer session.Close() + } + + session.autoResetStatement = false + defer func() { + session.autoResetStatement = true + session.resetStatement() + }() + for _, bean := range beans { + var cnt int64 + var err error + switch v := bean.(type) { + case map[string]interface{}: + cnt, err = session.upsertMapInterface(doUpdate, v) + case []map[string]interface{}: + for _, m := range v { + cnt, err := session.upsertMapInterface(doUpdate, m) + if err != nil { + return affected, err + } + affected += cnt + } + case map[string]string: + cnt, err = session.upsertMapString(doUpdate, v) + case []map[string]string: + for _, m := range v { + cnt, err := session.upsertMapString(doUpdate, m) + if err != nil { + return affected, err + } + affected += cnt + } + default: + sliceValue := reflect.Indirect(reflect.ValueOf(bean)) + if sliceValue.Kind() == reflect.Slice { + if sliceValue.Len() <= 0 { + return 0, ErrNoElementsOnSlice + } + for i := 0; i < sliceValue.Len(); i++ { + v := sliceValue.Index(i) + bean := v.Interface() + cnt, err := session.upsertStruct(doUpdate, bean) + if err != nil { + return affected, err + } + affected += cnt + } + } else { + cnt, err = session.upsertStruct(doUpdate, bean) + } + } + if err != nil { + return affected, err + } + affected += cnt + } + + return affected, err +} + +func (session *Session) upsertMapInterface(doUpdate bool, m map[string]interface{}) (int64, error) { + if len(m) == 0 { + return 0, ErrParamsType + } + + tableName := session.statement.TableName() + if len(tableName) == 0 { + return 0, ErrTableNotFound + } + + columns, args := utils.MapToSlices(m, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim) + return session.upsertMap(doUpdate, columns, args) +} + +func (session *Session) upsertMapString(doUpdate bool, m map[string]string) (int64, error) { + if len(m) == 0 { + return 0, ErrParamsType + } + + tableName := session.statement.TableName() + if len(tableName) == 0 { + return 0, ErrTableNotFound + } + + columns, args := utils.MapStringToSlices(m, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim) + return session.upsertMap(doUpdate, columns, args) +} + +func (session *Session) upsertMap(doUpdate bool, columns []string, args []interface{}) (int64, error) { + tableName := session.statement.TableName() + if len(tableName) == 0 { + return 0, ErrTableNotFound + } + + uniqueColValMap, err := session.getUniqueColumns(columns, args) + if err != nil { + return 0, err + } + + sql, args, err := session.statement.GenUpsertMapSQL(doUpdate, columns, args, uniqueColValMap) + if err != nil { + return 0, err + } + sql = session.engine.dialect.Quoter().Replace(sql) + + if err := session.cacheInsert(tableName); err != nil { + return 0, err + } + + res, err := session.exec(sql, args...) + if err != nil { + return 0, err + } + affected, err := res.RowsAffected() + if err != nil { + return 0, err + } + return affected, fmt.Errorf("unimplemented") +} + +func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, error) { + if err := session.statement.SetRefBean(bean); err != nil { + return 0, err + } + if len(session.statement.TableName()) == 0 { + return 0, ErrTableNotFound + } + + // handle BeforeInsertProcessor + for _, closure := range session.beforeClosures { + closure(bean) + } + cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used + + if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok { + processor.BeforeInsert() + } + + tableName := session.statement.TableName() + table := session.statement.RefTable + + colNames, args, err := session.genInsertColumns(bean) + if err != nil { + return 0, err + } + + uniqueColValMap, err := session.getUniqueColumns(colNames, args) + if err != nil { + return 0, err + } + + sqlStr, args, err := session.statement.GenUpsertSQL(doUpdate, colNames, args, uniqueColValMap) + if err != nil { + return 0, err + } + sqlStr = session.engine.dialect.Quoter().Replace(sqlStr) + + // if there is auto increment column and driver doesn't support return it + if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID { + n, err := session.execInsertSqlNoAutoReturn(sqlStr, bean, colNames, args) + if err == sql.ErrNoRows { + return n, nil + } + return n, err + } + + res, err := session.exec(sqlStr, args...) + if err != nil { + return 0, err + } + + defer session.handleAfterInsertProcessorFunc(bean) + + _ = session.cacheInsert(tableName) + + if table.Version != "" && session.statement.CheckVersion { + verValue, err := table.VersionColumn().ValueOf(bean) + if err != nil { + session.engine.logger.Errorf("%v", err) + } else if verValue.IsValid() && verValue.CanSet() { + session.incrVersionFieldValue(verValue) + } + } + + if table.AutoIncrement == "" { + return res.RowsAffected() + } + + n, err := res.RowsAffected() + if err != nil || n == 0 { + return 0, err + } + + var id int64 + id, err = res.LastInsertId() + if err != nil || id <= 0 { + return n, err + } + + aiValue, err := table.AutoIncrColumn().ValueOf(bean) + if err != nil { + session.engine.logger.Errorf("%v", err) + } + + if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { + return n, err + } + + if err := convert.AssignValue(*aiValue, id); err != nil { + return 0, err + } + + return n, err +} + +func (session *Session) getUniqueColumns(colNames []string, args []interface{}) (uniqueColValMap map[string]interface{}, err error) { + uniqueColValMap = make(map[string]interface{}) + table := session.statement.RefTable + if len(table.Indexes) == 0 { + return nil, fmt.Errorf("provided bean has no unique constraints") + } + + // Iterate across the indexes in the provided table + for _, index := range table.Indexes { + if index.Type != schemas.UniqueType { + continue + } + + // index is a Unique constraint + indexCol: + for _, indexColumnName := range index.Cols { + if _, has := uniqueColValMap[indexColumnName]; has { + // column is already included in uniqueCols so we don't need to add it again + continue indexCol + } + + // Now iterate across colNames and add to the uniqueCols + for i, col := range colNames { + if col == indexColumnName { + uniqueColValMap[col] = args[i] + continue indexCol + } + } + + indexColumn := table.GetColumn(indexColumnName) + if !indexColumn.DefaultIsEmpty { + uniqueColValMap[indexColumnName] = indexColumn.Default + } + + if indexColumn.MapType == schemas.ONLYFROMDB || indexColumn.IsAutoIncrement { + continue + } + if session.statement.OmitColumnMap.Contain(indexColumn.Name) { + continue + } + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(indexColumn.Name) { + continue + } + // FIXME: what do we do here?! + if session.statement.IncrColumns.IsColExist(indexColumn.Name) { + continue + } else if session.statement.DecrColumns.IsColExist(indexColumn.Name) { + continue + } else if session.statement.ExprColumns.IsColExist(indexColumn.Name) { + continue + } + + // FIXME: not sure if there's anything else we can do + return nil, fmt.Errorf("provided bean does not provide a value for unique constraint %s field %s which has no default", index.Name, indexColumnName) + } + } + return uniqueColValMap, nil +}