diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 8475fc57..321d8395 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -59,16 +59,16 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) } func (statement *Statement) includeAutoIncrement(colNames []string) bool { - includesAutoIncrement := len(statement.RefTable.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG) - if includesAutoIncrement { + needToIncludeAutoIncrement := len(statement.RefTable.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG) + if needToIncludeAutoIncrement { for _, col := range colNames { if strings.EqualFold(col, statement.RefTable.AutoIncrement) { - includesAutoIncrement = false + needToIncludeAutoIncrement = false break } } } - return includesAutoIncrement + return needToIncludeAutoIncrement } func (statement *Statement) genInsertValues(buf *builder.BytesWriter, colNames []string, args []interface{}) error { diff --git a/internal/statements/upsert.go b/internal/statements/upsert.go index 0c75eb1b..098a11fb 100644 --- a/internal/statements/upsert.go +++ b/internal/statements/upsert.go @@ -12,11 +12,11 @@ import ( ) // GenUpsertSQL generates upsert beans SQL -func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) { +func (statement *Statement) GenUpsertSQL(doUpdate bool, addOuput bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}, uniqueConstraints [][]string) (string, []interface{}, error) { if statement.dialect.URI().DBType == schemas.MSSQL || statement.dialect.URI().DBType == schemas.DAMENG || statement.dialect.URI().DBType == schemas.ORACLE { - return statement.genMergeSQL(doUpdate, columns, args, uniqueColValMap) + return statement.genMergeSQL(doUpdate, addOuput, columns, args, uniqueColValMap, uniqueConstraints) } var ( @@ -70,35 +70,14 @@ func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args [ } case schemas.POSTGRES: if doUpdate { - primaryColumnIncluded := false - for _, primaryKeyColumn := range table.PrimaryKeys { - if _, has := uniqueColValMap[primaryKeyColumn]; !has { - continue - } - primaryColumnIncluded = true + // In doUpdate we know that uniqueConstraints has to be length 1 + write(" ON CONFLICT (", quote(uniqueConstraints[0][0])) + for _, uniqueColumn := range uniqueConstraints[0][1:] { + write(", ", uniqueColumn) } - if primaryColumnIncluded { - write(" ON CONFLICT (", quote(table.PrimaryKeys[0])) - for _, col := range table.PrimaryKeys[1:] { - write(", ", quote(col)) - } - write(") DO UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0]) - for _, column := range updateColumns[1:] { - write(", ", column, " = excluded.", column) - } - } - for _, index := range table.Indexes { - if index.Type != schemas.UniqueType { - continue - } - write(" ON CONFLICT (", quote(index.Cols[0])) - for _, col := range index.Cols[1:] { - write(", ", quote(col)) - } - write(") DO UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0]) - for _, column := range updateColumns[1:] { - write(", ", column, " = excluded.", column) - } + write(") DO UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0]) + for _, column := range updateColumns[1:] { + write(", ", column, " = excluded.", column) } } else { write(" ON CONFLICT DO NOTHING") @@ -131,7 +110,7 @@ func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args [ return buf.String(), buf.Args(), nil } -func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) { +func (statement *Statement) genMergeSQL(doUpdate bool, addOutput bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}, uniqueConstraints [][]string) (string, []interface{}, error) { var ( buf = builder.NewWriter() table = statement.RefTable @@ -151,18 +130,16 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args [] } write(" AS target USING (SELECT ") - uniqueCols := make([]string, 0, len(uniqueColValMap)) - for colName := range uniqueColValMap { - uniqueCols = append(uniqueCols, colName) - } - for i, colName := range uniqueCols { - if err := statement.WriteArg(buf, uniqueColValMap[colName]); err != nil { - return "", nil, err - } - write(" AS ", quote(colName)) - if i < len(uniqueCols)-1 { + uniqueColumnsCount := 0 + for uniqueColumn, uniqueValue := range uniqueColValMap { + if uniqueColumnsCount > 0 { write(", ") } + if err := statement.WriteArg(buf, uniqueValue); err != nil { + return "", nil, err + } + write(" AS ", quote(uniqueColumn)) + uniqueColumnsCount++ } var updateColumns []string @@ -181,37 +158,13 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args [] } write(") AS src ON (") - - countUniques := 0 - primaryColumnIncluded := false - for _, primaryKeyColumn := range table.PrimaryKeys { - if _, has := uniqueColValMap[primaryKeyColumn]; !has { - continue - } - if !primaryColumnIncluded { - write("(") - } else { - write(" AND ") - } - write("src.", quote(primaryKeyColumn), " = target.", quote(primaryKeyColumn)) - primaryColumnIncluded = true - } - if primaryColumnIncluded { - write(")") - countUniques++ - } - for _, index := range table.Indexes { - if index.Type != schemas.UniqueType { - continue - } - if countUniques > 0 { + for i, uniqueColumns := range uniqueConstraints { + if i > 0 { // if !doUpdate there may be more than one uniqueConstraint write(" OR ") } - countUniques++ - write("(") - write("src.", quote(index.Cols[0]), " = target.", quote(index.Cols[0])) - for _, col := range index.Cols[1:] { - write(" AND src.", quote(col), " = target.", quote(col)) + write("(src.", quote(uniqueColumns[0]), " = target.", quote(uniqueColumns[0])) + for _, uniqueColumn := range uniqueColumns[1:] { + write(" AND src.", quote(uniqueColumn), " = target.", quote(uniqueColumn)) } write(")") } @@ -228,10 +181,10 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args [] write(" WHEN NOT MATCHED THEN INSERT ") includeAutoIncrement := statement.includeAutoIncrement(columns) if len(columns) == 0 && statement.dialect.URI().DBType == schemas.MSSQL { - write(" DEFAULT VALUES ") + write("DEFAULT VALUES ") } else { // We have some values - Write the column names we need to insert: - write(" (") + write("(") if includeAutoIncrement { columns = append(columns, table.AutoIncrement) } @@ -246,191 +199,12 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args [] } } - if err := statement.writeInsertOutput(buf.Builder, table); err != nil { - return "", nil, err + if addOutput { + if err := statement.writeInsertOutput(buf.Builder, table); err != nil { + return "", nil, err + } } write(";") return buf.String(), buf.Args(), nil } - -// GenUpsertMapSQL generates insert map SQL -func (statement *Statement) GenUpsertMapSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) { - if statement.dialect.URI().DBType == schemas.MSSQL || - statement.dialect.URI().DBType == schemas.DAMENG || - statement.dialect.URI().DBType == schemas.ORACLE { - return statement.genMergeMapSQL(doUpdate, columns, args, uniqueColValMap) - } - var ( - buf = builder.NewWriter() - exprs = statement.ExprColumns - table = statement.RefTable - tableName = statement.TableName() - ) - quote := statement.dialect.Quoter().Quote - write := func(args ...string) { - for _, arg := range args { - _, _ = buf.WriteString(arg) - } - } - var updateColumns []string - if doUpdate { - updateColumns = make([]string, 0, len(columns)) - for _, column := range append(columns, exprs.ColNames()...) { - if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has { - continue - } - updateColumns = append(updateColumns, quote(column)) - } - doUpdate = doUpdate && (len(updateColumns) > 0) - } - - if statement.dialect.URI().DBType == schemas.MYSQL && !doUpdate { - write("INSERT IGNORE INTO ", quote(tableName), " (") - } else { - write("INSERT INTO ", quote(tableName), " (") - } - if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil { - return "", nil, err - } - write(")") - - if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil { - return "", nil, err - } - - switch statement.dialect.URI().DBType { - case schemas.SQLITE, schemas.POSTGRES: - write(" ON CONFLICT DO ") - if doUpdate { - write("UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0]) - for _, column := range updateColumns[1:] { - write(", ", column, " = excluded.", column) - } - } else { - write("NOTHING") - } - case schemas.MYSQL: - if doUpdate { - // FIXME: mysql >= 8.0.19 should use table alias - write(" ON DUPLICATE KEY ") - write("UPDATE ", updateColumns[0], " = VALUES(", updateColumns[0], ")") - for _, column := range updateColumns[1:] { - write(", ", column, " = VALUES(", column, ")") - } - if len(table.AutoIncrement) > 0 { - write(", ", quote(table.AutoIncrement), " = LAST_INSERT_ID(", quote(table.AutoIncrement), ")") - } - } - default: - return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT - } - - if len(table.AutoIncrement) > 0 && - (statement.dialect.URI().DBType == schemas.POSTGRES || - statement.dialect.URI().DBType == schemas.SQLITE) { - write(" RETURNING ") - if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil { - return "", nil, err - } - } - - return buf.String(), buf.Args(), nil -} - -func (statement *Statement) genMergeMapSQL(doUpdate bool, columns []string, args []interface{}, uniqueColValMap map[string]interface{}) (string, []interface{}, error) { - var ( - buf = builder.NewWriter() - table = statement.RefTable - exprs = statement.ExprColumns - tableName = statement.TableName() - ) - - quote := statement.dialect.Quoter().Quote - write := func(args ...string) { - for _, arg := range args { - _, _ = buf.WriteString(arg) - } - } - - write("MERGE INTO ", quote(tableName)) - if statement.dialect.URI().DBType == schemas.MSSQL { - write(" WITH (HOLDLOCK)") - } - write(" AS target USING (SELECT ") - - uniqueCols := make([]string, 0, len(uniqueColValMap)) - for colName := range uniqueColValMap { - uniqueCols = append(uniqueCols, colName) - } - for i, colName := range uniqueCols { - if err := statement.WriteArg(buf, uniqueColValMap[colName]); err != nil { - return "", nil, err - } - write(" AS ", quote(colName)) - if i < len(uniqueCols)-1 { - write(", ") - } - } - var updateColumns []string - var updateArgs []interface{} - if doUpdate { - updateColumns = make([]string, 0, len(columns)) - for _, expr := range exprs { - if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(expr.ColName)]; has { - continue - } - updateColumns = append(updateColumns, quote(expr.ColName)) - updateArgs = append(updateArgs, expr.Arg) - } - for i, column := range columns { - if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has { - continue - } - updateColumns = append(updateColumns, quote(column)) - updateArgs = append(updateArgs, args[i]) - } - doUpdate = doUpdate && (len(updateColumns) > 0) - } - - write(") AS src ON (") - - countUniques := 0 - for _, index := range table.Indexes { - if index.Type != schemas.UniqueType { - continue - } - if countUniques > 0 { - write(" OR ") - } - countUniques++ - write("(") - write("src.", quote(index.Cols[0]), " = target.", quote(index.Cols[0])) - for _, col := range index.Cols[1:] { - write(" AND src.", quote(col), " = target.", quote(col)) - } - write(")") - } - write(")") - if doUpdate { - write(" WHEN MATCHED THEN UPDATE SET ") - write("target.", quote(updateColumns[0]), " = ?") - buf.Append(updateArgs[0]) - for i, col := range updateColumns[1:] { - write(", target.", quote(col), " = ?") - buf.Append(updateArgs[i+1]) - } - } - write(" WHEN NOT MATCHED THEN INSERT ") - if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil { - return "", nil, err - } - write(")") - - if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil { - return "", nil, err - } - write(";") - - return buf.String(), buf.Args(), nil -} diff --git a/session_upsert.go b/session_upsert.go index c6ccba36..f4b896dc 100644 --- a/session_upsert.go +++ b/session_upsert.go @@ -117,15 +117,12 @@ func (session *Session) upsertMap(doUpdate bool, columns []string, args []interf return 0, ErrTableNotFound } - uniqueColValMap, numberOfUniqueConstraints, err := session.getUniqueColumns(columns, args) + uniqueColValMap, uniqueConstraints, err := session.getUniqueColumns(doUpdate, columns, args) if err != nil { return 0, err } - if doUpdate && numberOfUniqueConstraints > 1 { - return 0, fmt.Errorf("cannot upsert if there is more than one unique constraint") - } - sql, args, err := session.statement.GenUpsertMapSQL(doUpdate, columns, args, uniqueColValMap) + sql, args, err := session.statement.GenUpsertSQL(doUpdate, false, columns, args, uniqueColValMap, uniqueConstraints) if err != nil { return 0, err } @@ -172,15 +169,12 @@ func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, er return 0, err } - uniqueColValMap, numberOfUniqueConstraints, err := session.getUniqueColumns(colNames, args) + uniqueColValMap, uniqueConstraints, err := session.getUniqueColumns(doUpdate, colNames, args) if err != nil { return 0, err } - if doUpdate && numberOfUniqueConstraints > 1 { - return 0, fmt.Errorf("cannot upsert if there is more than one unique constraint") - } - sqlStr, args, err := session.statement.GenUpsertSQL(doUpdate, colNames, args, uniqueColValMap) + sqlStr, args, err := session.statement.GenUpsertSQL(doUpdate, true, colNames, args, uniqueColValMap, uniqueConstraints) if err != nil { return 0, err } @@ -247,30 +241,60 @@ func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, er return n, err } -func (session *Session) getUniqueColumns(argColumns []string, args []interface{}) (uniqueColValMap map[string]interface{}, numberOfUniqueConstraints int, err error) { +var ( + ErrNoUniqueConstraints = fmt.Errorf("provided bean has no unique constraints") + ErrMultipleUniqueConstraints = fmt.Errorf("cannot upsert if there is more than one unique constraint tested") +) + +func (session *Session) getUniqueColumns(doUpdate bool, argColumns []string, args []interface{}) (uniqueColValMap map[string]interface{}, constraints [][]string, err error) { + // We need to collect the constraints that are being "tested" by argColumns as compared to the table + // + // There are two cases: + // + // 1. Insert on conflict do nothing + // 2. Upsert + // + // If we are an "Insert on conflict do nothing" then more than one "constraint" can be tested. + // If we are an "Upsert" only one "constraint" can be tested. + // + // In Xorm the only constraints we know of are "Unique Indices" and "Primary Keys". + // + // For unique indices - every column in the unique index is being tested. + // + // If the primary key is a single column and it is autoincrement then an empty PK column is not testing an unique constraint + // otherwise it does count. + uniqueColValMap = make(map[string]interface{}) table := session.statement.RefTable - if len(table.Indexes) == 0 && (len(table.PrimaryKeys) == 0 || (len(table.PrimaryKeys) == 1 && table.AutoIncrement == table.PrimaryKeys[0])) { - return nil, 0, fmt.Errorf("provided bean has no unique constraints") + // Shortcut when there are no indices and no primary keys + if len(table.Indexes) == 0 && len(table.PrimaryKeys) == 0 { + return nil, nil, ErrNoUniqueConstraints } - // Check the primary key - primaryColumnIncluded := false -primaryCol: - for _, primaryKeyColumn := range table.PrimaryKeys { - for i, column := range argColumns { - if column == primaryKeyColumn { - uniqueColValMap[column] = args[i] - primaryColumnIncluded = true - continue primaryCol - } + numberOfUniqueConstraints := 0 + + // Check the primary key: + switch len(table.PrimaryKeys) { + case 0: + // No primary keys - nothing to do + case 1: + // check if the pkColumn is included + value := session.getUniqueColumnValue(table.PrimaryKeys[0], argColumns, args) + if value != nil { + numberOfUniqueConstraints++ + uniqueColValMap[table.PrimaryKeys[0]] = value + constraints = append(constraints, table.PrimaryKeys) } - if primaryKeyColumn != table.AutoIncrement { - primaryColumnIncluded = true - } - } - if primaryColumnIncluded { + default: numberOfUniqueConstraints++ + constraints = append(constraints, table.PrimaryKeys) + for _, column := range table.PrimaryKeys { + value := session.getUniqueColumnValue(column, argColumns, args) + if value == nil { + value = "" // default to empty + } + uniqueColValMap[column] = value + } } // Iterate across the indexes in the provided table @@ -279,64 +303,84 @@ primaryCol: continue } numberOfUniqueConstraints++ + constraints = append(constraints, index.Cols) + // index is a Unique constraint - indexCol: - for _, indexColumnName := range index.Cols { - if _, has := uniqueColValMap[indexColumnName]; has { - // column is already included in uniqueCols so we don't need to add it again - continue indexCol + for _, column := range index.Cols { + if _, has := uniqueColValMap[column]; has { + continue } - // Now iterate across colNames and add to the uniqueCols - for i, col := range argColumns { - if col == indexColumnName { - uniqueColValMap[col] = args[i] - continue indexCol - } + value := session.getUniqueColumnValue(column, argColumns, args) + if value == nil { + value = "" // default to empty } - - indexColumn := table.GetColumn(indexColumnName) - if !indexColumn.DefaultIsEmpty { - uniqueColValMap[indexColumnName] = indexColumn.Default - } - - if indexColumn.MapType == schemas.ONLYFROMDB || indexColumn.IsAutoIncrement { - continue indexCol - } - // FIXME: what do we do here?! - if session.statement.OmitColumnMap.Contain(indexColumn.Name) { - continue indexCol - } - // FIXME: what do we do here?! - if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(indexColumn.Name) { - continue indexCol - } - // FIXME: what do we do here?! - if session.statement.IncrColumns.IsColExist(indexColumn.Name) { - for _, exprCol := range session.statement.IncrColumns { - if exprCol.ColName == indexColumn.Name { - uniqueColValMap[indexColumnName] = exprCol.Arg - } - } - continue indexCol - } else if session.statement.DecrColumns.IsColExist(indexColumn.Name) { - for _, exprCol := range session.statement.DecrColumns { - if exprCol.ColName == indexColumn.Name { - uniqueColValMap[indexColumnName] = exprCol.Arg - } - } - continue indexCol - } else if session.statement.ExprColumns.IsColExist(indexColumn.Name) { - for _, exprCol := range session.statement.ExprColumns { - if exprCol.ColName == indexColumn.Name { - uniqueColValMap[indexColumnName] = exprCol.Arg - } - } - } - - // FIXME: not sure if there's anything else we can do - return nil, 0, fmt.Errorf("provided bean does not provide a value for unique constraint %s field %s which has no default", index.Name, indexColumnName) + uniqueColValMap[column] = value } } - return uniqueColValMap, numberOfUniqueConstraints, nil + if doUpdate && numberOfUniqueConstraints > 1 { + return nil, nil, ErrMultipleUniqueConstraints + } + if len(constraints) == 0 { + return nil, nil, ErrNoUniqueConstraints + } + + return uniqueColValMap, constraints, nil +} + +func (session *Session) getUniqueColumnValue(indexColumnName string, argColumns []string, args []interface{}) (value interface{}) { + table := session.statement.RefTable + + // Now iterate across colNames and add to the uniqueCols + for i, col := range argColumns { + if col == indexColumnName { + return args[i] + } + } + + indexColumn := table.GetColumn(indexColumnName) + if indexColumn.IsAutoIncrement { + return nil + } + + if !indexColumn.DefaultIsEmpty { + value = indexColumn.Default + } + + if indexColumn.MapType == schemas.ONLYFROMDB { + return value + } + // FIXME: what do we do here?! + if session.statement.OmitColumnMap.Contain(indexColumn.Name) { + return value + } + // FIXME: what do we do here?! + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(indexColumn.Name) { + return value + } + // FIXME: what do we do here?! + if session.statement.IncrColumns.IsColExist(indexColumn.Name) { + for _, exprCol := range session.statement.IncrColumns { + if exprCol.ColName == indexColumn.Name { + return exprCol.Arg + } + } + return value + } else if session.statement.DecrColumns.IsColExist(indexColumn.Name) { + for _, exprCol := range session.statement.DecrColumns { + if exprCol.ColName == indexColumn.Name { + return exprCol.Arg + } + } + return value + } else if session.statement.ExprColumns.IsColExist(indexColumn.Name) { + for _, exprCol := range session.statement.ExprColumns { + if exprCol.ColName == indexColumn.Name { + return exprCol.Arg + } + } + } + + // FIXME: not sure if there's anything else we can do + return value }