Merge branch 'master' into lunny/add_test

This commit is contained in:
Lunny Xiao 2023-07-26 09:15:13 +08:00
commit aa6e8b6c14
11 changed files with 710 additions and 160 deletions

View File

@ -15,6 +15,7 @@ import (
) )
// String2Time converts a string to time with original location // String2Time converts a string to time with original location
// be aware for time strings (HH:mm:ss) returns zero year (LMT) for converted location
func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) { func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) {
if len(s) == 19 { if len(s) == 19 {
if s == utils.ZeroTime0 || s == utils.ZeroTime1 { if s == utils.ZeroTime0 || s == utils.ZeroTime1 {
@ -32,6 +33,7 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t
return nil, err return nil, err
} }
dt = dt.In(convertedLocation) dt = dt.In(convertedLocation)
dt.IsZero()
return &dt, nil return &dt, nil
} else if len(s) == 25 && s[10] == 'T' && s[19] == '+' && s[22] == ':' { } else if len(s) == 25 && s[10] == 'T' && s[19] == '+' && s[22] == ':' {
dt, err := time.Parse(time.RFC3339, s) dt, err := time.Parse(time.RFC3339, s)
@ -48,7 +50,7 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t
dt = dt.In(convertedLocation) dt = dt.In(convertedLocation)
return &dt, nil return &dt, nil
} else if len(s) >= 21 && s[19] == '.' { } else if len(s) >= 21 && s[19] == '.' {
var layout = "2006-01-02 15:04:05." + strings.Repeat("0", len(s)-20) layout := "2006-01-02 15:04:05." + strings.Repeat("0", len(s)-20)
dt, err := time.ParseInLocation(layout, s, originalLocation) dt, err := time.ParseInLocation(layout, s, originalLocation)
if err != nil { if err != nil {
return nil, err return nil, err
@ -65,6 +67,18 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t
} }
dt = dt.In(convertedLocation) dt = dt.In(convertedLocation)
return &dt, nil return &dt, nil
} else if len(s) == 8 && s[2] == ':' && s[5] == ':' {
currentDate := time.Now()
dt, err := time.ParseInLocation("15:04:05", s, originalLocation)
if err != nil {
return nil, err
}
// add current date for correct time locations
dt = dt.AddDate(currentDate.Year(), int(currentDate.Month()), currentDate.Day())
dt = dt.In(convertedLocation)
// back to zero year
dt = dt.AddDate(-currentDate.Year(), int(-currentDate.Month()), -currentDate.Day())
return &dt, nil
} else { } else {
i, err := strconv.ParseInt(s, 10, 64) i, err := strconv.ParseInt(s, 10, 64)
if err == nil { if err == nil {

View File

@ -15,7 +15,7 @@ func TestString2Time(t *testing.T) {
expectedLoc, err := time.LoadLocation("Asia/Shanghai") expectedLoc, err := time.LoadLocation("Asia/Shanghai")
assert.NoError(t, err) assert.NoError(t, err)
var kases = map[string]time.Time{ cases := map[string]time.Time{
"2021-08-10": time.Date(2021, 8, 10, 8, 0, 0, 0, expectedLoc), "2021-08-10": time.Date(2021, 8, 10, 8, 0, 0, 0, expectedLoc),
"2021-07-11 10:44:00": time.Date(2021, 7, 11, 18, 44, 0, 0, expectedLoc), "2021-07-11 10:44:00": time.Date(2021, 7, 11, 18, 44, 0, 0, expectedLoc),
"2021-07-11 10:44:00.999": time.Date(2021, 7, 11, 18, 44, 0, 999000000, expectedLoc), "2021-07-11 10:44:00.999": time.Date(2021, 7, 11, 18, 44, 0, 999000000, expectedLoc),
@ -25,12 +25,13 @@ func TestString2Time(t *testing.T) {
"2021-06-06T22:58:20.999+08:00": time.Date(2021, 6, 6, 22, 58, 20, 999000000, expectedLoc), "2021-06-06T22:58:20.999+08:00": time.Date(2021, 6, 6, 22, 58, 20, 999000000, expectedLoc),
"2021-06-06T22:58:20.999999+08:00": time.Date(2021, 6, 6, 22, 58, 20, 999999000, expectedLoc), "2021-06-06T22:58:20.999999+08:00": time.Date(2021, 6, 6, 22, 58, 20, 999999000, expectedLoc),
"2021-06-06T22:58:20.999999999+08:00": time.Date(2021, 6, 6, 22, 58, 20, 999999999, expectedLoc), "2021-06-06T22:58:20.999999999+08:00": time.Date(2021, 6, 6, 22, 58, 20, 999999999, expectedLoc),
"2021-08-10T10:33:04Z": time.Date(2021, 8, 10, 18, 33, 04, 0, expectedLoc), "2021-08-10T10:33:04Z": time.Date(2021, 8, 10, 18, 33, 0o4, 0, expectedLoc),
"2021-08-10T10:33:04.999Z": time.Date(2021, 8, 10, 18, 33, 04, 999000000, expectedLoc), "2021-08-10T10:33:04.999Z": time.Date(2021, 8, 10, 18, 33, 0o4, 999000000, expectedLoc),
"2021-08-10T10:33:04.999999Z": time.Date(2021, 8, 10, 18, 33, 04, 999999000, expectedLoc), "2021-08-10T10:33:04.999999Z": time.Date(2021, 8, 10, 18, 33, 0o4, 999999000, expectedLoc),
"2021-08-10T10:33:04.999999999Z": time.Date(2021, 8, 10, 18, 33, 04, 999999999, expectedLoc), "2021-08-10T10:33:04.999999999Z": time.Date(2021, 8, 10, 18, 33, 0o4, 999999999, expectedLoc),
"10:22:33": time.Date(0, 1, 1, 18, 22, 33, 0, expectedLoc),
} }
for layout, tm := range kases { for layout, tm := range cases {
t.Run(layout, func(t *testing.T) { t.Run(layout, func(t *testing.T) {
target, err := String2Time(layout, time.UTC, expectedLoc) target, err := String2Time(layout, time.UTC, expectedLoc)
assert.NoError(t, err) assert.NoError(t, err)

190
dialects/time_test.go Normal file
View File

@ -0,0 +1,190 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dialects
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"xorm.io/xorm/schemas"
)
type dialect struct {
Dialect
dbType schemas.DBType
}
func (d dialect) URI() *URI {
return &URI{
DBType: d.dbType,
}
}
func TestFormatColumnTime(t *testing.T) {
date := time.Date(2020, 10, 23, 10, 14, 15, 123456, time.Local)
tests := []struct {
name string
dialect Dialect
location *time.Location
column *schemas.Column
time time.Time
wantRes interface{}
wantErr error
}{
{
name: "nullable",
dialect: nil,
location: nil,
column: &schemas.Column{Nullable: true},
time: time.Time{},
wantRes: nil,
wantErr: nil,
},
{
name: "invalid sqltype",
dialect: nil,
location: nil,
column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Bit}},
time: time.Time{},
wantRes: 0,
wantErr: nil,
},
{
name: "return default",
dialect: nil,
location: date.Location(),
column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Bit}},
time: date,
wantRes: date,
wantErr: nil,
},
{
name: "return default (set timezone)",
dialect: nil,
location: date.Location(),
column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Bit}, TimeZone: time.UTC},
time: date,
wantRes: date.In(time.UTC),
wantErr: nil,
},
{
name: "format date",
dialect: nil,
location: date.Location(),
column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Date}},
time: date,
wantRes: date.Format("2006-01-02"),
wantErr: nil,
},
{
name: "format time",
dialect: nil,
location: date.Location(),
column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Time}},
time: date,
wantRes: date.Format("15:04:05"),
wantErr: nil,
},
{
name: "format time (set length)",
dialect: nil,
location: date.Location(),
column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Time}, Length: 64},
time: date,
wantRes: date.Format("15:04:05.999999999"),
wantErr: nil,
},
{
name: "format datetime",
dialect: nil,
location: date.Location(),
column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.DateTime}},
time: date,
wantRes: date.Format("2006-01-02 15:04:05"),
wantErr: nil,
},
{
name: "format datetime (set length)",
dialect: nil,
location: date.Location(),
column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.DateTime}, Length: 64},
time: date,
wantRes: date.Format("2006-01-02 15:04:05.999999999"),
wantErr: nil,
},
{
name: "format timestamp",
dialect: nil,
location: date.Location(),
column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.TimeStamp}},
time: date,
wantRes: date.Format("2006-01-02 15:04:05"),
wantErr: nil,
},
{
name: "format timestamp (set length)",
dialect: nil,
location: date.Location(),
column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.TimeStamp}, Length: 64},
time: date,
wantRes: date.Format("2006-01-02 15:04:05.999999999"),
wantErr: nil,
},
{
name: "format varchar",
dialect: nil,
location: date.Location(),
column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Varchar}},
time: date,
wantRes: date.Format("2006-01-02 15:04:05"),
wantErr: nil,
},
{
name: "format timestampz",
dialect: dialect{},
location: date.Location(),
column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.TimeStampz}},
time: date,
wantRes: date.Format(time.RFC3339Nano),
wantErr: nil,
},
{
name: "format timestampz (mssql)",
dialect: dialect{dbType: schemas.MSSQL},
location: date.Location(),
column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.TimeStampz}},
time: date,
wantRes: date.Format("2006-01-02T15:04:05.9999999Z07:00"),
wantErr: nil,
},
{
name: "format int",
dialect: nil,
location: date.Location(),
column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Int}},
time: date,
wantRes: date.Unix(),
wantErr: nil,
},
{
name: "format bigint",
dialect: nil,
location: date.Location(),
column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.BigInt}},
time: date,
wantRes: date.Unix(),
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := FormatColumnTime(tt.dialect, tt.location, tt.column, tt.time)
assert.Equal(t, tt.wantErr, err)
assert.Equal(t, tt.wantRes, got)
})
}
}

View File

@ -125,6 +125,11 @@ func TestWithTableName(t *testing.T) {
total, err = testEngine.OrderBy("count(`id`) desc").Count(CountWithTableName{}) total, err = testEngine.OrderBy("count(`id`) desc").Count(CountWithTableName{})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, total) assert.EqualValues(t, 2, total)
// the orderby will be ignored by count because some databases will return errors if the orderby columns not in group by
total, err = testEngine.OrderBy("`name`").Count(CountWithTableName{})
assert.NoError(t, err)
assert.EqualValues(t, 2, total)
} }
func TestCountWithSelectCols(t *testing.T) { func TestCountWithSelectCols(t *testing.T) {

View File

@ -30,7 +30,7 @@ func TestQueryString(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(GetVar2))) assert.NoError(t, testEngine.Sync(new(GetVar2)))
var data = GetVar2{ data := GetVar2{
Msg: "hi", Msg: "hi",
Age: 28, Age: 28,
Money: 1.5, Money: 1.5,
@ -58,7 +58,7 @@ func TestQueryString2(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(GetVar3))) assert.NoError(t, testEngine.Sync(new(GetVar3)))
var data = GetVar3{ data := GetVar3{
Msg: false, Msg: false,
} }
_, err := testEngine.Insert(data) _, err := testEngine.Insert(data)
@ -95,7 +95,7 @@ func TestQueryInterface(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(GetVarInterface))) assert.NoError(t, testEngine.Sync(new(GetVarInterface)))
var data = GetVarInterface{ data := GetVarInterface{
Msg: "hi", Msg: "hi",
Age: 28, Age: 28,
Money: 1.5, Money: 1.5,
@ -128,7 +128,7 @@ func TestQueryNoParams(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(QueryNoParams))) assert.NoError(t, testEngine.Sync(new(QueryNoParams)))
var q = QueryNoParams{ q := QueryNoParams{
Msg: "message", Msg: "message",
Age: 20, Age: 20,
Money: 3000, Money: 3000,
@ -172,7 +172,7 @@ func TestQueryStringNoParam(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(GetVar4))) assert.NoError(t, testEngine.Sync(new(GetVar4)))
var data = GetVar4{ data := GetVar4{
Msg: false, Msg: false,
} }
_, err := testEngine.Insert(data) _, err := testEngine.Insert(data)
@ -209,7 +209,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(GetVar6))) assert.NoError(t, testEngine.Sync(new(GetVar6)))
var data = GetVar6{ data := GetVar6{
Msg: false, Msg: false,
} }
_, err := testEngine.Insert(data) _, err := testEngine.Insert(data)
@ -246,7 +246,7 @@ func TestQueryInterfaceNoParam(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(GetVar5))) assert.NoError(t, testEngine.Sync(new(GetVar5)))
var data = GetVar5{ data := GetVar5{
Msg: false, Msg: false,
} }
_, err := testEngine.Insert(data) _, err := testEngine.Insert(data)
@ -280,7 +280,7 @@ func TestQueryWithBuilder(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(QueryWithBuilder))) assert.NoError(t, testEngine.Sync(new(QueryWithBuilder)))
var q = QueryWithBuilder{ q := QueryWithBuilder{
Msg: "message", Msg: "message",
Age: 20, Age: 20,
Money: 3000, Money: 3000,
@ -329,14 +329,14 @@ func TestJoinWithSubQuery(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(JoinWithSubQuery1), new(JoinWithSubQueryDepart))) assert.NoError(t, testEngine.Sync(new(JoinWithSubQuery1), new(JoinWithSubQueryDepart)))
var depart = JoinWithSubQueryDepart{ depart := JoinWithSubQueryDepart{
Name: "depart1", Name: "depart1",
} }
cnt, err := testEngine.Insert(&depart) cnt, err := testEngine.Insert(&depart)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var q = JoinWithSubQuery1{ q := JoinWithSubQuery1{
Msg: "message", Msg: "message",
DepartId: depart.Id, DepartId: depart.Id,
Money: 3000, Money: 3000,
@ -401,7 +401,7 @@ func TestQueryBLOBInMySQL(t *testing.T) {
} }
const N = 10 const N = 10
var data = []Avatar{} data := []Avatar{}
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
// allocate a []byte that is as twice big as the last one // allocate a []byte that is as twice big as the last one
// so that the underlying buffer will need to reallocate when querying // so that the underlying buffer will need to reallocate when querying
@ -448,3 +448,54 @@ func TestQueryBLOBInMySQL(t *testing.T) {
} }
} }
} }
func TestRowsReset(t *testing.T) {
assert.NoError(t, PrepareEngine())
type RowsReset1 struct {
Id int64
Name string
}
type RowsReset2 struct {
Id int64
Name string
}
assert.NoError(t, testEngine.Sync(new(RowsReset1), new(RowsReset2)))
data := []RowsReset1{
{0, "1"},
{0, "2"},
{0, "3"},
}
_, err := testEngine.Insert(data)
assert.NoError(t, err)
data2 := []RowsReset2{
{0, "4"},
{0, "5"},
{0, "6"},
}
_, err = testEngine.Insert(data2)
assert.NoError(t, err)
sess := testEngine.NewSession()
defer sess.Close()
rows, err := sess.Rows(new(RowsReset1))
assert.NoError(t, err)
for rows.Next() {
var data1 RowsReset1
assert.NoError(t, rows.Scan(&data1))
}
rows.Close()
var rrs []RowsReset2
assert.NoError(t, sess.Find(&rrs))
assert.Len(t, rrs, 3)
assert.EqualValues(t, "4", rrs[0].Name)
assert.EqualValues(t, "5", rrs[1].Name)
assert.EqualValues(t, "6", rrs[2].Name)
}

View File

@ -28,7 +28,7 @@ func TestUpdateMap(t *testing.T) {
} }
assert.NoError(t, testEngine.Sync(new(UpdateTable))) assert.NoError(t, testEngine.Sync(new(UpdateTable)))
var tb = UpdateTable{ tb := UpdateTable{
Name: "test", Name: "test",
Age: 35, Age: 35,
} }
@ -79,7 +79,7 @@ func TestUpdateLimit(t *testing.T) {
} }
assert.NoError(t, testEngine.Sync(new(UpdateTable2))) assert.NoError(t, testEngine.Sync(new(UpdateTable2)))
var tb = UpdateTable2{ tb := UpdateTable2{
Name: "test1", Name: "test1",
Age: 35, Age: 35,
} }
@ -400,7 +400,7 @@ func TestUpdate1(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var s = "test" s := "test"
col1 := &UpdateAllCols{Ptr: &s} col1 := &UpdateAllCols{Ptr: &s}
err = testEngine.Sync(col1) err = testEngine.Sync(col1)
@ -864,7 +864,7 @@ func TestCreatedUpdated2(t *testing.T) {
assertSync(t, new(CreatedUpdatedStruct)) assertSync(t, new(CreatedUpdatedStruct))
var s = CreatedUpdatedStruct{ s := CreatedUpdatedStruct{
Name: "test", Name: "test",
} }
cnt, err := testEngine.Insert(&s) cnt, err := testEngine.Insert(&s)
@ -874,7 +874,7 @@ func TestCreatedUpdated2(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
var s1 = CreatedUpdatedStruct{ s1 := CreatedUpdatedStruct{
Name: "test1", Name: "test1",
CreateAt: s.CreateAt, CreateAt: s.CreateAt,
UpdateAt: s.UpdateAt, UpdateAt: s.UpdateAt,
@ -907,7 +907,7 @@ func TestDeletedUpdate(t *testing.T) {
assertSync(t, new(DeletedUpdatedStruct)) assertSync(t, new(DeletedUpdatedStruct))
var s = DeletedUpdatedStruct{ s := DeletedUpdatedStruct{
Name: "test", Name: "test",
} }
cnt, err := testEngine.Insert(&s) cnt, err := testEngine.Insert(&s)
@ -956,7 +956,7 @@ func TestUpdateMapCondition(t *testing.T) {
assertSync(t, new(UpdateMapCondition)) assertSync(t, new(UpdateMapCondition))
var c = UpdateMapCondition{ c := UpdateMapCondition{
String: "string", String: "string",
} }
_, err := testEngine.Insert(&c) _, err := testEngine.Insert(&c)
@ -990,7 +990,7 @@ func TestUpdateMapContent(t *testing.T) {
assertSync(t, new(UpdateMapContent)) assertSync(t, new(UpdateMapContent))
var c = UpdateMapContent{ c := UpdateMapContent{
Name: "lunny", Name: "lunny",
IsMan: true, IsMan: true,
Gender: 1, Gender: 1,
@ -1126,7 +1126,7 @@ func TestUpdateDeleted(t *testing.T) {
assertSync(t, new(UpdateDeletedStruct)) assertSync(t, new(UpdateDeletedStruct))
var s = UpdateDeletedStruct{ s := UpdateDeletedStruct{
Name: "test", Name: "test",
} }
cnt, err := testEngine.Insert(&s) cnt, err := testEngine.Insert(&s)
@ -1232,7 +1232,7 @@ func TestUpdateExprs2(t *testing.T) {
assertSync(t, new(UpdateExprsRelease)) assertSync(t, new(UpdateExprsRelease))
var uer = UpdateExprsRelease{ uer := UpdateExprsRelease{
RepoId: 1, RepoId: 1,
IsTag: false, IsTag: false,
IsDraft: false, IsDraft: false,
@ -1407,7 +1407,7 @@ func TestNilFromDB(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
assertSync(t, new(TestTable1)) assertSync(t, new(TestTable1))
var tt0 = TestTable1{ tt0 := TestTable1{
Field1: &TestFieldType1{ Field1: &TestFieldType1{
cb: []byte("string"), cb: []byte("string"),
}, },
@ -1437,7 +1437,7 @@ func TestNilFromDB(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var tt = TestTable1{ tt := TestTable1{
UpdateTime: time.Now(), UpdateTime: time.Now(),
Field1: &TestFieldType1{ Field1: &TestFieldType1{
cb: nil, cb: nil,
@ -1453,7 +1453,7 @@ func TestNilFromDB(t *testing.T) {
assert.True(t, has) assert.True(t, has)
assert.Nil(t, tt2.Field1) assert.Nil(t, tt2.Field1)
var tt3 = TestTable1{ tt3 := TestTable1{
UpdateTime: time.Now(), UpdateTime: time.Now(),
Field1: &TestFieldType1{ Field1: &TestFieldType1{
cb: []byte{}, cb: []byte{},
@ -1470,3 +1470,34 @@ func TestNilFromDB(t *testing.T) {
assert.NotNil(t, tt4.Field1) assert.NotNil(t, tt4.Field1)
assert.NotNil(t, tt4.Field1.cb) assert.NotNil(t, tt4.Field1.cb)
} }
/*
func TestUpdateWithJoin(t *testing.T) {
type TestUpdateWithJoin struct {
Id int64
ExtId int64
Name string
}
type TestUpdateWithJoin2 struct {
Id int64
Name string
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(TestUpdateWithJoin), new(TestUpdateWithJoin2))
b := TestUpdateWithJoin2{Name: "test"}
_, err := testEngine.Insert(&b)
assert.NoError(t, err)
_, err = testEngine.Insert(&TestUpdateWithJoin{ExtId: b.Id, Name: "test"})
assert.NoError(t, err)
_, err = testEngine.Table("test_update_with_join").
Join("INNER", "test_update_with_join2", "test_update_with_join.ext_id = test_update_with_join2.id").
Where("test_update_with_join2.name = ?", "test").
Update(&TestUpdateWithJoin{Name: "test2"})
assert.NoError(t, err)
}
*/

View File

@ -5,6 +5,7 @@
package statements package statements
import ( import (
"errors"
"fmt" "fmt"
"xorm.io/builder" "xorm.io/builder"
@ -16,6 +17,26 @@ type orderBy struct {
direction string // ASC, DESC or "", "" means raw orderStr direction string // ASC, DESC or "", "" means raw orderStr
} }
func (ob orderBy) CheckValid() error {
if ob.orderStr == nil {
return fmt.Errorf("order by string is nil")
}
switch t := ob.orderStr.(type) {
case string:
if t == "" {
return fmt.Errorf("order by string is empty")
}
return nil
case *builder.Expression:
if t.Content() == "" {
return fmt.Errorf("order by string is empty")
}
return nil
default:
return fmt.Errorf("order by string is not string or builder.Expression")
}
}
func (statement *Statement) HasOrderBy() bool { func (statement *Statement) HasOrderBy() bool {
return len(statement.orderBy) > 0 return len(statement.orderBy) > 0
} }
@ -25,6 +46,8 @@ func (statement *Statement) ResetOrderBy() {
statement.orderBy = []orderBy{} statement.orderBy = []orderBy{}
} }
var ErrNoColumnName = errors.New("no column name")
func (statement *Statement) writeOrderBy(w *builder.BytesWriter, orderBy orderBy) error { func (statement *Statement) writeOrderBy(w *builder.BytesWriter, orderBy orderBy) error {
switch t := orderBy.orderStr.(type) { switch t := orderBy.orderStr.(type) {
case (*builder.Expression): case (*builder.Expression):
@ -75,22 +98,45 @@ func (statement *Statement) writeOrderBys(w *builder.BytesWriter) error {
// OrderBy generate "Order By order" statement // OrderBy generate "Order By order" statement
func (statement *Statement) OrderBy(order interface{}, args ...interface{}) *Statement { func (statement *Statement) OrderBy(order interface{}, args ...interface{}) *Statement {
statement.orderBy = append(statement.orderBy, orderBy{order, args, ""}) ob := orderBy{order, args, ""}
if err := ob.CheckValid(); err != nil {
statement.LastError = err
return statement
}
statement.orderBy = append(statement.orderBy, ob)
return statement return statement
} }
// Desc generate `ORDER BY xx DESC` // Desc generate `ORDER BY xx DESC`
func (statement *Statement) Desc(colNames ...string) *Statement { func (statement *Statement) Desc(colNames ...string) *Statement {
if len(colNames) == 0 {
statement.LastError = ErrNoColumnName
return statement
}
for _, colName := range colNames { for _, colName := range colNames {
statement.orderBy = append(statement.orderBy, orderBy{colName, nil, "DESC"}) ob := orderBy{colName, nil, "DESC"}
statement.orderBy = append(statement.orderBy, ob)
if err := ob.CheckValid(); err != nil {
statement.LastError = err
return statement
}
} }
return statement return statement
} }
// Asc provide asc order by query condition, the input parameters are columns. // Asc provide asc order by query condition, the input parameters are columns.
func (statement *Statement) Asc(colNames ...string) *Statement { func (statement *Statement) Asc(colNames ...string) *Statement {
if len(colNames) == 0 {
statement.LastError = ErrNoColumnName
return statement
}
for _, colName := range colNames { for _, colName := range colNames {
statement.orderBy = append(statement.orderBy, orderBy{colName, nil, "ASC"}) ob := orderBy{colName, nil, "ASC"}
statement.orderBy = append(statement.orderBy, ob)
if err := ob.CheckValid(); err != nil {
statement.LastError = err
return statement
}
} }
return statement return statement
} }

View File

@ -35,7 +35,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
} }
buf := builder.NewWriter() buf := builder.NewWriter()
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil {
return "", nil, err return "", nil, err
} }
return buf.String(), buf.Args(), nil return buf.String(), buf.Args(), nil
@ -66,7 +66,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
} }
buf := builder.NewWriter() buf := builder.NewWriter()
if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true); err != nil { if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true, true); err != nil {
return "", nil, err return "", nil, err
} }
return buf.String(), buf.Args(), nil return buf.String(), buf.Args(), nil
@ -122,7 +122,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
} }
buf := builder.NewWriter() buf := builder.NewWriter()
if err := statement.writeSelect(buf, columnStr, true); err != nil { if err := statement.writeSelect(buf, columnStr, true, true); err != nil {
return "", nil, err return "", nil, err
} }
return buf.String(), buf.Args(), nil return buf.String(), buf.Args(), nil
@ -153,12 +153,6 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
selectSQL = "count(*)" selectSQL = "count(*)"
} }
} }
var subQuerySelect string
if statement.GroupByStr != "" {
subQuerySelect = statement.GroupByStr
} else {
subQuerySelect = selectSQL
}
buf := builder.NewWriter() buf := builder.NewWriter()
if statement.GroupByStr != "" { if statement.GroupByStr != "" {
@ -167,7 +161,14 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
} }
} }
if err := statement.writeSelect(buf, subQuerySelect, false); err != nil { var subQuerySelect string
if statement.GroupByStr != "" {
subQuerySelect = statement.GroupByStr
} else {
subQuerySelect = selectSQL
}
if err := statement.writeSelect(buf, subQuerySelect, false, false); err != nil {
return "", nil, err return "", nil, err
} }
@ -243,14 +244,19 @@ func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr
return err return err
} }
func (statement *Statement) writeWhere(w *builder.BytesWriter) error { func (statement *Statement) writeWhereCond(w *builder.BytesWriter, cond builder.Cond) error {
if !statement.cond.IsValid() { if !cond.IsValid() {
return nil return nil
} }
if _, err := fmt.Fprint(w, " WHERE "); err != nil { if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err return err
} }
return statement.cond.WriteTo(statement.QuoteReplacer(w)) return cond.WriteTo(statement.QuoteReplacer(w))
}
func (statement *Statement) writeWhere(w *builder.BytesWriter) error {
return statement.writeWhereCond(w, statement.cond)
} }
func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error { func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error {
@ -359,7 +365,7 @@ func (statement *Statement) writeOracleLimit(w *builder.BytesWriter, columnStr s
return err return err
} }
func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error { func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit, needOrderBy bool) error {
if err := statement.writeSelectColumns(buf, columnStr); err != nil { if err := statement.writeSelectColumns(buf, columnStr); err != nil {
return err return err
} }
@ -375,8 +381,10 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri
if err := statement.writeHaving(buf); err != nil { if err := statement.writeHaving(buf); err != nil {
return err return err
} }
if err := statement.writeOrderBys(buf); err != nil { if needOrderBy {
return err if err := statement.writeOrderBys(buf); err != nil {
return err
}
} }
dialect := statement.dialect dialect := statement.dialect
@ -514,7 +522,7 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa
statement.cond = statement.cond.And(autoCond) statement.cond = statement.cond.And(autoCond)
buf := builder.NewWriter() buf := builder.NewWriter()
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil {
return "", nil, err return "", nil, err
} }
return buf.String(), buf.Args(), nil return buf.String(), buf.Args(), nil

View File

@ -9,7 +9,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"strings"
"time" "time"
"xorm.io/builder" "xorm.io/builder"
@ -311,84 +310,328 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value,
return colNames, args, nil return colNames, args, nil
} }
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string) error { func (statement *Statement) writeUpdateTop(updateWriter *builder.BytesWriter) error {
whereWriter := builder.NewWriter() if statement.dialect.URI().DBType != schemas.MSSQL || statement.LimitN == nil {
if cond.IsValid() { return nil
fmt.Fprint(whereWriter, "WHERE ")
} }
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
table := statement.RefTable
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
return nil
}
_, err := fmt.Fprintf(updateWriter, " TOP (%d)", *statement.LimitN)
return err
}
func (statement *Statement) writeUpdateTableName(updateWriter *builder.BytesWriter) error {
tableName := statement.quote(statement.TableName())
if statement.TableAlias == "" {
_, err := fmt.Fprint(updateWriter, " ", tableName)
return err return err
} }
if err := statement.writeOrderBys(whereWriter); err != nil {
switch statement.dialect.URI().DBType {
case schemas.MSSQL:
_, err := fmt.Fprint(updateWriter, " ", statement.TableAlias)
return err return err
default:
_, err := fmt.Fprint(updateWriter, " ", tableName, " AS ", statement.TableAlias)
return err
}
}
func (statement *Statement) writeUpdateFrom(updateWriter *builder.BytesWriter) error {
if statement.dialect.URI().DBType != schemas.MSSQL || statement.TableAlias == "" {
return nil
}
_, err := fmt.Fprint(updateWriter, " FROM ", statement.quote(statement.TableName()), " ", statement.TableAlias)
return err
}
func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter, cond builder.Cond) error {
if statement.LimitN == nil {
return nil
} }
table := statement.RefTable table := statement.RefTable
tableName := statement.TableName() tableName := statement.TableName()
// TODO: Oracle support needed
var top string
if statement.LimitN != nil {
limitValue := *statement.LimitN
switch statement.dialect.URI().DBType {
case schemas.MYSQL:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
case schemas.SQLITE:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", limitValue := *statement.LimitN
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...)) switch statement.dialect.URI().DBType {
case schemas.MYSQL:
whereWriter = builder.NewWriter() _, err := fmt.Fprintf(updateWriter, " LIMIT %d", limitValue)
fmt.Fprint(whereWriter, "WHERE ") return err
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil { case schemas.SQLITE:
if cond.IsValid() {
if _, err := fmt.Fprint(updateWriter, " AND "); err != nil {
return err return err
} }
case schemas.POSTGRES: } else {
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil {
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...))
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
return err return err
} }
case schemas.MSSQL: }
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { if _, err := fmt.Fprint(updateWriter, "rowid IN (SELECT rowid FROM ", statement.quote(tableName)); err != nil {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", return err
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], }
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...) if err := statement.writeWhereCond(updateWriter, cond); err != nil {
return err
whereWriter = builder.NewWriter() }
fmt.Fprint(whereWriter, "WHERE ") if err := statement.writeOrderBys(updateWriter); err != nil {
if err := cond.WriteTo(whereWriter); err != nil { return err
return err }
} _, err := fmt.Fprintf(updateWriter, " LIMIT %d)", limitValue)
} else { return err
top = fmt.Sprintf("TOP (%d) ", limitValue) case schemas.POSTGRES:
if cond.IsValid() {
if _, err := fmt.Fprint(updateWriter, " AND "); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil {
return err
} }
} }
} if _, err := fmt.Fprint(updateWriter, "CTID IN (SELECT CTID FROM ", statement.quote(tableName)); err != nil {
return err
tableAlias := statement.quote(tableName)
var fromSQL string
if statement.TableAlias != "" {
switch statement.dialect.URI().DBType {
case schemas.MSSQL:
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, statement.TableAlias)
tableAlias = statement.TableAlias
default:
tableAlias = fmt.Sprintf("%s AS %s", tableAlias, statement.TableAlias)
} }
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
return err
}
if err := statement.writeOrderBys(updateWriter); err != nil {
return err
}
_, err := fmt.Fprintf(updateWriter, " LIMIT %d)", limitValue)
return err
case schemas.MSSQL:
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
if _, err := fmt.Fprintf(updateWriter, " WHERE %s IN (SELECT TOP (%d) %s FROM %v",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
statement.quote(tableName)); err != nil {
return err
}
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
return err
}
if err := statement.writeOrderBys(updateWriter); err != nil {
return err
}
_, err := fmt.Fprint(updateWriter, ")")
return err
}
return nil
default: // TODO: Oracle support needed
return fmt.Errorf("not implemented")
}
}
func (statement *Statement) GenConditionsFromMap(m interface{}) ([]builder.Cond, error) {
switch t := m.(type) {
case map[string]interface{}:
conds := []builder.Cond{}
for k, v := range t {
conds = append(conds, builder.Eq{k: v})
}
return conds, nil
case map[string]string:
conds := []builder.Cond{}
for k, v := range t {
conds = append(conds, builder.Eq{k: v})
}
return conds, nil
default:
return nil, fmt.Errorf("unsupported condition map type %v", t)
}
}
func (statement *Statement) writeVersionIncrSet(w builder.Writer, v reflect.Value, hasPreviousSet bool) error {
if v.Type().Kind() != reflect.Struct {
return nil
} }
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v", table := statement.RefTable
top, if !(statement.RefTable != nil && table.Version != "" && statement.CheckVersion) {
tableAlias, return nil
strings.Join(colNames, ", "), }
fromSQL); err != nil {
verValue, err := table.VersionColumn().ValueOfV(&v)
if err != nil {
return err return err
} }
return utils.WriteBuilder(updateWriter, whereWriter)
if verValue == nil {
return nil
}
if hasPreviousSet {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, statement.quote(table.Version), " = ", statement.quote(table.Version), " + 1"); err != nil {
return err
}
return nil
}
func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool) error {
for i, expr := range statement.IncrColumns {
if i > 0 || hasPreviousSet {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " + ?"); err != nil {
return err
}
w.Append(expr.Arg)
}
return nil
}
func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool) error {
// for update action to like "column = column - ?"
for i, expr := range statement.DecrColumns {
if i > 0 || hasPreviousSet {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " - ?"); err != nil {
return err
}
w.Append(expr.Arg)
}
return nil
}
func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet bool) error {
// for update action to like "column = expression"
for i, expr := range statement.ExprColumns {
if i > 0 || hasPreviousSet {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
switch tp := expr.Arg.(type) {
case string:
if len(tp) == 0 {
tp = "''"
}
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", tp); err != nil {
return err
}
case *builder.Builder:
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ("); err != nil {
return err
}
if err := tp.WriteTo(statement.QuoteReplacer(w)); err != nil {
return err
}
if _, err := fmt.Fprint(w, ")"); err != nil {
return err
}
default:
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ?"); err != nil {
return err
}
w.Append(expr.Arg)
}
}
return nil
}
func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error {
previousLen := w.Len()
for i, colName := range colNames {
if i > 0 {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, colName); err != nil {
return err
}
}
w.Append(args...)
if err := statement.writeIncrSets(w, w.Len() > previousLen); err != nil {
return err
}
if err := statement.writeDecrSets(w, w.Len() > previousLen); err != nil {
return err
}
if err := statement.writeExprSets(w, w.Len() > previousLen); err != nil {
return err
}
if err := statement.writeVersionIncrSet(w, v, w.Len() > previousLen); err != nil {
return err
}
return nil
}
var ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []interface{}) error {
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
return err
}
if err := statement.writeUpdateTop(updateWriter); err != nil {
return err
}
if err := statement.writeUpdateTableName(updateWriter); err != nil {
return err
}
// write set
if _, err := fmt.Fprint(updateWriter, " SET "); err != nil {
return err
}
previousLen := updateWriter.Len()
if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil {
return err
}
// if no columns to be updated, return error
if previousLen == updateWriter.Len() {
return ErrNoColumnsTobeUpdated
}
// write from
if err := statement.writeUpdateFrom(updateWriter); err != nil {
return err
}
if statement.dialect.URI().DBType == schemas.MSSQL {
table := statement.RefTable
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
} else {
// write where
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
return err
}
}
} else {
// write where
if err := statement.writeWhereCond(updateWriter, cond); err != nil {
return err
}
}
if statement.dialect.URI().DBType == schemas.MYSQL {
if err := statement.writeOrderBys(updateWriter); err != nil {
return err
}
}
return statement.writeUpdateLimit(updateWriter, cond)
} }

View File

@ -144,6 +144,8 @@ func (rows *Rows) Close() error {
defer rows.session.Close() defer rows.session.Close()
} }
defer rows.session.resetStatement()
if rows.rows != nil { if rows.rows != nil {
return rows.rows.Close() return rows.rows.Close()
} }

View File

@ -5,17 +5,17 @@
package xorm package xorm
import ( import (
"errors"
"reflect" "reflect"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm/internal/statements"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
// enumerated all errors // enumerated all errors
var ( var (
ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated") ErrNoColumnsTobeUpdated = statements.ErrNoColumnsTobeUpdated
) )
func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) { func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) {
@ -74,9 +74,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
v := utils.ReflectValue(bean) v := utils.ReflectValue(bean)
t := v.Type() t := v.Type()
var colNames []string
var args []interface{}
// handle before update processors // handle before update processors
for _, closure := range session.beforeClosures { for _, closure := range session.beforeClosures {
closure(bean) closure(bean)
@ -87,6 +84,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
// -- // --
var colNames []string
var args []interface{}
var err error var err error
isMap := t.Kind() == reflect.Map isMap := t.Kind() == reflect.Map
isStruct := t.Kind() == reflect.Struct isStruct := t.Kind() == reflect.Struct
@ -148,41 +147,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
// for update action to like "column = column + ?"
incColumns := session.statement.IncrColumns
for _, expr := range incColumns {
colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" + ?")
args = append(args, expr.Arg)
}
// for update action to like "column = column - ?"
decColumns := session.statement.DecrColumns
for _, expr := range decColumns {
colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" - ?")
args = append(args, expr.Arg)
}
// for update action to like "column = expression"
exprColumns := session.statement.ExprColumns
for _, expr := range exprColumns {
switch tp := expr.Arg.(type) {
case string:
if len(tp) == 0 {
tp = "''"
}
colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp)
case *builder.Builder:
subQuery, subArgs, err := builder.ToSQL(tp)
if err != nil {
return 0, err
}
subQuery = session.statement.ReplaceQuote(subQuery)
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")")
args = append(args, subArgs...)
default:
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=?")
args = append(args, expr.Arg)
}
}
if err = session.statement.ProcessIDParam(); err != nil { if err = session.statement.ProcessIDParam(); err != nil {
return 0, err return 0, err
} }
@ -211,30 +175,25 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
verValue *reflect.Value verValue *reflect.Value
) )
if doIncVer { if doIncVer {
verValue, err = table.VersionColumn().ValueOf(bean) verValue, err = table.VersionColumn().ValueOfV(&v)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if verValue != nil { if verValue != nil {
cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()}) cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1")
} }
} }
if len(colNames) == 0 {
return 0, ErrNoColumnsTobeUpdated
}
updateWriter := builder.NewWriter() updateWriter := builder.NewWriter()
if err := session.statement.WriteUpdate(updateWriter, cond, colNames); err != nil { if err := session.statement.WriteUpdate(updateWriter, cond, v, colNames, args); err != nil {
return 0, err return 0, err
} }
tableName := session.statement.TableName() // table name must been get before exec because statement will be reset tableName := session.statement.TableName() // table name must been get before exec because statement will be reset
useCache := session.statement.UseCache useCache := session.statement.UseCache
res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...) res, err := session.exec(updateWriter.String(), updateWriter.Args()...)
if err != nil { if err != nil {
return 0, err return 0, err
} else if doIncVer { } else if doIncVer {