This commit is contained in:
Lunny Xiao 2023-12-30 16:36:19 +08:00
parent c3fb1bb5cb
commit 4d4311de20
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
3 changed files with 137 additions and 48 deletions

View File

@ -34,13 +34,7 @@ func (statement *Statement) writeJoins(w *builder.BytesWriter) error {
return nil
}
func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error {
// write join operator
if _, err := fmt.Fprint(buf, " ", join.op, " JOIN"); err != nil {
return err
}
// write join table or subquery
func (statement *Statement) writeJoinTable(buf *builder.BytesWriter, join join) error {
switch tp := join.table.(type) {
case builder.Builder:
if _, err := fmt.Fprintf(buf, " ("); err != nil {
@ -87,6 +81,19 @@ func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error
return err
}
}
return nil
}
func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error {
// write join operator
if _, err := fmt.Fprint(buf, " ", join.op, " JOIN"); err != nil {
return err
}
// write join table or subquery
if err := statement.writeJoinTable(buf, join); err != nil {
return err
}
// write on condition
if _, err := fmt.Fprint(buf, " ON "); err != nil {
@ -109,3 +116,14 @@ func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error
return nil
}
func (statement *Statement) convertJoinCondition(join join) (builder.Cond, error) {
switch condTp := join.condition.(type) {
case string:
return builder.Expr(statement.ReplaceQuote(condTp), join.args...), nil
case builder.Cond:
return condTp, nil
default:
return nil, fmt.Errorf("unsupported join condition type: %v", condTp)
}
}

View File

@ -341,13 +341,41 @@ func (statement *Statement) writeUpdateTableName(updateWriter *builder.BytesWrit
}
}
func (statement *Statement) writeUpdateFrom(updateWriter *builder.BytesWriter) error {
if statement.dialect.URI().DBType != schemas.MSSQL || statement.TableAlias == "" {
return nil
func (statement *Statement) writeUpdateFrom(updateWriter *builder.BytesWriter) (builder.Cond, error) {
if _, err := fmt.Fprint(updateWriter, " FROM"); err != nil {
return nil, err
}
_, err := fmt.Fprint(updateWriter, " FROM ", statement.quote(statement.TableName()), " ", statement.TableAlias)
return err
if statement.dialect.URI().DBType == schemas.MSSQL {
if _, err := fmt.Fprint(updateWriter, " ", statement.quote(statement.TableName())); err != nil {
return nil, err
}
if statement.TableAlias != "" {
if _, err := fmt.Fprint(updateWriter, " ", statement.TableAlias); err != nil {
return nil, err
}
}
}
cond := builder.NewCond()
for i, join := range statement.joins {
if statement.dialect.URI().DBType == schemas.MSSQL || i > 0 {
if _, err := fmt.Fprint(updateWriter, ","); err != nil {
return nil, err
}
}
if err := statement.writeJoinTable(updateWriter, join); err != nil {
return nil, err
}
joinCond, err := statement.convertJoinCondition(join)
if err != nil {
return nil, err
}
cond = cond.And(joinCond)
}
return cond, nil
}
func (statement *Statement) writeWhereOrAnd(updateWriter *builder.BytesWriter, hasConditions bool) error {
@ -555,7 +583,7 @@ func (statement *Statement) writeSetColumns(colNames []string, args []any) func(
return err
}
}
if len(statement.joins) > 0 {
if statement.dialect.URI().DBType == schemas.MSSQL && len(statement.joins) > 0 {
if _, err := fmt.Fprint(w, statement.tableName, ".", colName); err != nil {
return err
}
@ -611,6 +639,40 @@ func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Va
var ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []any) error {
switch statement.dialect.URI().DBType {
case schemas.MYSQL:
return statement.writeUpdateMySQL(updateWriter, cond, v, colNames, args)
case schemas.MSSQL:
return statement.writeUpdateMSSQL(updateWriter, cond, v, colNames, args)
default:
return statement.writeUpdateCommon(updateWriter, cond, v, colNames, args)
}
}
func (statement *Statement) writeUpdateMySQL(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []any) error {
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
return err
}
if err := statement.writeUpdateTableName(updateWriter); err != nil {
return err
}
if err := statement.writeJoins(updateWriter); err != nil {
return err
}
if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil {
return err
}
// write where
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
return err
}
if err := statement.writeOrderBys(updateWriter); err != nil {
return err
}
return statement.writeUpdateLimit(updateWriter, cond)
}
func (statement *Statement) writeUpdateMSSQL(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []any) error {
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
return err
}
@ -623,48 +685,56 @@ func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond
return err
}
if statement.dialect.URI().DBType == schemas.MYSQL {
if err := statement.writeJoins(updateWriter); err != nil {
if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil {
return err
}
// write from
joinConds, err := statement.writeUpdateFrom(updateWriter)
if err != nil {
return err
}
table := statement.RefTable
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
} else {
// write where
if err := statement.writeWhereCond(updateWriter, cond.And(joinConds)); err != nil {
return err
}
}
return statement.writeUpdateLimit(updateWriter, cond.And(joinConds))
}
// writeUpdateCommon write update sql for non mysql && non mssql
func (statement *Statement) writeUpdateCommon(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []any) error {
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
return err
}
if err := statement.writeUpdateTop(updateWriter); err != nil {
return err
}
if err := statement.writeUpdateTableName(updateWriter); err != nil {
return err
}
if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil {
return err
}
// write from
if err := statement.writeUpdateFrom(updateWriter); err != nil {
joinConds, err := statement.writeUpdateFrom(updateWriter)
if err != nil {
return err
}
if statement.dialect.URI().DBType != schemas.MYSQL {
if err := statement.writeJoins(updateWriter); err != nil {
return err
}
// write where
if err := statement.writeWhereCond(updateWriter, cond.And(joinConds)); err != nil {
return err
}
if statement.dialect.URI().DBType == schemas.MSSQL {
table := statement.RefTable
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
} else {
// write where
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
return err
}
}
} else {
// write where
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
return err
}
}
if statement.dialect.URI().DBType == schemas.MYSQL {
if err := statement.writeOrderBys(updateWriter); err != nil {
return err
}
}
return statement.writeUpdateLimit(updateWriter, cond)
return statement.writeUpdateLimit(updateWriter, cond.And(joinConds))
}

View File

@ -1473,11 +1473,6 @@ func TestNilFromDB(t *testing.T) {
}
func TestUpdateWithJoin(t *testing.T) {
if testEngine.Dialect().URI().DBType == schemas.SQLITE {
t.Skip()
return
}
type TestUpdateWithJoin struct {
Id int64
ExtId int64
@ -1499,6 +1494,12 @@ func TestUpdateWithJoin(t *testing.T) {
_, err = testEngine.Insert(&TestUpdateWithJoin{ExtId: b.Id, Name: "test"})
assert.NoError(t, err)
_, err = testEngine.Table("test_update_with_join").
Join("INNER", "test_update_with_join2 AS b", "test_update_with_join.ext_id = b.id").
Where("b.`name` = ?", "test").
Update(&TestUpdateWithJoin{Name: "test2"})
assert.NoError(t, err)
_, err = testEngine.Table("test_update_with_join").
Join("INNER", "test_update_with_join2", "test_update_with_join.ext_id = test_update_with_join2.id").
Where("test_update_with_join2.`name` = ?", "test").