diff --git a/internal/statements/upsert.go b/internal/statements/upsert.go index abcdaa14..01d7fe52 100644 --- a/internal/statements/upsert.go +++ b/internal/statements/upsert.go @@ -24,15 +24,29 @@ func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args [ table = statement.RefTable tableName = statement.TableName() ) + quote := statement.dialect.Quoter().Quote + write := func(args ...string) { + for _, arg := range args { + _, _ = buf.WriteString(arg) + } + } + + var updateColumns []string + if doUpdate { + updateColumns = make([]string, 0, len(columns)) + for _, column := range columns { + if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has { + continue + } + updateColumns = append(updateColumns, quote(column)) + } + doUpdate = doUpdate && (len(updateColumns) > 0) + } if statement.dialect.URI().DBType == schemas.MYSQL && !doUpdate { - if _, err := buf.WriteString("INSERT IGNORE INTO "); err != nil { - return "", nil, err - } + write("INSERT IGNORE INTO ") } else { - if _, err := buf.WriteString("INSERT INTO "); err != nil { - return "", nil, err - } + write("INSERT INTO ") } if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil { @@ -45,28 +59,31 @@ func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args [ switch statement.dialect.URI().DBType { case schemas.SQLITE, schemas.POSTGRES: - if _, err := buf.WriteString(" ON CONFLICT DO "); err != nil { - return "", nil, err - } + write(" ON CONFLICT DO ") if doUpdate { - return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT - } else { - if _, err := buf.WriteString("NOTHING"); err != nil { - return "", nil, err + write("UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0]) + for _, column := range updateColumns[1:] { + write(", ", column, " = excluded.", column) } + } else { + write("NOTHING") } case schemas.MYSQL: if doUpdate { - return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT + // FIXME: mysql >= 8.0.19 should use table alias + write(" ON DUPLICATE KEY ") + write("UPDATE SET ", updateColumns[0], " = VALUES(", updateColumns[0], ")") + for _, column := range updateColumns[1:] { + write(", ", column, " = VALUES(", column, ")") + } + } default: return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT } - if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES { - if _, err := buf.WriteString(" RETURNING "); err != nil { - return "", nil, err - } + if len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.POSTGRES || statement.dialect.URI().DBType == schemas.SQLITE) { + write(" RETURNING ") if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil { return "", nil, err } @@ -91,9 +108,9 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args [] write("MERGE INTO ", quote(tableName)) if statement.dialect.URI().DBType == schemas.MSSQL { - write("WITH (HOLDLOCK) AS target ") + write(" WITH (HOLDLOCK)") } - write("USING (SELECT ") + write(" AS target USING (SELECT ") uniqueCols := make([]string, 0, len(uniqueColValMap)) for colName := range uniqueColValMap { @@ -108,6 +125,19 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args [] write(", ") } } + + var updateColumns []string + if doUpdate { + updateColumns = make([]string, 0, len(columns)) + for _, column := range columns { + if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has { + continue + } + updateColumns = append(updateColumns, quote(column)) + } + doUpdate = doUpdate && (len(updateColumns) > 0) + } + write(") AS src ON (") countUniques := 0 @@ -126,10 +156,15 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args [] } write(")") } + write(")") if doUpdate { - return "", nil, fmt.Errorf("unimplemented") + write(" WHEN MATCHED THEN UPDATE SET ") + write("src.", quote(updateColumns[0]), "= target.", quote(updateColumns[0])) + for _, col := range updateColumns[1:] { + write(", src.", quote(col), "= target.", quote(col)) + } } - write(") WHEN NOT MATCHED THEN INSERT") + write(" WHEN NOT MATCHED THEN INSERT ") if err := statement.genInsertValues(buf, columns, args); err != nil { return "", nil, err } @@ -155,6 +190,17 @@ func (statement *Statement) GenUpsertMapSQL(doUpdate bool, columns []string, arg _, _ = buf.WriteString(arg) } } + var updateColumns []string + if doUpdate { + updateColumns = make([]string, 0, len(columns)) + for _, column := range append(columns, exprs.ColNames()...) { + if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has { + continue + } + updateColumns = append(updateColumns, quote(column)) + } + doUpdate = doUpdate && (len(updateColumns) > 0) + } if statement.dialect.URI().DBType == schemas.MYSQL && !doUpdate { write("INSERT IGNORE INTO ", quote(tableName), " (") @@ -172,25 +218,29 @@ func (statement *Statement) GenUpsertMapSQL(doUpdate bool, columns []string, arg switch statement.dialect.URI().DBType { case schemas.SQLITE, schemas.POSTGRES: - if _, err := buf.WriteString(" ON CONFLICT DO "); err != nil { - return "", nil, err - } + write(" ON CONFLICT DO ") if doUpdate { - return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT - } else { - if _, err := buf.WriteString("NOTHING"); err != nil { - return "", nil, err + write("UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0]) + for _, column := range updateColumns[1:] { + write(", ", column, " = excluded.", column) } + } else { + write("NOTHING") } case schemas.MYSQL: if doUpdate { - return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT + // FIXME: mysql >= 8.0.19 should use table alias + write(" ON DUPLICATE KEY ") + write("UPDATE SET ", updateColumns[0], " = VALUES(", updateColumns[0], ")") + for _, column := range updateColumns[1:] { + write(", ", column, " = VALUES(", column, ")") + } } default: return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT } - if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES { + if len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.POSTGRES || statement.dialect.URI().DBType == schemas.SQLITE) { if _, err := buf.WriteString(" RETURNING "); err != nil { return "", nil, err } @@ -219,9 +269,9 @@ func (statement *Statement) genMergeMapSQL(doUpdate bool, columns []string, args write("MERGE INTO ", quote(tableName)) if statement.dialect.URI().DBType == schemas.MSSQL { - write("WITH (HOLDLOCK) AS target ") + write(" WITH (HOLDLOCK)") } - write("USING (SELECT ") + write(" AS target USING (SELECT ") uniqueCols := make([]string, 0, len(uniqueColValMap)) for colName := range uniqueColValMap { @@ -236,6 +286,18 @@ func (statement *Statement) genMergeMapSQL(doUpdate bool, columns []string, args write(", ") } } + var updateColumns []string + if doUpdate { + updateColumns = make([]string, 0, len(columns)) + for _, column := range append(columns, exprs.ColNames()...) { + if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has { + continue + } + updateColumns = append(updateColumns, quote(column)) + } + doUpdate = doUpdate && (len(updateColumns) > 0) + } + write(") AS src ON (") countUniques := 0 @@ -254,10 +316,16 @@ func (statement *Statement) genMergeMapSQL(doUpdate bool, columns []string, args } write(")") } + write(")") if doUpdate { - return "", nil, fmt.Errorf("unimplemented") + write(" WHEN MATCHED THEN UPDATE SET ") + write("src.", quote(updateColumns[0]), "= target.", quote(updateColumns[0])) + for _, col := range updateColumns[1:] { + write(", src.", quote(col), "= target.", quote(col)) + } + } - write(") WHEN NOT MATCHED THEN INSERT") + write(" WHEN NOT MATCHED THEN INSERT ") if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil { return "", nil, err } diff --git a/session_upsert.go b/session_upsert.go index 695b7b69..e53d052c 100644 --- a/session_upsert.go +++ b/session_upsert.go @@ -37,7 +37,7 @@ func (session *Session) upsert(doUpdate bool, beans ...interface{}) (int64, erro switch v := bean.(type) { case map[string]interface{}: cnt, err = session.upsertMapInterface(doUpdate, v) - case []map[string]interface{}: + case []map[string]interface{}: // FIXME: handle multiple for _, m := range v { cnt, err := session.upsertMapInterface(doUpdate, m) if err != nil { @@ -47,7 +47,7 @@ func (session *Session) upsert(doUpdate bool, beans ...interface{}) (int64, erro } case map[string]string: cnt, err = session.upsertMapString(doUpdate, v) - case []map[string]string: + case []map[string]string: // FIXME: handle multiple for _, m := range v { cnt, err := session.upsertMapString(doUpdate, m) if err != nil { @@ -57,7 +57,7 @@ func (session *Session) upsert(doUpdate bool, beans ...interface{}) (int64, erro } default: sliceValue := reflect.Indirect(reflect.ValueOf(bean)) - if sliceValue.Kind() == reflect.Slice { + if sliceValue.Kind() == reflect.Slice { // FIXME: handle multiple if sliceValue.Len() <= 0 { return 0, ErrNoElementsOnSlice } @@ -140,7 +140,7 @@ func (session *Session) upsertMap(doUpdate bool, columns []string, args []interf if err != nil { return 0, err } - return affected, fmt.Errorf("unimplemented") + return affected, nil } func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, error) { @@ -273,21 +273,37 @@ func (session *Session) getUniqueColumns(colNames []string, args []interface{}) } if indexColumn.MapType == schemas.ONLYFROMDB || indexColumn.IsAutoIncrement { - continue + continue indexCol } + // FIXME: what do we do here?! if session.statement.OmitColumnMap.Contain(indexColumn.Name) { - continue + continue indexCol } + // FIXME: what do we do here?! if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(indexColumn.Name) { - continue + continue indexCol } // FIXME: what do we do here?! if session.statement.IncrColumns.IsColExist(indexColumn.Name) { - continue + for _, exprCol := range session.statement.IncrColumns { + if exprCol.ColName == indexColumn.Name { + uniqueColValMap[indexColumnName] = exprCol.Arg + } + } + continue indexCol } else if session.statement.DecrColumns.IsColExist(indexColumn.Name) { - continue + for _, exprCol := range session.statement.DecrColumns { + if exprCol.ColName == indexColumn.Name { + uniqueColValMap[indexColumnName] = exprCol.Arg + } + } + continue indexCol } else if session.statement.ExprColumns.IsColExist(indexColumn.Name) { - continue + for _, exprCol := range session.statement.ExprColumns { + if exprCol.ColName == indexColumn.Name { + uniqueColValMap[indexColumnName] = exprCol.Arg + } + } } // FIXME: not sure if there's anything else we can do