refactor write update sql
This commit is contained in:
parent
3626de1459
commit
01b903f996
|
@ -28,7 +28,7 @@ func TestUpdateMap(t *testing.T) {
|
|||
}
|
||||
|
||||
assert.NoError(t, testEngine.Sync(new(UpdateTable)))
|
||||
var tb = UpdateTable{
|
||||
tb := UpdateTable{
|
||||
Name: "test",
|
||||
Age: 35,
|
||||
}
|
||||
|
@ -79,7 +79,7 @@ func TestUpdateLimit(t *testing.T) {
|
|||
}
|
||||
|
||||
assert.NoError(t, testEngine.Sync(new(UpdateTable2)))
|
||||
var tb = UpdateTable2{
|
||||
tb := UpdateTable2{
|
||||
Name: "test1",
|
||||
Age: 35,
|
||||
}
|
||||
|
@ -400,7 +400,7 @@ func TestUpdate1(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, cnt)
|
||||
|
||||
var s = "test"
|
||||
s := "test"
|
||||
|
||||
col1 := &UpdateAllCols{Ptr: &s}
|
||||
err = testEngine.Sync(col1)
|
||||
|
@ -864,7 +864,7 @@ func TestCreatedUpdated2(t *testing.T) {
|
|||
|
||||
assertSync(t, new(CreatedUpdatedStruct))
|
||||
|
||||
var s = CreatedUpdatedStruct{
|
||||
s := CreatedUpdatedStruct{
|
||||
Name: "test",
|
||||
}
|
||||
cnt, err := testEngine.Insert(&s)
|
||||
|
@ -874,7 +874,7 @@ func TestCreatedUpdated2(t *testing.T) {
|
|||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
var s1 = CreatedUpdatedStruct{
|
||||
s1 := CreatedUpdatedStruct{
|
||||
Name: "test1",
|
||||
CreateAt: s.CreateAt,
|
||||
UpdateAt: s.UpdateAt,
|
||||
|
@ -907,7 +907,7 @@ func TestDeletedUpdate(t *testing.T) {
|
|||
|
||||
assertSync(t, new(DeletedUpdatedStruct))
|
||||
|
||||
var s = DeletedUpdatedStruct{
|
||||
s := DeletedUpdatedStruct{
|
||||
Name: "test",
|
||||
}
|
||||
cnt, err := testEngine.Insert(&s)
|
||||
|
@ -956,7 +956,7 @@ func TestUpdateMapCondition(t *testing.T) {
|
|||
|
||||
assertSync(t, new(UpdateMapCondition))
|
||||
|
||||
var c = UpdateMapCondition{
|
||||
c := UpdateMapCondition{
|
||||
String: "string",
|
||||
}
|
||||
_, err := testEngine.Insert(&c)
|
||||
|
@ -990,7 +990,7 @@ func TestUpdateMapContent(t *testing.T) {
|
|||
|
||||
assertSync(t, new(UpdateMapContent))
|
||||
|
||||
var c = UpdateMapContent{
|
||||
c := UpdateMapContent{
|
||||
Name: "lunny",
|
||||
IsMan: true,
|
||||
Gender: 1,
|
||||
|
@ -1126,7 +1126,7 @@ func TestUpdateDeleted(t *testing.T) {
|
|||
|
||||
assertSync(t, new(UpdateDeletedStruct))
|
||||
|
||||
var s = UpdateDeletedStruct{
|
||||
s := UpdateDeletedStruct{
|
||||
Name: "test",
|
||||
}
|
||||
cnt, err := testEngine.Insert(&s)
|
||||
|
@ -1232,7 +1232,7 @@ func TestUpdateExprs2(t *testing.T) {
|
|||
|
||||
assertSync(t, new(UpdateExprsRelease))
|
||||
|
||||
var uer = UpdateExprsRelease{
|
||||
uer := UpdateExprsRelease{
|
||||
RepoId: 1,
|
||||
IsTag: false,
|
||||
IsDraft: false,
|
||||
|
@ -1407,7 +1407,7 @@ func TestNilFromDB(t *testing.T) {
|
|||
assert.NoError(t, PrepareEngine())
|
||||
assertSync(t, new(TestTable1))
|
||||
|
||||
var tt0 = TestTable1{
|
||||
tt0 := TestTable1{
|
||||
Field1: &TestFieldType1{
|
||||
cb: []byte("string"),
|
||||
},
|
||||
|
@ -1437,7 +1437,7 @@ func TestNilFromDB(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, cnt)
|
||||
|
||||
var tt = TestTable1{
|
||||
tt := TestTable1{
|
||||
UpdateTime: time.Now(),
|
||||
Field1: &TestFieldType1{
|
||||
cb: nil,
|
||||
|
@ -1453,7 +1453,7 @@ func TestNilFromDB(t *testing.T) {
|
|||
assert.True(t, has)
|
||||
assert.Nil(t, tt2.Field1)
|
||||
|
||||
var tt3 = TestTable1{
|
||||
tt3 := TestTable1{
|
||||
UpdateTime: time.Now(),
|
||||
Field1: &TestFieldType1{
|
||||
cb: []byte{},
|
||||
|
@ -1470,3 +1470,34 @@ func TestNilFromDB(t *testing.T) {
|
|||
assert.NotNil(t, tt4.Field1)
|
||||
assert.NotNil(t, tt4.Field1.cb)
|
||||
}
|
||||
|
||||
/*
|
||||
func TestUpdateWithJoin(t *testing.T) {
|
||||
type TestUpdateWithJoin struct {
|
||||
Id int64
|
||||
ExtId int64
|
||||
Name string
|
||||
}
|
||||
|
||||
type TestUpdateWithJoin2 struct {
|
||||
Id int64
|
||||
Name string
|
||||
}
|
||||
|
||||
assert.NoError(t, PrepareEngine())
|
||||
assertSync(t, new(TestUpdateWithJoin), new(TestUpdateWithJoin2))
|
||||
|
||||
b := TestUpdateWithJoin2{Name: "test"}
|
||||
_, err := testEngine.Insert(&b)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = testEngine.Insert(&TestUpdateWithJoin{ExtId: b.Id, Name: "test"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = testEngine.Table("test_update_with_join").
|
||||
Join("INNER", "test_update_with_join2", "test_update_with_join.ext_id = test_update_with_join2.id").
|
||||
Where("test_update_with_join2.name = ?", "test").
|
||||
Update(&TestUpdateWithJoin{Name: "test2"})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
*/
|
||||
|
|
|
@ -243,14 +243,19 @@ func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr
|
|||
return err
|
||||
}
|
||||
|
||||
func (statement *Statement) writeWhere(w *builder.BytesWriter) error {
|
||||
func (statement *Statement) writeWhereCond(w *builder.BytesWriter, cond builder.Cond) error {
|
||||
if !statement.cond.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
|
||||
return err
|
||||
}
|
||||
return statement.cond.WriteTo(statement.QuoteReplacer(w))
|
||||
return cond.WriteTo(statement.QuoteReplacer(w))
|
||||
}
|
||||
|
||||
func (statement *Statement) writeWhere(w *builder.BytesWriter) error {
|
||||
return statement.writeWhereCond(w, statement.cond)
|
||||
}
|
||||
|
||||
func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error {
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"xorm.io/builder"
|
||||
|
@ -311,84 +310,177 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value,
|
|||
return colNames, args, nil
|
||||
}
|
||||
|
||||
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string) error {
|
||||
whereWriter := builder.NewWriter()
|
||||
if cond.IsValid() {
|
||||
fmt.Fprint(whereWriter, "WHERE ")
|
||||
func (statement *Statement) writeUpdateTop(updateWriter *builder.BytesWriter) error {
|
||||
if statement.dialect.URI().DBType != schemas.MSSQL || statement.LimitN == nil {
|
||||
return nil
|
||||
}
|
||||
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
|
||||
|
||||
table := statement.RefTable
|
||||
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := fmt.Fprintf(updateWriter, " TOP (%d)", *statement.LimitN)
|
||||
return err
|
||||
}
|
||||
|
||||
func (statement *Statement) writeUpdateTableName(updateWriter *builder.BytesWriter) error {
|
||||
tableName := statement.quote(statement.TableName())
|
||||
if statement.TableAlias == "" {
|
||||
_, err := fmt.Fprint(updateWriter, " ", tableName)
|
||||
return err
|
||||
}
|
||||
if err := statement.writeOrderBys(whereWriter); err != nil {
|
||||
|
||||
switch statement.dialect.URI().DBType {
|
||||
case schemas.MSSQL:
|
||||
_, err := fmt.Fprint(updateWriter, " ", statement.TableAlias)
|
||||
return err
|
||||
default:
|
||||
_, err := fmt.Fprint(updateWriter, " ", tableName, " AS ", statement.TableAlias)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (statement *Statement) writeUpdateFrom(updateWriter *builder.BytesWriter) error {
|
||||
if statement.dialect.URI().DBType != schemas.MSSQL || statement.TableAlias == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := fmt.Fprint(updateWriter, " FROM ", statement.quote(statement.TableName()), " ", statement.TableAlias)
|
||||
return err
|
||||
}
|
||||
|
||||
func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter, cond builder.Cond) error {
|
||||
if statement.LimitN == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
table := statement.RefTable
|
||||
tableName := statement.TableName()
|
||||
// TODO: Oracle support needed
|
||||
var top string
|
||||
if statement.LimitN != nil {
|
||||
|
||||
limitValue := *statement.LimitN
|
||||
switch statement.dialect.URI().DBType {
|
||||
case schemas.MYSQL:
|
||||
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
|
||||
_, err := fmt.Fprintf(updateWriter, " LIMIT %d", limitValue)
|
||||
return err
|
||||
case schemas.SQLITE:
|
||||
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
|
||||
|
||||
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
|
||||
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...))
|
||||
|
||||
whereWriter = builder.NewWriter()
|
||||
fmt.Fprint(whereWriter, "WHERE ")
|
||||
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
|
||||
return err
|
||||
}
|
||||
case schemas.POSTGRES:
|
||||
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
|
||||
|
||||
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
|
||||
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...))
|
||||
|
||||
whereWriter = builder.NewWriter()
|
||||
fmt.Fprint(whereWriter, "WHERE ")
|
||||
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
|
||||
return err
|
||||
}
|
||||
case schemas.MSSQL:
|
||||
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
|
||||
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
|
||||
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
|
||||
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...)
|
||||
|
||||
whereWriter = builder.NewWriter()
|
||||
fmt.Fprint(whereWriter, "WHERE ")
|
||||
if err := cond.WriteTo(whereWriter); err != nil {
|
||||
if cond.IsValid() {
|
||||
if _, err := fmt.Fprint(updateWriter, " AND "); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
top = fmt.Sprintf("TOP (%d) ", limitValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tableAlias := statement.quote(tableName)
|
||||
var fromSQL string
|
||||
if statement.TableAlias != "" {
|
||||
switch statement.dialect.URI().DBType {
|
||||
case schemas.MSSQL:
|
||||
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, statement.TableAlias)
|
||||
tableAlias = statement.TableAlias
|
||||
default:
|
||||
tableAlias = fmt.Sprintf("%s AS %s", tableAlias, statement.TableAlias)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v",
|
||||
top,
|
||||
tableAlias,
|
||||
strings.Join(colNames, ", "),
|
||||
fromSQL); err != nil {
|
||||
if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil {
|
||||
return err
|
||||
}
|
||||
return utils.WriteBuilder(updateWriter, whereWriter)
|
||||
}
|
||||
if _, err := fmt.Fprint(updateWriter, "rowid IN (SELECT rowid FROM ", statement.quote(tableName)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.writeOrderBys(updateWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := fmt.Fprintf(updateWriter, " LIMIT %d)", limitValue)
|
||||
return err
|
||||
case schemas.POSTGRES:
|
||||
if cond.IsValid() {
|
||||
if _, err := fmt.Fprint(updateWriter, " AND "); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := fmt.Fprint(updateWriter, "CTID IN (SELECT CTID FROM ", statement.quote(tableName)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.writeOrderBys(updateWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := fmt.Fprintf(updateWriter, " LIMIT %d)", limitValue)
|
||||
return err
|
||||
case schemas.MSSQL:
|
||||
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
|
||||
if _, err := fmt.Fprintf(updateWriter, " WHERE %s IN (SELECT TOP (%d) %s FROM %v",
|
||||
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
|
||||
statement.quote(tableName)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.writeOrderBys(updateWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := fmt.Fprint(updateWriter, ")")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
default: // TODO: Oracle support needed
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
}
|
||||
|
||||
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string) error {
|
||||
if _, err := fmt.Fprintf(updateWriter, "UPDATE "); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := statement.writeUpdateTop(updateWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := statement.writeUpdateTableName(updateWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// write set
|
||||
if _, err := fmt.Fprint(updateWriter, " SET "); err != nil {
|
||||
return err
|
||||
}
|
||||
for i, colName := range colNames {
|
||||
if i > 0 {
|
||||
if _, err := fmt.Fprint(updateWriter, ", "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := fmt.Fprint(updateWriter, colName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// write from
|
||||
if err := statement.writeUpdateFrom(updateWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if statement.dialect.URI().DBType == schemas.MSSQL {
|
||||
table := statement.RefTable
|
||||
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
|
||||
} else {
|
||||
// write where
|
||||
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// write where
|
||||
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if statement.dialect.URI().DBType == schemas.MYSQL {
|
||||
if err := statement.writeOrderBys(updateWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return statement.writeUpdateLimit(updateWriter, cond)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue