diff --git a/.gitea/workflows/test-cockroach.yml b/.gitea/workflows/test-cockroach.yml index 0ca18861..ba966dc9 100644 --- a/.gitea/workflows/test-cockroach.yml +++ b/.gitea/workflows/test-cockroach.yml @@ -43,6 +43,7 @@ jobs: TEST_COCKROACH_DBNAME: xorm_test TEST_COCKROACH_USERNAME: root TEST_COCKROACH_PASSWORD: + IGNORE_TEST_DELETE_LIMIT: true run: sleep 20 && make test-cockroach services: diff --git a/convert/int.go b/convert/int.go index af8d4f75..03994773 100644 --- a/convert/int.go +++ b/convert/int.go @@ -35,6 +35,56 @@ func AsInt64(src interface{}) (int64, error) { return int64(v), nil case uint64: return int64(v), nil + case *int: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *int16: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *int32: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *int8: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *int64: + if v == nil { + return 0, nil + } + return *v, nil + case *uint: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *uint8: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *uint16: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *uint32: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *uint64: + if v == nil { + return 0, nil + } + return int64(*v), nil case []byte: return strconv.ParseInt(string(v), 10, 64) case string: @@ -110,9 +160,7 @@ func AsUint64(src interface{}) (uint64, error) { return 0, fmt.Errorf("unsupported value %T as uint64", src) } -var ( - _ sql.Scanner = &NullUint64{} -) +var _ sql.Scanner = &NullUint64{} // NullUint64 represents an uint64 that may be null. // NullUint64 implements the Scanner interface so @@ -142,9 +190,7 @@ func (n NullUint64) Value() (driver.Value, error) { return n.Uint64, nil } -var ( - _ sql.Scanner = &NullUint32{} -) +var _ sql.Scanner = &NullUint32{} // NullUint32 represents an uint32 that may be null. // NullUint32 implements the Scanner interface so diff --git a/convert/time.go b/convert/time.go index 544eca6f..fa9d832a 100644 --- a/convert/time.go +++ b/convert/time.go @@ -15,6 +15,7 @@ import ( ) // String2Time converts a string to time with original location +// be aware for time strings (HH:mm:ss) returns zero year (LMT) for converted location func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) { if len(s) == 19 { if s == utils.ZeroTime0 || s == utils.ZeroTime1 { @@ -62,7 +63,7 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t strings.HasPrefix(s, "0001-01-01T00:00:00."+strings.Repeat("0", len(s)-20)) { return &time.Time{}, nil } - var layout = "2006-01-02 15:04:05." + strings.Repeat("0", len(s)-20) + layout := "2006-01-02 15:04:05." + strings.Repeat("0", len(s)-20) dt, err := time.ParseInLocation(layout, s, originalLocation) if err != nil { return nil, err @@ -79,6 +80,18 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t } dt = dt.In(convertedLocation) return &dt, nil + } else if len(s) == 8 && s[2] == ':' && s[5] == ':' { + currentDate := time.Now() + dt, err := time.ParseInLocation("15:04:05", s, originalLocation) + if err != nil { + return nil, err + } + // add current date for correct time locations + dt = dt.AddDate(currentDate.Year(), int(currentDate.Month()), currentDate.Day()) + dt = dt.In(convertedLocation) + // back to zero year + dt = dt.AddDate(-currentDate.Year(), int(-currentDate.Month()), -currentDate.Day()) + return &dt, nil } else { i, err := strconv.ParseInt(s, 10, 64) if err == nil { diff --git a/convert/time_test.go b/convert/time_test.go index 4b1c2279..d7a9d5ad 100644 --- a/convert/time_test.go +++ b/convert/time_test.go @@ -15,7 +15,7 @@ func TestString2Time(t *testing.T) { expectedLoc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) - var kases = map[string]time.Time{ + cases := map[string]time.Time{ "2021-08-10": time.Date(2021, 8, 10, 8, 0, 0, 0, expectedLoc), "2021-07-11 10:44:00": time.Date(2021, 7, 11, 18, 44, 0, 0, expectedLoc), "2021-07-11 10:44:00.999": time.Date(2021, 7, 11, 18, 44, 0, 999000000, expectedLoc), @@ -25,12 +25,13 @@ func TestString2Time(t *testing.T) { "2021-06-06T22:58:20.999+08:00": time.Date(2021, 6, 6, 22, 58, 20, 999000000, expectedLoc), "2021-06-06T22:58:20.999999+08:00": time.Date(2021, 6, 6, 22, 58, 20, 999999000, expectedLoc), "2021-06-06T22:58:20.999999999+08:00": time.Date(2021, 6, 6, 22, 58, 20, 999999999, expectedLoc), - "2021-08-10T10:33:04Z": time.Date(2021, 8, 10, 18, 33, 04, 0, expectedLoc), - "2021-08-10T10:33:04.999Z": time.Date(2021, 8, 10, 18, 33, 04, 999000000, expectedLoc), - "2021-08-10T10:33:04.999999Z": time.Date(2021, 8, 10, 18, 33, 04, 999999000, expectedLoc), - "2021-08-10T10:33:04.999999999Z": time.Date(2021, 8, 10, 18, 33, 04, 999999999, expectedLoc), + "2021-08-10T10:33:04Z": time.Date(2021, 8, 10, 18, 33, 0o4, 0, expectedLoc), + "2021-08-10T10:33:04.999Z": time.Date(2021, 8, 10, 18, 33, 0o4, 999000000, expectedLoc), + "2021-08-10T10:33:04.999999Z": time.Date(2021, 8, 10, 18, 33, 0o4, 999999000, expectedLoc), + "2021-08-10T10:33:04.999999999Z": time.Date(2021, 8, 10, 18, 33, 0o4, 999999999, expectedLoc), + "10:22:33": time.Date(0, 1, 1, 18, 22, 33, 0, expectedLoc), } - for layout, tm := range kases { + for layout, tm := range cases { t.Run(layout, func(t *testing.T) { target, err := String2Time(layout, time.UTC, expectedLoc) assert.NoError(t, err) diff --git a/dialects/time_test.go b/dialects/time_test.go new file mode 100644 index 00000000..670207c6 --- /dev/null +++ b/dialects/time_test.go @@ -0,0 +1,190 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "xorm.io/xorm/schemas" +) + +type dialect struct { + Dialect + dbType schemas.DBType +} + +func (d dialect) URI() *URI { + return &URI{ + DBType: d.dbType, + } +} + +func TestFormatColumnTime(t *testing.T) { + date := time.Date(2020, 10, 23, 10, 14, 15, 123456, time.Local) + tests := []struct { + name string + dialect Dialect + location *time.Location + column *schemas.Column + time time.Time + wantRes interface{} + wantErr error + }{ + { + name: "nullable", + dialect: nil, + location: nil, + column: &schemas.Column{Nullable: true}, + time: time.Time{}, + wantRes: nil, + wantErr: nil, + }, + { + name: "invalid sqltype", + dialect: nil, + location: nil, + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Bit}}, + time: time.Time{}, + wantRes: 0, + wantErr: nil, + }, + { + name: "return default", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Bit}}, + time: date, + wantRes: date, + wantErr: nil, + }, + { + name: "return default (set timezone)", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Bit}, TimeZone: time.UTC}, + time: date, + wantRes: date.In(time.UTC), + wantErr: nil, + }, + { + name: "format date", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Date}}, + time: date, + wantRes: date.Format("2006-01-02"), + wantErr: nil, + }, + { + name: "format time", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Time}}, + time: date, + wantRes: date.Format("15:04:05"), + wantErr: nil, + }, + { + name: "format time (set length)", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Time}, Length: 64}, + time: date, + wantRes: date.Format("15:04:05.999999999"), + wantErr: nil, + }, + { + name: "format datetime", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.DateTime}}, + time: date, + wantRes: date.Format("2006-01-02 15:04:05"), + wantErr: nil, + }, + { + name: "format datetime (set length)", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.DateTime}, Length: 64}, + time: date, + wantRes: date.Format("2006-01-02 15:04:05.999999999"), + wantErr: nil, + }, + { + name: "format timestamp", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.TimeStamp}}, + time: date, + wantRes: date.Format("2006-01-02 15:04:05"), + wantErr: nil, + }, + { + name: "format timestamp (set length)", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.TimeStamp}, Length: 64}, + time: date, + wantRes: date.Format("2006-01-02 15:04:05.999999999"), + wantErr: nil, + }, + { + name: "format varchar", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Varchar}}, + time: date, + wantRes: date.Format("2006-01-02 15:04:05"), + wantErr: nil, + }, + { + name: "format timestampz", + dialect: dialect{}, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.TimeStampz}}, + time: date, + wantRes: date.Format(time.RFC3339Nano), + wantErr: nil, + }, + { + name: "format timestampz (mssql)", + dialect: dialect{dbType: schemas.MSSQL}, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.TimeStampz}}, + time: date, + wantRes: date.Format("2006-01-02T15:04:05.9999999Z07:00"), + wantErr: nil, + }, + { + name: "format int", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Int}}, + time: date, + wantRes: date.Unix(), + wantErr: nil, + }, + { + name: "format bigint", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.BigInt}}, + time: date, + wantRes: date.Unix(), + wantErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := FormatColumnTime(tt.dialect, tt.location, tt.column, tt.time) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.wantRes, got) + }) + } +} diff --git a/go.mod b/go.mod index 5f2f18e0..b48b3bbc 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module xorm.io/xorm -go 1.13 +go 1.16 require ( gitee.com/travelliu/dm v1.8.11192 diff --git a/go.sum b/go.sum index 9f197fae..5c780ec6 100644 --- a/go.sum +++ b/go.sum @@ -24,7 +24,6 @@ github.com/denisenkom/go-mssqldb v0.12.3/go.mod h1:k0mtMFOnU+AihqFxPMiF05rtiDror github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= @@ -39,7 +38,6 @@ github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZ github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= -github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= @@ -231,7 +229,6 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/integrations/session_count_test.go b/integrations/session_count_test.go index 079602c3..c6e64e76 100644 --- a/integrations/session_count_test.go +++ b/integrations/session_count_test.go @@ -125,6 +125,11 @@ func TestWithTableName(t *testing.T) { total, err = testEngine.OrderBy("count(`id`) desc").Count(CountWithTableName{}) assert.NoError(t, err) assert.EqualValues(t, 2, total) + + // the orderby will be ignored by count because some databases will return errors if the orderby columns not in group by + total, err = testEngine.OrderBy("`name`").Count(CountWithTableName{}) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) } func TestCountWithSelectCols(t *testing.T) { diff --git a/integrations/session_delete_test.go b/integrations/session_delete_test.go index 680c3215..1ed3e706 100644 --- a/integrations/session_delete_test.go +++ b/integrations/session_delete_test.go @@ -5,6 +5,7 @@ package integrations import ( + "os" "testing" "time" @@ -70,6 +71,63 @@ func TestDelete(t *testing.T) { assert.False(t, has) } +func TestDeleteLimit(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + if testEngine.Dialect().URI().DBType == schemas.MSSQL || os.Getenv("IGNORE_TEST_DELETE_LIMIT") == "true" { + t.Skip() + return + } + + type UserinfoDeleteLimit struct { + Uid int64 `xorm:"id pk not null autoincr"` + IsMan bool + } + + assert.NoError(t, testEngine.Sync2(new(UserinfoDeleteLimit))) + + session := testEngine.NewSession() + defer session.Close() + + var err error + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Begin() + assert.NoError(t, err) + _, err = session.Exec("SET IDENTITY_INSERT userinfo_delete_limit ON") + assert.NoError(t, err) + } + + user := UserinfoDeleteLimit{Uid: 1, IsMan: true} + cnt, err := session.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + user2 := UserinfoDeleteLimit{Uid: 2} + cnt, err = session.Insert(&user2) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Commit() + assert.NoError(t, err) + } + + cnt, err = testEngine.Limit(1, 1).Delete(&UserinfoDeleteLimit{}) + assert.Error(t, err) + assert.EqualValues(t, 0, cnt) + + cnt, err = testEngine.Limit(1).Desc("id").Delete(&UserinfoDeleteLimit{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var users []UserinfoDeleteLimit + err = testEngine.Find(&users) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users)) + assert.EqualValues(t, 1, users[0].Uid) + assert.EqualValues(t, true, users[0].IsMan) +} + func TestDeleted(t *testing.T) { assert.NoError(t, PrepareEngine()) diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index 841ec709..d3403814 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -1012,4 +1012,12 @@ func TestGetBytesVars(t *testing.T) { assert.True(t, has) assert.EqualValues(t, []byte("bytes1-1"), gbv.Bytes1) assert.EqualValues(t, []byte("bytes2-2"), gbv.Bytes2) + + type MyID int64 + var myID MyID + + has, err = testEngine.Table("get_bytes_vars").Select("id").Desc("id").Get(&myID) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, gbv.Id, myID) } diff --git a/integrations/session_query_test.go b/integrations/session_query_test.go index 00b7d7a6..ff62f25d 100644 --- a/integrations/session_query_test.go +++ b/integrations/session_query_test.go @@ -30,7 +30,7 @@ func TestQueryString(t *testing.T) { assert.NoError(t, testEngine.Sync(new(GetVar2))) - var data = GetVar2{ + data := GetVar2{ Msg: "hi", Age: 28, Money: 1.5, @@ -58,7 +58,7 @@ func TestQueryString2(t *testing.T) { assert.NoError(t, testEngine.Sync(new(GetVar3))) - var data = GetVar3{ + data := GetVar3{ Msg: false, } _, err := testEngine.Insert(data) @@ -95,7 +95,7 @@ func TestQueryInterface(t *testing.T) { assert.NoError(t, testEngine.Sync(new(GetVarInterface))) - var data = GetVarInterface{ + data := GetVarInterface{ Msg: "hi", Age: 28, Money: 1.5, @@ -128,7 +128,7 @@ func TestQueryNoParams(t *testing.T) { assert.NoError(t, testEngine.Sync(new(QueryNoParams))) - var q = QueryNoParams{ + q := QueryNoParams{ Msg: "message", Age: 20, Money: 3000, @@ -172,7 +172,7 @@ func TestQueryStringNoParam(t *testing.T) { assert.NoError(t, testEngine.Sync(new(GetVar4))) - var data = GetVar4{ + data := GetVar4{ Msg: false, } _, err := testEngine.Insert(data) @@ -209,7 +209,7 @@ func TestQuerySliceStringNoParam(t *testing.T) { assert.NoError(t, testEngine.Sync(new(GetVar6))) - var data = GetVar6{ + data := GetVar6{ Msg: false, } _, err := testEngine.Insert(data) @@ -246,7 +246,7 @@ func TestQueryInterfaceNoParam(t *testing.T) { assert.NoError(t, testEngine.Sync(new(GetVar5))) - var data = GetVar5{ + data := GetVar5{ Msg: false, } _, err := testEngine.Insert(data) @@ -280,7 +280,7 @@ func TestQueryWithBuilder(t *testing.T) { assert.NoError(t, testEngine.Sync(new(QueryWithBuilder))) - var q = QueryWithBuilder{ + q := QueryWithBuilder{ Msg: "message", Age: 20, Money: 3000, @@ -329,14 +329,14 @@ func TestJoinWithSubQuery(t *testing.T) { assert.NoError(t, testEngine.Sync(new(JoinWithSubQuery1), new(JoinWithSubQueryDepart))) - var depart = JoinWithSubQueryDepart{ + depart := JoinWithSubQueryDepart{ Name: "depart1", } cnt, err := testEngine.Insert(&depart) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var q = JoinWithSubQuery1{ + q := JoinWithSubQuery1{ Msg: "message", DepartId: depart.Id, Money: 3000, @@ -401,7 +401,7 @@ func TestQueryBLOBInMySQL(t *testing.T) { } const N = 10 - var data = []Avatar{} + data := []Avatar{} for i := 0; i < N; i++ { // allocate a []byte that is as twice big as the last one // so that the underlying buffer will need to reallocate when querying @@ -448,3 +448,54 @@ func TestQueryBLOBInMySQL(t *testing.T) { } } } + +func TestRowsReset(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type RowsReset1 struct { + Id int64 + Name string + } + + type RowsReset2 struct { + Id int64 + Name string + } + + assert.NoError(t, testEngine.Sync(new(RowsReset1), new(RowsReset2))) + + data := []RowsReset1{ + {0, "1"}, + {0, "2"}, + {0, "3"}, + } + _, err := testEngine.Insert(data) + assert.NoError(t, err) + + data2 := []RowsReset2{ + {0, "4"}, + {0, "5"}, + {0, "6"}, + } + _, err = testEngine.Insert(data2) + assert.NoError(t, err) + + sess := testEngine.NewSession() + defer sess.Close() + + rows, err := sess.Rows(new(RowsReset1)) + assert.NoError(t, err) + for rows.Next() { + var data1 RowsReset1 + assert.NoError(t, rows.Scan(&data1)) + } + rows.Close() + + var rrs []RowsReset2 + assert.NoError(t, sess.Find(&rrs)) + + assert.Len(t, rrs, 3) + assert.EqualValues(t, "4", rrs[0].Name) + assert.EqualValues(t, "5", rrs[1].Name) + assert.EqualValues(t, "6", rrs[2].Name) +} diff --git a/integrations/session_update_test.go b/integrations/session_update_test.go index 45338cad..2a8f8187 100644 --- a/integrations/session_update_test.go +++ b/integrations/session_update_test.go @@ -28,7 +28,7 @@ func TestUpdateMap(t *testing.T) { } assert.NoError(t, testEngine.Sync(new(UpdateTable))) - var tb = UpdateTable{ + tb := UpdateTable{ Name: "test", Age: 35, } @@ -79,7 +79,7 @@ func TestUpdateLimit(t *testing.T) { } assert.NoError(t, testEngine.Sync(new(UpdateTable2))) - var tb = UpdateTable2{ + tb := UpdateTable2{ Name: "test1", Age: 35, } @@ -400,7 +400,7 @@ func TestUpdate1(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var s = "test" + s := "test" col1 := &UpdateAllCols{Ptr: &s} err = testEngine.Sync(col1) @@ -864,7 +864,7 @@ func TestCreatedUpdated2(t *testing.T) { assertSync(t, new(CreatedUpdatedStruct)) - var s = CreatedUpdatedStruct{ + s := CreatedUpdatedStruct{ Name: "test", } cnt, err := testEngine.Insert(&s) @@ -874,7 +874,7 @@ func TestCreatedUpdated2(t *testing.T) { time.Sleep(time.Second) - var s1 = CreatedUpdatedStruct{ + s1 := CreatedUpdatedStruct{ Name: "test1", CreateAt: s.CreateAt, UpdateAt: s.UpdateAt, @@ -907,7 +907,7 @@ func TestDeletedUpdate(t *testing.T) { assertSync(t, new(DeletedUpdatedStruct)) - var s = DeletedUpdatedStruct{ + s := DeletedUpdatedStruct{ Name: "test", } cnt, err := testEngine.Insert(&s) @@ -956,7 +956,7 @@ func TestUpdateMapCondition(t *testing.T) { assertSync(t, new(UpdateMapCondition)) - var c = UpdateMapCondition{ + c := UpdateMapCondition{ String: "string", } _, err := testEngine.Insert(&c) @@ -990,7 +990,7 @@ func TestUpdateMapContent(t *testing.T) { assertSync(t, new(UpdateMapContent)) - var c = UpdateMapContent{ + c := UpdateMapContent{ Name: "lunny", IsMan: true, Gender: 1, @@ -1126,7 +1126,7 @@ func TestUpdateDeleted(t *testing.T) { assertSync(t, new(UpdateDeletedStruct)) - var s = UpdateDeletedStruct{ + s := UpdateDeletedStruct{ Name: "test", } cnt, err := testEngine.Insert(&s) @@ -1232,7 +1232,7 @@ func TestUpdateExprs2(t *testing.T) { assertSync(t, new(UpdateExprsRelease)) - var uer = UpdateExprsRelease{ + uer := UpdateExprsRelease{ RepoId: 1, IsTag: false, IsDraft: false, @@ -1407,7 +1407,7 @@ func TestNilFromDB(t *testing.T) { assert.NoError(t, PrepareEngine()) assertSync(t, new(TestTable1)) - var tt0 = TestTable1{ + tt0 := TestTable1{ Field1: &TestFieldType1{ cb: []byte("string"), }, @@ -1437,7 +1437,7 @@ func TestNilFromDB(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var tt = TestTable1{ + tt := TestTable1{ UpdateTime: time.Now(), Field1: &TestFieldType1{ cb: nil, @@ -1453,7 +1453,7 @@ func TestNilFromDB(t *testing.T) { assert.True(t, has) assert.Nil(t, tt2.Field1) - var tt3 = TestTable1{ + tt3 := TestTable1{ UpdateTime: time.Now(), Field1: &TestFieldType1{ cb: []byte{}, @@ -1470,3 +1470,34 @@ func TestNilFromDB(t *testing.T) { assert.NotNil(t, tt4.Field1) assert.NotNil(t, tt4.Field1.cb) } + +/* +func TestUpdateWithJoin(t *testing.T) { + type TestUpdateWithJoin struct { + Id int64 + ExtId int64 + Name string + } + + type TestUpdateWithJoin2 struct { + Id int64 + Name string + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(TestUpdateWithJoin), new(TestUpdateWithJoin2)) + + b := TestUpdateWithJoin2{Name: "test"} + _, err := testEngine.Insert(&b) + assert.NoError(t, err) + + _, err = testEngine.Insert(&TestUpdateWithJoin{ExtId: b.Id, Name: "test"}) + assert.NoError(t, err) + + _, err = testEngine.Table("test_update_with_join"). + Join("INNER", "test_update_with_join2", "test_update_with_join.ext_id = test_update_with_join2.id"). + Where("test_update_with_join2.name = ?", "test"). + Update(&TestUpdateWithJoin{Name: "test2"}) + assert.NoError(t, err) +} +*/ diff --git a/internal/statements/delete.go b/internal/statements/delete.go new file mode 100644 index 00000000..6e859399 --- /dev/null +++ b/internal/statements/delete.go @@ -0,0 +1,125 @@ +// Copyright 2023 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "errors" + "fmt" + "time" + + "xorm.io/builder" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" +) + +func (statement *Statement) writeDeleteOrder(w *builder.BytesWriter) error { + if err := statement.writeOrderBys(w); err != nil { + return err + } + + if statement.LimitN != nil && *statement.LimitN > 0 { + if statement.Start > 0 { + return fmt.Errorf("Delete with Limit start is unsupported") + } + limitNValue := *statement.LimitN + if _, err := fmt.Fprintf(w, " LIMIT %d", limitNValue); err != nil { + return err + } + } + + return nil +} + +// ErrNotImplemented not implemented +var ErrNotImplemented = errors.New("Not implemented") + +func (statement *Statement) writeOrderCond(orderCondWriter *builder.BytesWriter, tableName string) error { + orderSQLWriter := builder.NewWriter() + if err := statement.writeDeleteOrder(orderSQLWriter); err != nil { + return err + } + + if orderSQLWriter.Len() == 0 { + return nil + } + + switch statement.dialect.URI().DBType { + case schemas.POSTGRES: + if statement.cond.IsValid() { + if _, err := fmt.Fprint(orderCondWriter, " AND "); err != nil { + return err + } + } else { + if _, err := fmt.Fprint(orderCondWriter, " WHERE "); err != nil { + return err + } + } + if _, err := fmt.Fprintf(orderCondWriter, "ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQLWriter.String()); err != nil { + return err + } + orderCondWriter.Append(orderSQLWriter.Args()...) + return nil + case schemas.SQLITE: + if statement.cond.IsValid() { + if _, err := fmt.Fprint(orderCondWriter, " AND "); err != nil { + return err + } + } else { + if _, err := fmt.Fprint(orderCondWriter, " WHERE "); err != nil { + return err + } + } + if _, err := fmt.Fprintf(orderCondWriter, "rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQLWriter.String()); err != nil { + return err + } + orderCondWriter.Append(orderSQLWriter.Args()...) + return nil + // TODO: how to handle delete limit on mssql? + case schemas.MSSQL: + return ErrNotImplemented + default: + return utils.WriteBuilder(orderCondWriter, orderSQLWriter) + } +} + +func (statement *Statement) WriteDelete(realSQLWriter, deleteSQLWriter *builder.BytesWriter, nowTime func(*schemas.Column) (interface{}, time.Time, error)) error { + tableNameNoQuote := statement.TableName() + tableName := statement.dialect.Quoter().Quote(tableNameNoQuote) + table := statement.RefTable + if _, err := fmt.Fprint(deleteSQLWriter, "DELETE FROM ", tableName); err != nil { + return err + } + if err := statement.writeWhere(deleteSQLWriter); err != nil { + return err + } + + orderCondWriter := builder.NewWriter() + if err := statement.writeOrderCond(orderCondWriter, tableName); err != nil { + return err + } + + if statement.GetUnscoped() || table == nil || table.DeletedColumn() == nil { // tag "deleted" is disabled + return utils.WriteBuilder(realSQLWriter, deleteSQLWriter, orderCondWriter) + } + + deletedColumn := table.DeletedColumn() + if _, err := fmt.Fprintf(realSQLWriter, "UPDATE %v SET %v = ?", + statement.dialect.Quoter().Quote(statement.TableName()), + statement.dialect.Quoter().Quote(deletedColumn.Name)); err != nil { + return err + } + + val, _, err := nowTime(deletedColumn) + if err != nil { + return err + } + realSQLWriter.Append(val) + + if err := statement.writeWhere(realSQLWriter); err != nil { + return err + } + + return utils.WriteBuilder(realSQLWriter, orderCondWriter) +} diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 187b94a3..9370c984 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -293,3 +293,99 @@ func (statement *Statement) GenInsertMultipleMapSQL(columns []string, argss [][] return buf.String(), buf.Args(), nil } + +func (statement *Statement) writeColumns(w *builder.BytesWriter, slice []string) error { + for i, s := range slice { + if i > 0 { + if _, err := fmt.Fprint(w, ","); err != nil { + return err + } + } + if err := statement.dialect.Quoter().QuoteTo(w.Builder, s); err != nil { + return err + } + } + return nil +} + +func (statement *Statement) writeQuestions(w *builder.BytesWriter, length int) error { + for i := 0; i < length; i++ { + if i > 0 { + if _, err := fmt.Fprint(w, ","); err != nil { + return err + } + } + if _, err := fmt.Fprint(w, "?"); err != nil { + return err + } + } + return nil +} + +func (statement *Statement) oracleWriteInsertMultiple(w *builder.BytesWriter, tableName string, colNames []string, colMultiPlaces []string) error { + if _, err := fmt.Fprint(w, "INSERT ALL"); err != nil { + return err + } + + for _, cols := range colMultiPlaces { + if _, err := fmt.Fprint(w, " INTO "); err != nil { + return err + } + if err := statement.dialect.Quoter().QuoteTo(w.Builder, tableName); err != nil { + return err + } + if _, err := fmt.Fprint(w, " ("); err != nil { + return err + } + if err := statement.writeColumns(w, colNames); err != nil { + return err + } + if _, err := fmt.Fprint(w, ") VALUES ("); err != nil { + return err + } + if _, err := fmt.Fprintf(w, cols, ")"); err != nil { + return err + } + } + + if _, err := fmt.Fprint(w, " SELECT 1 FROM DUAL"); err != nil { + return err + } + return nil +} + +func (statement *Statement) WriteInsertMultiple(w *builder.BytesWriter, tableName string, colNames []string, colMultiPlaces []string) error { + if statement.dialect.URI().DBType == schemas.ORACLE { + return statement.oracleWriteInsertMultiple(w, tableName, colNames, colMultiPlaces) + } + return statement.plainWriteInsertMultiple(w, tableName, colNames, colMultiPlaces) +} + +func (statement *Statement) plainWriteInsertMultiple(w *builder.BytesWriter, tableName string, colNames []string, colMultiPlaces []string) error { + if _, err := fmt.Fprint(w, "INSERT INTO "); err != nil { + return err + } + if err := statement.dialect.Quoter().QuoteTo(w.Builder, tableName); err != nil { + return err + } + if _, err := fmt.Fprint(w, " ("); err != nil { + return err + } + if err := statement.writeColumns(w, colNames); err != nil { + return err + } + if _, err := fmt.Fprint(w, ") VALUES ("); err != nil { + return err + } + for i, cols := range colMultiPlaces { + if _, err := fmt.Fprint(w, cols, ")"); err != nil { + return err + } + if i < len(colMultiPlaces)-1 { + if _, err := fmt.Fprint(w, ",("); err != nil { + return err + } + } + } + return nil +} diff --git a/internal/statements/join.go b/internal/statements/join.go index 61a1b4de..740b31de 100644 --- a/internal/statements/join.go +++ b/internal/statements/join.go @@ -36,7 +36,7 @@ func (statement *Statement) writeJoins(w *builder.BytesWriter) error { func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error { // write join operator - if _, err := fmt.Fprintf(buf, " %v JOIN", join.op); err != nil { + if _, err := fmt.Fprint(buf, " ", join.op, " JOIN"); err != nil { return err } diff --git a/internal/statements/order_by.go b/internal/statements/order_by.go index 08a8263b..54a3c6e0 100644 --- a/internal/statements/order_by.go +++ b/internal/statements/order_by.go @@ -5,86 +5,138 @@ package statements import ( + "errors" "fmt" - "strings" "xorm.io/builder" ) +type orderBy struct { + orderStr interface{} + orderArgs []interface{} + direction string // ASC, DESC or "", "" means raw orderStr +} + +func (ob orderBy) CheckValid() error { + if ob.orderStr == nil { + return fmt.Errorf("order by string is nil") + } + switch t := ob.orderStr.(type) { + case string: + if t == "" { + return fmt.Errorf("order by string is empty") + } + return nil + case *builder.Expression: + if t.Content() == "" { + return fmt.Errorf("order by string is empty") + } + return nil + default: + return fmt.Errorf("order by string is not string or builder.Expression") + } +} + func (statement *Statement) HasOrderBy() bool { - return statement.orderStr != "" + return len(statement.orderBy) > 0 } // ResetOrderBy reset ordery conditions func (statement *Statement) ResetOrderBy() { - statement.orderStr = "" - statement.orderArgs = nil + statement.orderBy = []orderBy{} +} + +var ErrNoColumnName = errors.New("no column name") + +func (statement *Statement) writeOrderBy(w *builder.BytesWriter, orderBy orderBy) error { + switch t := orderBy.orderStr.(type) { + case (*builder.Expression): + if _, err := fmt.Fprint(w.Builder, statement.dialect.Quoter().Replace(t.Content())); err != nil { + return err + } + w.Append(t.Args()...) + return nil + case string: + if orderBy.direction == "" { + if _, err := fmt.Fprint(w.Builder, statement.dialect.Quoter().Replace(t)); err != nil { + return err + } + w.Append(orderBy.orderArgs...) + return nil + } + if err := statement.dialect.Quoter().QuoteTo(w.Builder, t); err != nil { + return err + } + _, err := fmt.Fprint(w, " ", orderBy.direction) + return err + default: + return ErrUnSupportedSQLType + } } // WriteOrderBy write order by to writer -func (statement *Statement) WriteOrderBy(w builder.Writer) error { - if len(statement.orderStr) > 0 { - if _, err := fmt.Fprintf(w, " ORDER BY %s", statement.orderStr); err != nil { +func (statement *Statement) writeOrderBys(w *builder.BytesWriter) error { + if len(statement.orderBy) == 0 { + return nil + } + + if _, err := fmt.Fprint(w, " ORDER BY "); err != nil { + return err + } + for i, ob := range statement.orderBy { + if err := statement.writeOrderBy(w, ob); err != nil { return err } - w.Append(statement.orderArgs...) + if i < len(statement.orderBy)-1 { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } } return nil } // OrderBy generate "Order By order" statement func (statement *Statement) OrderBy(order interface{}, args ...interface{}) *Statement { - if len(statement.orderStr) > 0 { - statement.orderStr += ", " - } - var rawOrder string - switch t := order.(type) { - case (*builder.Expression): - rawOrder = t.Content() - args = t.Args() - case string: - rawOrder = t - default: - statement.LastError = ErrUnSupportedSQLType + ob := orderBy{order, args, ""} + if err := ob.CheckValid(); err != nil { + statement.LastError = err return statement } - statement.orderStr += statement.ReplaceQuote(rawOrder) - if len(args) > 0 { - statement.orderArgs = append(statement.orderArgs, args...) - } + statement.orderBy = append(statement.orderBy, ob) return statement } // Desc generate `ORDER BY xx DESC` func (statement *Statement) Desc(colNames ...string) *Statement { - var buf strings.Builder - if len(statement.orderStr) > 0 { - fmt.Fprint(&buf, statement.orderStr, ", ") + if len(colNames) == 0 { + statement.LastError = ErrNoColumnName + return statement } - for i, col := range colNames { - if i > 0 { - fmt.Fprint(&buf, ", ") + for _, colName := range colNames { + ob := orderBy{colName, nil, "DESC"} + statement.orderBy = append(statement.orderBy, ob) + if err := ob.CheckValid(); err != nil { + statement.LastError = err + return statement } - _ = statement.dialect.Quoter().QuoteTo(&buf, col) - fmt.Fprint(&buf, " DESC") } - statement.orderStr = buf.String() return statement } // Asc provide asc order by query condition, the input parameters are columns. func (statement *Statement) Asc(colNames ...string) *Statement { - var buf strings.Builder - if len(statement.orderStr) > 0 { - fmt.Fprint(&buf, statement.orderStr, ", ") + if len(colNames) == 0 { + statement.LastError = ErrNoColumnName + return statement } - for i, col := range colNames { - if i > 0 { - fmt.Fprint(&buf, ", ") + for _, colName := range colNames { + ob := orderBy{colName, nil, "ASC"} + statement.orderBy = append(statement.orderBy, ob) + if err := ob.CheckValid(); err != nil { + statement.LastError = err + return statement } - _ = statement.dialect.Quoter().QuoteTo(&buf, col) - fmt.Fprint(&buf, " ASC") } - statement.orderStr = buf.String() return statement } diff --git a/internal/statements/query.go b/internal/statements/query.go index 2e38f0fe..216a2028 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -35,7 +35,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int } buf := builder.NewWriter() - if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { + if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -66,7 +66,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri } buf := builder.NewWriter() - if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true); err != nil { + if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true, true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -122,7 +122,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, } buf := builder.NewWriter() - if err := statement.writeSelect(buf, columnStr, true); err != nil { + if err := statement.writeSelect(buf, columnStr, true, true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -153,12 +153,6 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa selectSQL = "count(*)" } } - var subQuerySelect string - if statement.GroupByStr != "" { - subQuerySelect = statement.GroupByStr - } else { - subQuerySelect = selectSQL - } buf := builder.NewWriter() if statement.GroupByStr != "" { @@ -167,7 +161,14 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa } } - if err := statement.writeSelect(buf, subQuerySelect, false); err != nil { + var subQuerySelect string + if statement.GroupByStr != "" { + subQuerySelect = statement.GroupByStr + } else { + subQuerySelect = selectSQL + } + + if err := statement.writeSelect(buf, subQuerySelect, false, false); err != nil { return "", nil, err } @@ -230,7 +231,7 @@ func (statement *Statement) writeDistinct(w builder.Writer) error { } func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr string) error { - if _, err := fmt.Fprintf(w, "SELECT "); err != nil { + if _, err := fmt.Fprintf(w, "SELECT"); err != nil { return err } if err := statement.writeDistinct(w); err != nil { @@ -243,14 +244,19 @@ func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr return err } -func (statement *Statement) writeWhere(w *builder.BytesWriter) error { - if !statement.cond.IsValid() { +func (statement *Statement) writeWhereCond(w *builder.BytesWriter, cond builder.Cond) error { + if !cond.IsValid() { return nil } + if _, err := fmt.Fprint(w, " WHERE "); err != nil { return err } - return statement.cond.WriteTo(statement.QuoteReplacer(w)) + return cond.WriteTo(statement.QuoteReplacer(w)) +} + +func (statement *Statement) writeWhere(w *builder.BytesWriter) error { + return statement.writeWhereCond(w, statement.cond) } func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error { @@ -320,7 +326,7 @@ func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) err if err := statement.writeWhere(subWriter); err != nil { return err } - if err := statement.WriteOrderBy(subWriter); err != nil { + if err := statement.writeOrderBys(subWriter); err != nil { return err } if err := statement.writeGroupBy(subWriter); err != nil { @@ -359,7 +365,7 @@ func (statement *Statement) writeOracleLimit(w *builder.BytesWriter, columnStr s return err } -func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error { +func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit, needOrderBy bool) error { if err := statement.writeSelectColumns(buf, columnStr); err != nil { return err } @@ -375,8 +381,10 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri if err := statement.writeHaving(buf); err != nil { return err } - if err := statement.WriteOrderBy(buf); err != nil { - return err + if needOrderBy { + if err := statement.writeOrderBys(buf); err != nil { + return err + } } dialect := statement.dialect @@ -514,7 +522,7 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa statement.cond = statement.cond.And(autoCond) buf := builder.NewWriter() - if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { + if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 61488ff7..7ad735f5 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -50,8 +50,7 @@ type Statement struct { Start int LimitN *int idParam schemas.PK - orderStr string - orderArgs []interface{} + orderBy []orderBy joins []join GroupByStr string HavingStr string @@ -163,15 +162,15 @@ func (statement *Statement) Reset() { // SQL adds raw sql statement func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement { - switch query.(type) { + switch t := query.(type) { case (*builder.Builder): var err error - statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL() + statement.RawSQL, statement.RawParams, err = t.ToSQL() if err != nil { statement.LastError = err } case string: - statement.RawSQL = query.(string) + statement.RawSQL = t statement.RawParams = args default: statement.LastError = ErrUnSupportedSQLType diff --git a/internal/statements/update.go b/internal/statements/update.go index 16ab5676..5d71f34d 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "reflect" - "strings" "time" "xorm.io/builder" @@ -311,84 +310,328 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value, return colNames, args, nil } -func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string) error { - whereWriter := builder.NewWriter() - if cond.IsValid() { - fmt.Fprint(whereWriter, "WHERE ") +func (statement *Statement) writeUpdateTop(updateWriter *builder.BytesWriter) error { + if statement.dialect.URI().DBType != schemas.MSSQL || statement.LimitN == nil { + return nil } - if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil { + + table := statement.RefTable + if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { + return nil + } + + _, err := fmt.Fprintf(updateWriter, " TOP (%d)", *statement.LimitN) + return err +} + +func (statement *Statement) writeUpdateTableName(updateWriter *builder.BytesWriter) error { + tableName := statement.quote(statement.TableName()) + if statement.TableAlias == "" { + _, err := fmt.Fprint(updateWriter, " ", tableName) return err } - if err := statement.WriteOrderBy(whereWriter); err != nil { + + switch statement.dialect.URI().DBType { + case schemas.MSSQL: + _, err := fmt.Fprint(updateWriter, " ", statement.TableAlias) return err + default: + _, err := fmt.Fprint(updateWriter, " ", tableName, " AS ", statement.TableAlias) + return err + } +} + +func (statement *Statement) writeUpdateFrom(updateWriter *builder.BytesWriter) error { + if statement.dialect.URI().DBType != schemas.MSSQL || statement.TableAlias == "" { + return nil + } + + _, err := fmt.Fprint(updateWriter, " FROM ", statement.quote(statement.TableName()), " ", statement.TableAlias) + return err +} + +func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter, cond builder.Cond) error { + if statement.LimitN == nil { + return nil } table := statement.RefTable tableName := statement.TableName() - // TODO: Oracle support needed - var top string - if statement.LimitN != nil { - limitValue := *statement.LimitN - switch statement.dialect.URI().DBType { - case schemas.MYSQL: - fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) - case schemas.SQLITE: - fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) - cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", - statement.quote(tableName), whereWriter.String()), whereWriter.Args()...)) - - whereWriter = builder.NewWriter() - fmt.Fprint(whereWriter, "WHERE ") - if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil { + limitValue := *statement.LimitN + switch statement.dialect.URI().DBType { + case schemas.MYSQL: + _, err := fmt.Fprintf(updateWriter, " LIMIT %d", limitValue) + return err + case schemas.SQLITE: + if cond.IsValid() { + if _, err := fmt.Fprint(updateWriter, " AND "); err != nil { return err } - case schemas.POSTGRES: - fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) - - cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", - statement.quote(tableName), whereWriter.String()), whereWriter.Args()...)) - - whereWriter = builder.NewWriter() - fmt.Fprint(whereWriter, "WHERE ") - if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil { + } else { + if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil { return err } - case schemas.MSSQL: - if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { - cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", - table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], - statement.quote(tableName), whereWriter.String()), whereWriter.Args()...) - - whereWriter = builder.NewWriter() - fmt.Fprint(whereWriter, "WHERE ") - if err := cond.WriteTo(whereWriter); err != nil { - return err - } - } else { - top = fmt.Sprintf("TOP (%d) ", limitValue) + } + if _, err := fmt.Fprint(updateWriter, "rowid IN (SELECT rowid FROM ", statement.quote(tableName)); err != nil { + return err + } + if err := statement.writeWhereCond(updateWriter, cond); err != nil { + return err + } + if err := statement.writeOrderBys(updateWriter); err != nil { + return err + } + _, err := fmt.Fprintf(updateWriter, " LIMIT %d)", limitValue) + return err + case schemas.POSTGRES: + if cond.IsValid() { + if _, err := fmt.Fprint(updateWriter, " AND "); err != nil { + return err + } + } else { + if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil { + return err } } - } - - tableAlias := statement.quote(tableName) - var fromSQL string - if statement.TableAlias != "" { - switch statement.dialect.URI().DBType { - case schemas.MSSQL: - fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, statement.TableAlias) - tableAlias = statement.TableAlias - default: - tableAlias = fmt.Sprintf("%s AS %s", tableAlias, statement.TableAlias) + if _, err := fmt.Fprint(updateWriter, "CTID IN (SELECT CTID FROM ", statement.quote(tableName)); err != nil { + return err } + if err := statement.writeWhereCond(updateWriter, cond); err != nil { + return err + } + if err := statement.writeOrderBys(updateWriter); err != nil { + return err + } + _, err := fmt.Fprintf(updateWriter, " LIMIT %d)", limitValue) + return err + case schemas.MSSQL: + if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { + if _, err := fmt.Fprintf(updateWriter, " WHERE %s IN (SELECT TOP (%d) %s FROM %v", + table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], + statement.quote(tableName)); err != nil { + return err + } + if err := statement.writeWhereCond(updateWriter, cond); err != nil { + return err + } + if err := statement.writeOrderBys(updateWriter); err != nil { + return err + } + _, err := fmt.Fprint(updateWriter, ")") + return err + } + return nil + default: // TODO: Oracle support needed + return fmt.Errorf("not implemented") + } +} + +func (statement *Statement) GenConditionsFromMap(m interface{}) ([]builder.Cond, error) { + switch t := m.(type) { + case map[string]interface{}: + conds := []builder.Cond{} + for k, v := range t { + conds = append(conds, builder.Eq{k: v}) + } + return conds, nil + case map[string]string: + conds := []builder.Cond{} + for k, v := range t { + conds = append(conds, builder.Eq{k: v}) + } + return conds, nil + default: + return nil, fmt.Errorf("unsupported condition map type %v", t) + } +} + +func (statement *Statement) writeVersionIncrSet(w builder.Writer, v reflect.Value, hasPreviousSet bool) error { + if v.Type().Kind() != reflect.Struct { + return nil } - if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v", - top, - tableAlias, - strings.Join(colNames, ", "), - fromSQL); err != nil { + table := statement.RefTable + if !(statement.RefTable != nil && table.Version != "" && statement.CheckVersion) { + return nil + } + + verValue, err := table.VersionColumn().ValueOfV(&v) + if err != nil { return err } - return utils.WriteBuilder(updateWriter, whereWriter) + + if verValue == nil { + return nil + } + + if hasPreviousSet { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } + + if _, err := fmt.Fprint(w, statement.quote(table.Version), " = ", statement.quote(table.Version), " + 1"); err != nil { + return err + } + return nil +} + +func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool) error { + for i, expr := range statement.IncrColumns { + if i > 0 || hasPreviousSet { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } + if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " + ?"); err != nil { + return err + } + w.Append(expr.Arg) + } + return nil +} + +func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool) error { + // for update action to like "column = column - ?" + for i, expr := range statement.DecrColumns { + if i > 0 || hasPreviousSet { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } + if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " - ?"); err != nil { + return err + } + w.Append(expr.Arg) + } + return nil +} + +func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet bool) error { + // for update action to like "column = expression" + for i, expr := range statement.ExprColumns { + if i > 0 || hasPreviousSet { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } + switch tp := expr.Arg.(type) { + case string: + if len(tp) == 0 { + tp = "''" + } + if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", tp); err != nil { + return err + } + case *builder.Builder: + if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ("); err != nil { + return err + } + if err := tp.WriteTo(statement.QuoteReplacer(w)); err != nil { + return err + } + if _, err := fmt.Fprint(w, ")"); err != nil { + return err + } + default: + if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ?"); err != nil { + return err + } + w.Append(expr.Arg) + } + } + return nil +} + +func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error { + previousLen := w.Len() + for i, colName := range colNames { + if i > 0 { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } + if _, err := fmt.Fprint(w, colName); err != nil { + return err + } + } + w.Append(args...) + + if err := statement.writeIncrSets(w, w.Len() > previousLen); err != nil { + return err + } + + if err := statement.writeDecrSets(w, w.Len() > previousLen); err != nil { + return err + } + + if err := statement.writeExprSets(w, w.Len() > previousLen); err != nil { + return err + } + + if err := statement.writeVersionIncrSet(w, v, w.Len() > previousLen); err != nil { + return err + } + return nil +} + +var ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated") + +func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []interface{}) error { + if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil { + return err + } + + if err := statement.writeUpdateTop(updateWriter); err != nil { + return err + } + + if err := statement.writeUpdateTableName(updateWriter); err != nil { + return err + } + + // write set + if _, err := fmt.Fprint(updateWriter, " SET "); err != nil { + return err + } + previousLen := updateWriter.Len() + + if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil { + return err + } + + // if no columns to be updated, return error + if previousLen == updateWriter.Len() { + return ErrNoColumnsTobeUpdated + } + + // write from + if err := statement.writeUpdateFrom(updateWriter); err != nil { + return err + } + + if statement.dialect.URI().DBType == schemas.MSSQL { + table := statement.RefTable + if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { + } else { + // write where + if err := statement.writeWhereCond(updateWriter, cond); err != nil { + return err + } + } + } else { + // write where + if err := statement.writeWhereCond(updateWriter, cond); err != nil { + return err + } + } + + if statement.dialect.URI().DBType == schemas.MYSQL { + if err := statement.writeOrderBys(updateWriter); err != nil { + return err + } + } + + return statement.writeUpdateLimit(updateWriter, cond) } diff --git a/rows.go b/rows.go index a42eedb9..c539410e 100644 --- a/rows.go +++ b/rows.go @@ -144,6 +144,8 @@ func (rows *Rows) Close() error { defer rows.session.Close() } + defer rows.session.resetStatement() + if rows.rows != nil { return rows.rows.Close() } diff --git a/session_delete.go b/session_delete.go index 6c63d9b0..7336040f 100644 --- a/session_delete.go +++ b/session_delete.go @@ -6,22 +6,15 @@ package xorm import ( "errors" - "fmt" "strconv" "xorm.io/builder" "xorm.io/xorm/caches" - "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) -var ( - // ErrNeedDeletedCond delete needs less one condition error - ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") - - // ErrNotImplemented not implemented - ErrNotImplemented = errors.New("Not implemented") -) +// ErrNeedDeletedCond delete needs less one condition error +var ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error { if table == nil || @@ -112,9 +105,8 @@ func (session *Session) delete(beans []interface{}, mustHaveConditions bool) (in } var ( - condWriter = builder.NewWriter() - err error - bean interface{} + err error + bean interface{} ) if len(beans) > 0 { bean = beans[0] @@ -133,90 +125,27 @@ func (session *Session) delete(beans []interface{}, mustHaveConditions bool) (in } } - if err = session.statement.Conds().WriteTo(session.statement.QuoteReplacer(condWriter)); err != nil { - return 0, err - } - pLimitN := session.statement.LimitN - if mustHaveConditions && condWriter.Len() == 0 && (pLimitN == nil || *pLimitN == 0) { + if mustHaveConditions && !session.statement.Conds().IsValid() && (pLimitN == nil || *pLimitN == 0) { return 0, ErrNeedDeletedCond } tableNameNoQuote := session.statement.TableName() - tableName := session.engine.Quote(tableNameNoQuote) table := session.statement.RefTable - deleteSQLWriter := builder.NewWriter() - fmt.Fprintf(deleteSQLWriter, "DELETE FROM %v", tableName) - if condWriter.Len() > 0 { - fmt.Fprintf(deleteSQLWriter, " WHERE %v", condWriter.String()) - deleteSQLWriter.Append(condWriter.Args()...) - } - orderSQLWriter := builder.NewWriter() - if err := session.statement.WriteOrderBy(orderSQLWriter); err != nil { + realSQLWriter := builder.NewWriter() + deleteSQLWriter := builder.NewWriter() + if err := session.statement.WriteDelete(realSQLWriter, deleteSQLWriter, session.engine.nowTime); err != nil { return 0, err } - if pLimitN != nil && *pLimitN > 0 { - limitNValue := *pLimitN - if _, err := fmt.Fprintf(orderSQLWriter, " LIMIT %d", limitNValue); err != nil { - return 0, err - } - } - - orderCondWriter := builder.NewWriter() - if orderSQLWriter.Len() > 0 { - switch session.engine.dialect.URI().DBType { - case schemas.POSTGRES: - if condWriter.Len() > 0 { - fmt.Fprintf(orderCondWriter, " AND ") - } else { - fmt.Fprintf(orderCondWriter, " WHERE ") - } - fmt.Fprintf(orderCondWriter, "ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQLWriter.String()) - orderCondWriter.Append(orderSQLWriter.Args()...) - case schemas.SQLITE: - if condWriter.Len() > 0 { - fmt.Fprintf(orderCondWriter, " AND ") - } else { - fmt.Fprintf(orderCondWriter, " WHERE ") - } - fmt.Fprintf(orderCondWriter, "rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQLWriter.String()) - // TODO: how to handle delete limit on mssql? - case schemas.MSSQL: - return 0, ErrNotImplemented - default: - fmt.Fprint(orderCondWriter, orderSQLWriter.String()) - orderCondWriter.Append(orderSQLWriter.Args()...) - } - } - - realSQLWriter := builder.NewWriter() - argsForCache := make([]interface{}, 0, len(deleteSQLWriter.Args())*2) - copy(argsForCache, deleteSQLWriter.Args()) - argsForCache = append(deleteSQLWriter.Args(), argsForCache...) if session.statement.GetUnscoped() || table == nil || table.DeletedColumn() == nil { // tag "deleted" is disabled - if err := utils.WriteBuilder(realSQLWriter, deleteSQLWriter, orderCondWriter); err != nil { - return 0, err - } } else { deletedColumn := table.DeletedColumn() - if _, err := fmt.Fprintf(realSQLWriter, "UPDATE %v SET %v = ? WHERE %v", - session.engine.Quote(session.statement.TableName()), - session.engine.Quote(deletedColumn.Name), - condWriter.String()); err != nil { - return 0, err - } - val, t, err := session.engine.nowTime(deletedColumn) + _, t, err := session.engine.nowTime(deletedColumn) if err != nil { return 0, err } - realSQLWriter.Append(val) - realSQLWriter.Append(condWriter.Args()...) - - if err := utils.WriteBuilder(realSQLWriter, orderCondWriter); err != nil { - return 0, err - } colName := deletedColumn.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { @@ -225,6 +154,10 @@ func (session *Session) delete(beans []interface{}, mustHaveConditions bool) (in }) } + argsForCache := make([]interface{}, 0, len(deleteSQLWriter.Args())*2) + copy(argsForCache, deleteSQLWriter.Args()) + argsForCache = append(deleteSQLWriter.Args(), argsForCache...) + if cacher := session.engine.GetCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache { _ = session.cacheDelete(table, tableNameNoQuote, deleteSQLWriter.String(), argsForCache...) } diff --git a/session_insert.go b/session_insert.go index 4067c3a8..7cc15241 100644 --- a/session_insert.go +++ b/session_insert.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "xorm.io/builder" "xorm.io/xorm/convert" "xorm.io/xorm/dialects" "xorm.io/xorm/internal/utils" @@ -156,14 +157,14 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, e } args = append(args, val) - var colName = col.Name + colName := col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnTime(bean, col, t) }) } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) - var colName = col.Name + colName := col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnInt(bean, col, 1) @@ -186,24 +187,12 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, e } cleanupProcessorsClosures(&session.beforeClosures) - quoter := session.engine.dialect.Quoter() - var sql string - colStr := quoter.Join(colNames, ",") - if session.engine.dialect.URI().DBType == schemas.ORACLE { - temp := fmt.Sprintf(") INTO %s (%v) VALUES (", - quoter.Quote(tableName), - colStr) - sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL", - quoter.Quote(tableName), - colStr, - strings.Join(colMultiPlaces, temp)) - } else { - sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)", - quoter.Quote(tableName), - colStr, - strings.Join(colMultiPlaces, "),(")) + w := builder.NewWriter() + if err := session.statement.WriteInsertMultiple(w, tableName, colNames, colMultiPlaces); err != nil { + return 0, err } - res, err := session.exec(sql, args...) + + res, err := session.exec(w.String(), args...) if err != nil { return 0, err } @@ -276,7 +265,7 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { processor.BeforeInsert() } - var tableName = session.statement.TableName() + tableName := session.statement.TableName() table := session.statement.RefTable colNames, args, err := session.genInsertColumns(bean) @@ -518,7 +507,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } args = append(args, val) - var colName = col.Name + colName := col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnTime(bean, col, t) @@ -548,7 +537,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err return 0, ErrTableNotFound } - var columns = make([]string, 0, len(m)) + columns := make([]string, 0, len(m)) exprs := session.statement.ExprColumns for k := range m { if !exprs.IsColExist(k) { @@ -557,7 +546,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err } sort.Strings(columns) - var args = make([]interface{}, 0, len(m)) + args := make([]interface{}, 0, len(m)) for _, colName := range columns { args = append(args, m[colName]) } @@ -575,7 +564,7 @@ func (session *Session) insertMultipleMapInterface(maps []map[string]interface{} return 0, ErrTableNotFound } - var columns = make([]string, 0, len(maps[0])) + columns := make([]string, 0, len(maps[0])) exprs := session.statement.ExprColumns for k := range maps[0] { if !exprs.IsColExist(k) { @@ -584,9 +573,9 @@ func (session *Session) insertMultipleMapInterface(maps []map[string]interface{} } sort.Strings(columns) - var argss = make([][]interface{}, 0, len(maps)) + argss := make([][]interface{}, 0, len(maps)) for _, m := range maps { - var args = make([]interface{}, 0, len(m)) + args := make([]interface{}, 0, len(m)) for _, colName := range columns { args = append(args, m[colName]) } @@ -606,7 +595,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { return 0, ErrTableNotFound } - var columns = make([]string, 0, len(m)) + columns := make([]string, 0, len(m)) exprs := session.statement.ExprColumns for k := range m { if !exprs.IsColExist(k) { @@ -616,7 +605,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { sort.Strings(columns) - var args = make([]interface{}, 0, len(m)) + args := make([]interface{}, 0, len(m)) for _, colName := range columns { args = append(args, m[colName]) } @@ -634,7 +623,7 @@ func (session *Session) insertMultipleMapString(maps []map[string]string) (int64 return 0, ErrTableNotFound } - var columns = make([]string, 0, len(maps[0])) + columns := make([]string, 0, len(maps[0])) exprs := session.statement.ExprColumns for k := range maps[0] { if !exprs.IsColExist(k) { @@ -643,9 +632,9 @@ func (session *Session) insertMultipleMapString(maps []map[string]string) (int64 } sort.Strings(columns) - var argss = make([][]interface{}, 0, len(maps)) + argss := make([][]interface{}, 0, len(maps)) for _, m := range maps { - var args = make([]interface{}, 0, len(m)) + args := make([]interface{}, 0, len(m)) for _, colName := range columns { args = append(args, m[colName]) } diff --git a/session_update.go b/session_update.go index 9a6964f1..b3640ad2 100644 --- a/session_update.go +++ b/session_update.go @@ -5,17 +5,17 @@ package xorm import ( - "errors" "reflect" "xorm.io/builder" + "xorm.io/xorm/internal/statements" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) // enumerated all errors var ( - ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated") + ErrNoColumnsTobeUpdated = statements.ErrNoColumnsTobeUpdated ) func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) { @@ -74,9 +74,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 v := utils.ReflectValue(bean) t := v.Type() - var colNames []string - var args []interface{} - // handle before update processors for _, closure := range session.beforeClosures { closure(bean) @@ -87,6 +84,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } // -- + var colNames []string + var args []interface{} var err error isMap := t.Kind() == reflect.Map isStruct := t.Kind() == reflect.Struct @@ -148,41 +147,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - // for update action to like "column = column + ?" - incColumns := session.statement.IncrColumns - for _, expr := range incColumns { - colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" + ?") - args = append(args, expr.Arg) - } - // for update action to like "column = column - ?" - decColumns := session.statement.DecrColumns - for _, expr := range decColumns { - colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" - ?") - args = append(args, expr.Arg) - } - // for update action to like "column = expression" - exprColumns := session.statement.ExprColumns - for _, expr := range exprColumns { - switch tp := expr.Arg.(type) { - case string: - if len(tp) == 0 { - tp = "''" - } - colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp) - case *builder.Builder: - subQuery, subArgs, err := builder.ToSQL(tp) - if err != nil { - return 0, err - } - subQuery = session.statement.ReplaceQuote(subQuery) - colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")") - args = append(args, subArgs...) - default: - colNames = append(colNames, session.engine.Quote(expr.ColName)+"=?") - args = append(args, expr.Arg) - } - } - if err = session.statement.ProcessIDParam(); err != nil { return 0, err } @@ -211,30 +175,25 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 verValue *reflect.Value ) if doIncVer { - verValue, err = table.VersionColumn().ValueOf(bean) + verValue, err = table.VersionColumn().ValueOfV(&v) if err != nil { return 0, err } if verValue != nil { cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()}) - colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1") } } - if len(colNames) == 0 { - return 0, ErrNoColumnsTobeUpdated - } - updateWriter := builder.NewWriter() - if err := session.statement.WriteUpdate(updateWriter, cond, colNames); err != nil { + if err := session.statement.WriteUpdate(updateWriter, cond, v, colNames, args); err != nil { return 0, err } tableName := session.statement.TableName() // table name must been get before exec because statement will be reset useCache := session.statement.UseCache - res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...) + res, err := session.exec(updateWriter.String(), updateWriter.Args()...) if err != nil { return 0, err } else if doIncVer {