More cleanly handle primary keys as unique constraints
Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
parent
4bf706dd0c
commit
625167ded5
|
@ -59,16 +59,16 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) includeAutoIncrement(colNames []string) bool {
|
func (statement *Statement) includeAutoIncrement(colNames []string) bool {
|
||||||
includesAutoIncrement := len(statement.RefTable.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
|
needToIncludeAutoIncrement := len(statement.RefTable.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
|
||||||
if includesAutoIncrement {
|
if needToIncludeAutoIncrement {
|
||||||
for _, col := range colNames {
|
for _, col := range colNames {
|
||||||
if strings.EqualFold(col, statement.RefTable.AutoIncrement) {
|
if strings.EqualFold(col, statement.RefTable.AutoIncrement) {
|
||||||
includesAutoIncrement = false
|
needToIncludeAutoIncrement = false
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return includesAutoIncrement
|
return needToIncludeAutoIncrement
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) genInsertValues(buf *builder.BytesWriter, colNames []string, args []interface{}) error {
|
func (statement *Statement) genInsertValues(buf *builder.BytesWriter, colNames []string, args []interface{}) error {
|
||||||
|
|
|
@ -12,11 +12,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// GenUpsertSQL generates upsert beans SQL
|
// GenUpsertSQL generates upsert beans SQL
|
||||||
func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) {
|
func (statement *Statement) GenUpsertSQL(doUpdate bool, addOuput bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}, uniqueConstraints [][]string) (string, []interface{}, error) {
|
||||||
if statement.dialect.URI().DBType == schemas.MSSQL ||
|
if statement.dialect.URI().DBType == schemas.MSSQL ||
|
||||||
statement.dialect.URI().DBType == schemas.DAMENG ||
|
statement.dialect.URI().DBType == schemas.DAMENG ||
|
||||||
statement.dialect.URI().DBType == schemas.ORACLE {
|
statement.dialect.URI().DBType == schemas.ORACLE {
|
||||||
return statement.genMergeSQL(doUpdate, columns, args, uniqueColValMap)
|
return statement.genMergeSQL(doUpdate, addOuput, columns, args, uniqueColValMap, uniqueConstraints)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -70,35 +70,14 @@ func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args [
|
||||||
}
|
}
|
||||||
case schemas.POSTGRES:
|
case schemas.POSTGRES:
|
||||||
if doUpdate {
|
if doUpdate {
|
||||||
primaryColumnIncluded := false
|
// In doUpdate we know that uniqueConstraints has to be length 1
|
||||||
for _, primaryKeyColumn := range table.PrimaryKeys {
|
write(" ON CONFLICT (", quote(uniqueConstraints[0][0]))
|
||||||
if _, has := uniqueColValMap[primaryKeyColumn]; !has {
|
for _, uniqueColumn := range uniqueConstraints[0][1:] {
|
||||||
continue
|
write(", ", uniqueColumn)
|
||||||
}
|
|
||||||
primaryColumnIncluded = true
|
|
||||||
}
|
}
|
||||||
if primaryColumnIncluded {
|
write(") DO UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0])
|
||||||
write(" ON CONFLICT (", quote(table.PrimaryKeys[0]))
|
for _, column := range updateColumns[1:] {
|
||||||
for _, col := range table.PrimaryKeys[1:] {
|
write(", ", column, " = excluded.", column)
|
||||||
write(", ", quote(col))
|
|
||||||
}
|
|
||||||
write(") DO UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0])
|
|
||||||
for _, column := range updateColumns[1:] {
|
|
||||||
write(", ", column, " = excluded.", column)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, index := range table.Indexes {
|
|
||||||
if index.Type != schemas.UniqueType {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
write(" ON CONFLICT (", quote(index.Cols[0]))
|
|
||||||
for _, col := range index.Cols[1:] {
|
|
||||||
write(", ", quote(col))
|
|
||||||
}
|
|
||||||
write(") DO UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0])
|
|
||||||
for _, column := range updateColumns[1:] {
|
|
||||||
write(", ", column, " = excluded.", column)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
write(" ON CONFLICT DO NOTHING")
|
write(" ON CONFLICT DO NOTHING")
|
||||||
|
@ -131,7 +110,7 @@ func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args [
|
||||||
return buf.String(), buf.Args(), nil
|
return buf.String(), buf.Args(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) {
|
func (statement *Statement) genMergeSQL(doUpdate bool, addOutput bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}, uniqueConstraints [][]string) (string, []interface{}, error) {
|
||||||
var (
|
var (
|
||||||
buf = builder.NewWriter()
|
buf = builder.NewWriter()
|
||||||
table = statement.RefTable
|
table = statement.RefTable
|
||||||
|
@ -151,18 +130,16 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []
|
||||||
}
|
}
|
||||||
write(" AS target USING (SELECT ")
|
write(" AS target USING (SELECT ")
|
||||||
|
|
||||||
uniqueCols := make([]string, 0, len(uniqueColValMap))
|
uniqueColumnsCount := 0
|
||||||
for colName := range uniqueColValMap {
|
for uniqueColumn, uniqueValue := range uniqueColValMap {
|
||||||
uniqueCols = append(uniqueCols, colName)
|
if uniqueColumnsCount > 0 {
|
||||||
}
|
|
||||||
for i, colName := range uniqueCols {
|
|
||||||
if err := statement.WriteArg(buf, uniqueColValMap[colName]); err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
write(" AS ", quote(colName))
|
|
||||||
if i < len(uniqueCols)-1 {
|
|
||||||
write(", ")
|
write(", ")
|
||||||
}
|
}
|
||||||
|
if err := statement.WriteArg(buf, uniqueValue); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
write(" AS ", quote(uniqueColumn))
|
||||||
|
uniqueColumnsCount++
|
||||||
}
|
}
|
||||||
|
|
||||||
var updateColumns []string
|
var updateColumns []string
|
||||||
|
@ -181,37 +158,13 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []
|
||||||
}
|
}
|
||||||
|
|
||||||
write(") AS src ON (")
|
write(") AS src ON (")
|
||||||
|
for i, uniqueColumns := range uniqueConstraints {
|
||||||
countUniques := 0
|
if i > 0 { // if !doUpdate there may be more than one uniqueConstraint
|
||||||
primaryColumnIncluded := false
|
|
||||||
for _, primaryKeyColumn := range table.PrimaryKeys {
|
|
||||||
if _, has := uniqueColValMap[primaryKeyColumn]; !has {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !primaryColumnIncluded {
|
|
||||||
write("(")
|
|
||||||
} else {
|
|
||||||
write(" AND ")
|
|
||||||
}
|
|
||||||
write("src.", quote(primaryKeyColumn), " = target.", quote(primaryKeyColumn))
|
|
||||||
primaryColumnIncluded = true
|
|
||||||
}
|
|
||||||
if primaryColumnIncluded {
|
|
||||||
write(")")
|
|
||||||
countUniques++
|
|
||||||
}
|
|
||||||
for _, index := range table.Indexes {
|
|
||||||
if index.Type != schemas.UniqueType {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if countUniques > 0 {
|
|
||||||
write(" OR ")
|
write(" OR ")
|
||||||
}
|
}
|
||||||
countUniques++
|
write("(src.", quote(uniqueColumns[0]), " = target.", quote(uniqueColumns[0]))
|
||||||
write("(")
|
for _, uniqueColumn := range uniqueColumns[1:] {
|
||||||
write("src.", quote(index.Cols[0]), " = target.", quote(index.Cols[0]))
|
write(" AND src.", quote(uniqueColumn), " = target.", quote(uniqueColumn))
|
||||||
for _, col := range index.Cols[1:] {
|
|
||||||
write(" AND src.", quote(col), " = target.", quote(col))
|
|
||||||
}
|
}
|
||||||
write(")")
|
write(")")
|
||||||
}
|
}
|
||||||
|
@ -228,10 +181,10 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []
|
||||||
write(" WHEN NOT MATCHED THEN INSERT ")
|
write(" WHEN NOT MATCHED THEN INSERT ")
|
||||||
includeAutoIncrement := statement.includeAutoIncrement(columns)
|
includeAutoIncrement := statement.includeAutoIncrement(columns)
|
||||||
if len(columns) == 0 && statement.dialect.URI().DBType == schemas.MSSQL {
|
if len(columns) == 0 && statement.dialect.URI().DBType == schemas.MSSQL {
|
||||||
write(" DEFAULT VALUES ")
|
write("DEFAULT VALUES ")
|
||||||
} else {
|
} else {
|
||||||
// We have some values - Write the column names we need to insert:
|
// We have some values - Write the column names we need to insert:
|
||||||
write(" (")
|
write("(")
|
||||||
if includeAutoIncrement {
|
if includeAutoIncrement {
|
||||||
columns = append(columns, table.AutoIncrement)
|
columns = append(columns, table.AutoIncrement)
|
||||||
}
|
}
|
||||||
|
@ -246,191 +199,12 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
|
if addOutput {
|
||||||
return "", nil, err
|
if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
write(";")
|
write(";")
|
||||||
return buf.String(), buf.Args(), nil
|
return buf.String(), buf.Args(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenUpsertMapSQL generates insert map SQL
|
|
||||||
func (statement *Statement) GenUpsertMapSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) {
|
|
||||||
if statement.dialect.URI().DBType == schemas.MSSQL ||
|
|
||||||
statement.dialect.URI().DBType == schemas.DAMENG ||
|
|
||||||
statement.dialect.URI().DBType == schemas.ORACLE {
|
|
||||||
return statement.genMergeMapSQL(doUpdate, columns, args, uniqueColValMap)
|
|
||||||
}
|
|
||||||
var (
|
|
||||||
buf = builder.NewWriter()
|
|
||||||
exprs = statement.ExprColumns
|
|
||||||
table = statement.RefTable
|
|
||||||
tableName = statement.TableName()
|
|
||||||
)
|
|
||||||
quote := statement.dialect.Quoter().Quote
|
|
||||||
write := func(args ...string) {
|
|
||||||
for _, arg := range args {
|
|
||||||
_, _ = buf.WriteString(arg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var updateColumns []string
|
|
||||||
if doUpdate {
|
|
||||||
updateColumns = make([]string, 0, len(columns))
|
|
||||||
for _, column := range append(columns, exprs.ColNames()...) {
|
|
||||||
if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
updateColumns = append(updateColumns, quote(column))
|
|
||||||
}
|
|
||||||
doUpdate = doUpdate && (len(updateColumns) > 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
if statement.dialect.URI().DBType == schemas.MYSQL && !doUpdate {
|
|
||||||
write("INSERT IGNORE INTO ", quote(tableName), " (")
|
|
||||||
} else {
|
|
||||||
write("INSERT INTO ", quote(tableName), " (")
|
|
||||||
}
|
|
||||||
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
write(")")
|
|
||||||
|
|
||||||
if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch statement.dialect.URI().DBType {
|
|
||||||
case schemas.SQLITE, schemas.POSTGRES:
|
|
||||||
write(" ON CONFLICT DO ")
|
|
||||||
if doUpdate {
|
|
||||||
write("UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0])
|
|
||||||
for _, column := range updateColumns[1:] {
|
|
||||||
write(", ", column, " = excluded.", column)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
write("NOTHING")
|
|
||||||
}
|
|
||||||
case schemas.MYSQL:
|
|
||||||
if doUpdate {
|
|
||||||
// FIXME: mysql >= 8.0.19 should use table alias
|
|
||||||
write(" ON DUPLICATE KEY ")
|
|
||||||
write("UPDATE ", updateColumns[0], " = VALUES(", updateColumns[0], ")")
|
|
||||||
for _, column := range updateColumns[1:] {
|
|
||||||
write(", ", column, " = VALUES(", column, ")")
|
|
||||||
}
|
|
||||||
if len(table.AutoIncrement) > 0 {
|
|
||||||
write(", ", quote(table.AutoIncrement), " = LAST_INSERT_ID(", quote(table.AutoIncrement), ")")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(table.AutoIncrement) > 0 &&
|
|
||||||
(statement.dialect.URI().DBType == schemas.POSTGRES ||
|
|
||||||
statement.dialect.URI().DBType == schemas.SQLITE) {
|
|
||||||
write(" RETURNING ")
|
|
||||||
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf.String(), buf.Args(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (statement *Statement) genMergeMapSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) {
|
|
||||||
var (
|
|
||||||
buf = builder.NewWriter()
|
|
||||||
table = statement.RefTable
|
|
||||||
exprs = statement.ExprColumns
|
|
||||||
tableName = statement.TableName()
|
|
||||||
)
|
|
||||||
|
|
||||||
quote := statement.dialect.Quoter().Quote
|
|
||||||
write := func(args ...string) {
|
|
||||||
for _, arg := range args {
|
|
||||||
_, _ = buf.WriteString(arg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
write("MERGE INTO ", quote(tableName))
|
|
||||||
if statement.dialect.URI().DBType == schemas.MSSQL {
|
|
||||||
write(" WITH (HOLDLOCK)")
|
|
||||||
}
|
|
||||||
write(" AS target USING (SELECT ")
|
|
||||||
|
|
||||||
uniqueCols := make([]string, 0, len(uniqueColValMap))
|
|
||||||
for colName := range uniqueColValMap {
|
|
||||||
uniqueCols = append(uniqueCols, colName)
|
|
||||||
}
|
|
||||||
for i, colName := range uniqueCols {
|
|
||||||
if err := statement.WriteArg(buf, uniqueColValMap[colName]); err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
write(" AS ", quote(colName))
|
|
||||||
if i < len(uniqueCols)-1 {
|
|
||||||
write(", ")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var updateColumns []string
|
|
||||||
var updateArgs []interface{}
|
|
||||||
if doUpdate {
|
|
||||||
updateColumns = make([]string, 0, len(columns))
|
|
||||||
for _, expr := range exprs {
|
|
||||||
if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(expr.ColName)]; has {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
updateColumns = append(updateColumns, quote(expr.ColName))
|
|
||||||
updateArgs = append(updateArgs, expr.Arg)
|
|
||||||
}
|
|
||||||
for i, column := range columns {
|
|
||||||
if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
updateColumns = append(updateColumns, quote(column))
|
|
||||||
updateArgs = append(updateArgs, args[i])
|
|
||||||
}
|
|
||||||
doUpdate = doUpdate && (len(updateColumns) > 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
write(") AS src ON (")
|
|
||||||
|
|
||||||
countUniques := 0
|
|
||||||
for _, index := range table.Indexes {
|
|
||||||
if index.Type != schemas.UniqueType {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if countUniques > 0 {
|
|
||||||
write(" OR ")
|
|
||||||
}
|
|
||||||
countUniques++
|
|
||||||
write("(")
|
|
||||||
write("src.", quote(index.Cols[0]), " = target.", quote(index.Cols[0]))
|
|
||||||
for _, col := range index.Cols[1:] {
|
|
||||||
write(" AND src.", quote(col), " = target.", quote(col))
|
|
||||||
}
|
|
||||||
write(")")
|
|
||||||
}
|
|
||||||
write(")")
|
|
||||||
if doUpdate {
|
|
||||||
write(" WHEN MATCHED THEN UPDATE SET ")
|
|
||||||
write("target.", quote(updateColumns[0]), " = ?")
|
|
||||||
buf.Append(updateArgs[0])
|
|
||||||
for i, col := range updateColumns[1:] {
|
|
||||||
write(", target.", quote(col), " = ?")
|
|
||||||
buf.Append(updateArgs[i+1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
write(" WHEN NOT MATCHED THEN INSERT ")
|
|
||||||
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
write(")")
|
|
||||||
|
|
||||||
if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
write(";")
|
|
||||||
|
|
||||||
return buf.String(), buf.Args(), nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -117,15 +117,12 @@ func (session *Session) upsertMap(doUpdate bool, columns []string, args []interf
|
||||||
return 0, ErrTableNotFound
|
return 0, ErrTableNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
uniqueColValMap, numberOfUniqueConstraints, err := session.getUniqueColumns(columns, args)
|
uniqueColValMap, uniqueConstraints, err := session.getUniqueColumns(doUpdate, columns, args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
if doUpdate && numberOfUniqueConstraints > 1 {
|
|
||||||
return 0, fmt.Errorf("cannot upsert if there is more than one unique constraint")
|
|
||||||
}
|
|
||||||
|
|
||||||
sql, args, err := session.statement.GenUpsertMapSQL(doUpdate, columns, args, uniqueColValMap)
|
sql, args, err := session.statement.GenUpsertSQL(doUpdate, false, columns, args, uniqueColValMap, uniqueConstraints)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -172,15 +169,12 @@ func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, er
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
uniqueColValMap, numberOfUniqueConstraints, err := session.getUniqueColumns(colNames, args)
|
uniqueColValMap, uniqueConstraints, err := session.getUniqueColumns(doUpdate, colNames, args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
if doUpdate && numberOfUniqueConstraints > 1 {
|
|
||||||
return 0, fmt.Errorf("cannot upsert if there is more than one unique constraint")
|
|
||||||
}
|
|
||||||
|
|
||||||
sqlStr, args, err := session.statement.GenUpsertSQL(doUpdate, colNames, args, uniqueColValMap)
|
sqlStr, args, err := session.statement.GenUpsertSQL(doUpdate, true, colNames, args, uniqueColValMap, uniqueConstraints)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -247,30 +241,60 @@ func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, er
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *Session) getUniqueColumns(argColumns []string, args []interface{}) (uniqueColValMap map[string]interface{}, numberOfUniqueConstraints int, err error) {
|
var (
|
||||||
|
ErrNoUniqueConstraints = fmt.Errorf("provided bean has no unique constraints")
|
||||||
|
ErrMultipleUniqueConstraints = fmt.Errorf("cannot upsert if there is more than one unique constraint tested")
|
||||||
|
)
|
||||||
|
|
||||||
|
func (session *Session) getUniqueColumns(doUpdate bool, argColumns []string, args []interface{}) (uniqueColValMap map[string]interface{}, constraints [][]string, err error) {
|
||||||
|
// We need to collect the constraints that are being "tested" by argColumns as compared to the table
|
||||||
|
//
|
||||||
|
// There are two cases:
|
||||||
|
//
|
||||||
|
// 1. Insert on conflict do nothing
|
||||||
|
// 2. Upsert
|
||||||
|
//
|
||||||
|
// If we are an "Insert on conflict do nothing" then more than one "constraint" can be tested.
|
||||||
|
// If we are an "Upsert" only one "constraint" can be tested.
|
||||||
|
//
|
||||||
|
// In Xorm the only constraints we know of are "Unique Indices" and "Primary Keys".
|
||||||
|
//
|
||||||
|
// For unique indices - every column in the unique index is being tested.
|
||||||
|
//
|
||||||
|
// If the primary key is a single column and it is autoincrement then an empty PK column is not testing an unique constraint
|
||||||
|
// otherwise it does count.
|
||||||
|
|
||||||
uniqueColValMap = make(map[string]interface{})
|
uniqueColValMap = make(map[string]interface{})
|
||||||
table := session.statement.RefTable
|
table := session.statement.RefTable
|
||||||
if len(table.Indexes) == 0 && (len(table.PrimaryKeys) == 0 || (len(table.PrimaryKeys) == 1 && table.AutoIncrement == table.PrimaryKeys[0])) {
|
// Shortcut when there are no indices and no primary keys
|
||||||
return nil, 0, fmt.Errorf("provided bean has no unique constraints")
|
if len(table.Indexes) == 0 && len(table.PrimaryKeys) == 0 {
|
||||||
|
return nil, nil, ErrNoUniqueConstraints
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the primary key
|
numberOfUniqueConstraints := 0
|
||||||
primaryColumnIncluded := false
|
|
||||||
primaryCol:
|
// Check the primary key:
|
||||||
for _, primaryKeyColumn := range table.PrimaryKeys {
|
switch len(table.PrimaryKeys) {
|
||||||
for i, column := range argColumns {
|
case 0:
|
||||||
if column == primaryKeyColumn {
|
// No primary keys - nothing to do
|
||||||
uniqueColValMap[column] = args[i]
|
case 1:
|
||||||
primaryColumnIncluded = true
|
// check if the pkColumn is included
|
||||||
continue primaryCol
|
value := session.getUniqueColumnValue(table.PrimaryKeys[0], argColumns, args)
|
||||||
}
|
if value != nil {
|
||||||
|
numberOfUniqueConstraints++
|
||||||
|
uniqueColValMap[table.PrimaryKeys[0]] = value
|
||||||
|
constraints = append(constraints, table.PrimaryKeys)
|
||||||
}
|
}
|
||||||
if primaryKeyColumn != table.AutoIncrement {
|
default:
|
||||||
primaryColumnIncluded = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if primaryColumnIncluded {
|
|
||||||
numberOfUniqueConstraints++
|
numberOfUniqueConstraints++
|
||||||
|
constraints = append(constraints, table.PrimaryKeys)
|
||||||
|
for _, column := range table.PrimaryKeys {
|
||||||
|
value := session.getUniqueColumnValue(column, argColumns, args)
|
||||||
|
if value == nil {
|
||||||
|
value = "" // default to empty
|
||||||
|
}
|
||||||
|
uniqueColValMap[column] = value
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Iterate across the indexes in the provided table
|
// Iterate across the indexes in the provided table
|
||||||
|
@ -279,64 +303,84 @@ primaryCol:
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
numberOfUniqueConstraints++
|
numberOfUniqueConstraints++
|
||||||
|
constraints = append(constraints, index.Cols)
|
||||||
|
|
||||||
// index is a Unique constraint
|
// index is a Unique constraint
|
||||||
indexCol:
|
for _, column := range index.Cols {
|
||||||
for _, indexColumnName := range index.Cols {
|
if _, has := uniqueColValMap[column]; has {
|
||||||
if _, has := uniqueColValMap[indexColumnName]; has {
|
continue
|
||||||
// column is already included in uniqueCols so we don't need to add it again
|
|
||||||
continue indexCol
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now iterate across colNames and add to the uniqueCols
|
value := session.getUniqueColumnValue(column, argColumns, args)
|
||||||
for i, col := range argColumns {
|
if value == nil {
|
||||||
if col == indexColumnName {
|
value = "" // default to empty
|
||||||
uniqueColValMap[col] = args[i]
|
|
||||||
continue indexCol
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
uniqueColValMap[column] = value
|
||||||
indexColumn := table.GetColumn(indexColumnName)
|
|
||||||
if !indexColumn.DefaultIsEmpty {
|
|
||||||
uniqueColValMap[indexColumnName] = indexColumn.Default
|
|
||||||
}
|
|
||||||
|
|
||||||
if indexColumn.MapType == schemas.ONLYFROMDB || indexColumn.IsAutoIncrement {
|
|
||||||
continue indexCol
|
|
||||||
}
|
|
||||||
// FIXME: what do we do here?!
|
|
||||||
if session.statement.OmitColumnMap.Contain(indexColumn.Name) {
|
|
||||||
continue indexCol
|
|
||||||
}
|
|
||||||
// FIXME: what do we do here?!
|
|
||||||
if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(indexColumn.Name) {
|
|
||||||
continue indexCol
|
|
||||||
}
|
|
||||||
// FIXME: what do we do here?!
|
|
||||||
if session.statement.IncrColumns.IsColExist(indexColumn.Name) {
|
|
||||||
for _, exprCol := range session.statement.IncrColumns {
|
|
||||||
if exprCol.ColName == indexColumn.Name {
|
|
||||||
uniqueColValMap[indexColumnName] = exprCol.Arg
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue indexCol
|
|
||||||
} else if session.statement.DecrColumns.IsColExist(indexColumn.Name) {
|
|
||||||
for _, exprCol := range session.statement.DecrColumns {
|
|
||||||
if exprCol.ColName == indexColumn.Name {
|
|
||||||
uniqueColValMap[indexColumnName] = exprCol.Arg
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue indexCol
|
|
||||||
} else if session.statement.ExprColumns.IsColExist(indexColumn.Name) {
|
|
||||||
for _, exprCol := range session.statement.ExprColumns {
|
|
||||||
if exprCol.ColName == indexColumn.Name {
|
|
||||||
uniqueColValMap[indexColumnName] = exprCol.Arg
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// FIXME: not sure if there's anything else we can do
|
|
||||||
return nil, 0, fmt.Errorf("provided bean does not provide a value for unique constraint %s field %s which has no default", index.Name, indexColumnName)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return uniqueColValMap, numberOfUniqueConstraints, nil
|
if doUpdate && numberOfUniqueConstraints > 1 {
|
||||||
|
return nil, nil, ErrMultipleUniqueConstraints
|
||||||
|
}
|
||||||
|
if len(constraints) == 0 {
|
||||||
|
return nil, nil, ErrNoUniqueConstraints
|
||||||
|
}
|
||||||
|
|
||||||
|
return uniqueColValMap, constraints, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) getUniqueColumnValue(indexColumnName string, argColumns []string, args []interface{}) (value interface{}) {
|
||||||
|
table := session.statement.RefTable
|
||||||
|
|
||||||
|
// Now iterate across colNames and add to the uniqueCols
|
||||||
|
for i, col := range argColumns {
|
||||||
|
if col == indexColumnName {
|
||||||
|
return args[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
indexColumn := table.GetColumn(indexColumnName)
|
||||||
|
if indexColumn.IsAutoIncrement {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !indexColumn.DefaultIsEmpty {
|
||||||
|
value = indexColumn.Default
|
||||||
|
}
|
||||||
|
|
||||||
|
if indexColumn.MapType == schemas.ONLYFROMDB {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
// FIXME: what do we do here?!
|
||||||
|
if session.statement.OmitColumnMap.Contain(indexColumn.Name) {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
// FIXME: what do we do here?!
|
||||||
|
if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(indexColumn.Name) {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
// FIXME: what do we do here?!
|
||||||
|
if session.statement.IncrColumns.IsColExist(indexColumn.Name) {
|
||||||
|
for _, exprCol := range session.statement.IncrColumns {
|
||||||
|
if exprCol.ColName == indexColumn.Name {
|
||||||
|
return exprCol.Arg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
} else if session.statement.DecrColumns.IsColExist(indexColumn.Name) {
|
||||||
|
for _, exprCol := range session.statement.DecrColumns {
|
||||||
|
if exprCol.ColName == indexColumn.Name {
|
||||||
|
return exprCol.Arg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
} else if session.statement.ExprColumns.IsColExist(indexColumn.Name) {
|
||||||
|
for _, exprCol := range session.statement.ExprColumns {
|
||||||
|
if exprCol.ColName == indexColumn.Name {
|
||||||
|
return exprCol.Arg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FIXME: not sure if there's anything else we can do
|
||||||
|
return value
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue