Add InsertOnConflictDoNothing and Upsert functionality

This PR adds functionality for xorm to perform InsertOnConflictDoNothing
and Upserts.

Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
Andrew Thornton 2023-03-12 11:31:08 +00:00
parent d485abba57
commit 3a4fbeaa6f
No known key found for this signature in database
GPG Key ID: 3CDE74631F13A748
9 changed files with 1262 additions and 367 deletions

View File

@ -1224,6 +1224,13 @@ func (engine *Engine) InsertOne(bean interface{}) (int64, error) {
return session.InsertOne(bean) return session.InsertOne(bean)
} }
// InsertOnConflictDoNothing attempt to insert a record but on conflict do nothing
func (engine *Engine) InsertOnConflictDoNothing(beans ...interface{}) (int64, error) {
session := engine.NewSession()
defer session.Close()
return session.InsertOnConflictDoNothing(beans...)
}
// Update records, bean's non-empty fields are updated contents, // Update records, bean's non-empty fields are updated contents,
// condiBean' non-empty filds are conditions // condiBean' non-empty filds are conditions
// CAUTION: // CAUTION:
@ -1237,6 +1244,13 @@ func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64
return session.Update(bean, condiBeans...) return session.Update(bean, condiBeans...)
} }
// Upsert attempt to insert a record but on conflict do update
func (engine *Engine) Upsert(beans ...interface{}) (int64, error) {
session := engine.NewSession()
defer session.Close()
return session.Upsert(beans...)
}
// Delete records, bean's non-empty fields are conditions // Delete records, bean's non-empty fields are conditions
// At least one condition must be set. // At least one condition must be set.
func (engine *Engine) Delete(beans ...interface{}) (int64, error) { func (engine *Engine) Delete(beans ...interface{}) (int64, error) {

View File

@ -16,6 +16,231 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestInsertOnConflictDoNothing(t *testing.T) {
assert.NoError(t, PrepareEngine())
t.Run("NoUnique", func(t *testing.T) {
// InsertOnConflictDoNothing does not work if there is no unique constraint
type NoUniques struct {
ID int64 `xorm:"pk autoincr"`
Data string
}
assert.NoError(t, testEngine.Sync(new(NoUniques)))
toInsert := &NoUniques{Data: "shouldErr"}
n, err := testEngine.InsertOnConflictDoNothing(toInsert)
assert.Error(t, err)
assert.Equal(t, int64(0), n)
assert.Equal(t, int64(0), toInsert.ID)
toInsert = &NoUniques{Data: ""}
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
assert.Error(t, err)
assert.Equal(t, int64(0), n)
assert.Equal(t, int64(0), toInsert.ID)
})
t.Run("OneUnique", func(t *testing.T) {
type OneUnique struct {
ID int64 `xorm:"pk autoincr"`
Data string `xorm:"UNIQUE NOT NULL"`
}
assert.NoError(t, testEngine.Sync2(&OneUnique{}))
_, _ = testEngine.Exec("DELETE FROM one_unique")
// Insert with the default value for the unique field
toInsert := &OneUnique{}
n, err := testEngine.InsertOnConflictDoNothing(toInsert)
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
assert.NotEqual(t, int64(0), toInsert.ID)
// but not twice
toInsert = &OneUnique{}
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
assert.Equal(t, int64(0), toInsert.ID)
// Successfully insert test
toInsert = &OneUnique{Data: "test"}
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
assert.NotEqual(t, int64(0), toInsert.ID)
// Successfully insert test2
toInsert = &OneUnique{Data: "test2"}
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
assert.NotEqual(t, int64(0), toInsert.ID)
// Successfully don't reinsert test
toInsert = &OneUnique{Data: "test"}
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
assert.Equal(t, int64(0), toInsert.ID)
})
t.Run("MultiUnique", func(t *testing.T) {
type MultiUnique struct {
ID int64 `xorm:"pk autoincr"`
NotUnique string
Data1 string `xorm:"UNIQUE(s) NOT NULL"`
Data2 string `xorm:"UNIQUE(s) NOT NULL"`
}
assert.NoError(t, testEngine.Sync2(&MultiUnique{}))
_, _ = testEngine.Exec("DELETE FROM multi_unique")
// Insert with default values
toInsert := &MultiUnique{}
n, err := testEngine.InsertOnConflictDoNothing(toInsert)
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
assert.NotEqual(t, int64(0), toInsert.ID)
// successfully insert test, t1
toInsert = &MultiUnique{Data1: "test", NotUnique: "t1"}
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
assert.NotEqual(t, int64(0), toInsert.ID)
// successfully insert test2, t1
toInsert = &MultiUnique{Data1: "test2", NotUnique: "t1"}
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
assert.NotEqual(t, int64(0), toInsert.ID)
// successfully don't insert test2, t2
toInsert = &MultiUnique{Data1: "test2", NotUnique: "t2"}
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
assert.Equal(t, int64(0), toInsert.ID)
// successfully don't insert test, t2
toInsert = &MultiUnique{Data1: "test", NotUnique: "t2"}
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
assert.Equal(t, int64(0), toInsert.ID)
// successfully insert test/test2, t2
toInsert = &MultiUnique{Data1: "test", Data2: "test2", NotUnique: "t1"}
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
assert.NotEqual(t, int64(0), toInsert.ID)
// successfully don't insert test/test2, t2
toInsert = &MultiUnique{Data1: "test", Data2: "test2", NotUnique: "t2"}
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
assert.Equal(t, int64(0), toInsert.ID)
})
t.Run("MultiMultiUnique", func(t *testing.T) {
type MultiMultiUnique struct {
ID int64 `xorm:"pk autoincr"`
Data0 string `xorm:"UNIQUE NOT NULL"`
Data1 string `xorm:"UNIQUE(s) NOT NULL"`
Data2 string `xorm:"UNIQUE(s) NOT NULL"`
}
assert.NoError(t, testEngine.Sync2(&MultiMultiUnique{}))
_, _ = testEngine.Exec("DELETE FROM multi_multi_unique")
// Insert with default values
n, err := testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{})
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
// Insert with value for t1, <test, "">
n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data1: "test", Data0: "t1"})
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
// Fail insert with value for t1, <test2, "">
n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data2: "test2", Data0: "t1"})
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
// Insert with value for t2, <test2, "">
n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data2: "test2", Data0: "t2"})
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
// Fail insert with value for t2, <test2, "">
n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data2: "test2", Data0: "t2"})
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
// Fail insert with value for t2, <test, "">
n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data1: "test", Data0: "t2"})
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
// Insert with value for t3, <test, test2>
n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data1: "test", Data2: "test2", Data0: "t3"})
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
// fail insert with value for t2, <test, test2>
n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data1: "test", Data2: "test2", Data0: "t2"})
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
})
t.Run("NoPK", func(t *testing.T) {
type NoPrimaryKey struct {
NotID int64
Uniqued string `xorm:"UNIQUE"`
}
assert.NoError(t, testEngine.Sync2(&NoPrimaryKey{}))
_, _ = testEngine.Exec("DELETE FROM no_primary_unique")
empty := &NoPrimaryKey{}
// Insert default
n, err := testEngine.InsertOnConflictDoNothing(empty)
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
// Insert with 1
n, err = testEngine.InsertOnConflictDoNothing(&NoPrimaryKey{Uniqued: "1"})
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
// Fail reinsert default
n, err = testEngine.InsertOnConflictDoNothing(&NoPrimaryKey{NotID: 1})
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
// Fail reinsert default
n, err = testEngine.InsertOnConflictDoNothing(&NoPrimaryKey{NotID: 2})
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
// Insert with 2
n, err = testEngine.InsertOnConflictDoNothing(&NoPrimaryKey{NotID: 2, Uniqued: "2"})
assert.NoError(t, err)
assert.Equal(t, int64(1), n)
// Fail reinsert with 2
n, err = testEngine.InsertOnConflictDoNothing(&NoPrimaryKey{NotID: 1, Uniqued: "2"})
assert.NoError(t, err)
assert.Equal(t, int64(0), n)
})
}
func TestInsertOne(t *testing.T) { func TestInsertOne(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
@ -142,8 +367,13 @@ func TestInsert(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
assertSync(t, new(Userinfo)) assertSync(t, new(Userinfo))
user := Userinfo{0, "xiaolunwen", "dev", "lunny", time.Now(), user := Userinfo{
Userdetail{Id: 1}, 1.78, []byte{1, 2, 3}, true} 0, "xiaolunwen", "dev", "lunny", time.Now(),
Userdetail{Id: 1},
1.78,
[]byte{1, 2, 3},
true,
}
cnt, err := testEngine.Insert(&user) cnt, err := testEngine.Insert(&user)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt, "insert not returned 1") assert.EqualValues(t, 1, cnt, "insert not returned 1")
@ -161,8 +391,10 @@ func TestInsertAutoIncr(t *testing.T) {
assertSync(t, new(Userinfo)) assertSync(t, new(Userinfo))
// auto increment insert // auto increment insert
user := Userinfo{Username: "xiaolunwen2", Departname: "dev", Alias: "lunny", Created: time.Now(), user := Userinfo{
Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true} Username: "xiaolunwen2", Departname: "dev", Alias: "lunny", Created: time.Now(),
Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true,
}
cnt, err := testEngine.Insert(&user) cnt, err := testEngine.Insert(&user)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
@ -184,7 +416,7 @@ func TestInsertDefault(t *testing.T) {
err := testEngine.Sync(di) err := testEngine.Sync(di)
assert.NoError(t, err) assert.NoError(t, err)
var di2 = DefaultInsert{Name: "test"} di2 := DefaultInsert{Name: "test"}
_, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("Status")).Insert(&di2) _, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("Status")).Insert(&di2)
assert.NoError(t, err) assert.NoError(t, err)
@ -210,7 +442,7 @@ func TestInsertDefault2(t *testing.T) {
err := testEngine.Sync(di) err := testEngine.Sync(di)
assert.NoError(t, err) assert.NoError(t, err)
var di2 = DefaultInsert2{Name: "test"} di2 := DefaultInsert2{Name: "test"}
_, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("CheckTime")).Insert(&di2) _, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("CheckTime")).Insert(&di2)
assert.NoError(t, err) assert.NoError(t, err)
@ -438,7 +670,7 @@ func TestCreatedJsonTime(t *testing.T) {
assert.True(t, has) assert.True(t, has)
assert.EqualValues(t, time.Time(ci5.Created).Unix(), time.Time(di5.Created).Unix()) assert.EqualValues(t, time.Time(ci5.Created).Unix(), time.Time(di5.Created).Unix())
var dis = make([]MyJSONTime, 0) dis := make([]MyJSONTime, 0)
err = testEngine.Find(&dis) err = testEngine.Find(&dis)
assert.NoError(t, err) assert.NoError(t, err)
} }
@ -762,7 +994,7 @@ func TestInsertWhere(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
assertSync(t, new(InsertWhere)) assertSync(t, new(InsertWhere))
var i = InsertWhere{ i := InsertWhere{
RepoId: 1, RepoId: 1,
Width: 10, Width: 10,
Height: 20, Height: 20,
@ -872,7 +1104,7 @@ func TestInsertExpr2(t *testing.T) {
assertSync(t, new(InsertExprsRelease)) assertSync(t, new(InsertExprsRelease))
var ie = InsertExprsRelease{ ie := InsertExprsRelease{
RepoId: 1, RepoId: 1,
IsTag: true, IsTag: true,
} }
@ -1047,7 +1279,7 @@ func TestInsertIntSlice(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(InsertIntSlice))) assert.NoError(t, testEngine.Sync(new(InsertIntSlice)))
var v = InsertIntSlice{ v := InsertIntSlice{
NameIDs: []int{1, 2}, NameIDs: []int{1, 2},
} }
cnt, err := testEngine.Insert(&v) cnt, err := testEngine.Insert(&v)
@ -1064,7 +1296,7 @@ func TestInsertIntSlice(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var v3 = InsertIntSlice{ v3 := InsertIntSlice{
NameIDs: nil, NameIDs: nil,
} }
cnt, err = testEngine.Insert(&v3) cnt, err = testEngine.Insert(&v3)

View File

@ -44,7 +44,8 @@ type Interface interface {
In(string, ...interface{}) *Session In(string, ...interface{}) *Session
Incr(column string, arg ...interface{}) *Session Incr(column string, arg ...interface{}) *Session
Insert(...interface{}) (int64, error) Insert(...interface{}) (int64, error)
InsertOne(interface{}) (int64, error) InsertOne(interface{}) (int64, error) // Deprecated: Please use Insert directly
InsertOnConflictDoNothing(beans ...interface{}) (int64, error)
IsTableEmpty(bean interface{}) (bool, error) IsTableEmpty(bean interface{}) (bool, error)
IsTableExist(beanOrTableName interface{}) (bool, error) IsTableExist(beanOrTableName interface{}) (bool, error)
Iterate(interface{}, IterFunc) error Iterate(interface{}, IterFunc) error
@ -71,6 +72,7 @@ type Interface interface {
Table(tableNameOrBean interface{}) *Session Table(tableNameOrBean interface{}) *Session
Unscoped() *Session Unscoped() *Session
Update(bean interface{}, condiBeans ...interface{}) (int64, error) Update(bean interface{}, condiBeans ...interface{}) (int64, error)
Upsert(beans ...interface{}) (int64, error)
UseBool(...string) *Session UseBool(...string) *Session
Where(interface{}, ...interface{}) *Session Where(interface{}, ...interface{}) *Session
} }

View File

@ -59,13 +59,21 @@ func (expr *Expr) WriteArgs(w *builder.BytesWriter) error {
type exprParams []Expr type exprParams []Expr
func (exprs exprParams) ColNames() []string { func (exprs exprParams) ColNames() []string {
var cols = make([]string, 0, len(exprs)) cols := make([]string, 0, len(exprs))
for _, expr := range exprs { for _, expr := range exprs {
cols = append(cols, expr.ColName) cols = append(cols, expr.ColName)
} }
return cols return cols
} }
func (exprs exprParams) ColNamesTrim() []string {
cols := make([]string, 0, len(exprs))
for _, expr := range exprs {
cols = append(cols, schemas.CommonQuoter.Trim(expr.ColName))
}
return cols
}
func (exprs *exprParams) Add(name string, arg interface{}) { func (exprs *exprParams) Add(name string, arg interface{}) {
*exprs = append(*exprs, Expr{name, arg}) *exprs = append(*exprs, Expr{name, arg})
} }

View File

@ -30,7 +30,6 @@ func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schem
func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) (string, []interface{}, error) { func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) (string, []interface{}, error) {
var ( var (
buf = builder.NewWriter() buf = builder.NewWriter()
exprs = statement.ExprColumns
table = statement.RefTable table = statement.RefTable
tableName = statement.TableName() tableName = statement.TableName()
) )
@ -43,130 +42,9 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
return "", nil, err return "", nil, err
} }
var hasInsertColumns = len(colNames) > 0 if err := statement.genInsertValues(buf, colNames, args); err != nil {
var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
if needSeq {
for _, col := range colNames {
if strings.EqualFold(col, table.AutoIncrement) {
needSeq = false
break
}
}
}
if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE &&
statement.dialect.URI().DBType != schemas.DAMENG {
if statement.dialect.URI().DBType == schemas.MYSQL {
if _, err := buf.WriteString(" VALUES ()"); err != nil {
return "", nil, err return "", nil, err
} }
} else {
if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
return "", nil, err
}
if _, err := buf.WriteString(" DEFAULT VALUES"); err != nil {
return "", nil, err
}
}
} else {
if _, err := buf.WriteString(" ("); err != nil {
return "", nil, err
}
if needSeq {
colNames = append(colNames, table.AutoIncrement)
}
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames()...), ","); err != nil {
return "", nil, err
}
if _, err := buf.WriteString(")"); err != nil {
return "", nil, err
}
if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
return "", nil, err
}
if statement.Conds().IsValid() {
if _, err := buf.WriteString(" SELECT "); err != nil {
return "", nil, err
}
if err := statement.WriteArgs(buf, args); err != nil {
return "", nil, err
}
if needSeq {
if len(args) > 0 {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
}
}
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil {
return "", nil, err
}
}
if len(exprs) > 0 {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
}
if err := exprs.WriteArgs(buf); err != nil {
return "", nil, err
}
}
if _, err := buf.WriteString(" FROM "); err != nil {
return "", nil, err
}
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil {
return "", nil, err
}
if _, err := buf.WriteString(" WHERE "); err != nil {
return "", nil, err
}
if err := statement.Conds().WriteTo(buf); err != nil {
return "", nil, err
}
} else {
if _, err := buf.WriteString(" VALUES ("); err != nil {
return "", nil, err
}
if err := statement.WriteArgs(buf, args); err != nil {
return "", nil, err
}
// Insert tablename (id) Values(seq_tablename.nextval)
if needSeq {
if hasInsertColumns {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
}
}
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil {
return "", nil, err
}
}
if len(exprs) > 0 {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
}
}
if err := exprs.WriteArgs(buf); err != nil {
return "", nil, err
}
if _, err := buf.WriteString(")"); err != nil {
return "", nil, err
}
}
}
if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES { if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES {
if _, err := buf.WriteString(" RETURNING "); err != nil { if _, err := buf.WriteString(" RETURNING "); err != nil {
@ -180,6 +58,169 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
return buf.String(), buf.Args(), nil return buf.String(), buf.Args(), nil
} }
func (statement *Statement) includeAutoIncrement(colNames []string) bool {
includesAutoIncrement := len(statement.RefTable.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
if includesAutoIncrement {
for _, col := range colNames {
if strings.EqualFold(col, statement.RefTable.AutoIncrement) {
includesAutoIncrement = false
break
}
}
}
return includesAutoIncrement
}
func (statement *Statement) genInsertValues(buf *builder.BytesWriter, colNames []string, args []interface{}) error {
var (
exprs = statement.ExprColumns
table = statement.RefTable
)
hasInsertColumns := len(colNames) > 0
includeAutoIncrement := statement.includeAutoIncrement(colNames)
// Empty insert - i.e. insert default values only
if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE &&
statement.dialect.URI().DBType != schemas.DAMENG {
if statement.dialect.URI().DBType == schemas.MYSQL {
// MySQL doesn't have DEFAULT VALUES and uses VALUES () for this.
if _, err := buf.WriteString(" VALUES ()"); err != nil {
return err
}
return nil
}
// (MSSQL: return the inserted values)
if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
return err
}
// All others use DEFAULT VALUES
if _, err := buf.WriteString(" DEFAULT VALUES"); err != nil {
return err
}
return nil
}
// We have some values - Write the column names we need to insert:
if _, err := buf.WriteString(" ("); err != nil {
return err
}
if includeAutoIncrement {
colNames = append(colNames, table.AutoIncrement)
}
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames()...), ","); err != nil {
return err
}
if _, err := buf.WriteString(")"); err != nil {
return err
}
// (MSSQL: return the inserted values)
if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
return err
}
return statement.genInsertValuesValues(buf, includeAutoIncrement, colNames, args)
}
func (statement *Statement) genInsertValuesValues(buf *builder.BytesWriter, includeAutoIncrement bool, colNames []string, args []interface{}) error {
var (
exprs = statement.ExprColumns
tableName = statement.TableName()
)
hasInsertColumns := len(colNames) > 0
if statement.Conds().IsValid() {
// We have conditions which we're trying to insert
if _, err := buf.WriteString(" SELECT "); err != nil {
return err
}
if err := statement.WriteArgs(buf, args); err != nil {
return err
}
if includeAutoIncrement {
if len(args) > 0 {
if _, err := buf.WriteString(","); err != nil {
return err
}
}
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil {
return err
}
}
if len(exprs) > 0 {
if _, err := buf.WriteString(","); err != nil {
return err
}
if err := exprs.WriteArgs(buf); err != nil {
return err
}
}
if _, err := buf.WriteString(" FROM "); err != nil {
return err
}
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil {
return err
}
if _, err := buf.WriteString(" WHERE "); err != nil {
return err
}
if err := statement.Conds().WriteTo(buf); err != nil {
return err
}
return nil
}
// Direct insertion of values
if _, err := buf.WriteString(" VALUES ("); err != nil {
return err
}
if err := statement.WriteArgs(buf, args); err != nil {
return err
}
// Insert tablename (id) Values(seq_tablename.nextval)
if includeAutoIncrement {
if hasInsertColumns {
if _, err := buf.WriteString(","); err != nil {
return err
}
}
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil {
return err
}
}
if len(exprs) > 0 {
if _, err := buf.WriteString(","); err != nil {
return err
}
}
if err := exprs.WriteArgs(buf); err != nil {
return err
}
if _, err := buf.WriteString(")"); err != nil {
return err
}
return nil
}
// GenInsertMapSQL generates insert map SQL // GenInsertMapSQL generates insert map SQL
func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{}) (string, []interface{}, error) { func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{}) (string, []interface{}, error) {
var ( var (
@ -196,51 +237,12 @@ func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{}
return "", nil, err return "", nil, err
} }
// if insert where
if statement.Conds().IsValid() {
if _, err := buf.WriteString(") SELECT "); err != nil {
return "", nil, err
}
if err := statement.WriteArgs(buf, args); err != nil {
return "", nil, err
}
if len(exprs) > 0 {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
}
if err := exprs.WriteArgs(buf); err != nil {
return "", nil, err
}
}
if _, err := buf.WriteString(fmt.Sprintf(" FROM %s WHERE ", statement.quote(tableName))); err != nil {
return "", nil, err
}
if err := statement.Conds().WriteTo(buf); err != nil {
return "", nil, err
}
} else {
if _, err := buf.WriteString(") VALUES ("); err != nil {
return "", nil, err
}
if err := statement.WriteArgs(buf, args); err != nil {
return "", nil, err
}
if len(exprs) > 0 {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
}
if err := exprs.WriteArgs(buf); err != nil {
return "", nil, err
}
}
if _, err := buf.WriteString(")"); err != nil { if _, err := buf.WriteString(")"); err != nil {
return "", nil, err return "", nil, err
} }
if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil {
return "", nil, err
} }
return buf.String(), buf.Args(), nil return buf.String(), buf.Args(), nil

View File

@ -0,0 +1,271 @@
// Copyright 2020 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"fmt"
"xorm.io/builder"
"xorm.io/xorm/schemas"
)
// GenUpsertSQL generates upsert beans SQL
func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) {
if statement.dialect.URI().DBType == schemas.MSSQL ||
statement.dialect.URI().DBType == schemas.DAMENG ||
statement.dialect.URI().DBType == schemas.ORACLE {
return statement.genMergeSQL(doUpdate, columns, args, uniqueColValMap)
}
var (
buf = builder.NewWriter()
table = statement.RefTable
tableName = statement.TableName()
)
if statement.dialect.URI().DBType == schemas.MYSQL && !doUpdate {
if _, err := buf.WriteString("INSERT IGNORE INTO "); err != nil {
return "", nil, err
}
} else {
if _, err := buf.WriteString("INSERT INTO "); err != nil {
return "", nil, err
}
}
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil {
return "", nil, err
}
if err := statement.genInsertValues(buf, columns, args); err != nil {
return "", nil, err
}
switch statement.dialect.URI().DBType {
case schemas.SQLITE, schemas.POSTGRES:
if _, err := buf.WriteString(" ON CONFLICT DO "); err != nil {
return "", nil, err
}
if doUpdate {
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
} else {
if _, err := buf.WriteString("NOTHING"); err != nil {
return "", nil, err
}
}
case schemas.MYSQL:
if doUpdate {
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
}
default:
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
}
if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES {
if _, err := buf.WriteString(" RETURNING "); err != nil {
return "", nil, err
}
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil {
return "", nil, err
}
}
return buf.String(), buf.Args(), nil
}
func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) {
var (
buf = builder.NewWriter()
table = statement.RefTable
tableName = statement.TableName()
)
quote := statement.dialect.Quoter().Quote
write := func(args ...string) {
for _, arg := range args {
_, _ = buf.WriteString(arg)
}
}
write("MERGE INTO ", quote(tableName))
if statement.dialect.URI().DBType == schemas.MSSQL {
write("WITH (HOLDLOCK) AS target ")
}
write("USING (SELECT ")
uniqueCols := make([]string, 0, len(uniqueColValMap))
for colName := range uniqueColValMap {
uniqueCols = append(uniqueCols, colName)
}
for i, colName := range uniqueCols {
if err := statement.WriteArg(buf, uniqueColValMap[colName]); err != nil {
return "", nil, err
}
write(" AS ", quote(colName))
if i < len(uniqueCols)-1 {
write(", ")
}
}
write(") AS src ON (")
countUniques := 0
for _, index := range table.Indexes {
if index.Type != schemas.UniqueType {
continue
}
if countUniques > 0 {
write(" OR ")
}
countUniques++
write("(")
write("src.", quote(index.Cols[0]), "= target.", quote(index.Cols[0]))
for _, col := range index.Cols[1:] {
write(" AND src.", quote(col), "= target.", quote(col))
}
write(")")
}
if doUpdate {
return "", nil, fmt.Errorf("unimplemented")
}
write(") WHEN NOT MATCHED THEN INSERT")
if err := statement.genInsertValues(buf, columns, args); err != nil {
return "", nil, err
}
return buf.String(), buf.Args(), nil
}
// GenUpsertMapSQL generates insert map SQL
func (statement *Statement) GenUpsertMapSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) {
if statement.dialect.URI().DBType == schemas.MSSQL ||
statement.dialect.URI().DBType == schemas.DAMENG ||
statement.dialect.URI().DBType == schemas.ORACLE {
return statement.genMergeMapSQL(doUpdate, columns, args, uniqueColValMap)
}
var (
buf = builder.NewWriter()
exprs = statement.ExprColumns
table = statement.RefTable
tableName = statement.TableName()
)
quote := statement.dialect.Quoter().Quote
write := func(args ...string) {
for _, arg := range args {
_, _ = buf.WriteString(arg)
}
}
if statement.dialect.URI().DBType == schemas.MYSQL && !doUpdate {
write("INSERT IGNORE INTO ", quote(tableName), " (")
} else {
write("INSERT INTO ", quote(tableName), " (")
}
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil {
return "", nil, err
}
write(")")
if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil {
return "", nil, err
}
switch statement.dialect.URI().DBType {
case schemas.SQLITE, schemas.POSTGRES:
if _, err := buf.WriteString(" ON CONFLICT DO "); err != nil {
return "", nil, err
}
if doUpdate {
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
} else {
if _, err := buf.WriteString("NOTHING"); err != nil {
return "", nil, err
}
}
case schemas.MYSQL:
if doUpdate {
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
}
default:
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
}
if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES {
if _, err := buf.WriteString(" RETURNING "); err != nil {
return "", nil, err
}
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil {
return "", nil, err
}
}
return buf.String(), buf.Args(), nil
}
func (statement *Statement) genMergeMapSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) {
var (
buf = builder.NewWriter()
table = statement.RefTable
exprs = statement.ExprColumns
tableName = statement.TableName()
)
quote := statement.dialect.Quoter().Quote
write := func(args ...string) {
for _, arg := range args {
_, _ = buf.WriteString(arg)
}
}
write("MERGE INTO ", quote(tableName))
if statement.dialect.URI().DBType == schemas.MSSQL {
write("WITH (HOLDLOCK) AS target ")
}
write("USING (SELECT ")
uniqueCols := make([]string, 0, len(uniqueColValMap))
for colName := range uniqueColValMap {
uniqueCols = append(uniqueCols, colName)
}
for i, colName := range uniqueCols {
if err := statement.WriteArg(buf, uniqueColValMap[colName]); err != nil {
return "", nil, err
}
write(" AS ", quote(colName))
if i < len(uniqueCols)-1 {
write(", ")
}
}
write(") AS src ON (")
countUniques := 0
for _, index := range table.Indexes {
if index.Type != schemas.UniqueType {
continue
}
if countUniques > 0 {
write(" OR ")
}
countUniques++
write("(")
write("src.", quote(index.Cols[0]), "= target.", quote(index.Cols[0]))
for _, col := range index.Cols[1:] {
write(" AND src.", quote(col), "= target.", quote(col))
}
write(")")
}
if doUpdate {
return "", nil, fmt.Errorf("unimplemented")
}
write(") WHEN NOT MATCHED THEN INSERT")
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil {
return "", nil, err
}
write(")")
if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil {
return "", nil, err
}
return buf.String(), buf.Args(), nil
}

110
internal/utils/map.go Normal file
View File

@ -0,0 +1,110 @@
// Copyright 2020 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package utils
import (
"sort"
"strings"
)
func MapToSlices(m map[string]interface{}, exclude []string, trimmer func(string) string) ([]string, []interface{}) {
columns := make([]string, 0, len(m))
outer:
for colName := range m {
trimmed := trimmer(colName)
for _, excluded := range exclude {
if strings.EqualFold(excluded, trimmed) {
continue outer
}
}
columns = append(columns, colName)
}
sort.Strings(columns)
args := make([]interface{}, 0, len(columns))
for _, colName := range columns {
args = append(args, m[colName])
}
return columns, args
}
func MapStringToSlices(m map[string]string, exclude []string, trimmer func(string) string) ([]string, []interface{}) {
columns := make([]string, 0, len(m))
outer:
for colName := range m {
trimmed := trimmer(colName)
for _, excluded := range exclude {
if strings.EqualFold(excluded, trimmed) {
continue outer
}
}
columns = append(columns, colName)
}
sort.Strings(columns)
args := make([]interface{}, 0, len(columns))
for _, colName := range columns {
args = append(args, m[colName])
}
return columns, args
}
func MultipleMapToSlices(maps []map[string]interface{}, exclude []string, trimmer func(string) string) ([]string, [][]interface{}) {
columns := make([]string, 0, len(maps[0]))
outer:
for colName := range maps[0] {
trimmed := trimmer(colName)
for _, excluded := range exclude {
if strings.EqualFold(excluded, trimmed) {
continue outer
}
}
columns = append(columns, colName)
}
sort.Strings(columns)
argss := make([][]interface{}, 0, len(maps))
for _, m := range maps {
args := make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
argss = append(argss, args)
}
return columns, argss
}
func MultipleMapStringToSlices(maps []map[string]string, exclude []string, trimmer func(string) string) ([]string, [][]interface{}) {
columns := make([]string, 0, len(maps[0]))
outer:
for colName := range maps[0] {
trimmed := trimmer(colName)
for _, excluded := range exclude {
if strings.EqualFold(excluded, trimmed) {
continue outer
}
}
columns = append(columns, colName)
}
sort.Strings(columns)
argss := make([][]interface{}, 0, len(maps))
for _, m := range maps {
args := make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
argss = append(argss, args)
}
return columns, argss
}

View File

@ -5,10 +5,10 @@
package xorm package xorm
import ( import (
"database/sql"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"sort"
"strings" "strings"
"time" "time"
@ -156,14 +156,14 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, e
} }
args = append(args, val) args = append(args, val)
var colName = col.Name colName := col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)
}) })
} else if col.IsVersion && session.statement.CheckVersion { } else if col.IsVersion && session.statement.CheckVersion {
args = append(args, 1) args = append(args, 1)
var colName = col.Name colName := col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnInt(bean, col, 1) setColumnInt(bean, col, 1)
@ -276,7 +276,7 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
processor.BeforeInsert() processor.BeforeInsert()
} }
var tableName = session.statement.TableName() tableName := session.statement.TableName()
table := session.statement.RefTable table := session.statement.RefTable
colNames, args, err := session.genInsertColumns(bean) colNames, args, err := session.genInsertColumns(bean)
@ -290,102 +290,9 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
} }
sqlStr = session.engine.dialect.Quoter().Replace(sqlStr) sqlStr = session.engine.dialect.Quoter().Replace(sqlStr)
handleAfterInsertProcessorFunc := func(bean interface{}) {
if session.isAutoCommit {
for _, closure := range session.afterClosures {
closure(bean)
}
if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
processor.AfterInsert()
}
} else {
lenAfterClosures := len(session.afterClosures)
if lenAfterClosures > 0 {
if value, has := session.afterInsertBeans[bean]; has && value != nil {
*value = append(*value, session.afterClosures...)
} else {
afterClosures := make([]func(interface{}), lenAfterClosures)
copy(afterClosures, session.afterClosures)
session.afterInsertBeans[bean] = &afterClosures
}
} else {
if _, ok := interface{}(bean).(AfterInsertProcessor); ok {
session.afterInsertBeans[bean] = nil
}
}
}
cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
}
// if there is auto increment column and driver don't support return it // if there is auto increment column and driver don't support return it
if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID { if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID {
var sql string return session.execInsertSqlNoAutoReturn(sqlStr, bean, colNames, args)
var newArgs []interface{}
var needCommit bool
var id int64
if session.engine.dialect.URI().DBType == schemas.ORACLE || session.engine.dialect.URI().DBType == schemas.DAMENG {
if session.isAutoCommit { // if it's not in transaction
if err := session.Begin(); err != nil {
return 0, err
}
needCommit = true
}
_, err := session.exec(sqlStr, args...)
if err != nil {
return 0, err
}
i := utils.IndexSlice(colNames, table.AutoIncrement)
if i > -1 {
id, err = convert.AsInt64(args[i])
if err != nil {
return 0, err
}
} else {
sql = fmt.Sprintf("select %s.currval from dual", utils.SeqName(tableName))
}
} else {
sql = sqlStr
newArgs = args
}
if id == 0 {
err := session.queryRow(sql, newArgs...).Scan(&id)
if err != nil {
return 0, err
}
if needCommit {
if err := session.Commit(); err != nil {
return 0, err
}
}
if id == 0 {
return 0, errors.New("insert successfully but not returned id")
}
}
defer handleAfterInsertProcessorFunc(bean)
_ = session.cacheInsert(tableName)
if table.Version != "" && session.statement.CheckVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)
} else if verValue.IsValid() && verValue.CanSet() {
session.incrVersionFieldValue(verValue)
}
}
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)
}
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
return 1, nil
}
return 1, convert.AssignValue(*aiValue, id)
} }
res, err := session.exec(sqlStr, args...) res, err := session.exec(sqlStr, args...)
@ -393,7 +300,7 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
return 0, err return 0, err
} }
defer handleAfterInsertProcessorFunc(bean) defer session.handleAfterInsertProcessorFunc(bean)
_ = session.cacheInsert(tableName) _ = session.cacheInsert(tableName)
@ -432,31 +339,6 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
return res.RowsAffected() return res.RowsAffected()
} }
// InsertOne insert only one struct into database as a record.
// The in parameter bean must a struct or a point to struct. The return
// parameter is inserted and error
// Deprecated: Please use Insert directly
func (session *Session) InsertOne(bean interface{}) (int64, error) {
if session.isAutoClose {
defer session.Close()
}
return session.insertStruct(bean)
}
func (session *Session) cacheInsert(table string) error {
if !session.statement.UseCache {
return nil
}
cacher := session.engine.cacherMgr.GetCacher(table)
if cacher == nil {
return nil
}
session.engine.logger.Debugf("[cache] clear SQL: %v", table)
cacher.ClearIds(table)
return nil
}
// genInsertColumns generates insert needed columns // genInsertColumns generates insert needed columns
func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) { func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) {
table := session.statement.RefTable table := session.statement.RefTable
@ -517,7 +399,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
} }
args = append(args, val) args = append(args, val)
var colName = col.Name colName := col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)
@ -537,6 +419,112 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
return colNames, args, nil return colNames, args, nil
} }
func (session *Session) execInsertSqlNoAutoReturn(sqlStr string, bean interface{}, colNames []string, args []interface{}) (int64, error) {
var newSql string
var newArgs []interface{}
var needCommit bool
var id int64
tableName := session.statement.TableName()
table := session.statement.RefTable
if session.engine.dialect.URI().DBType == schemas.ORACLE || session.engine.dialect.URI().DBType == schemas.DAMENG {
if session.isAutoCommit { // if it's not in transaction
if err := session.Begin(); err != nil {
return 0, err
}
needCommit = true
}
res, err := session.exec(sqlStr, args...)
if err != nil {
return 0, err
}
n, err := res.RowsAffected()
if err != nil {
return 0, err
}
if n == 0 {
return 0, sql.ErrNoRows
}
i := utils.IndexSlice(colNames, table.AutoIncrement)
if i > -1 {
id, err = convert.AsInt64(args[i])
if err != nil {
return 0, err
}
} else {
newSql = fmt.Sprintf("select %s.currval from dual", utils.SeqName(tableName))
}
} else {
newSql = sqlStr
newArgs = args
}
if id == 0 {
err := session.queryRow(newSql, newArgs...).Scan(&id)
if err != nil {
return 0, err
}
if needCommit {
if err := session.Commit(); err != nil {
return 0, err
}
}
if id == 0 {
return 0, errors.New("insert successfully but not returned id")
}
}
defer session.handleAfterInsertProcessorFunc(bean)
_ = session.cacheInsert(tableName)
if table.Version != "" && session.statement.CheckVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)
} else if verValue.IsValid() && verValue.CanSet() {
session.incrVersionFieldValue(verValue)
}
}
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)
}
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
return 1, nil
}
return 1, convert.AssignValue(*aiValue, id)
}
// InsertOne insert only one struct into database as a record.
// The in parameter bean must a struct or a point to struct. The return
// parameter is inserted and error
// Deprecated: Please use Insert directly
func (session *Session) InsertOne(bean interface{}) (int64, error) {
if session.isAutoClose {
defer session.Close()
}
return session.insertStruct(bean)
}
func (session *Session) cacheInsert(table string) error {
if !session.statement.UseCache {
return nil
}
cacher := session.engine.cacherMgr.GetCacher(table)
if cacher == nil {
return nil
}
session.engine.logger.Debugf("[cache] clear SQL: %v", table)
cacher.ClearIds(table)
return nil
}
func (session *Session) insertMapInterface(m map[string]interface{}) (int64, error) { func (session *Session) insertMapInterface(m map[string]interface{}) (int64, error) {
if len(m) == 0 { if len(m) == 0 {
return 0, ErrParamsType return 0, ErrParamsType
@ -547,19 +535,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
var columns = make([]string, 0, len(m)) columns, args := utils.MapToSlices(m, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim)
exprs := session.statement.ExprColumns
for k := range m {
if !exprs.IsColExist(k) {
columns = append(columns, k)
}
}
sort.Strings(columns)
var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
return session.insertMap(columns, args) return session.insertMap(columns, args)
} }
@ -574,23 +550,7 @@ func (session *Session) insertMultipleMapInterface(maps []map[string]interface{}
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
var columns = make([]string, 0, len(maps[0])) columns, argss := utils.MultipleMapToSlices(maps, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim)
exprs := session.statement.ExprColumns
for k := range maps[0] {
if !exprs.IsColExist(k) {
columns = append(columns, k)
}
}
sort.Strings(columns)
var argss = make([][]interface{}, 0, len(maps))
for _, m := range maps {
var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
argss = append(argss, args)
}
return session.insertMultipleMap(columns, argss) return session.insertMultipleMap(columns, argss)
} }
@ -605,20 +565,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
var columns = make([]string, 0, len(m)) columns, args := utils.MapStringToSlices(m, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim)
exprs := session.statement.ExprColumns
for k := range m {
if !exprs.IsColExist(k) {
columns = append(columns, k)
}
}
sort.Strings(columns)
var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
return session.insertMap(columns, args) return session.insertMap(columns, args)
} }
@ -633,23 +580,7 @@ func (session *Session) insertMultipleMapString(maps []map[string]string) (int64
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
var columns = make([]string, 0, len(maps[0])) columns, argss := utils.MultipleMapStringToSlices(maps, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim)
exprs := session.statement.ExprColumns
for k := range maps[0] {
if !exprs.IsColExist(k) {
columns = append(columns, k)
}
}
sort.Strings(columns)
var argss = make([][]interface{}, 0, len(maps))
for _, m := range maps {
var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
argss = append(argss, args)
}
return session.insertMultipleMap(columns, argss) return session.insertMultipleMap(columns, argss)
} }
@ -707,3 +638,30 @@ func (session *Session) insertMultipleMap(columns []string, argss [][]interface{
} }
return affected, nil return affected, nil
} }
func (session *Session) handleAfterInsertProcessorFunc(bean interface{}) {
if session.isAutoCommit {
for _, closure := range session.afterClosures {
closure(bean)
}
if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
processor.AfterInsert()
}
} else {
lenAfterClosures := len(session.afterClosures)
if lenAfterClosures > 0 {
if value, has := session.afterInsertBeans[bean]; has && value != nil {
*value = append(*value, session.afterClosures...)
} else {
afterClosures := make([]func(interface{}), lenAfterClosures)
copy(afterClosures, session.afterClosures)
session.afterInsertBeans[bean] = &afterClosures
}
} else {
if _, ok := interface{}(bean).(AfterInsertProcessor); ok {
session.afterInsertBeans[bean] = nil
}
}
}
cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
}

298
session_upsert.go Normal file
View File

@ -0,0 +1,298 @@
package xorm
import (
"database/sql"
"fmt"
"reflect"
"xorm.io/xorm/convert"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
func (session *Session) InsertOnConflictDoNothing(beans ...interface{}) (int64, error) {
return session.upsert(false, beans...)
}
func (session *Session) Upsert(beans ...interface{}) (int64, error) {
return session.upsert(true, beans...)
}
func (session *Session) upsert(doUpdate bool, beans ...interface{}) (int64, error) {
var affected int64
var err error
if session.isAutoClose {
defer session.Close()
}
session.autoResetStatement = false
defer func() {
session.autoResetStatement = true
session.resetStatement()
}()
for _, bean := range beans {
var cnt int64
var err error
switch v := bean.(type) {
case map[string]interface{}:
cnt, err = session.upsertMapInterface(doUpdate, v)
case []map[string]interface{}:
for _, m := range v {
cnt, err := session.upsertMapInterface(doUpdate, m)
if err != nil {
return affected, err
}
affected += cnt
}
case map[string]string:
cnt, err = session.upsertMapString(doUpdate, v)
case []map[string]string:
for _, m := range v {
cnt, err := session.upsertMapString(doUpdate, m)
if err != nil {
return affected, err
}
affected += cnt
}
default:
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
if sliceValue.Kind() == reflect.Slice {
if sliceValue.Len() <= 0 {
return 0, ErrNoElementsOnSlice
}
for i := 0; i < sliceValue.Len(); i++ {
v := sliceValue.Index(i)
bean := v.Interface()
cnt, err := session.upsertStruct(doUpdate, bean)
if err != nil {
return affected, err
}
affected += cnt
}
} else {
cnt, err = session.upsertStruct(doUpdate, bean)
}
}
if err != nil {
return affected, err
}
affected += cnt
}
return affected, err
}
func (session *Session) upsertMapInterface(doUpdate bool, m map[string]interface{}) (int64, error) {
if len(m) == 0 {
return 0, ErrParamsType
}
tableName := session.statement.TableName()
if len(tableName) == 0 {
return 0, ErrTableNotFound
}
columns, args := utils.MapToSlices(m, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim)
return session.upsertMap(doUpdate, columns, args)
}
func (session *Session) upsertMapString(doUpdate bool, m map[string]string) (int64, error) {
if len(m) == 0 {
return 0, ErrParamsType
}
tableName := session.statement.TableName()
if len(tableName) == 0 {
return 0, ErrTableNotFound
}
columns, args := utils.MapStringToSlices(m, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim)
return session.upsertMap(doUpdate, columns, args)
}
func (session *Session) upsertMap(doUpdate bool, columns []string, args []interface{}) (int64, error) {
tableName := session.statement.TableName()
if len(tableName) == 0 {
return 0, ErrTableNotFound
}
uniqueColValMap, err := session.getUniqueColumns(columns, args)
if err != nil {
return 0, err
}
sql, args, err := session.statement.GenUpsertMapSQL(doUpdate, columns, args, uniqueColValMap)
if err != nil {
return 0, err
}
sql = session.engine.dialect.Quoter().Replace(sql)
if err := session.cacheInsert(tableName); err != nil {
return 0, err
}
res, err := session.exec(sql, args...)
if err != nil {
return 0, err
}
affected, err := res.RowsAffected()
if err != nil {
return 0, err
}
return affected, fmt.Errorf("unimplemented")
}
func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, error) {
if err := session.statement.SetRefBean(bean); err != nil {
return 0, err
}
if len(session.statement.TableName()) == 0 {
return 0, ErrTableNotFound
}
// handle BeforeInsertProcessor
for _, closure := range session.beforeClosures {
closure(bean)
}
cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used
if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
processor.BeforeInsert()
}
tableName := session.statement.TableName()
table := session.statement.RefTable
colNames, args, err := session.genInsertColumns(bean)
if err != nil {
return 0, err
}
uniqueColValMap, err := session.getUniqueColumns(colNames, args)
if err != nil {
return 0, err
}
sqlStr, args, err := session.statement.GenUpsertSQL(doUpdate, colNames, args, uniqueColValMap)
if err != nil {
return 0, err
}
sqlStr = session.engine.dialect.Quoter().Replace(sqlStr)
// if there is auto increment column and driver doesn't support return it
if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID {
n, err := session.execInsertSqlNoAutoReturn(sqlStr, bean, colNames, args)
if err == sql.ErrNoRows {
return n, nil
}
return n, err
}
res, err := session.exec(sqlStr, args...)
if err != nil {
return 0, err
}
defer session.handleAfterInsertProcessorFunc(bean)
_ = session.cacheInsert(tableName)
if table.Version != "" && session.statement.CheckVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)
} else if verValue.IsValid() && verValue.CanSet() {
session.incrVersionFieldValue(verValue)
}
}
if table.AutoIncrement == "" {
return res.RowsAffected()
}
n, err := res.RowsAffected()
if err != nil || n == 0 {
return 0, err
}
var id int64
id, err = res.LastInsertId()
if err != nil || id <= 0 {
return n, err
}
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)
}
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
return n, err
}
if err := convert.AssignValue(*aiValue, id); err != nil {
return 0, err
}
return n, err
}
func (session *Session) getUniqueColumns(colNames []string, args []interface{}) (uniqueColValMap map[string]interface{}, err error) {
uniqueColValMap = make(map[string]interface{})
table := session.statement.RefTable
if len(table.Indexes) == 0 {
return nil, fmt.Errorf("provided bean has no unique constraints")
}
// Iterate across the indexes in the provided table
for _, index := range table.Indexes {
if index.Type != schemas.UniqueType {
continue
}
// index is a Unique constraint
indexCol:
for _, indexColumnName := range index.Cols {
if _, has := uniqueColValMap[indexColumnName]; has {
// column is already included in uniqueCols so we don't need to add it again
continue indexCol
}
// Now iterate across colNames and add to the uniqueCols
for i, col := range colNames {
if col == indexColumnName {
uniqueColValMap[col] = args[i]
continue indexCol
}
}
indexColumn := table.GetColumn(indexColumnName)
if !indexColumn.DefaultIsEmpty {
uniqueColValMap[indexColumnName] = indexColumn.Default
}
if indexColumn.MapType == schemas.ONLYFROMDB || indexColumn.IsAutoIncrement {
continue
}
if session.statement.OmitColumnMap.Contain(indexColumn.Name) {
continue
}
if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(indexColumn.Name) {
continue
}
// FIXME: what do we do here?!
if session.statement.IncrColumns.IsColExist(indexColumn.Name) {
continue
} else if session.statement.DecrColumns.IsColExist(indexColumn.Name) {
continue
} else if session.statement.ExprColumns.IsColExist(indexColumn.Name) {
continue
}
// FIXME: not sure if there's anything else we can do
return nil, fmt.Errorf("provided bean does not provide a value for unique constraint %s field %s which has no default", index.Name, indexColumnName)
}
}
return uniqueColValMap, nil
}