refactor write update sql (#2304)

Reviewed-on: https://gitea.com/xorm/xorm/pulls/2304
This commit is contained in:
Lunny Xiao 2023-07-24 07:57:05 +00:00
parent 24a672be3c
commit a13564976c
4 changed files with 208 additions and 79 deletions

View File

@ -28,7 +28,7 @@ func TestUpdateMap(t *testing.T) {
}
assert.NoError(t, testEngine.Sync(new(UpdateTable)))
var tb = UpdateTable{
tb := UpdateTable{
Name: "test",
Age: 35,
}
@ -79,7 +79,7 @@ func TestUpdateLimit(t *testing.T) {
}
assert.NoError(t, testEngine.Sync(new(UpdateTable2)))
var tb = UpdateTable2{
tb := UpdateTable2{
Name: "test1",
Age: 35,
}
@ -400,7 +400,7 @@ func TestUpdate1(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
var s = "test"
s := "test"
col1 := &UpdateAllCols{Ptr: &s}
err = testEngine.Sync(col1)
@ -864,7 +864,7 @@ func TestCreatedUpdated2(t *testing.T) {
assertSync(t, new(CreatedUpdatedStruct))
var s = CreatedUpdatedStruct{
s := CreatedUpdatedStruct{
Name: "test",
}
cnt, err := testEngine.Insert(&s)
@ -874,7 +874,7 @@ func TestCreatedUpdated2(t *testing.T) {
time.Sleep(time.Second)
var s1 = CreatedUpdatedStruct{
s1 := CreatedUpdatedStruct{
Name: "test1",
CreateAt: s.CreateAt,
UpdateAt: s.UpdateAt,
@ -907,7 +907,7 @@ func TestDeletedUpdate(t *testing.T) {
assertSync(t, new(DeletedUpdatedStruct))
var s = DeletedUpdatedStruct{
s := DeletedUpdatedStruct{
Name: "test",
}
cnt, err := testEngine.Insert(&s)
@ -956,7 +956,7 @@ func TestUpdateMapCondition(t *testing.T) {
assertSync(t, new(UpdateMapCondition))
var c = UpdateMapCondition{
c := UpdateMapCondition{
String: "string",
}
_, err := testEngine.Insert(&c)
@ -990,7 +990,7 @@ func TestUpdateMapContent(t *testing.T) {
assertSync(t, new(UpdateMapContent))
var c = UpdateMapContent{
c := UpdateMapContent{
Name: "lunny",
IsMan: true,
Gender: 1,
@ -1126,7 +1126,7 @@ func TestUpdateDeleted(t *testing.T) {
assertSync(t, new(UpdateDeletedStruct))
var s = UpdateDeletedStruct{
s := UpdateDeletedStruct{
Name: "test",
}
cnt, err := testEngine.Insert(&s)
@ -1232,7 +1232,7 @@ func TestUpdateExprs2(t *testing.T) {
assertSync(t, new(UpdateExprsRelease))
var uer = UpdateExprsRelease{
uer := UpdateExprsRelease{
RepoId: 1,
IsTag: false,
IsDraft: false,
@ -1407,7 +1407,7 @@ func TestNilFromDB(t *testing.T) {
assert.NoError(t, PrepareEngine())
assertSync(t, new(TestTable1))
var tt0 = TestTable1{
tt0 := TestTable1{
Field1: &TestFieldType1{
cb: []byte("string"),
},
@ -1437,7 +1437,7 @@ func TestNilFromDB(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
var tt = TestTable1{
tt := TestTable1{
UpdateTime: time.Now(),
Field1: &TestFieldType1{
cb: nil,
@ -1453,7 +1453,7 @@ func TestNilFromDB(t *testing.T) {
assert.True(t, has)
assert.Nil(t, tt2.Field1)
var tt3 = TestTable1{
tt3 := TestTable1{
UpdateTime: time.Now(),
Field1: &TestFieldType1{
cb: []byte{},
@ -1470,3 +1470,34 @@ func TestNilFromDB(t *testing.T) {
assert.NotNil(t, tt4.Field1)
assert.NotNil(t, tt4.Field1.cb)
}
/*
func TestUpdateWithJoin(t *testing.T) {
type TestUpdateWithJoin struct {
Id int64
ExtId int64
Name string
}
type TestUpdateWithJoin2 struct {
Id int64
Name string
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(TestUpdateWithJoin), new(TestUpdateWithJoin2))
b := TestUpdateWithJoin2{Name: "test"}
_, err := testEngine.Insert(&b)
assert.NoError(t, err)
_, err = testEngine.Insert(&TestUpdateWithJoin{ExtId: b.Id, Name: "test"})
assert.NoError(t, err)
_, err = testEngine.Table("test_update_with_join").
Join("INNER", "test_update_with_join2", "test_update_with_join.ext_id = test_update_with_join2.id").
Where("test_update_with_join2.name = ?", "test").
Update(&TestUpdateWithJoin{Name: "test2"})
assert.NoError(t, err)
}
*/

View File

@ -243,14 +243,19 @@ func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr
return err
}
func (statement *Statement) writeWhere(w *builder.BytesWriter) error {
if !statement.cond.IsValid() {
func (statement *Statement) writeWhereCond(w *builder.BytesWriter, cond builder.Cond) error {
if !cond.IsValid() {
return nil
}
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
return statement.cond.WriteTo(statement.QuoteReplacer(w))
return cond.WriteTo(statement.QuoteReplacer(w))
}
func (statement *Statement) writeWhere(w *builder.BytesWriter) error {
return statement.writeWhereCond(w, statement.cond)
}
func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error {

View File

@ -9,7 +9,6 @@ import (
"errors"
"fmt"
"reflect"
"strings"
"time"
"xorm.io/builder"
@ -311,84 +310,178 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value,
return colNames, args, nil
}
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string) error {
whereWriter := builder.NewWriter()
if cond.IsValid() {
fmt.Fprint(whereWriter, "WHERE ")
func (statement *Statement) writeUpdateTop(updateWriter *builder.BytesWriter) error {
if statement.dialect.URI().DBType != schemas.MSSQL || statement.LimitN == nil {
return nil
}
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
table := statement.RefTable
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
return nil
}
_, err := fmt.Fprintf(updateWriter, " TOP (%d)", *statement.LimitN)
return err
}
func (statement *Statement) writeUpdateTableName(updateWriter *builder.BytesWriter) error {
tableName := statement.quote(statement.TableName())
if statement.TableAlias == "" {
_, err := fmt.Fprint(updateWriter, " ", tableName)
return err
}
if err := statement.writeOrderBys(whereWriter); err != nil {
switch statement.dialect.URI().DBType {
case schemas.MSSQL:
_, err := fmt.Fprint(updateWriter, " ", statement.TableAlias)
return err
default:
_, err := fmt.Fprint(updateWriter, " ", tableName, " AS ", statement.TableAlias)
return err
}
}
func (statement *Statement) writeUpdateFrom(updateWriter *builder.BytesWriter) error {
if statement.dialect.URI().DBType != schemas.MSSQL || statement.TableAlias == "" {
return nil
}
_, err := fmt.Fprint(updateWriter, " FROM ", statement.quote(statement.TableName()), " ", statement.TableAlias)
return err
}
func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter, cond builder.Cond) error {
if statement.LimitN == nil {
return nil
}
table := statement.RefTable
tableName := statement.TableName()
// TODO: Oracle support needed
var top string
if statement.LimitN != nil {
limitValue := *statement.LimitN
switch statement.dialect.URI().DBType {
case schemas.MYSQL:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
_, err := fmt.Fprintf(updateWriter, " LIMIT %d", limitValue)
return err
case schemas.SQLITE:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...))
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
return err
}
case schemas.POSTGRES:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...))
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
return err
}
case schemas.MSSQL:
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...)
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(whereWriter); err != nil {
if cond.IsValid() {
if _, err := fmt.Fprint(updateWriter, " AND "); err != nil {
return err
}
} else {
top = fmt.Sprintf("TOP (%d) ", limitValue)
}
}
}
tableAlias := statement.quote(tableName)
var fromSQL string
if statement.TableAlias != "" {
switch statement.dialect.URI().DBType {
case schemas.MSSQL:
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, statement.TableAlias)
tableAlias = statement.TableAlias
default:
tableAlias = fmt.Sprintf("%s AS %s", tableAlias, statement.TableAlias)
}
}
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v",
top,
tableAlias,
strings.Join(colNames, ", "),
fromSQL); err != nil {
if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil {
return err
}
return utils.WriteBuilder(updateWriter, whereWriter)
}
if _, err := fmt.Fprint(updateWriter, "rowid IN (SELECT rowid FROM ", statement.quote(tableName)); err != nil {
return err
}
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
return err
}
if err := statement.writeOrderBys(updateWriter); err != nil {
return err
}
_, err := fmt.Fprintf(updateWriter, " LIMIT %d)", limitValue)
return err
case schemas.POSTGRES:
if cond.IsValid() {
if _, err := fmt.Fprint(updateWriter, " AND "); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil {
return err
}
}
if _, err := fmt.Fprint(updateWriter, "CTID IN (SELECT CTID FROM ", statement.quote(tableName)); err != nil {
return err
}
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
return err
}
if err := statement.writeOrderBys(updateWriter); err != nil {
return err
}
_, err := fmt.Fprintf(updateWriter, " LIMIT %d)", limitValue)
return err
case schemas.MSSQL:
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
if _, err := fmt.Fprintf(updateWriter, " WHERE %s IN (SELECT TOP (%d) %s FROM %v",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
statement.quote(tableName)); err != nil {
return err
}
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
return err
}
if err := statement.writeOrderBys(updateWriter); err != nil {
return err
}
_, err := fmt.Fprint(updateWriter, ")")
return err
}
return nil
default: // TODO: Oracle support needed
return fmt.Errorf("not implemented")
}
}
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string, args []interface{}) error {
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
return err
}
if err := statement.writeUpdateTop(updateWriter); err != nil {
return err
}
if err := statement.writeUpdateTableName(updateWriter); err != nil {
return err
}
// write set
if _, err := fmt.Fprint(updateWriter, " SET "); err != nil {
return err
}
for i, colName := range colNames {
if i > 0 {
if _, err := fmt.Fprint(updateWriter, ", "); err != nil {
return err
}
}
if _, err := fmt.Fprint(updateWriter, colName); err != nil {
return err
}
}
updateWriter.Append(args...)
// write from
if err := statement.writeUpdateFrom(updateWriter); err != nil {
return err
}
if statement.dialect.URI().DBType == schemas.MSSQL {
table := statement.RefTable
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
} else {
// write where
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
return err
}
}
} else {
// write where
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
return err
}
}
if statement.dialect.URI().DBType == schemas.MYSQL {
if err := statement.writeOrderBys(updateWriter); err != nil {
return err
}
}
return statement.writeUpdateLimit(updateWriter, cond)
}

View File

@ -227,14 +227,14 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
updateWriter := builder.NewWriter()
if err := session.statement.WriteUpdate(updateWriter, cond, colNames); err != nil {
if err := session.statement.WriteUpdate(updateWriter, cond, colNames, args); err != nil {
return 0, err
}
tableName := session.statement.TableName() // table name must been get before exec because statement will be reset
useCache := session.statement.UseCache
res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...)
res, err := session.exec(updateWriter.String(), updateWriter.Args()...)
if err != nil {
return 0, err
} else if doIncVer {