xorm/session_upsert.go

397 lines
10 KiB
Go

package xorm
import (
"database/sql"
"fmt"
"reflect"
"xorm.io/xorm/convert"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
// InsertOnConflictDoNothing inserts beans, leaving rows untouched when an
// insert would conflict on the table's primary key or a unique index.
// More than one such constraint may be tested. It returns the total number of
// rows affected.
func (session *Session) InsertOnConflictDoNothing(beans ...interface{}) (int64, error) {
	const doUpdate = false
	return session.upsert(doUpdate, beans...)
}
// Upsert inserts beans, updating the existing row instead when an insert
// conflicts on the table's unique constraint. Exactly one unique constraint
// (primary key or unique index) may be tested, otherwise
// ErrMultipleUniqueConstraints is returned. The result is the number of rows
// inserted or updated; MySQL's "2 rows affected" for an update counts as 1.
func (session *Session) Upsert(beans ...interface{}) (int64, error) {
	const doUpdate = true
	return session.upsert(doUpdate, beans...)
}
// upsert is the shared implementation of Upsert (doUpdate == true) and
// InsertOnConflictDoNothing (doUpdate == false).
//
// Each element of beans may be:
//   - a map[string]interface{} or a map[string]string;
//   - a slice of either map type;
//   - a struct bean;
//   - a slice (or pointer to a slice) of struct beans.
//
// Every row is sent as its own statement. The returned count is the sum of the
// rows affected by the rows that succeeded; on error it covers only the rows
// processed before the failing one.
func (session *Session) upsert(doUpdate bool, beans ...interface{}) (int64, error) {
	if session.isAutoClose {
		defer session.Close()
	}

	// Keep the statement (table, column filters, ...) alive across all beans
	// and reset it exactly once, after every bean has been processed.
	session.autoResetStatement = false
	defer func() {
		session.autoResetStatement = true
		session.resetStatement()
	}()

	var affected int64
	for _, bean := range beans {
		cnt, err := session.upsertBean(doUpdate, bean)
		affected += cnt
		if err != nil {
			return affected, err
		}
	}
	return affected, nil
}

// upsertBean upserts a single element of the beans passed to upsert. For
// slices it returns the rows affected by the elements that succeeded before
// an error; a failing single row contributes no count.
func (session *Session) upsertBean(doUpdate bool, bean interface{}) (int64, error) {
	var affected int64

	switch v := bean.(type) {
	case map[string]interface{}:
		return session.upsertMapInterface(doUpdate, v)
	case []map[string]interface{}: // FIXME: upsert multiple rows in one statement?
		for _, m := range v {
			cnt, err := session.upsertMapInterface(doUpdate, m)
			if err != nil {
				return affected, err
			}
			affected += cnt
		}
		return affected, nil
	case map[string]string:
		return session.upsertMapString(doUpdate, v)
	case []map[string]string: // FIXME: upsert multiple rows in one statement?
		for _, m := range v {
			cnt, err := session.upsertMapString(doUpdate, m)
			if err != nil {
				return affected, err
			}
			affected += cnt
		}
		return affected, nil
	}

	sliceValue := reflect.Indirect(reflect.ValueOf(bean))
	if sliceValue.Kind() != reflect.Slice {
		cnt, err := session.upsertStruct(doUpdate, bean)
		if err != nil {
			return 0, err
		}
		return cnt, nil
	}

	// FIXME: upsert multiple rows in one statement?
	if sliceValue.Len() == 0 {
		return 0, ErrNoElementsOnSlice
	}
	for i := 0; i < sliceValue.Len(); i++ {
		cnt, err := session.upsertStruct(doUpdate, sliceValue.Index(i).Interface())
		if err != nil {
			return affected, err
		}
		affected += cnt
	}
	return affected, nil
}
// upsertMapInterface upserts one row given as a column-name -> value map.
// It fails with ErrParamsType for an empty map and ErrTableNotFound when no
// table name has been set on the statement. The map is flattened by
// utils.MapToSlices, which also receives the trimmed ExprColumns names
// (presumably so those columns are excluded; confirm in utils).
func (session *Session) upsertMapInterface(doUpdate bool, m map[string]interface{}) (int64, error) {
	switch {
	case len(m) == 0:
		return 0, ErrParamsType
	case session.statement.TableName() == "":
		return 0, ErrTableNotFound
	}

	exprCols := session.statement.ExprColumns.ColNamesTrim()
	columns, args := utils.MapToSlices(m, exprCols, schemas.CommonQuoter.Trim)
	return session.upsertMap(doUpdate, columns, args)
}
// upsertMapString upserts one row given as a column-name -> string value map.
// It fails with ErrParamsType for an empty map and ErrTableNotFound when no
// table name has been set on the statement. The map is flattened by
// utils.MapStringToSlices, which also receives the trimmed ExprColumns names
// (presumably so those columns are excluded; confirm in utils).
func (session *Session) upsertMapString(doUpdate bool, m map[string]string) (int64, error) {
	switch {
	case len(m) == 0:
		return 0, ErrParamsType
	case session.statement.TableName() == "":
		return 0, ErrTableNotFound
	}

	exprCols := session.statement.ExprColumns.ColNamesTrim()
	columns, args := utils.MapStringToSlices(m, exprCols, schemas.CommonQuoter.Trim)
	return session.upsertMap(doUpdate, columns, args)
}
// upsertMap runs a single INSERT ... ON CONFLICT statement for one row given as
// parallel columns/args slices and returns the number of rows affected.
//
// Both a table name and a RefTable are required (ErrTableNotFound otherwise).
// The unique constraints to test come from getUniqueColumns.
func (session *Session) upsertMap(doUpdate bool, columns []string, args []interface{}) (int64, error) {
	tableName := session.statement.TableName()
	if tableName == "" || session.statement.RefTable == nil {
		return 0, ErrTableNotFound
	}

	uniqueColValMap, constraints, err := session.getUniqueColumns(doUpdate, columns, args)
	if err != nil {
		return 0, err
	}

	sqlStr, sqlArgs, err := session.statement.GenUpsertSQL(doUpdate, false, columns, args, uniqueColValMap, constraints)
	if err != nil {
		return 0, err
	}
	sqlStr = session.engine.dialect.Quoter().Replace(sqlStr)

	// cacheInsert runs before the statement executes, and its error aborts
	// the upsert.
	if err = session.cacheInsert(tableName); err != nil {
		return 0, err
	}

	res, err := session.exec(sqlStr, sqlArgs...)
	if err != nil {
		return 0, err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return 0, err
	}

	// For MySQL, RowsAffected == 2 after INSERT ... ON CONFLICT means the
	// existing row was UPDATEd; report it as a single row.
	if doUpdate && session.engine.dialect.URI().DBType == schemas.MYSQL && affected == 2 {
		return 1, nil
	}
	return affected, nil
}
// upsertStruct upserts a single struct bean and returns the number of rows
// affected (MySQL's 2-for-update is reported as 1).
//
// Steps, in order:
//  1. Bind the statement to bean (SetRefBean) and require a table name.
//  2. Run the session's before-closures (then clear them) and the bean's
//     BeforeInsert hook, if it implements BeforeInsertProcessor.
//  3. Build the insert columns/args, find the unique constraints to test, and
//     generate the upsert SQL.
//  4. Execute it and, when the table has an auto-increment column, write
//     LastInsertId back into the bean.
//
// When the table has an auto-increment column and the driver cannot return
// the inserted ID (or the DB is SQLite), execution is delegated to
// execInsertSqlNoAutoReturn and steps after exec below are skipped here.
func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, error) {
// Bind the statement to bean; RefTable is read further down.
if err := session.statement.SetRefBean(bean); err != nil {
return 0, err
}
if len(session.statement.TableName()) == 0 {
return 0, ErrTableNotFound
}
// handle BeforeInsertProcessor
for _, closure := range session.beforeClosures {
closure(bean)
}
cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used
if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
processor.BeforeInsert()
}
tableName := session.statement.TableName()
table := session.statement.RefTable
colNames, args, err := session.genInsertColumns(bean)
if err != nil {
return 0, err
}
uniqueColValMap, uniqueConstraints, err := session.getUniqueColumns(doUpdate, colNames, args)
if err != nil {
return 0, err
}
sqlStr, args, err := session.statement.GenUpsertSQL(doUpdate, true, colNames, args, uniqueColValMap, uniqueConstraints)
if err != nil {
return 0, err
}
sqlStr = session.engine.dialect.Quoter().Replace(sqlStr)
// if there is auto increment column and driver doesn't support return it
if len(table.AutoIncrement) > 0 && (!session.engine.driver.Features().SupportReturnInsertedID || session.engine.dialect.URI().DBType == schemas.SQLITE) {
n, err := session.execInsertSqlNoAutoReturn(sqlStr, bean, colNames, args)
// sql.ErrNoRows is treated as success with count n. Presumably this is
// the "conflict, nothing inserted" case where no ID row comes back.
// TODO confirm against execInsertSqlNoAutoReturn.
if err == sql.ErrNoRows {
return n, nil
}
return n, err
}
res, err := session.exec(sqlStr, args...)
if err != nil {
return 0, err
}
// The after-insert processor runs on every return path from here on,
// including when no row was affected.
defer session.handleAfterInsertProcessorFunc(bean)
// NOTE(review): the cacheInsert error is discarded here, whereas upsertMap
// returns it.
_ = session.cacheInsert(tableName)
// NOTE(review): the in-memory version field is bumped before RowsAffected
// is checked, so it is incremented even when the insert was skipped.
if table.Version != "" && session.statement.CheckVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)
} else if verValue.IsValid() && verValue.CanSet() {
session.incrVersionFieldValue(verValue)
}
}
n, err := res.RowsAffected()
if err != nil || n == 0 {
return 0, err
}
if session.engine.dialect.URI().DBType == schemas.MYSQL && n == 2 {
// for MYSQL if INSERT ... ON CONFLICT RowsAffected == 2 means UPDATE
n = 1
}
if table.AutoIncrement == "" {
return n, nil
}
// Write the generated ID back into the bean's auto-increment field.
// Non-positive IDs are not written.
id, err := res.LastInsertId()
if err != nil || id <= 0 {
return n, err
}
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)
}
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
return n, err
}
if err := convert.AssignValue(*aiValue, id); err != nil {
return 0, err
}
return n, err
}
// Sentinel errors returned by Upsert and InsertOnConflictDoNothing while
// working out which unique constraints a row would be tested against.
var (
	// ErrNoUniqueConstraints is returned when the target table has no primary
	// key or unique index that the inserted row could conflict on.
	ErrNoUniqueConstraints = fmt.Errorf("provided bean has no unique constraints")
	// ErrMultipleUniqueConstraints is returned by Upsert when more than one
	// unique constraint (primary key or unique index) would be tested.
	// InsertOnConflictDoNothing allows several.
	ErrMultipleUniqueConstraints = fmt.Errorf("cannot upsert if there is more than one unique constraint tested")
)
// getUniqueColumns determines which unique constraints of the statement's
// RefTable an insert of argColumns/args would test.
//
// It returns:
//   - uniqueColValMap: each constrained column mapped to the value it would
//     take ("" when unknown);
//   - constraints: the column lists of the tested constraints.
//
// The only constraints xorm knows about are the primary key and unique
// indexes:
//   - A composite primary key, and every unique index, is always tested.
//   - A single-column primary key is tested only when a value for it is known.
//     For example, an omitted auto-increment key tests nothing.
//
// Upsert (doUpdate) may test at most one constraint, otherwise
// ErrMultipleUniqueConstraints is returned; insert-on-conflict-do-nothing may
// test several. ErrNoUniqueConstraints is returned when nothing is tested.
func (session *Session) getUniqueColumns(doUpdate bool, argColumns []string, args []interface{}) (uniqueColValMap map[string]interface{}, constraints [][]string, err error) {
	table := session.statement.RefTable
	if len(table.PrimaryKeys) == 0 && len(table.Indexes) == 0 {
		return nil, nil, ErrNoUniqueConstraints
	}
	uniqueColValMap = make(map[string]interface{})

	// valueOrEmpty is the column's insert value, or "" when none is known.
	valueOrEmpty := func(column string) interface{} {
		if v := session.getUniqueColumnValue(column, argColumns, args); v != nil {
			return v
		}
		return ""
	}

	// Primary key.
	if pks := table.PrimaryKeys; len(pks) == 1 {
		if v := session.getUniqueColumnValue(pks[0], argColumns, args); v != nil {
			uniqueColValMap[pks[0]] = v
			constraints = append(constraints, pks)
		}
	} else if len(pks) > 1 {
		constraints = append(constraints, pks)
		for _, col := range pks {
			uniqueColValMap[col] = valueOrEmpty(col)
		}
	}

	// Unique indexes: every column of each one is tested.
	for _, index := range table.Indexes {
		if index.Type != schemas.UniqueType {
			continue
		}
		constraints = append(constraints, index.Cols)
		for _, col := range index.Cols {
			if _, seen := uniqueColValMap[col]; !seen {
				uniqueColValMap[col] = valueOrEmpty(col)
			}
		}
	}

	switch {
	case doUpdate && len(constraints) > 1:
		return nil, nil, ErrMultipleUniqueConstraints
	case len(constraints) == 0:
		return nil, nil, ErrNoUniqueConstraints
	}
	return uniqueColValMap, constraints, nil
}
// getUniqueColumnValue returns the value column indexColumnName would take in
// the current insert, for use in conflict tests. A nil result means "no known
// value".
//
// Resolution order:
//  1. If indexColumnName appears in argColumns, the matching entry of args.
//  2. nil if the column is auto-increment, or is not a column of RefTable.
//     An unknown column previously caused a nil-pointer panic.
//  3. Otherwise the column's declared default (nil if it has none). This is
//     returned directly when the column is ONLYFROMDB, omitted, or excluded
//     by an explicit column list.
//  4. Otherwise the Arg of a matching Incr/Decr/Expr column expression if one
//     has exactly this name, else the default.
func (session *Session) getUniqueColumnValue(indexColumnName string, argColumns []string, args []interface{}) (value interface{}) {
	for i, col := range argColumns {
		if col == indexColumnName {
			return args[i]
		}
	}

	indexColumn := session.statement.RefTable.GetColumn(indexColumnName)
	if indexColumn == nil || indexColumn.IsAutoIncrement {
		return nil
	}

	if !indexColumn.DefaultIsEmpty {
		value = indexColumn.Default
	}

	// The column is presumably not written by this insert, so only its default
	// is known. FIXME: is the default the right value to test here?
	stmt := session.statement
	if indexColumn.MapType == schemas.ONLYFROMDB ||
		stmt.OmitColumnMap.Contain(indexColumn.Name) ||
		(len(stmt.ColumnMap) > 0 && !stmt.ColumnMap.Contain(indexColumn.Name)) {
		return value
	}

	// FIXME: for Incr/Decr, Arg is presumably the delta rather than the
	// resulting value. IsColExist and the exact-name loop may disagree (e.g. on
	// quoting/case); in that case the default is used.
	switch {
	case stmt.IncrColumns.IsColExist(indexColumn.Name):
		for _, exprCol := range stmt.IncrColumns {
			if exprCol.ColName == indexColumn.Name {
				return exprCol.Arg
			}
		}
	case stmt.DecrColumns.IsColExist(indexColumn.Name):
		for _, exprCol := range stmt.DecrColumns {
			if exprCol.ColName == indexColumn.Name {
				return exprCol.Arg
			}
		}
	case stmt.ExprColumns.IsColExist(indexColumn.Name):
		for _, exprCol := range stmt.ExprColumns {
			if exprCol.ColName == indexColumn.Name {
				return exprCol.Arg
			}
		}
	}

	// FIXME: not sure if there's anything else we can do
	return value
}