More cleanly handle primary keys as unique constraints

Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
Andrew Thornton 2023-03-15 11:47:44 +00:00
parent 4bf706dd0c
commit 625167ded5
No known key found for this signature in database
GPG Key ID: 3CDE74631F13A748
3 changed files with 159 additions and 341 deletions

View File

@ -59,16 +59,16 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
} }
func (statement *Statement) includeAutoIncrement(colNames []string) bool { func (statement *Statement) includeAutoIncrement(colNames []string) bool {
includesAutoIncrement := len(statement.RefTable.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG) needToIncludeAutoIncrement := len(statement.RefTable.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
if includesAutoIncrement { if needToIncludeAutoIncrement {
for _, col := range colNames { for _, col := range colNames {
if strings.EqualFold(col, statement.RefTable.AutoIncrement) { if strings.EqualFold(col, statement.RefTable.AutoIncrement) {
includesAutoIncrement = false needToIncludeAutoIncrement = false
break break
} }
} }
} }
return includesAutoIncrement return needToIncludeAutoIncrement
} }
func (statement *Statement) genInsertValues(buf *builder.BytesWriter, colNames []string, args []interface{}) error { func (statement *Statement) genInsertValues(buf *builder.BytesWriter, colNames []string, args []interface{}) error {

View File

@ -12,11 +12,11 @@ import (
) )
// GenUpsertSQL generates upsert beans SQL // GenUpsertSQL generates upsert beans SQL
func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) { func (statement *Statement) GenUpsertSQL(doUpdate bool, addOuput bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}, uniqueConstraints [][]string) (string, []interface{}, error) {
if statement.dialect.URI().DBType == schemas.MSSQL || if statement.dialect.URI().DBType == schemas.MSSQL ||
statement.dialect.URI().DBType == schemas.DAMENG || statement.dialect.URI().DBType == schemas.DAMENG ||
statement.dialect.URI().DBType == schemas.ORACLE { statement.dialect.URI().DBType == schemas.ORACLE {
return statement.genMergeSQL(doUpdate, columns, args, uniqueColValMap) return statement.genMergeSQL(doUpdate, addOuput, columns, args, uniqueColValMap, uniqueConstraints)
} }
var ( var (
@ -70,35 +70,14 @@ func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args [
} }
case schemas.POSTGRES: case schemas.POSTGRES:
if doUpdate { if doUpdate {
primaryColumnIncluded := false // In doUpdate we know that uniqueConstraints has to be length 1
for _, primaryKeyColumn := range table.PrimaryKeys { write(" ON CONFLICT (", quote(uniqueConstraints[0][0]))
if _, has := uniqueColValMap[primaryKeyColumn]; !has { for _, uniqueColumn := range uniqueConstraints[0][1:] {
continue write(", ", uniqueColumn)
}
primaryColumnIncluded = true
} }
if primaryColumnIncluded { write(") DO UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0])
write(" ON CONFLICT (", quote(table.PrimaryKeys[0])) for _, column := range updateColumns[1:] {
for _, col := range table.PrimaryKeys[1:] { write(", ", column, " = excluded.", column)
write(", ", quote(col))
}
write(") DO UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0])
for _, column := range updateColumns[1:] {
write(", ", column, " = excluded.", column)
}
}
for _, index := range table.Indexes {
if index.Type != schemas.UniqueType {
continue
}
write(" ON CONFLICT (", quote(index.Cols[0]))
for _, col := range index.Cols[1:] {
write(", ", quote(col))
}
write(") DO UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0])
for _, column := range updateColumns[1:] {
write(", ", column, " = excluded.", column)
}
} }
} else { } else {
write(" ON CONFLICT DO NOTHING") write(" ON CONFLICT DO NOTHING")
@ -131,7 +110,7 @@ func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args [
return buf.String(), buf.Args(), nil return buf.String(), buf.Args(), nil
} }
func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) { func (statement *Statement) genMergeSQL(doUpdate bool, addOutput bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}, uniqueConstraints [][]string) (string, []interface{}, error) {
var ( var (
buf = builder.NewWriter() buf = builder.NewWriter()
table = statement.RefTable table = statement.RefTable
@ -151,18 +130,16 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []
} }
write(" AS target USING (SELECT ") write(" AS target USING (SELECT ")
uniqueCols := make([]string, 0, len(uniqueColValMap)) uniqueColumnsCount := 0
for colName := range uniqueColValMap { for uniqueColumn, uniqueValue := range uniqueColValMap {
uniqueCols = append(uniqueCols, colName) if uniqueColumnsCount > 0 {
}
for i, colName := range uniqueCols {
if err := statement.WriteArg(buf, uniqueColValMap[colName]); err != nil {
return "", nil, err
}
write(" AS ", quote(colName))
if i < len(uniqueCols)-1 {
write(", ") write(", ")
} }
if err := statement.WriteArg(buf, uniqueValue); err != nil {
return "", nil, err
}
write(" AS ", quote(uniqueColumn))
uniqueColumnsCount++
} }
var updateColumns []string var updateColumns []string
@ -181,37 +158,13 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []
} }
write(") AS src ON (") write(") AS src ON (")
for i, uniqueColumns := range uniqueConstraints {
countUniques := 0 if i > 0 { // if !doUpdate there may be more than one uniqueConstraint
primaryColumnIncluded := false
for _, primaryKeyColumn := range table.PrimaryKeys {
if _, has := uniqueColValMap[primaryKeyColumn]; !has {
continue
}
if !primaryColumnIncluded {
write("(")
} else {
write(" AND ")
}
write("src.", quote(primaryKeyColumn), " = target.", quote(primaryKeyColumn))
primaryColumnIncluded = true
}
if primaryColumnIncluded {
write(")")
countUniques++
}
for _, index := range table.Indexes {
if index.Type != schemas.UniqueType {
continue
}
if countUniques > 0 {
write(" OR ") write(" OR ")
} }
countUniques++ write("(src.", quote(uniqueColumns[0]), " = target.", quote(uniqueColumns[0]))
write("(") for _, uniqueColumn := range uniqueColumns[1:] {
write("src.", quote(index.Cols[0]), " = target.", quote(index.Cols[0])) write(" AND src.", quote(uniqueColumn), " = target.", quote(uniqueColumn))
for _, col := range index.Cols[1:] {
write(" AND src.", quote(col), " = target.", quote(col))
} }
write(")") write(")")
} }
@ -228,10 +181,10 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []
write(" WHEN NOT MATCHED THEN INSERT ") write(" WHEN NOT MATCHED THEN INSERT ")
includeAutoIncrement := statement.includeAutoIncrement(columns) includeAutoIncrement := statement.includeAutoIncrement(columns)
if len(columns) == 0 && statement.dialect.URI().DBType == schemas.MSSQL { if len(columns) == 0 && statement.dialect.URI().DBType == schemas.MSSQL {
write(" DEFAULT VALUES ") write("DEFAULT VALUES ")
} else { } else {
// We have some values - Write the column names we need to insert: // We have some values - Write the column names we need to insert:
write(" (") write("(")
if includeAutoIncrement { if includeAutoIncrement {
columns = append(columns, table.AutoIncrement) columns = append(columns, table.AutoIncrement)
} }
@ -246,191 +199,12 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []
} }
} }
if err := statement.writeInsertOutput(buf.Builder, table); err != nil { if addOutput {
return "", nil, err if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
return "", nil, err
}
} }
write(";") write(";")
return buf.String(), buf.Args(), nil return buf.String(), buf.Args(), nil
} }
// GenUpsertMapSQL generates insert map SQL
func (statement *Statement) GenUpsertMapSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) {
if statement.dialect.URI().DBType == schemas.MSSQL ||
statement.dialect.URI().DBType == schemas.DAMENG ||
statement.dialect.URI().DBType == schemas.ORACLE {
return statement.genMergeMapSQL(doUpdate, columns, args, uniqueColValMap)
}
var (
buf = builder.NewWriter()
exprs = statement.ExprColumns
table = statement.RefTable
tableName = statement.TableName()
)
quote := statement.dialect.Quoter().Quote
write := func(args ...string) {
for _, arg := range args {
_, _ = buf.WriteString(arg)
}
}
var updateColumns []string
if doUpdate {
updateColumns = make([]string, 0, len(columns))
for _, column := range append(columns, exprs.ColNames()...) {
if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has {
continue
}
updateColumns = append(updateColumns, quote(column))
}
doUpdate = doUpdate && (len(updateColumns) > 0)
}
if statement.dialect.URI().DBType == schemas.MYSQL && !doUpdate {
write("INSERT IGNORE INTO ", quote(tableName), " (")
} else {
write("INSERT INTO ", quote(tableName), " (")
}
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil {
return "", nil, err
}
write(")")
if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil {
return "", nil, err
}
switch statement.dialect.URI().DBType {
case schemas.SQLITE, schemas.POSTGRES:
write(" ON CONFLICT DO ")
if doUpdate {
write("UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0])
for _, column := range updateColumns[1:] {
write(", ", column, " = excluded.", column)
}
} else {
write("NOTHING")
}
case schemas.MYSQL:
if doUpdate {
// FIXME: mysql >= 8.0.19 should use table alias
write(" ON DUPLICATE KEY ")
write("UPDATE ", updateColumns[0], " = VALUES(", updateColumns[0], ")")
for _, column := range updateColumns[1:] {
write(", ", column, " = VALUES(", column, ")")
}
if len(table.AutoIncrement) > 0 {
write(", ", quote(table.AutoIncrement), " = LAST_INSERT_ID(", quote(table.AutoIncrement), ")")
}
}
default:
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
}
if len(table.AutoIncrement) > 0 &&
(statement.dialect.URI().DBType == schemas.POSTGRES ||
statement.dialect.URI().DBType == schemas.SQLITE) {
write(" RETURNING ")
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil {
return "", nil, err
}
}
return buf.String(), buf.Args(), nil
}
func (statement *Statement) genMergeMapSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) {
var (
buf = builder.NewWriter()
table = statement.RefTable
exprs = statement.ExprColumns
tableName = statement.TableName()
)
quote := statement.dialect.Quoter().Quote
write := func(args ...string) {
for _, arg := range args {
_, _ = buf.WriteString(arg)
}
}
write("MERGE INTO ", quote(tableName))
if statement.dialect.URI().DBType == schemas.MSSQL {
write(" WITH (HOLDLOCK)")
}
write(" AS target USING (SELECT ")
uniqueCols := make([]string, 0, len(uniqueColValMap))
for colName := range uniqueColValMap {
uniqueCols = append(uniqueCols, colName)
}
for i, colName := range uniqueCols {
if err := statement.WriteArg(buf, uniqueColValMap[colName]); err != nil {
return "", nil, err
}
write(" AS ", quote(colName))
if i < len(uniqueCols)-1 {
write(", ")
}
}
var updateColumns []string
var updateArgs []interface{}
if doUpdate {
updateColumns = make([]string, 0, len(columns))
for _, expr := range exprs {
if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(expr.ColName)]; has {
continue
}
updateColumns = append(updateColumns, quote(expr.ColName))
updateArgs = append(updateArgs, expr.Arg)
}
for i, column := range columns {
if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has {
continue
}
updateColumns = append(updateColumns, quote(column))
updateArgs = append(updateArgs, args[i])
}
doUpdate = doUpdate && (len(updateColumns) > 0)
}
write(") AS src ON (")
countUniques := 0
for _, index := range table.Indexes {
if index.Type != schemas.UniqueType {
continue
}
if countUniques > 0 {
write(" OR ")
}
countUniques++
write("(")
write("src.", quote(index.Cols[0]), " = target.", quote(index.Cols[0]))
for _, col := range index.Cols[1:] {
write(" AND src.", quote(col), " = target.", quote(col))
}
write(")")
}
write(")")
if doUpdate {
write(" WHEN MATCHED THEN UPDATE SET ")
write("target.", quote(updateColumns[0]), " = ?")
buf.Append(updateArgs[0])
for i, col := range updateColumns[1:] {
write(", target.", quote(col), " = ?")
buf.Append(updateArgs[i+1])
}
}
write(" WHEN NOT MATCHED THEN INSERT ")
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil {
return "", nil, err
}
write(")")
if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil {
return "", nil, err
}
write(";")
return buf.String(), buf.Args(), nil
}

View File

@ -117,15 +117,12 @@ func (session *Session) upsertMap(doUpdate bool, columns []string, args []interf
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
uniqueColValMap, numberOfUniqueConstraints, err := session.getUniqueColumns(columns, args) uniqueColValMap, uniqueConstraints, err := session.getUniqueColumns(doUpdate, columns, args)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if doUpdate && numberOfUniqueConstraints > 1 {
return 0, fmt.Errorf("cannot upsert if there is more than one unique constraint")
}
sql, args, err := session.statement.GenUpsertMapSQL(doUpdate, columns, args, uniqueColValMap) sql, args, err := session.statement.GenUpsertSQL(doUpdate, false, columns, args, uniqueColValMap, uniqueConstraints)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -172,15 +169,12 @@ func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, er
return 0, err return 0, err
} }
uniqueColValMap, numberOfUniqueConstraints, err := session.getUniqueColumns(colNames, args) uniqueColValMap, uniqueConstraints, err := session.getUniqueColumns(doUpdate, colNames, args)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if doUpdate && numberOfUniqueConstraints > 1 {
return 0, fmt.Errorf("cannot upsert if there is more than one unique constraint")
}
sqlStr, args, err := session.statement.GenUpsertSQL(doUpdate, colNames, args, uniqueColValMap) sqlStr, args, err := session.statement.GenUpsertSQL(doUpdate, true, colNames, args, uniqueColValMap, uniqueConstraints)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -247,30 +241,60 @@ func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, er
return n, err return n, err
} }
func (session *Session) getUniqueColumns(argColumns []string, args []interface{}) (uniqueColValMap map[string]interface{}, numberOfUniqueConstraints int, err error) { var (
ErrNoUniqueConstraints = fmt.Errorf("provided bean has no unique constraints")
ErrMultipleUniqueConstraints = fmt.Errorf("cannot upsert if there is more than one unique constraint tested")
)
func (session *Session) getUniqueColumns(doUpdate bool, argColumns []string, args []interface{}) (uniqueColValMap map[string]interface{}, constraints [][]string, err error) {
// We need to collect the constraints that are being "tested" by argColumns as compared to the table
//
// There are two cases:
//
// 1. Insert on conflict do nothing
// 2. Upsert
//
// If we are an "Insert on conflict do nothing" then more than one "constraint" can be tested.
// If we are an "Upsert" only one "constraint" can be tested.
//
// In Xorm the only constraints we know of are "Unique Indices" and "Primary Keys".
//
// For unique indices - every column in the unique index is being tested.
//
// If the primary key is a single column and it is autoincrement then an empty PK column is not testing an unique constraint
// otherwise it does count.
uniqueColValMap = make(map[string]interface{}) uniqueColValMap = make(map[string]interface{})
table := session.statement.RefTable table := session.statement.RefTable
if len(table.Indexes) == 0 && (len(table.PrimaryKeys) == 0 || (len(table.PrimaryKeys) == 1 && table.AutoIncrement == table.PrimaryKeys[0])) { // Shortcut when there are no indices and no primary keys
return nil, 0, fmt.Errorf("provided bean has no unique constraints") if len(table.Indexes) == 0 && len(table.PrimaryKeys) == 0 {
return nil, nil, ErrNoUniqueConstraints
} }
// Check the primary key numberOfUniqueConstraints := 0
primaryColumnIncluded := false
primaryCol: // Check the primary key:
for _, primaryKeyColumn := range table.PrimaryKeys { switch len(table.PrimaryKeys) {
for i, column := range argColumns { case 0:
if column == primaryKeyColumn { // No primary keys - nothing to do
uniqueColValMap[column] = args[i] case 1:
primaryColumnIncluded = true // check if the pkColumn is included
continue primaryCol value := session.getUniqueColumnValue(table.PrimaryKeys[0], argColumns, args)
} if value != nil {
numberOfUniqueConstraints++
uniqueColValMap[table.PrimaryKeys[0]] = value
constraints = append(constraints, table.PrimaryKeys)
} }
if primaryKeyColumn != table.AutoIncrement { default:
primaryColumnIncluded = true
}
}
if primaryColumnIncluded {
numberOfUniqueConstraints++ numberOfUniqueConstraints++
constraints = append(constraints, table.PrimaryKeys)
for _, column := range table.PrimaryKeys {
value := session.getUniqueColumnValue(column, argColumns, args)
if value == nil {
value = "" // default to empty
}
uniqueColValMap[column] = value
}
} }
// Iterate across the indexes in the provided table // Iterate across the indexes in the provided table
@ -279,64 +303,84 @@ primaryCol:
continue continue
} }
numberOfUniqueConstraints++ numberOfUniqueConstraints++
constraints = append(constraints, index.Cols)
// index is a Unique constraint // index is a Unique constraint
indexCol: for _, column := range index.Cols {
for _, indexColumnName := range index.Cols { if _, has := uniqueColValMap[column]; has {
if _, has := uniqueColValMap[indexColumnName]; has { continue
// column is already included in uniqueCols so we don't need to add it again
continue indexCol
} }
// Now iterate across colNames and add to the uniqueCols value := session.getUniqueColumnValue(column, argColumns, args)
for i, col := range argColumns { if value == nil {
if col == indexColumnName { value = "" // default to empty
uniqueColValMap[col] = args[i]
continue indexCol
}
} }
uniqueColValMap[column] = value
indexColumn := table.GetColumn(indexColumnName)
if !indexColumn.DefaultIsEmpty {
uniqueColValMap[indexColumnName] = indexColumn.Default
}
if indexColumn.MapType == schemas.ONLYFROMDB || indexColumn.IsAutoIncrement {
continue indexCol
}
// FIXME: what do we do here?!
if session.statement.OmitColumnMap.Contain(indexColumn.Name) {
continue indexCol
}
// FIXME: what do we do here?!
if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(indexColumn.Name) {
continue indexCol
}
// FIXME: what do we do here?!
if session.statement.IncrColumns.IsColExist(indexColumn.Name) {
for _, exprCol := range session.statement.IncrColumns {
if exprCol.ColName == indexColumn.Name {
uniqueColValMap[indexColumnName] = exprCol.Arg
}
}
continue indexCol
} else if session.statement.DecrColumns.IsColExist(indexColumn.Name) {
for _, exprCol := range session.statement.DecrColumns {
if exprCol.ColName == indexColumn.Name {
uniqueColValMap[indexColumnName] = exprCol.Arg
}
}
continue indexCol
} else if session.statement.ExprColumns.IsColExist(indexColumn.Name) {
for _, exprCol := range session.statement.ExprColumns {
if exprCol.ColName == indexColumn.Name {
uniqueColValMap[indexColumnName] = exprCol.Arg
}
}
}
// FIXME: not sure if there's anything else we can do
return nil, 0, fmt.Errorf("provided bean does not provide a value for unique constraint %s field %s which has no default", index.Name, indexColumnName)
} }
} }
return uniqueColValMap, numberOfUniqueConstraints, nil if doUpdate && numberOfUniqueConstraints > 1 {
return nil, nil, ErrMultipleUniqueConstraints
}
if len(constraints) == 0 {
return nil, nil, ErrNoUniqueConstraints
}
return uniqueColValMap, constraints, nil
}
func (session *Session) getUniqueColumnValue(indexColumnName string, argColumns []string, args []interface{}) (value interface{}) {
table := session.statement.RefTable
// Now iterate across colNames and add to the uniqueCols
for i, col := range argColumns {
if col == indexColumnName {
return args[i]
}
}
indexColumn := table.GetColumn(indexColumnName)
if indexColumn.IsAutoIncrement {
return nil
}
if !indexColumn.DefaultIsEmpty {
value = indexColumn.Default
}
if indexColumn.MapType == schemas.ONLYFROMDB {
return value
}
// FIXME: what do we do here?!
if session.statement.OmitColumnMap.Contain(indexColumn.Name) {
return value
}
// FIXME: what do we do here?!
if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(indexColumn.Name) {
return value
}
// FIXME: what do we do here?!
if session.statement.IncrColumns.IsColExist(indexColumn.Name) {
for _, exprCol := range session.statement.IncrColumns {
if exprCol.ColName == indexColumn.Name {
return exprCol.Arg
}
}
return value
} else if session.statement.DecrColumns.IsColExist(indexColumn.Name) {
for _, exprCol := range session.statement.DecrColumns {
if exprCol.ColName == indexColumn.Name {
return exprCol.Arg
}
}
return value
} else if session.statement.ExprColumns.IsColExist(indexColumn.Name) {
for _, exprCol := range session.statement.ExprColumns {
if exprCol.ColName == indexColumn.Name {
return exprCol.Arg
}
}
}
// FIXME: not sure if there's anything else we can do
return value
} }