backport #2383 Reviewed-on: https://gitea.com/xorm/xorm/pulls/2383 Reviewed-on: https://gitea.com/xorm/xorm/pulls/2385
This commit is contained in:
parent
cc28d99161
commit
0398dee813
|
@ -34,13 +34,7 @@ func (statement *Statement) writeJoins(w *builder.BytesWriter) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error {
|
||||
// write join operator
|
||||
if _, err := fmt.Fprint(buf, " ", join.op, " JOIN"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// write join table or subquery
|
||||
func (statement *Statement) writeJoinTable(buf *builder.BytesWriter, join join) error {
|
||||
switch tp := join.table.(type) {
|
||||
case builder.Builder:
|
||||
if _, err := fmt.Fprintf(buf, " ("); err != nil {
|
||||
|
@ -87,6 +81,19 @@ func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error
|
|||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error {
|
||||
// write join operator
|
||||
if _, err := fmt.Fprint(buf, " ", join.op, " JOIN"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// write join table or subquery
|
||||
if err := statement.writeJoinTable(buf, join); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// write on condition
|
||||
if _, err := fmt.Fprint(buf, " ON "); err != nil {
|
||||
|
@ -109,3 +116,14 @@ func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (statement *Statement) convertJoinCondition(join join) (builder.Cond, error) {
|
||||
switch condTp := join.condition.(type) {
|
||||
case string:
|
||||
return builder.Expr(statement.ReplaceQuote(condTp), join.args...), nil
|
||||
case builder.Cond:
|
||||
return condTp, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported join condition type: %v", condTp)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -341,13 +341,51 @@ func (statement *Statement) writeUpdateTableName(updateWriter *builder.BytesWrit
|
|||
}
|
||||
}
|
||||
|
||||
func (statement *Statement) writeUpdateFrom(updateWriter *builder.BytesWriter) error {
|
||||
if statement.dialect.URI().DBType != schemas.MSSQL || statement.TableAlias == "" {
|
||||
return nil
|
||||
func (statement *Statement) writeUpdateFrom(updateWriter *builder.BytesWriter) (builder.Cond, error) {
|
||||
if statement.dialect.URI().DBType == schemas.MSSQL {
|
||||
if _, err := fmt.Fprint(updateWriter, " FROM"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err := fmt.Fprint(updateWriter, " FROM ", statement.quote(statement.TableName()), " ", statement.TableAlias)
|
||||
return err
|
||||
if _, err := fmt.Fprint(updateWriter, " ", statement.quote(statement.TableName())); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if statement.TableAlias != "" {
|
||||
if _, err := fmt.Fprint(updateWriter, " ", statement.TableAlias); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(statement.joins) == 0 {
|
||||
return builder.NewCond(), nil
|
||||
}
|
||||
|
||||
if statement.dialect.URI().DBType != schemas.MSSQL {
|
||||
if _, err := fmt.Fprint(updateWriter, " FROM"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
cond := builder.NewCond()
|
||||
for i, join := range statement.joins {
|
||||
if statement.dialect.URI().DBType == schemas.MSSQL || i > 0 {
|
||||
if _, err := fmt.Fprint(updateWriter, ","); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := statement.writeJoinTable(updateWriter, join); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
joinCond, err := statement.convertJoinCondition(join)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cond = cond.And(joinCond)
|
||||
}
|
||||
|
||||
return cond, nil
|
||||
}
|
||||
|
||||
func (statement *Statement) writeWhereOrAnd(updateWriter *builder.BytesWriter, hasConditions bool) error {
|
||||
|
@ -555,16 +593,32 @@ func (statement *Statement) writeSetColumns(colNames []string, args []interface{
|
|||
return err
|
||||
}
|
||||
}
|
||||
if statement.dialect.URI().DBType != schemas.SQLITE && statement.dialect.URI().DBType != schemas.POSTGRES && len(statement.joins) > 0 {
|
||||
tbName := statement.TableAlias
|
||||
if tbName == "" {
|
||||
tbName = statement.TableName()
|
||||
}
|
||||
if _, err := fmt.Fprint(w, tbName, ".", colName); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := fmt.Fprint(w, colName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
w.Append(args...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error {
|
||||
// write set
|
||||
if _, err := fmt.Fprint(w, " SET "); err != nil {
|
||||
return err
|
||||
}
|
||||
previousLen := w.Len()
|
||||
|
||||
if err := statement.writeSetColumns(colNames, args)(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -588,12 +642,51 @@ func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Va
|
|||
if err := statement.writeVersionIncrSet(w, v, setNumber > 0); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if no columns to be updated, return error
|
||||
if previousLen == w.Len() {
|
||||
return ErrNoColumnsTobeUpdated
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
|
||||
|
||||
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []interface{}) error {
|
||||
switch statement.dialect.URI().DBType {
|
||||
case schemas.MYSQL:
|
||||
return statement.writeUpdateMySQL(updateWriter, cond, v, colNames, args)
|
||||
case schemas.MSSQL:
|
||||
return statement.writeUpdateMSSQL(updateWriter, cond, v, colNames, args)
|
||||
default:
|
||||
return statement.writeUpdateCommon(updateWriter, cond, v, colNames, args)
|
||||
}
|
||||
}
|
||||
|
||||
func (statement *Statement) writeUpdateMySQL(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []interface{}) error {
|
||||
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.writeUpdateTableName(updateWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.writeJoins(updateWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil {
|
||||
return err
|
||||
}
|
||||
// write where
|
||||
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.writeOrderBys(updateWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
return statement.writeUpdateLimit(updateWriter, cond)
|
||||
}
|
||||
|
||||
func (statement *Statement) writeUpdateMSSQL(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []interface{}) error {
|
||||
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -606,47 +699,56 @@ func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond
|
|||
return err
|
||||
}
|
||||
|
||||
// write set
|
||||
if _, err := fmt.Fprint(updateWriter, " SET "); err != nil {
|
||||
if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// write from
|
||||
joinConds, err := statement.writeUpdateFrom(updateWriter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
table := statement.RefTable
|
||||
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
|
||||
} else {
|
||||
// write where
|
||||
if err := statement.writeWhereCond(updateWriter, cond.And(joinConds)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return statement.writeUpdateLimit(updateWriter, cond.And(joinConds))
|
||||
}
|
||||
|
||||
// writeUpdateCommon write update sql for non mysql && non mssql
|
||||
func (statement *Statement) writeUpdateCommon(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []interface{}) error {
|
||||
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := statement.writeUpdateTop(updateWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := statement.writeUpdateTableName(updateWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
previousLen := updateWriter.Len()
|
||||
|
||||
if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if no columns to be updated, return error
|
||||
if previousLen == updateWriter.Len() {
|
||||
return ErrNoColumnsTobeUpdated
|
||||
}
|
||||
|
||||
// write from
|
||||
if err := statement.writeUpdateFrom(updateWriter); err != nil {
|
||||
joinConds, err := statement.writeUpdateFrom(updateWriter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if statement.dialect.URI().DBType == schemas.MSSQL {
|
||||
table := statement.RefTable
|
||||
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
|
||||
} else {
|
||||
// write where
|
||||
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
|
||||
if err := statement.writeWhereCond(updateWriter, cond.And(joinConds)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// write where
|
||||
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if statement.dialect.URI().DBType == schemas.MYSQL {
|
||||
if err := statement.writeOrderBys(updateWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return statement.writeUpdateLimit(updateWriter, cond)
|
||||
return statement.writeUpdateLimit(updateWriter, cond.And(joinConds))
|
||||
}
|
||||
|
|
|
@ -1471,7 +1471,6 @@ func TestNilFromDB(t *testing.T) {
|
|||
assert.NotNil(t, tt4.Field1.cb)
|
||||
}
|
||||
|
||||
/*
|
||||
func TestUpdateWithJoin(t *testing.T) {
|
||||
type TestUpdateWithJoin struct {
|
||||
Id int64
|
||||
|
@ -1494,10 +1493,15 @@ func TestUpdateWithJoin(t *testing.T) {
|
|||
_, err = testEngine.Insert(&TestUpdateWithJoin{ExtId: b.Id, Name: "test"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = testEngine.Table("test_update_with_join").
|
||||
Join("INNER", "test_update_with_join2 AS b", "test_update_with_join.ext_id = b.id").
|
||||
Where("b.`name` = ?", "test").
|
||||
Update(&TestUpdateWithJoin{Name: "test2"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = testEngine.Table("test_update_with_join").
|
||||
Join("INNER", "test_update_with_join2", "test_update_with_join.ext_id = test_update_with_join2.id").
|
||||
Where("test_update_with_join2.name = ?", "test").
|
||||
Where("test_update_with_join2.`name` = ?", "test").
|
||||
Update(&TestUpdateWithJoin{Name: "test2"})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
*/
|
||||
|
|
Loading…
Reference in New Issue