diff --git a/internal/statements/upsert.go b/internal/statements/upsert.go index 21ad4154..637a10b2 100644 --- a/internal/statements/upsert.go +++ b/internal/statements/upsert.go @@ -131,13 +131,16 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args [] } var updateColumns []string + var updateArgs []interface{} if doUpdate { updateColumns = make([]string, 0, len(columns)) - for _, column := range columns { + updateArgs = make([]interface{}, 0, len(columns)) + for i, column := range columns { if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has { continue } updateColumns = append(updateColumns, quote(column)) + updateArgs = append(updateArgs, args[i]) } doUpdate = doUpdate && (len(updateColumns) > 0) } @@ -163,15 +166,39 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args [] write(")") if doUpdate { write(" WHEN MATCHED THEN UPDATE SET ") - write("src.", quote(updateColumns[0]), "= target.", quote(updateColumns[0])) - for _, col := range updateColumns[1:] { - write(", src.", quote(col), "= target.", quote(col)) + write("target.", quote(updateColumns[0]), " = ?") + buf.Append(updateArgs[0]) + for i, col := range updateColumns[1:] { + write(", target.", quote(col), " = ?") + buf.Append(updateArgs[i+1]) } } write(" WHEN NOT MATCHED THEN INSERT ") - if err := statement.genInsertValues(buf, columns, args); err != nil { + includeAutoIncrement := statement.includeAutoIncrement(columns) + if len(columns) == 0 { + write(" DEFAULT VALUES ") + } else { + // We have some values - Write the column names we need to insert: + write(" (") + if includeAutoIncrement { + columns = append(columns, table.AutoIncrement) + } + + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, statement.ExprColumns.ColNames()...), ","); err != nil { + return "", nil, err + } + + write(")") + if err := statement.genInsertValuesValues(buf, includeAutoIncrement, columns, args); err != nil { + return "", nil, err + } + + } + if err := statement.writeInsertOutput(buf.Builder, table); err != nil { return "", nil, err } + + write(";") return buf.String(), buf.Args(), nil } @@ -294,13 +321,22 @@ func (statement *Statement) genMergeMapSQL(doUpdate bool, columns []string, args } } var updateColumns []string + var updateArgs []interface{} if doUpdate { updateColumns = make([]string, 0, len(columns)) - for _, column := range append(columns, exprs.ColNames()...) { + for _, expr := range exprs { + if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(expr.ColName)]; has { + continue + } + updateColumns = append(updateColumns, quote(expr.ColName)) + updateArgs = append(updateArgs, expr.Arg) + } + for i, column := range columns { if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has { continue } updateColumns = append(updateColumns, quote(column)) + updateArgs = append(updateArgs, args[i]) } doUpdate = doUpdate && (len(updateColumns) > 0) } @@ -317,20 +353,21 @@ func (statement *Statement) genMergeMapSQL(doUpdate bool, columns []string, args } countUniques++ write("(") - write("src.", quote(index.Cols[0]), "= target.", quote(index.Cols[0])) + write("src.", quote(index.Cols[0]), " = target.", quote(index.Cols[0])) for _, col := range index.Cols[1:] { - write(" AND src.", quote(col), "= target.", quote(col)) + write(" AND src.", quote(col), " = target.", quote(col)) } write(")") } write(")") if doUpdate { write(" WHEN MATCHED THEN UPDATE SET ") - write("src.", quote(updateColumns[0]), "= target.", quote(updateColumns[0])) - for _, col := range updateColumns[1:] { - write(", src.", quote(col), "= target.", quote(col)) + write("target.", quote(updateColumns[0]), " = ?") + buf.Append(updateArgs[0]) + for i, col := range updateColumns[1:] { + write(", target.", quote(col), " = ?") + buf.Append(updateArgs[i+1]) } - } write(" WHEN NOT MATCHED THEN INSERT ") if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil { @@ -341,6 +378,7 @@ func (statement *Statement) genMergeMapSQL(doUpdate bool, columns []string, args if err := statement.genInsertValuesValues(buf, false, columns, args); err != nil { return "", nil, err } + write(";") return buf.String(), buf.Args(), nil }