Add InsertOnConflictDoNothing and Upsert functionality
This PR adds functionality for xorm to perform InsertOnConflictDoNothing and Upserts. Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
parent
d485abba57
commit
3a4fbeaa6f
14
engine.go
14
engine.go
|
@ -1224,6 +1224,13 @@ func (engine *Engine) InsertOne(bean interface{}) (int64, error) {
|
|||
return session.InsertOne(bean)
|
||||
}
|
||||
|
||||
// InsertOnConflictDoNothing attempt to insert a record but on conflict do nothing
|
||||
func (engine *Engine) InsertOnConflictDoNothing(beans ...interface{}) (int64, error) {
|
||||
session := engine.NewSession()
|
||||
defer session.Close()
|
||||
return session.InsertOnConflictDoNothing(beans...)
|
||||
}
|
||||
|
||||
// Update records, bean's non-empty fields are updated contents,
|
||||
// condiBean' non-empty filds are conditions
|
||||
// CAUTION:
|
||||
|
@ -1237,6 +1244,13 @@ func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64
|
|||
return session.Update(bean, condiBeans...)
|
||||
}
|
||||
|
||||
// Upsert attempt to insert a record but on conflict do update
|
||||
func (engine *Engine) Upsert(beans ...interface{}) (int64, error) {
|
||||
session := engine.NewSession()
|
||||
defer session.Close()
|
||||
return session.Upsert(beans...)
|
||||
}
|
||||
|
||||
// Delete records, bean's non-empty fields are conditions
|
||||
// At least one condition must be set.
|
||||
func (engine *Engine) Delete(beans ...interface{}) (int64, error) {
|
||||
|
|
|
@ -16,6 +16,231 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInsertOnConflictDoNothing(t *testing.T) {
|
||||
assert.NoError(t, PrepareEngine())
|
||||
|
||||
t.Run("NoUnique", func(t *testing.T) {
|
||||
// InsertOnConflictDoNothing does not work if there is no unique constraint
|
||||
type NoUniques struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Data string
|
||||
}
|
||||
assert.NoError(t, testEngine.Sync(new(NoUniques)))
|
||||
|
||||
toInsert := &NoUniques{Data: "shouldErr"}
|
||||
n, err := testEngine.InsertOnConflictDoNothing(toInsert)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, int64(0), n)
|
||||
assert.Equal(t, int64(0), toInsert.ID)
|
||||
|
||||
toInsert = &NoUniques{Data: ""}
|
||||
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, int64(0), n)
|
||||
assert.Equal(t, int64(0), toInsert.ID)
|
||||
})
|
||||
|
||||
t.Run("OneUnique", func(t *testing.T) {
|
||||
type OneUnique struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Data string `xorm:"UNIQUE NOT NULL"`
|
||||
}
|
||||
|
||||
assert.NoError(t, testEngine.Sync2(&OneUnique{}))
|
||||
_, _ = testEngine.Exec("DELETE FROM one_unique")
|
||||
|
||||
// Insert with the default value for the unique field
|
||||
toInsert := &OneUnique{}
|
||||
n, err := testEngine.InsertOnConflictDoNothing(toInsert)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
assert.NotEqual(t, int64(0), toInsert.ID)
|
||||
|
||||
// but not twice
|
||||
toInsert = &OneUnique{}
|
||||
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), n)
|
||||
assert.Equal(t, int64(0), toInsert.ID)
|
||||
|
||||
// Successfully insert test
|
||||
toInsert = &OneUnique{Data: "test"}
|
||||
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
assert.NotEqual(t, int64(0), toInsert.ID)
|
||||
|
||||
// Successfully insert test2
|
||||
toInsert = &OneUnique{Data: "test2"}
|
||||
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
assert.NotEqual(t, int64(0), toInsert.ID)
|
||||
|
||||
// Successfully don't reinsert test
|
||||
toInsert = &OneUnique{Data: "test"}
|
||||
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), n)
|
||||
assert.Equal(t, int64(0), toInsert.ID)
|
||||
})
|
||||
|
||||
t.Run("MultiUnique", func(t *testing.T) {
|
||||
type MultiUnique struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
NotUnique string
|
||||
Data1 string `xorm:"UNIQUE(s) NOT NULL"`
|
||||
Data2 string `xorm:"UNIQUE(s) NOT NULL"`
|
||||
}
|
||||
|
||||
assert.NoError(t, testEngine.Sync2(&MultiUnique{}))
|
||||
_, _ = testEngine.Exec("DELETE FROM multi_unique")
|
||||
|
||||
// Insert with default values
|
||||
toInsert := &MultiUnique{}
|
||||
n, err := testEngine.InsertOnConflictDoNothing(toInsert)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
assert.NotEqual(t, int64(0), toInsert.ID)
|
||||
|
||||
// successfully insert test, t1
|
||||
toInsert = &MultiUnique{Data1: "test", NotUnique: "t1"}
|
||||
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
assert.NotEqual(t, int64(0), toInsert.ID)
|
||||
|
||||
// successfully insert test2, t1
|
||||
toInsert = &MultiUnique{Data1: "test2", NotUnique: "t1"}
|
||||
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
assert.NotEqual(t, int64(0), toInsert.ID)
|
||||
|
||||
// successfully don't insert test2, t2
|
||||
toInsert = &MultiUnique{Data1: "test2", NotUnique: "t2"}
|
||||
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), n)
|
||||
assert.Equal(t, int64(0), toInsert.ID)
|
||||
|
||||
// successfully don't insert test, t2
|
||||
toInsert = &MultiUnique{Data1: "test", NotUnique: "t2"}
|
||||
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), n)
|
||||
assert.Equal(t, int64(0), toInsert.ID)
|
||||
|
||||
// successfully insert test/test2, t2
|
||||
toInsert = &MultiUnique{Data1: "test", Data2: "test2", NotUnique: "t1"}
|
||||
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
assert.NotEqual(t, int64(0), toInsert.ID)
|
||||
|
||||
// successfully don't insert test/test2, t2
|
||||
toInsert = &MultiUnique{Data1: "test", Data2: "test2", NotUnique: "t2"}
|
||||
n, err = testEngine.InsertOnConflictDoNothing(toInsert)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), n)
|
||||
assert.Equal(t, int64(0), toInsert.ID)
|
||||
})
|
||||
|
||||
t.Run("MultiMultiUnique", func(t *testing.T) {
|
||||
type MultiMultiUnique struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Data0 string `xorm:"UNIQUE NOT NULL"`
|
||||
Data1 string `xorm:"UNIQUE(s) NOT NULL"`
|
||||
Data2 string `xorm:"UNIQUE(s) NOT NULL"`
|
||||
}
|
||||
|
||||
assert.NoError(t, testEngine.Sync2(&MultiMultiUnique{}))
|
||||
_, _ = testEngine.Exec("DELETE FROM multi_multi_unique")
|
||||
|
||||
// Insert with default values
|
||||
n, err := testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
|
||||
// Insert with value for t1, <test, "">
|
||||
n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data1: "test", Data0: "t1"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
|
||||
// Fail insert with value for t1, <test2, "">
|
||||
n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data2: "test2", Data0: "t1"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), n)
|
||||
|
||||
// Insert with value for t2, <test2, "">
|
||||
n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data2: "test2", Data0: "t2"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
|
||||
// Fail insert with value for t2, <test2, "">
|
||||
n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data2: "test2", Data0: "t2"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), n)
|
||||
|
||||
// Fail insert with value for t2, <test, "">
|
||||
n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data1: "test", Data0: "t2"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), n)
|
||||
|
||||
// Insert with value for t3, <test, test2>
|
||||
n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data1: "test", Data2: "test2", Data0: "t3"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
|
||||
// fail insert with value for t2, <test, test2>
|
||||
n, err = testEngine.InsertOnConflictDoNothing(&MultiMultiUnique{Data1: "test", Data2: "test2", Data0: "t2"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), n)
|
||||
})
|
||||
|
||||
t.Run("NoPK", func(t *testing.T) {
|
||||
type NoPrimaryKey struct {
|
||||
NotID int64
|
||||
Uniqued string `xorm:"UNIQUE"`
|
||||
}
|
||||
|
||||
assert.NoError(t, testEngine.Sync2(&NoPrimaryKey{}))
|
||||
_, _ = testEngine.Exec("DELETE FROM no_primary_unique")
|
||||
|
||||
empty := &NoPrimaryKey{}
|
||||
|
||||
// Insert default
|
||||
n, err := testEngine.InsertOnConflictDoNothing(empty)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
|
||||
// Insert with 1
|
||||
n, err = testEngine.InsertOnConflictDoNothing(&NoPrimaryKey{Uniqued: "1"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
|
||||
// Fail reinsert default
|
||||
n, err = testEngine.InsertOnConflictDoNothing(&NoPrimaryKey{NotID: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), n)
|
||||
|
||||
// Fail reinsert default
|
||||
n, err = testEngine.InsertOnConflictDoNothing(&NoPrimaryKey{NotID: 2})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), n)
|
||||
|
||||
// Insert with 2
|
||||
n, err = testEngine.InsertOnConflictDoNothing(&NoPrimaryKey{NotID: 2, Uniqued: "2"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), n)
|
||||
|
||||
// Fail reinsert with 2
|
||||
n, err = testEngine.InsertOnConflictDoNothing(&NoPrimaryKey{NotID: 1, Uniqued: "2"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), n)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInsertOne(t *testing.T) {
|
||||
assert.NoError(t, PrepareEngine())
|
||||
|
||||
|
@ -142,8 +367,13 @@ func TestInsert(t *testing.T) {
|
|||
assert.NoError(t, PrepareEngine())
|
||||
assertSync(t, new(Userinfo))
|
||||
|
||||
user := Userinfo{0, "xiaolunwen", "dev", "lunny", time.Now(),
|
||||
Userdetail{Id: 1}, 1.78, []byte{1, 2, 3}, true}
|
||||
user := Userinfo{
|
||||
0, "xiaolunwen", "dev", "lunny", time.Now(),
|
||||
Userdetail{Id: 1},
|
||||
1.78,
|
||||
[]byte{1, 2, 3},
|
||||
true,
|
||||
}
|
||||
cnt, err := testEngine.Insert(&user)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, cnt, "insert not returned 1")
|
||||
|
@ -161,8 +391,10 @@ func TestInsertAutoIncr(t *testing.T) {
|
|||
assertSync(t, new(Userinfo))
|
||||
|
||||
// auto increment insert
|
||||
user := Userinfo{Username: "xiaolunwen2", Departname: "dev", Alias: "lunny", Created: time.Now(),
|
||||
Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true}
|
||||
user := Userinfo{
|
||||
Username: "xiaolunwen2", Departname: "dev", Alias: "lunny", Created: time.Now(),
|
||||
Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true,
|
||||
}
|
||||
cnt, err := testEngine.Insert(&user)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, cnt)
|
||||
|
@ -184,7 +416,7 @@ func TestInsertDefault(t *testing.T) {
|
|||
err := testEngine.Sync(di)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var di2 = DefaultInsert{Name: "test"}
|
||||
di2 := DefaultInsert{Name: "test"}
|
||||
_, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("Status")).Insert(&di2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
@ -210,7 +442,7 @@ func TestInsertDefault2(t *testing.T) {
|
|||
err := testEngine.Sync(di)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var di2 = DefaultInsert2{Name: "test"}
|
||||
di2 := DefaultInsert2{Name: "test"}
|
||||
_, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("CheckTime")).Insert(&di2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
@ -438,7 +670,7 @@ func TestCreatedJsonTime(t *testing.T) {
|
|||
assert.True(t, has)
|
||||
assert.EqualValues(t, time.Time(ci5.Created).Unix(), time.Time(di5.Created).Unix())
|
||||
|
||||
var dis = make([]MyJSONTime, 0)
|
||||
dis := make([]MyJSONTime, 0)
|
||||
err = testEngine.Find(&dis)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
@ -762,7 +994,7 @@ func TestInsertWhere(t *testing.T) {
|
|||
assert.NoError(t, PrepareEngine())
|
||||
assertSync(t, new(InsertWhere))
|
||||
|
||||
var i = InsertWhere{
|
||||
i := InsertWhere{
|
||||
RepoId: 1,
|
||||
Width: 10,
|
||||
Height: 20,
|
||||
|
@ -872,7 +1104,7 @@ func TestInsertExpr2(t *testing.T) {
|
|||
|
||||
assertSync(t, new(InsertExprsRelease))
|
||||
|
||||
var ie = InsertExprsRelease{
|
||||
ie := InsertExprsRelease{
|
||||
RepoId: 1,
|
||||
IsTag: true,
|
||||
}
|
||||
|
@ -1047,7 +1279,7 @@ func TestInsertIntSlice(t *testing.T) {
|
|||
|
||||
assert.NoError(t, testEngine.Sync(new(InsertIntSlice)))
|
||||
|
||||
var v = InsertIntSlice{
|
||||
v := InsertIntSlice{
|
||||
NameIDs: []int{1, 2},
|
||||
}
|
||||
cnt, err := testEngine.Insert(&v)
|
||||
|
@ -1064,7 +1296,7 @@ func TestInsertIntSlice(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, cnt)
|
||||
|
||||
var v3 = InsertIntSlice{
|
||||
v3 := InsertIntSlice{
|
||||
NameIDs: nil,
|
||||
}
|
||||
cnt, err = testEngine.Insert(&v3)
|
||||
|
|
|
@ -44,7 +44,8 @@ type Interface interface {
|
|||
In(string, ...interface{}) *Session
|
||||
Incr(column string, arg ...interface{}) *Session
|
||||
Insert(...interface{}) (int64, error)
|
||||
InsertOne(interface{}) (int64, error)
|
||||
InsertOne(interface{}) (int64, error) // Deprecated: Please use Insert directly
|
||||
InsertOnConflictDoNothing(beans ...interface{}) (int64, error)
|
||||
IsTableEmpty(bean interface{}) (bool, error)
|
||||
IsTableExist(beanOrTableName interface{}) (bool, error)
|
||||
Iterate(interface{}, IterFunc) error
|
||||
|
@ -71,6 +72,7 @@ type Interface interface {
|
|||
Table(tableNameOrBean interface{}) *Session
|
||||
Unscoped() *Session
|
||||
Update(bean interface{}, condiBeans ...interface{}) (int64, error)
|
||||
Upsert(beans ...interface{}) (int64, error)
|
||||
UseBool(...string) *Session
|
||||
Where(interface{}, ...interface{}) *Session
|
||||
}
|
||||
|
|
|
@ -59,13 +59,21 @@ func (expr *Expr) WriteArgs(w *builder.BytesWriter) error {
|
|||
type exprParams []Expr
|
||||
|
||||
func (exprs exprParams) ColNames() []string {
|
||||
var cols = make([]string, 0, len(exprs))
|
||||
cols := make([]string, 0, len(exprs))
|
||||
for _, expr := range exprs {
|
||||
cols = append(cols, expr.ColName)
|
||||
}
|
||||
return cols
|
||||
}
|
||||
|
||||
func (exprs exprParams) ColNamesTrim() []string {
|
||||
cols := make([]string, 0, len(exprs))
|
||||
for _, expr := range exprs {
|
||||
cols = append(cols, schemas.CommonQuoter.Trim(expr.ColName))
|
||||
}
|
||||
return cols
|
||||
}
|
||||
|
||||
func (exprs *exprParams) Add(name string, arg interface{}) {
|
||||
*exprs = append(*exprs, Expr{name, arg})
|
||||
}
|
||||
|
|
|
@ -30,7 +30,6 @@ func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schem
|
|||
func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) (string, []interface{}, error) {
|
||||
var (
|
||||
buf = builder.NewWriter()
|
||||
exprs = statement.ExprColumns
|
||||
table = statement.RefTable
|
||||
tableName = statement.TableName()
|
||||
)
|
||||
|
@ -43,129 +42,8 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
|
|||
return "", nil, err
|
||||
}
|
||||
|
||||
var hasInsertColumns = len(colNames) > 0
|
||||
var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
|
||||
if needSeq {
|
||||
for _, col := range colNames {
|
||||
if strings.EqualFold(col, table.AutoIncrement) {
|
||||
needSeq = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE &&
|
||||
statement.dialect.URI().DBType != schemas.DAMENG {
|
||||
if statement.dialect.URI().DBType == schemas.MYSQL {
|
||||
if _, err := buf.WriteString(" VALUES ()"); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
} else {
|
||||
if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if _, err := buf.WriteString(" DEFAULT VALUES"); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, err := buf.WriteString(" ("); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if needSeq {
|
||||
colNames = append(colNames, table.AutoIncrement)
|
||||
}
|
||||
|
||||
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames()...), ","); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if _, err := buf.WriteString(")"); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if statement.Conds().IsValid() {
|
||||
if _, err := buf.WriteString(" SELECT "); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if err := statement.WriteArgs(buf, args); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if needSeq {
|
||||
if len(args) > 0 {
|
||||
if _, err := buf.WriteString(","); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
if len(exprs) > 0 {
|
||||
if _, err := buf.WriteString(","); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if err := exprs.WriteArgs(buf); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := buf.WriteString(" FROM "); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if _, err := buf.WriteString(" WHERE "); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if err := statement.Conds().WriteTo(buf); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
} else {
|
||||
if _, err := buf.WriteString(" VALUES ("); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if err := statement.WriteArgs(buf, args); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Insert tablename (id) Values(seq_tablename.nextval)
|
||||
if needSeq {
|
||||
if hasInsertColumns {
|
||||
if _, err := buf.WriteString(","); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(exprs) > 0 {
|
||||
if _, err := buf.WriteString(","); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := exprs.WriteArgs(buf); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if _, err := buf.WriteString(")"); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
if err := statement.genInsertValues(buf, colNames, args); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES {
|
||||
|
@ -180,6 +58,169 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
|
|||
return buf.String(), buf.Args(), nil
|
||||
}
|
||||
|
||||
func (statement *Statement) includeAutoIncrement(colNames []string) bool {
|
||||
includesAutoIncrement := len(statement.RefTable.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
|
||||
if includesAutoIncrement {
|
||||
for _, col := range colNames {
|
||||
if strings.EqualFold(col, statement.RefTable.AutoIncrement) {
|
||||
includesAutoIncrement = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return includesAutoIncrement
|
||||
}
|
||||
|
||||
func (statement *Statement) genInsertValues(buf *builder.BytesWriter, colNames []string, args []interface{}) error {
|
||||
var (
|
||||
exprs = statement.ExprColumns
|
||||
table = statement.RefTable
|
||||
)
|
||||
|
||||
hasInsertColumns := len(colNames) > 0
|
||||
includeAutoIncrement := statement.includeAutoIncrement(colNames)
|
||||
|
||||
// Empty insert - i.e. insert default values only
|
||||
if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE &&
|
||||
statement.dialect.URI().DBType != schemas.DAMENG {
|
||||
|
||||
if statement.dialect.URI().DBType == schemas.MYSQL {
|
||||
// MySQL doesn't have DEFAULT VALUES and uses VALUES () for this.
|
||||
if _, err := buf.WriteString(" VALUES ()"); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// (MSSQL: return the inserted values)
|
||||
if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// All others use DEFAULT VALUES
|
||||
if _, err := buf.WriteString(" DEFAULT VALUES"); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// We have some values - Write the column names we need to insert:
|
||||
if _, err := buf.WriteString(" ("); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if includeAutoIncrement {
|
||||
colNames = append(colNames, table.AutoIncrement)
|
||||
}
|
||||
|
||||
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames()...), ","); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := buf.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// (MSSQL: return the inserted values)
|
||||
if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return statement.genInsertValuesValues(buf, includeAutoIncrement, colNames, args)
|
||||
}
|
||||
|
||||
func (statement *Statement) genInsertValuesValues(buf *builder.BytesWriter, includeAutoIncrement bool, colNames []string, args []interface{}) error {
|
||||
var (
|
||||
exprs = statement.ExprColumns
|
||||
tableName = statement.TableName()
|
||||
)
|
||||
hasInsertColumns := len(colNames) > 0
|
||||
|
||||
if statement.Conds().IsValid() {
|
||||
// We have conditions which we're trying to insert
|
||||
if _, err := buf.WriteString(" SELECT "); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := statement.WriteArgs(buf, args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if includeAutoIncrement {
|
||||
if len(args) > 0 {
|
||||
if _, err := buf.WriteString(","); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(exprs) > 0 {
|
||||
if _, err := buf.WriteString(","); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := exprs.WriteArgs(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := buf.WriteString(" FROM "); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := buf.WriteString(" WHERE "); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := statement.Conds().WriteTo(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Direct insertion of values
|
||||
if _, err := buf.WriteString(" VALUES ("); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := statement.WriteArgs(buf, args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Insert tablename (id) Values(seq_tablename.nextval)
|
||||
if includeAutoIncrement {
|
||||
if hasInsertColumns {
|
||||
if _, err := buf.WriteString(","); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(exprs) > 0 {
|
||||
if _, err := buf.WriteString(","); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := exprs.WriteArgs(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := buf.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenInsertMapSQL generates insert map SQL
|
||||
func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{}) (string, []interface{}, error) {
|
||||
var (
|
||||
|
@ -196,51 +237,12 @@ func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{}
|
|||
return "", nil, err
|
||||
}
|
||||
|
||||
// if insert where
|
||||
if statement.Conds().IsValid() {
|
||||
if _, err := buf.WriteString(") SELECT "); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if _, err := buf.WriteString(")"); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if err := statement.WriteArgs(buf, args); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if len(exprs) > 0 {
|
||||
if _, err := buf.WriteString(","); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if err := exprs.WriteArgs(buf); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := buf.WriteString(fmt.Sprintf(" FROM %s WHERE ", statement.quote(tableName))); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if err := statement.Conds().WriteTo(buf); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
} else {
|
||||
if _, err := buf.WriteString(") VALUES ("); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if err := statement.WriteArgs(buf, args); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if len(exprs) > 0 {
|
||||
if _, err := buf.WriteString(","); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if err := exprs.WriteArgs(buf); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
if _, err := buf.WriteString(")"); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return buf.String(), buf.Args(), nil
|
||||
|
|
|
@ -0,0 +1,271 @@
|
|||
// Copyright 2020 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package statements
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"xorm.io/builder"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
// GenUpsertSQL generates upsert beans SQL
|
||||
func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) {
|
||||
if statement.dialect.URI().DBType == schemas.MSSQL ||
|
||||
statement.dialect.URI().DBType == schemas.DAMENG ||
|
||||
statement.dialect.URI().DBType == schemas.ORACLE {
|
||||
return statement.genMergeSQL(doUpdate, columns, args, uniqueColValMap)
|
||||
}
|
||||
|
||||
var (
|
||||
buf = builder.NewWriter()
|
||||
table = statement.RefTable
|
||||
tableName = statement.TableName()
|
||||
)
|
||||
|
||||
if statement.dialect.URI().DBType == schemas.MYSQL && !doUpdate {
|
||||
if _, err := buf.WriteString("INSERT IGNORE INTO "); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
} else {
|
||||
if _, err := buf.WriteString("INSERT INTO "); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if err := statement.genInsertValues(buf, columns, args); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
switch statement.dialect.URI().DBType {
|
||||
case schemas.SQLITE, schemas.POSTGRES:
|
||||
if _, err := buf.WriteString(" ON CONFLICT DO "); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if doUpdate {
|
||||
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
|
||||
} else {
|
||||
if _, err := buf.WriteString("NOTHING"); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
case schemas.MYSQL:
|
||||
if doUpdate {
|
||||
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
|
||||
}
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
|
||||
}
|
||||
|
||||
if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES {
|
||||
if _, err := buf.WriteString(" RETURNING "); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return buf.String(), buf.Args(), nil
|
||||
}
|
||||
|
||||
func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) {
|
||||
var (
|
||||
buf = builder.NewWriter()
|
||||
table = statement.RefTable
|
||||
tableName = statement.TableName()
|
||||
)
|
||||
|
||||
quote := statement.dialect.Quoter().Quote
|
||||
write := func(args ...string) {
|
||||
for _, arg := range args {
|
||||
_, _ = buf.WriteString(arg)
|
||||
}
|
||||
}
|
||||
|
||||
write("MERGE INTO ", quote(tableName))
|
||||
if statement.dialect.URI().DBType == schemas.MSSQL {
|
||||
write("WITH (HOLDLOCK) AS target ")
|
||||
}
|
||||
write("USING (SELECT ")
|
||||
|
||||
uniqueCols := make([]string, 0, len(uniqueColValMap))
|
||||
for colName := range uniqueColValMap {
|
||||
uniqueCols = append(uniqueCols, colName)
|
||||
}
|
||||
for i, colName := range uniqueCols {
|
||||
if err := statement.WriteArg(buf, uniqueColValMap[colName]); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
write(" AS ", quote(colName))
|
||||
if i < len(uniqueCols)-1 {
|
||||
write(", ")
|
||||
}
|
||||
}
|
||||
write(") AS src ON (")
|
||||
|
||||
countUniques := 0
|
||||
for _, index := range table.Indexes {
|
||||
if index.Type != schemas.UniqueType {
|
||||
continue
|
||||
}
|
||||
if countUniques > 0 {
|
||||
write(" OR ")
|
||||
}
|
||||
countUniques++
|
||||
write("(")
|
||||
write("src.", quote(index.Cols[0]), "= target.", quote(index.Cols[0]))
|
||||
for _, col := range index.Cols[1:] {
|
||||
write(" AND src.", quote(col), "= target.", quote(col))
|
||||
}
|
||||
write(")")
|
||||
}
|
||||
if doUpdate {
|
||||
return "", nil, fmt.Errorf("unimplemented")
|
||||
}
|
||||
write(") WHEN NOT MATCHED THEN INSERT")
|
||||
if err := statement.genInsertValues(buf, columns, args); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return buf.String(), buf.Args(), nil
|
||||
}
|
||||
|
||||
// GenUpsertMapSQL generates insert map SQL
|
||||
func (statement *Statement) GenUpsertMapSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) {
|
||||
if statement.dialect.URI().DBType == schemas.MSSQL ||
|
||||
statement.dialect.URI().DBType == schemas.DAMENG ||
|
||||
statement.dialect.URI().DBType == schemas.ORACLE {
|
||||
return statement.genMergeMapSQL(doUpdate, columns, args, uniqueColValMap)
|
||||
}
|
||||
var (
|
||||
buf = builder.NewWriter()
|
||||
exprs = statement.ExprColumns
|
||||
table = statement.RefTable
|
||||
tableName = statement.TableName()
|
||||
)
|
||||
quote := statement.dialect.Quoter().Quote
|
||||
write := func(args ...string) {
|
||||
for _, arg := range args {
|
||||
_, _ = buf.WriteString(arg)
|
||||
}
|
||||
}
|
||||
|
||||
if statement.dialect.URI().DBType == schemas.MYSQL && !doUpdate {
|
||||
write("INSERT IGNORE INTO ", quote(tableName), " (")
|
||||
} else {
|
||||
write("INSERT INTO ", quote(tableName), " (")
|
||||
}
|
||||
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
write(")")
|
||||
|
||||
if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
switch statement.dialect.URI().DBType {
|
||||
case schemas.SQLITE, schemas.POSTGRES:
|
||||
if _, err := buf.WriteString(" ON CONFLICT DO "); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if doUpdate {
|
||||
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
|
||||
} else {
|
||||
if _, err := buf.WriteString("NOTHING"); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
case schemas.MYSQL:
|
||||
if doUpdate {
|
||||
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
|
||||
}
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
|
||||
}
|
||||
|
||||
if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES {
|
||||
if _, err := buf.WriteString(" RETURNING "); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return buf.String(), buf.Args(), nil
|
||||
}
|
||||
|
||||
func (statement *Statement) genMergeMapSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) {
|
||||
var (
|
||||
buf = builder.NewWriter()
|
||||
table = statement.RefTable
|
||||
exprs = statement.ExprColumns
|
||||
tableName = statement.TableName()
|
||||
)
|
||||
|
||||
quote := statement.dialect.Quoter().Quote
|
||||
write := func(args ...string) {
|
||||
for _, arg := range args {
|
||||
_, _ = buf.WriteString(arg)
|
||||
}
|
||||
}
|
||||
|
||||
write("MERGE INTO ", quote(tableName))
|
||||
if statement.dialect.URI().DBType == schemas.MSSQL {
|
||||
write("WITH (HOLDLOCK) AS target ")
|
||||
}
|
||||
write("USING (SELECT ")
|
||||
|
||||
uniqueCols := make([]string, 0, len(uniqueColValMap))
|
||||
for colName := range uniqueColValMap {
|
||||
uniqueCols = append(uniqueCols, colName)
|
||||
}
|
||||
for i, colName := range uniqueCols {
|
||||
if err := statement.WriteArg(buf, uniqueColValMap[colName]); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
write(" AS ", quote(colName))
|
||||
if i < len(uniqueCols)-1 {
|
||||
write(", ")
|
||||
}
|
||||
}
|
||||
write(") AS src ON (")
|
||||
|
||||
countUniques := 0
|
||||
for _, index := range table.Indexes {
|
||||
if index.Type != schemas.UniqueType {
|
||||
continue
|
||||
}
|
||||
if countUniques > 0 {
|
||||
write(" OR ")
|
||||
}
|
||||
countUniques++
|
||||
write("(")
|
||||
write("src.", quote(index.Cols[0]), "= target.", quote(index.Cols[0]))
|
||||
for _, col := range index.Cols[1:] {
|
||||
write(" AND src.", quote(col), "= target.", quote(col))
|
||||
}
|
||||
write(")")
|
||||
}
|
||||
if doUpdate {
|
||||
return "", nil, fmt.Errorf("unimplemented")
|
||||
}
|
||||
write(") WHEN NOT MATCHED THEN INSERT")
|
||||
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
write(")")
|
||||
|
||||
if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return buf.String(), buf.Args(), nil
|
||||
}
|
|
@ -0,0 +1,110 @@
|
|||
// Copyright 2020 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func MapToSlices(m map[string]interface{}, exclude []string, trimmer func(string) string) ([]string, []interface{}) {
|
||||
columns := make([]string, 0, len(m))
|
||||
|
||||
outer:
|
||||
for colName := range m {
|
||||
trimmed := trimmer(colName)
|
||||
for _, excluded := range exclude {
|
||||
if strings.EqualFold(excluded, trimmed) {
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
columns = append(columns, colName)
|
||||
}
|
||||
sort.Strings(columns)
|
||||
|
||||
args := make([]interface{}, 0, len(columns))
|
||||
for _, colName := range columns {
|
||||
args = append(args, m[colName])
|
||||
}
|
||||
|
||||
return columns, args
|
||||
}
|
||||
|
||||
func MapStringToSlices(m map[string]string, exclude []string, trimmer func(string) string) ([]string, []interface{}) {
|
||||
columns := make([]string, 0, len(m))
|
||||
|
||||
outer:
|
||||
for colName := range m {
|
||||
trimmed := trimmer(colName)
|
||||
for _, excluded := range exclude {
|
||||
if strings.EqualFold(excluded, trimmed) {
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
columns = append(columns, colName)
|
||||
}
|
||||
sort.Strings(columns)
|
||||
|
||||
args := make([]interface{}, 0, len(columns))
|
||||
for _, colName := range columns {
|
||||
args = append(args, m[colName])
|
||||
}
|
||||
|
||||
return columns, args
|
||||
}
|
||||
|
||||
func MultipleMapToSlices(maps []map[string]interface{}, exclude []string, trimmer func(string) string) ([]string, [][]interface{}) {
|
||||
columns := make([]string, 0, len(maps[0]))
|
||||
|
||||
outer:
|
||||
for colName := range maps[0] {
|
||||
trimmed := trimmer(colName)
|
||||
for _, excluded := range exclude {
|
||||
if strings.EqualFold(excluded, trimmed) {
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
columns = append(columns, colName)
|
||||
}
|
||||
sort.Strings(columns)
|
||||
|
||||
argss := make([][]interface{}, 0, len(maps))
|
||||
for _, m := range maps {
|
||||
args := make([]interface{}, 0, len(m))
|
||||
for _, colName := range columns {
|
||||
args = append(args, m[colName])
|
||||
}
|
||||
argss = append(argss, args)
|
||||
}
|
||||
|
||||
return columns, argss
|
||||
}
|
||||
|
||||
func MultipleMapStringToSlices(maps []map[string]string, exclude []string, trimmer func(string) string) ([]string, [][]interface{}) {
|
||||
columns := make([]string, 0, len(maps[0]))
|
||||
|
||||
outer:
|
||||
for colName := range maps[0] {
|
||||
trimmed := trimmer(colName)
|
||||
for _, excluded := range exclude {
|
||||
if strings.EqualFold(excluded, trimmed) {
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
columns = append(columns, colName)
|
||||
}
|
||||
sort.Strings(columns)
|
||||
|
||||
argss := make([][]interface{}, 0, len(maps))
|
||||
for _, m := range maps {
|
||||
args := make([]interface{}, 0, len(m))
|
||||
for _, colName := range columns {
|
||||
args = append(args, m[colName])
|
||||
}
|
||||
argss = append(argss, args)
|
||||
}
|
||||
|
||||
return columns, argss
|
||||
}
|
|
@ -5,10 +5,10 @@
|
|||
package xorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -156,14 +156,14 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, e
|
|||
}
|
||||
args = append(args, val)
|
||||
|
||||
var colName = col.Name
|
||||
colName := col.Name
|
||||
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
|
||||
col := table.GetColumn(colName)
|
||||
setColumnTime(bean, col, t)
|
||||
})
|
||||
} else if col.IsVersion && session.statement.CheckVersion {
|
||||
args = append(args, 1)
|
||||
var colName = col.Name
|
||||
colName := col.Name
|
||||
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
|
||||
col := table.GetColumn(colName)
|
||||
setColumnInt(bean, col, 1)
|
||||
|
@ -276,7 +276,7 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
|
|||
processor.BeforeInsert()
|
||||
}
|
||||
|
||||
var tableName = session.statement.TableName()
|
||||
tableName := session.statement.TableName()
|
||||
table := session.statement.RefTable
|
||||
|
||||
colNames, args, err := session.genInsertColumns(bean)
|
||||
|
@ -290,102 +290,9 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
|
|||
}
|
||||
sqlStr = session.engine.dialect.Quoter().Replace(sqlStr)
|
||||
|
||||
handleAfterInsertProcessorFunc := func(bean interface{}) {
|
||||
if session.isAutoCommit {
|
||||
for _, closure := range session.afterClosures {
|
||||
closure(bean)
|
||||
}
|
||||
if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
|
||||
processor.AfterInsert()
|
||||
}
|
||||
} else {
|
||||
lenAfterClosures := len(session.afterClosures)
|
||||
if lenAfterClosures > 0 {
|
||||
if value, has := session.afterInsertBeans[bean]; has && value != nil {
|
||||
*value = append(*value, session.afterClosures...)
|
||||
} else {
|
||||
afterClosures := make([]func(interface{}), lenAfterClosures)
|
||||
copy(afterClosures, session.afterClosures)
|
||||
session.afterInsertBeans[bean] = &afterClosures
|
||||
}
|
||||
} else {
|
||||
if _, ok := interface{}(bean).(AfterInsertProcessor); ok {
|
||||
session.afterInsertBeans[bean] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
|
||||
}
|
||||
|
||||
// if there is auto increment column and driver don't support return it
|
||||
if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID {
|
||||
var sql string
|
||||
var newArgs []interface{}
|
||||
var needCommit bool
|
||||
var id int64
|
||||
if session.engine.dialect.URI().DBType == schemas.ORACLE || session.engine.dialect.URI().DBType == schemas.DAMENG {
|
||||
if session.isAutoCommit { // if it's not in transaction
|
||||
if err := session.Begin(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
needCommit = true
|
||||
}
|
||||
_, err := session.exec(sqlStr, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
i := utils.IndexSlice(colNames, table.AutoIncrement)
|
||||
if i > -1 {
|
||||
id, err = convert.AsInt64(args[i])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
sql = fmt.Sprintf("select %s.currval from dual", utils.SeqName(tableName))
|
||||
}
|
||||
} else {
|
||||
sql = sqlStr
|
||||
newArgs = args
|
||||
}
|
||||
|
||||
if id == 0 {
|
||||
err := session.queryRow(sql, newArgs...).Scan(&id)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if needCommit {
|
||||
if err := session.Commit(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if id == 0 {
|
||||
return 0, errors.New("insert successfully but not returned id")
|
||||
}
|
||||
}
|
||||
|
||||
defer handleAfterInsertProcessorFunc(bean)
|
||||
|
||||
_ = session.cacheInsert(tableName)
|
||||
|
||||
if table.Version != "" && session.statement.CheckVersion {
|
||||
verValue, err := table.VersionColumn().ValueOf(bean)
|
||||
if err != nil {
|
||||
session.engine.logger.Errorf("%v", err)
|
||||
} else if verValue.IsValid() && verValue.CanSet() {
|
||||
session.incrVersionFieldValue(verValue)
|
||||
}
|
||||
}
|
||||
|
||||
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
|
||||
if err != nil {
|
||||
session.engine.logger.Errorf("%v", err)
|
||||
}
|
||||
|
||||
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
return 1, convert.AssignValue(*aiValue, id)
|
||||
return session.execInsertSqlNoAutoReturn(sqlStr, bean, colNames, args)
|
||||
}
|
||||
|
||||
res, err := session.exec(sqlStr, args...)
|
||||
|
@ -393,7 +300,7 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
|
|||
return 0, err
|
||||
}
|
||||
|
||||
defer handleAfterInsertProcessorFunc(bean)
|
||||
defer session.handleAfterInsertProcessorFunc(bean)
|
||||
|
||||
_ = session.cacheInsert(tableName)
|
||||
|
||||
|
@ -432,31 +339,6 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
|
|||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
// InsertOne insert only one struct into database as a record.
|
||||
// The in parameter bean must a struct or a point to struct. The return
|
||||
// parameter is inserted and error
|
||||
// Deprecated: Please use Insert directly
|
||||
func (session *Session) InsertOne(bean interface{}) (int64, error) {
|
||||
if session.isAutoClose {
|
||||
defer session.Close()
|
||||
}
|
||||
|
||||
return session.insertStruct(bean)
|
||||
}
|
||||
|
||||
func (session *Session) cacheInsert(table string) error {
|
||||
if !session.statement.UseCache {
|
||||
return nil
|
||||
}
|
||||
cacher := session.engine.cacherMgr.GetCacher(table)
|
||||
if cacher == nil {
|
||||
return nil
|
||||
}
|
||||
session.engine.logger.Debugf("[cache] clear SQL: %v", table)
|
||||
cacher.ClearIds(table)
|
||||
return nil
|
||||
}
|
||||
|
||||
// genInsertColumns generates insert needed columns
|
||||
func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) {
|
||||
table := session.statement.RefTable
|
||||
|
@ -517,7 +399,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
|
|||
}
|
||||
args = append(args, val)
|
||||
|
||||
var colName = col.Name
|
||||
colName := col.Name
|
||||
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
|
||||
col := table.GetColumn(colName)
|
||||
setColumnTime(bean, col, t)
|
||||
|
@ -537,6 +419,112 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
|
|||
return colNames, args, nil
|
||||
}
|
||||
|
||||
func (session *Session) execInsertSqlNoAutoReturn(sqlStr string, bean interface{}, colNames []string, args []interface{}) (int64, error) {
|
||||
var newSql string
|
||||
var newArgs []interface{}
|
||||
var needCommit bool
|
||||
var id int64
|
||||
|
||||
tableName := session.statement.TableName()
|
||||
table := session.statement.RefTable
|
||||
|
||||
if session.engine.dialect.URI().DBType == schemas.ORACLE || session.engine.dialect.URI().DBType == schemas.DAMENG {
|
||||
if session.isAutoCommit { // if it's not in transaction
|
||||
if err := session.Begin(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
needCommit = true
|
||||
}
|
||||
res, err := session.exec(sqlStr, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, sql.ErrNoRows
|
||||
}
|
||||
i := utils.IndexSlice(colNames, table.AutoIncrement)
|
||||
if i > -1 {
|
||||
id, err = convert.AsInt64(args[i])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
newSql = fmt.Sprintf("select %s.currval from dual", utils.SeqName(tableName))
|
||||
}
|
||||
} else {
|
||||
newSql = sqlStr
|
||||
newArgs = args
|
||||
}
|
||||
|
||||
if id == 0 {
|
||||
err := session.queryRow(newSql, newArgs...).Scan(&id)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if needCommit {
|
||||
if err := session.Commit(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if id == 0 {
|
||||
return 0, errors.New("insert successfully but not returned id")
|
||||
}
|
||||
}
|
||||
|
||||
defer session.handleAfterInsertProcessorFunc(bean)
|
||||
|
||||
_ = session.cacheInsert(tableName)
|
||||
|
||||
if table.Version != "" && session.statement.CheckVersion {
|
||||
verValue, err := table.VersionColumn().ValueOf(bean)
|
||||
if err != nil {
|
||||
session.engine.logger.Errorf("%v", err)
|
||||
} else if verValue.IsValid() && verValue.CanSet() {
|
||||
session.incrVersionFieldValue(verValue)
|
||||
}
|
||||
}
|
||||
|
||||
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
|
||||
if err != nil {
|
||||
session.engine.logger.Errorf("%v", err)
|
||||
}
|
||||
|
||||
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
return 1, convert.AssignValue(*aiValue, id)
|
||||
}
|
||||
|
||||
// InsertOne insert only one struct into database as a record.
|
||||
// The in parameter bean must a struct or a point to struct. The return
|
||||
// parameter is inserted and error
|
||||
// Deprecated: Please use Insert directly
|
||||
func (session *Session) InsertOne(bean interface{}) (int64, error) {
|
||||
if session.isAutoClose {
|
||||
defer session.Close()
|
||||
}
|
||||
|
||||
return session.insertStruct(bean)
|
||||
}
|
||||
|
||||
func (session *Session) cacheInsert(table string) error {
|
||||
if !session.statement.UseCache {
|
||||
return nil
|
||||
}
|
||||
cacher := session.engine.cacherMgr.GetCacher(table)
|
||||
if cacher == nil {
|
||||
return nil
|
||||
}
|
||||
session.engine.logger.Debugf("[cache] clear SQL: %v", table)
|
||||
cacher.ClearIds(table)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (session *Session) insertMapInterface(m map[string]interface{}) (int64, error) {
|
||||
if len(m) == 0 {
|
||||
return 0, ErrParamsType
|
||||
|
@ -547,19 +535,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
|
|||
return 0, ErrTableNotFound
|
||||
}
|
||||
|
||||
var columns = make([]string, 0, len(m))
|
||||
exprs := session.statement.ExprColumns
|
||||
for k := range m {
|
||||
if !exprs.IsColExist(k) {
|
||||
columns = append(columns, k)
|
||||
}
|
||||
}
|
||||
sort.Strings(columns)
|
||||
|
||||
var args = make([]interface{}, 0, len(m))
|
||||
for _, colName := range columns {
|
||||
args = append(args, m[colName])
|
||||
}
|
||||
columns, args := utils.MapToSlices(m, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim)
|
||||
|
||||
return session.insertMap(columns, args)
|
||||
}
|
||||
|
@ -574,23 +550,7 @@ func (session *Session) insertMultipleMapInterface(maps []map[string]interface{}
|
|||
return 0, ErrTableNotFound
|
||||
}
|
||||
|
||||
var columns = make([]string, 0, len(maps[0]))
|
||||
exprs := session.statement.ExprColumns
|
||||
for k := range maps[0] {
|
||||
if !exprs.IsColExist(k) {
|
||||
columns = append(columns, k)
|
||||
}
|
||||
}
|
||||
sort.Strings(columns)
|
||||
|
||||
var argss = make([][]interface{}, 0, len(maps))
|
||||
for _, m := range maps {
|
||||
var args = make([]interface{}, 0, len(m))
|
||||
for _, colName := range columns {
|
||||
args = append(args, m[colName])
|
||||
}
|
||||
argss = append(argss, args)
|
||||
}
|
||||
columns, argss := utils.MultipleMapToSlices(maps, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim)
|
||||
|
||||
return session.insertMultipleMap(columns, argss)
|
||||
}
|
||||
|
@ -605,20 +565,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
|
|||
return 0, ErrTableNotFound
|
||||
}
|
||||
|
||||
var columns = make([]string, 0, len(m))
|
||||
exprs := session.statement.ExprColumns
|
||||
for k := range m {
|
||||
if !exprs.IsColExist(k) {
|
||||
columns = append(columns, k)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(columns)
|
||||
|
||||
var args = make([]interface{}, 0, len(m))
|
||||
for _, colName := range columns {
|
||||
args = append(args, m[colName])
|
||||
}
|
||||
columns, args := utils.MapStringToSlices(m, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim)
|
||||
|
||||
return session.insertMap(columns, args)
|
||||
}
|
||||
|
@ -633,23 +580,7 @@ func (session *Session) insertMultipleMapString(maps []map[string]string) (int64
|
|||
return 0, ErrTableNotFound
|
||||
}
|
||||
|
||||
var columns = make([]string, 0, len(maps[0]))
|
||||
exprs := session.statement.ExprColumns
|
||||
for k := range maps[0] {
|
||||
if !exprs.IsColExist(k) {
|
||||
columns = append(columns, k)
|
||||
}
|
||||
}
|
||||
sort.Strings(columns)
|
||||
|
||||
var argss = make([][]interface{}, 0, len(maps))
|
||||
for _, m := range maps {
|
||||
var args = make([]interface{}, 0, len(m))
|
||||
for _, colName := range columns {
|
||||
args = append(args, m[colName])
|
||||
}
|
||||
argss = append(argss, args)
|
||||
}
|
||||
columns, argss := utils.MultipleMapStringToSlices(maps, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim)
|
||||
|
||||
return session.insertMultipleMap(columns, argss)
|
||||
}
|
||||
|
@ -707,3 +638,30 @@ func (session *Session) insertMultipleMap(columns []string, argss [][]interface{
|
|||
}
|
||||
return affected, nil
|
||||
}
|
||||
|
||||
func (session *Session) handleAfterInsertProcessorFunc(bean interface{}) {
|
||||
if session.isAutoCommit {
|
||||
for _, closure := range session.afterClosures {
|
||||
closure(bean)
|
||||
}
|
||||
if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
|
||||
processor.AfterInsert()
|
||||
}
|
||||
} else {
|
||||
lenAfterClosures := len(session.afterClosures)
|
||||
if lenAfterClosures > 0 {
|
||||
if value, has := session.afterInsertBeans[bean]; has && value != nil {
|
||||
*value = append(*value, session.afterClosures...)
|
||||
} else {
|
||||
afterClosures := make([]func(interface{}), lenAfterClosures)
|
||||
copy(afterClosures, session.afterClosures)
|
||||
session.afterInsertBeans[bean] = &afterClosures
|
||||
}
|
||||
} else {
|
||||
if _, ok := interface{}(bean).(AfterInsertProcessor); ok {
|
||||
session.afterInsertBeans[bean] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
|
||||
}
|
||||
|
|
|
@ -0,0 +1,298 @@
|
|||
package xorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"xorm.io/xorm/convert"
|
||||
"xorm.io/xorm/internal/utils"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
func (session *Session) InsertOnConflictDoNothing(beans ...interface{}) (int64, error) {
|
||||
return session.upsert(false, beans...)
|
||||
}
|
||||
|
||||
func (session *Session) Upsert(beans ...interface{}) (int64, error) {
|
||||
return session.upsert(true, beans...)
|
||||
}
|
||||
|
||||
func (session *Session) upsert(doUpdate bool, beans ...interface{}) (int64, error) {
|
||||
var affected int64
|
||||
var err error
|
||||
|
||||
if session.isAutoClose {
|
||||
defer session.Close()
|
||||
}
|
||||
|
||||
session.autoResetStatement = false
|
||||
defer func() {
|
||||
session.autoResetStatement = true
|
||||
session.resetStatement()
|
||||
}()
|
||||
for _, bean := range beans {
|
||||
var cnt int64
|
||||
var err error
|
||||
switch v := bean.(type) {
|
||||
case map[string]interface{}:
|
||||
cnt, err = session.upsertMapInterface(doUpdate, v)
|
||||
case []map[string]interface{}:
|
||||
for _, m := range v {
|
||||
cnt, err := session.upsertMapInterface(doUpdate, m)
|
||||
if err != nil {
|
||||
return affected, err
|
||||
}
|
||||
affected += cnt
|
||||
}
|
||||
case map[string]string:
|
||||
cnt, err = session.upsertMapString(doUpdate, v)
|
||||
case []map[string]string:
|
||||
for _, m := range v {
|
||||
cnt, err := session.upsertMapString(doUpdate, m)
|
||||
if err != nil {
|
||||
return affected, err
|
||||
}
|
||||
affected += cnt
|
||||
}
|
||||
default:
|
||||
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
|
||||
if sliceValue.Kind() == reflect.Slice {
|
||||
if sliceValue.Len() <= 0 {
|
||||
return 0, ErrNoElementsOnSlice
|
||||
}
|
||||
for i := 0; i < sliceValue.Len(); i++ {
|
||||
v := sliceValue.Index(i)
|
||||
bean := v.Interface()
|
||||
cnt, err := session.upsertStruct(doUpdate, bean)
|
||||
if err != nil {
|
||||
return affected, err
|
||||
}
|
||||
affected += cnt
|
||||
}
|
||||
} else {
|
||||
cnt, err = session.upsertStruct(doUpdate, bean)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return affected, err
|
||||
}
|
||||
affected += cnt
|
||||
}
|
||||
|
||||
return affected, err
|
||||
}
|
||||
|
||||
func (session *Session) upsertMapInterface(doUpdate bool, m map[string]interface{}) (int64, error) {
|
||||
if len(m) == 0 {
|
||||
return 0, ErrParamsType
|
||||
}
|
||||
|
||||
tableName := session.statement.TableName()
|
||||
if len(tableName) == 0 {
|
||||
return 0, ErrTableNotFound
|
||||
}
|
||||
|
||||
columns, args := utils.MapToSlices(m, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim)
|
||||
return session.upsertMap(doUpdate, columns, args)
|
||||
}
|
||||
|
||||
func (session *Session) upsertMapString(doUpdate bool, m map[string]string) (int64, error) {
|
||||
if len(m) == 0 {
|
||||
return 0, ErrParamsType
|
||||
}
|
||||
|
||||
tableName := session.statement.TableName()
|
||||
if len(tableName) == 0 {
|
||||
return 0, ErrTableNotFound
|
||||
}
|
||||
|
||||
columns, args := utils.MapStringToSlices(m, session.statement.ExprColumns.ColNamesTrim(), schemas.CommonQuoter.Trim)
|
||||
return session.upsertMap(doUpdate, columns, args)
|
||||
}
|
||||
|
||||
func (session *Session) upsertMap(doUpdate bool, columns []string, args []interface{}) (int64, error) {
|
||||
tableName := session.statement.TableName()
|
||||
if len(tableName) == 0 {
|
||||
return 0, ErrTableNotFound
|
||||
}
|
||||
|
||||
uniqueColValMap, err := session.getUniqueColumns(columns, args)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
sql, args, err := session.statement.GenUpsertMapSQL(doUpdate, columns, args, uniqueColValMap)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sql = session.engine.dialect.Quoter().Replace(sql)
|
||||
|
||||
if err := session.cacheInsert(tableName); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
res, err := session.exec(sql, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return affected, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, error) {
|
||||
if err := session.statement.SetRefBean(bean); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(session.statement.TableName()) == 0 {
|
||||
return 0, ErrTableNotFound
|
||||
}
|
||||
|
||||
// handle BeforeInsertProcessor
|
||||
for _, closure := range session.beforeClosures {
|
||||
closure(bean)
|
||||
}
|
||||
cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used
|
||||
|
||||
if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
|
||||
processor.BeforeInsert()
|
||||
}
|
||||
|
||||
tableName := session.statement.TableName()
|
||||
table := session.statement.RefTable
|
||||
|
||||
colNames, args, err := session.genInsertColumns(bean)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
uniqueColValMap, err := session.getUniqueColumns(colNames, args)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
sqlStr, args, err := session.statement.GenUpsertSQL(doUpdate, colNames, args, uniqueColValMap)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sqlStr = session.engine.dialect.Quoter().Replace(sqlStr)
|
||||
|
||||
// if there is auto increment column and driver doesn't support return it
|
||||
if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID {
|
||||
n, err := session.execInsertSqlNoAutoReturn(sqlStr, bean, colNames, args)
|
||||
if err == sql.ErrNoRows {
|
||||
return n, nil
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
res, err := session.exec(sqlStr, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
defer session.handleAfterInsertProcessorFunc(bean)
|
||||
|
||||
_ = session.cacheInsert(tableName)
|
||||
|
||||
if table.Version != "" && session.statement.CheckVersion {
|
||||
verValue, err := table.VersionColumn().ValueOf(bean)
|
||||
if err != nil {
|
||||
session.engine.logger.Errorf("%v", err)
|
||||
} else if verValue.IsValid() && verValue.CanSet() {
|
||||
session.incrVersionFieldValue(verValue)
|
||||
}
|
||||
}
|
||||
|
||||
if table.AutoIncrement == "" {
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
n, err := res.RowsAffected()
|
||||
if err != nil || n == 0 {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var id int64
|
||||
id, err = res.LastInsertId()
|
||||
if err != nil || id <= 0 {
|
||||
return n, err
|
||||
}
|
||||
|
||||
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
|
||||
if err != nil {
|
||||
session.engine.logger.Errorf("%v", err)
|
||||
}
|
||||
|
||||
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
|
||||
return n, err
|
||||
}
|
||||
|
||||
if err := convert.AssignValue(*aiValue, id); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (session *Session) getUniqueColumns(colNames []string, args []interface{}) (uniqueColValMap map[string]interface{}, err error) {
|
||||
uniqueColValMap = make(map[string]interface{})
|
||||
table := session.statement.RefTable
|
||||
if len(table.Indexes) == 0 {
|
||||
return nil, fmt.Errorf("provided bean has no unique constraints")
|
||||
}
|
||||
|
||||
// Iterate across the indexes in the provided table
|
||||
for _, index := range table.Indexes {
|
||||
if index.Type != schemas.UniqueType {
|
||||
continue
|
||||
}
|
||||
|
||||
// index is a Unique constraint
|
||||
indexCol:
|
||||
for _, indexColumnName := range index.Cols {
|
||||
if _, has := uniqueColValMap[indexColumnName]; has {
|
||||
// column is already included in uniqueCols so we don't need to add it again
|
||||
continue indexCol
|
||||
}
|
||||
|
||||
// Now iterate across colNames and add to the uniqueCols
|
||||
for i, col := range colNames {
|
||||
if col == indexColumnName {
|
||||
uniqueColValMap[col] = args[i]
|
||||
continue indexCol
|
||||
}
|
||||
}
|
||||
|
||||
indexColumn := table.GetColumn(indexColumnName)
|
||||
if !indexColumn.DefaultIsEmpty {
|
||||
uniqueColValMap[indexColumnName] = indexColumn.Default
|
||||
}
|
||||
|
||||
if indexColumn.MapType == schemas.ONLYFROMDB || indexColumn.IsAutoIncrement {
|
||||
continue
|
||||
}
|
||||
if session.statement.OmitColumnMap.Contain(indexColumn.Name) {
|
||||
continue
|
||||
}
|
||||
if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(indexColumn.Name) {
|
||||
continue
|
||||
}
|
||||
// FIXME: what do we do here?!
|
||||
if session.statement.IncrColumns.IsColExist(indexColumn.Name) {
|
||||
continue
|
||||
} else if session.statement.DecrColumns.IsColExist(indexColumn.Name) {
|
||||
continue
|
||||
} else if session.statement.ExprColumns.IsColExist(indexColumn.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
// FIXME: not sure if there's anything else we can do
|
||||
return nil, fmt.Errorf("provided bean does not provide a value for unique constraint %s field %s which has no default", index.Name, indexColumnName)
|
||||
}
|
||||
}
|
||||
return uniqueColValMap, nil
|
||||
}
|
Loading…
Reference in New Issue