From 4d4311de20144f663df6b6f20e4d16ad974c5eff Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 30 Dec 2023 16:36:19 +0800 Subject: [PATCH] Fix test --- internal/statements/join.go | 32 ++++++-- internal/statements/update.go | 142 +++++++++++++++++++++++++--------- tests/session_update_test.go | 11 +-- 3 files changed, 137 insertions(+), 48 deletions(-) diff --git a/internal/statements/join.go b/internal/statements/join.go index 48e7403b..0a213c6b 100644 --- a/internal/statements/join.go +++ b/internal/statements/join.go @@ -34,13 +34,7 @@ func (statement *Statement) writeJoins(w *builder.BytesWriter) error { return nil } -func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error { - // write join operator - if _, err := fmt.Fprint(buf, " ", join.op, " JOIN"); err != nil { - return err - } - - // write join table or subquery +func (statement *Statement) writeJoinTable(buf *builder.BytesWriter, join join) error { switch tp := join.table.(type) { case builder.Builder: if _, err := fmt.Fprintf(buf, " ("); err != nil { @@ -87,6 +81,19 @@ func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error return err } } + return nil +} + +func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error { + // write join operator + if _, err := fmt.Fprint(buf, " ", join.op, " JOIN"); err != nil { + return err + } + + // write join table or subquery + if err := statement.writeJoinTable(buf, join); err != nil { + return err + } // write on condition if _, err := fmt.Fprint(buf, " ON "); err != nil { @@ -109,3 +116,14 @@ func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error return nil } + +func (statement *Statement) convertJoinCondition(join join) (builder.Cond, error) { + switch condTp := join.condition.(type) { + case string: + return builder.Expr(statement.ReplaceQuote(condTp), join.args...), nil + case builder.Cond: + return condTp, nil + default: + return nil, fmt.Errorf("unsupported join condition type: %v", condTp) + } +} diff --git a/internal/statements/update.go b/internal/statements/update.go index 673d3848..cc586de2 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -341,13 +341,41 @@ func (statement *Statement) writeUpdateTableName(updateWriter *builder.BytesWrit } } -func (statement *Statement) writeUpdateFrom(updateWriter *builder.BytesWriter) error { - if statement.dialect.URI().DBType != schemas.MSSQL || statement.TableAlias == "" { - return nil +func (statement *Statement) writeUpdateFrom(updateWriter *builder.BytesWriter) (builder.Cond, error) { + if _, err := fmt.Fprint(updateWriter, " FROM"); err != nil { + return nil, err } - _, err := fmt.Fprint(updateWriter, " FROM ", statement.quote(statement.TableName()), " ", statement.TableAlias) - return err + if statement.dialect.URI().DBType == schemas.MSSQL { + if _, err := fmt.Fprint(updateWriter, " ", statement.quote(statement.TableName())); err != nil { + return nil, err + } + if statement.TableAlias != "" { + if _, err := fmt.Fprint(updateWriter, " ", statement.TableAlias); err != nil { + return nil, err + } + } + } + + cond := builder.NewCond() + for i, join := range statement.joins { + if statement.dialect.URI().DBType == schemas.MSSQL || i > 0 { + if _, err := fmt.Fprint(updateWriter, ","); err != nil { + return nil, err + } + } + if err := statement.writeJoinTable(updateWriter, join); err != nil { + return nil, err + } + + joinCond, err := statement.convertJoinCondition(join) + if err != nil { + return nil, err + } + cond = cond.And(joinCond) + } + + return cond, nil } func (statement *Statement) writeWhereOrAnd(updateWriter *builder.BytesWriter, hasConditions bool) error { @@ -555,7 +583,7 @@ func (statement *Statement) writeSetColumns(colNames []string, args []any) func( return err } } - if len(statement.joins) > 0 { + if statement.dialect.URI().DBType == schemas.MSSQL && len(statement.joins) > 0 { if _, err := fmt.Fprint(w, statement.tableName, ".", colName); err != nil { return err } @@ -611,6 +639,40 @@ func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Va var ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated") func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []any) error { + switch statement.dialect.URI().DBType { + case schemas.MYSQL: + return statement.writeUpdateMySQL(updateWriter, cond, v, colNames, args) + case schemas.MSSQL: + return statement.writeUpdateMSSQL(updateWriter, cond, v, colNames, args) + default: + return statement.writeUpdateCommon(updateWriter, cond, v, colNames, args) + } +} + +func (statement *Statement) writeUpdateMySQL(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []any) error { + if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil { + return err + } + if err := statement.writeUpdateTableName(updateWriter); err != nil { + return err + } + if err := statement.writeJoins(updateWriter); err != nil { + return err + } + if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil { + return err + } + // write where + if err := statement.writeWhereCond(updateWriter, cond); err != nil { + return err + } + if err := statement.writeOrderBys(updateWriter); err != nil { + return err + } + return statement.writeUpdateLimit(updateWriter, cond) +} + +func (statement *Statement) writeUpdateMSSQL(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []any) error { if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil { return err } @@ -623,48 +685,56 @@ func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond return err } - if statement.dialect.URI().DBType == schemas.MYSQL { - if err := statement.writeJoins(updateWriter); err != nil { + if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil { + return err + } + + // write from + joinConds, err := statement.writeUpdateFrom(updateWriter) + if err != nil { + return err + } + + table := statement.RefTable + if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { + } else { + // write where + if err := statement.writeWhereCond(updateWriter, cond.And(joinConds)); err != nil { return err } } + return statement.writeUpdateLimit(updateWriter, cond.And(joinConds)) +} + +// writeUpdateCommon write update sql for non mysql && non mssql +func (statement *Statement) writeUpdateCommon(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []any) error { + if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil { + return err + } + + if err := statement.writeUpdateTop(updateWriter); err != nil { + return err + } + + if err := statement.writeUpdateTableName(updateWriter); err != nil { + return err + } + if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil { return err } // write from - if err := statement.writeUpdateFrom(updateWriter); err != nil { + joinConds, err := statement.writeUpdateFrom(updateWriter) + if err != nil { return err } - if statement.dialect.URI().DBType != schemas.MYSQL { - if err := statement.writeJoins(updateWriter); err != nil { - return err - } + // write where + if err := statement.writeWhereCond(updateWriter, cond.And(joinConds)); err != nil { + return err } - if statement.dialect.URI().DBType == schemas.MSSQL { - table := statement.RefTable - if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { - } else { - // write where - if err := statement.writeWhereCond(updateWriter, cond); err != nil { - return err - } - } - } else { - // write where - if err := statement.writeWhereCond(updateWriter, cond); err != nil { - return err - } - } - - if statement.dialect.URI().DBType == schemas.MYSQL { - if err := statement.writeOrderBys(updateWriter); err != nil { - return err - } - } - - return statement.writeUpdateLimit(updateWriter, cond) + return statement.writeUpdateLimit(updateWriter, cond.And(joinConds)) } diff --git a/tests/session_update_test.go b/tests/session_update_test.go index 900f58b1..6b7bcd13 100644 --- a/tests/session_update_test.go +++ b/tests/session_update_test.go @@ -1473,11 +1473,6 @@ func TestNilFromDB(t *testing.T) { } func TestUpdateWithJoin(t *testing.T) { - if testEngine.Dialect().URI().DBType == schemas.SQLITE { - t.Skip() - return - } - type TestUpdateWithJoin struct { Id int64 ExtId int64 @@ -1499,6 +1494,12 @@ func TestUpdateWithJoin(t *testing.T) { _, err = testEngine.Insert(&TestUpdateWithJoin{ExtId: b.Id, Name: "test"}) assert.NoError(t, err) + _, err = testEngine.Table("test_update_with_join"). + Join("INNER", "test_update_with_join2 AS b", "test_update_with_join.ext_id = b.id"). + Where("b.`name` = ?", "test"). + Update(&TestUpdateWithJoin{Name: "test2"}) + assert.NoError(t, err) + _, err = testEngine.Table("test_update_with_join"). Join("INNER", "test_update_with_join2", "test_update_with_join.ext_id = test_update_with_join2.id"). Where("test_update_with_join2.`name` = ?", "test").