diff --git a/.gitea/workflows/release-tag.yml b/.gitea/workflows/release-tag.yml index 10ed831e..3b788e76 100644 --- a/.gitea/workflows/release-tag.yml +++ b/.gitea/workflows/release-tag.yml @@ -13,11 +13,11 @@ jobs: with: fetch-depth: 0 - name: setup go - uses: https://github.com/actions/setup-go@v4 + uses: actions/setup-go@v4 with: go-version: '>=1.20.1' - name: Use Go Action id: use-go-action - uses: actions/release-action@main + uses: https://gitea.com/actions/release-action@main with: api_key: '${{secrets.RELEASE_TOKEN}}' \ No newline at end of file diff --git a/.gitea/workflows/test-cockroach.yml b/.gitea/workflows/test-cockroach.yml index 0ca18861..cfcda89d 100644 --- a/.gitea/workflows/test-cockroach.yml +++ b/.gitea/workflows/test-cockroach.yml @@ -36,13 +36,14 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: test cockroach env: TEST_COCKROACH_HOST: "cockroach:26257" TEST_COCKROACH_DBNAME: xorm_test TEST_COCKROACH_USERNAME: root TEST_COCKROACH_PASSWORD: + IGNORE_TEST_DELETE_LIMIT: true run: sleep 20 && make test-cockroach services: diff --git a/.gitea/workflows/test-mariadb.yml b/.gitea/workflows/test-mariadb.yml index 466f3858..dbc819db 100644 --- a/.gitea/workflows/test-mariadb.yml +++ b/.gitea/workflows/test-mariadb.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: test mariadb env: TEST_MYSQL_HOST: mariadb diff --git a/.gitea/workflows/test-mssql.yml b/.gitea/workflows/test-mssql.yml index d02e6956..04b8031a 100644 --- a/.gitea/workflows/test-mssql.yml +++ b/.gitea/workflows/test-mssql.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: test mssql env: TEST_MSSQL_HOST: mssql diff --git a/.gitea/workflows/test-mysql.yml b/.gitea/workflows/test-mysql.yml index 03ee2725..e13354f0 100644 --- a/.gitea/workflows/test-mysql.yml +++ b/.gitea/workflows/test-mysql.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: test mysql utf8mb4 env: TEST_MYSQL_HOST: mysql diff --git a/.gitea/workflows/test-mysql8.yml b/.gitea/workflows/test-mysql8.yml index 3fbd7c30..7362065a 100644 --- a/.gitea/workflows/test-mysql8.yml +++ b/.gitea/workflows/test-mysql8.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: test mysql8 env: TEST_MYSQL_HOST: mysql8 diff --git a/.gitea/workflows/test-postgres.yml b/.gitea/workflows/test-postgres.yml index 89aa72c3..d4abb2ad 100644 --- a/.gitea/workflows/test-postgres.yml +++ b/.gitea/workflows/test-postgres.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: test postgres env: TEST_PGSQL_HOST: pgsql diff --git a/.gitea/workflows/test-sqlite.yml b/.gitea/workflows/test-sqlite.yml index cca2e786..164acc10 100644 --- a/.gitea/workflows/test-sqlite.yml +++ b/.gitea/workflows/test-sqlite.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: vet run: make vet - name: format check diff --git a/.gitea/workflows/test-tidb.yml b/.gitea/workflows/test-tidb.yml index fa6e27ad..ce898dcb 100644 --- a/.gitea/workflows/test-tidb.yml +++ b/.gitea/workflows/test-tidb.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: test tidb env: TEST_TIDB_HOST: "tidb:4000" diff --git a/.gitignore b/.gitignore index 4cd4252b..6dc08c05 100644 --- a/.gitignore +++ b/.gitignore @@ -35,7 +35,7 @@ test.db.sql *coverage.out test.db -integrations/*.sql -integrations/test_sqlite* +tests/*.sql +tests/test_sqlite* cover.out cover.html diff --git a/Makefile b/Makefile index ade90e6d..55183557 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ SED_INPLACE := sed -i GO_DIRS := caches contexts integrations core dialects internal log migrate names schemas tags GOFILES := $(wildcard *.go) GOFILES += $(shell find $(GO_DIRS) -name "*.go" -type f) -INTEGRATION_PACKAGES := xorm.io/xorm/integrations +INTEGRATION_PACKAGES := xorm.io/xorm/tests PACKAGES ?= $(filter-out $(INTEGRATION_PACKAGES),$(shell $(GO) list ./...)) TEST_COCKROACH_HOST ?= cockroach:26257 diff --git a/convert/conversion.go b/convert/conversion.go index b69e345c..5577e863 100644 --- a/convert/conversion.go +++ b/convert/conversion.go @@ -16,11 +16,21 @@ import ( "time" ) +// ConversionFrom is an inteface to allow retrieve data from database +type ConversionFrom interface { + FromDB([]byte) error +} + +// ConversionTo is an interface to allow store data to database +type ConversionTo interface { + ToDB() ([]byte, error) +} + // Conversion is an interface. A type implements Conversion will according // the custom method to fill into database and retrieve from database. type Conversion interface { - FromDB([]byte) error - ToDB() ([]byte, error) + ConversionFrom + ConversionTo } // ErrNilPtr represents an error diff --git a/convert/int.go b/convert/int.go index af8d4f75..03994773 100644 --- a/convert/int.go +++ b/convert/int.go @@ -35,6 +35,56 @@ func AsInt64(src interface{}) (int64, error) { return int64(v), nil case uint64: return int64(v), nil + case *int: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *int16: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *int32: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *int8: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *int64: + if v == nil { + return 0, nil + } + return *v, nil + case *uint: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *uint8: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *uint16: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *uint32: + if v == nil { + return 0, nil + } + return int64(*v), nil + case *uint64: + if v == nil { + return 0, nil + } + return int64(*v), nil case []byte: return strconv.ParseInt(string(v), 10, 64) case string: @@ -110,9 +160,7 @@ func AsUint64(src interface{}) (uint64, error) { return 0, fmt.Errorf("unsupported value %T as uint64", src) } -var ( - _ sql.Scanner = &NullUint64{} -) +var _ sql.Scanner = &NullUint64{} // NullUint64 represents an uint64 that may be null. // NullUint64 implements the Scanner interface so @@ -142,9 +190,7 @@ func (n NullUint64) Value() (driver.Value, error) { return n.Uint64, nil } -var ( - _ sql.Scanner = &NullUint32{} -) +var _ sql.Scanner = &NullUint32{} // NullUint32 represents an uint32 that may be null. // NullUint32 implements the Scanner interface so diff --git a/convert/time.go b/convert/time.go index dc36912b..d90dc428 100644 --- a/convert/time.go +++ b/convert/time.go @@ -15,6 +15,7 @@ import ( ) // String2Time converts a string to time with original location +// be aware for time strings (HH:mm:ss) returns zero year (LMT) for converted location func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) { if len(s) == 19 { if s == utils.ZeroTime0 || s == utils.ZeroTime1 { @@ -27,6 +28,9 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t dt = dt.In(convertedLocation) return &dt, nil } else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' { + if strings.HasPrefix(s, "0000-00-00T00:00:00") || strings.HasPrefix(s, "0001-01-01T00:00:00") { + return &time.Time{}, nil + } dt, err := time.ParseInLocation("2006-01-02T15:04:05", s[:19], originalLocation) if err != nil { return nil, err @@ -34,6 +38,9 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t dt = dt.In(convertedLocation) return &dt, nil } else if len(s) == 25 && s[10] == 'T' && s[19] == '+' && s[22] == ':' { + if strings.HasPrefix(s, "0000-00-00T00:00:00") || strings.HasPrefix(s, "0001-01-01T00:00:00") { + return &time.Time{}, nil + } dt, err := time.Parse(time.RFC3339, s) if err != nil { return nil, err @@ -41,6 +48,10 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t dt = dt.In(convertedLocation) return &dt, nil } else if len(s) >= 21 && s[10] == 'T' && s[19] == '.' { + if strings.HasPrefix(s, "0000-00-00T00:00:00."+strings.Repeat("0", len(s)-20)) || + strings.HasPrefix(s, "0001-01-01T00:00:00."+strings.Repeat("0", len(s)-20)) { + return &time.Time{}, nil + } dt, err := time.Parse(time.RFC3339Nano, s) if err != nil { return nil, err @@ -48,7 +59,11 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t dt = dt.In(convertedLocation) return &dt, nil } else if len(s) >= 21 && s[19] == '.' { - var layout = "2006-01-02 15:04:05." + strings.Repeat("0", len(s)-20) + if strings.HasPrefix(s, "0000-00-00T00:00:00."+strings.Repeat("0", len(s)-20)) || + strings.HasPrefix(s, "0001-01-01T00:00:00."+strings.Repeat("0", len(s)-20)) { + return &time.Time{}, nil + } + layout := "2006-01-02 15:04:05." + strings.Repeat("0", len(s)-20) dt, err := time.ParseInLocation(layout, s, originalLocation) if err != nil { return nil, err @@ -65,9 +80,24 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t } dt = dt.In(convertedLocation) return &dt, nil + } else if len(s) == 8 && s[2] == ':' && s[5] == ':' { + dt, err := time.ParseInLocation("15:04:05", s, originalLocation) + if err != nil { + return nil, err + } + currentDate := time.Now() + // add current date for correct time locations + dt = dt.AddDate(currentDate.Year(), int(currentDate.Month()), currentDate.Day()) + dt = dt.In(convertedLocation) + // back to zero year + dt = dt.AddDate(-currentDate.Year(), int(-currentDate.Month()), -currentDate.Day()) + return &dt, nil } else { i, err := strconv.ParseInt(s, 10, 64) if err == nil { + if i == 0 { + return &time.Time{}, nil + } tm := time.Unix(i, 0).In(convertedLocation) return &tm, nil } @@ -94,6 +124,9 @@ func AsTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time. if !t.Valid { return nil, nil } + if utils.IsTimeZero(t.Time) { + return &time.Time{}, nil + } z, _ := t.Time.Zone() if len(z) == 0 || t.Time.Year() == 0 || t.Time.Location().String() != dbLoc.String() { tm := time.Date(t.Time.Year(), t.Time.Month(), t.Time.Day(), t.Time.Hour(), @@ -103,6 +136,9 @@ func AsTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time. tm := t.Time.In(uiLoc) return &tm, nil case *time.Time: + if utils.IsTimeZero(*t) { + return &time.Time{}, nil + } z, _ := t.Zone() if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), @@ -112,6 +148,9 @@ func AsTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time. tm := t.In(uiLoc) return &tm, nil case time.Time: + if utils.IsTimeZero(t) { + return &time.Time{}, nil + } z, _ := t.Zone() if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), @@ -121,12 +160,21 @@ func AsTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time. tm := t.In(uiLoc) return &tm, nil case int: + if t == 0 { + return &time.Time{}, nil + } tm := time.Unix(int64(t), 0).In(uiLoc) return &tm, nil case int64: + if t == 0 { + return &time.Time{}, nil + } tm := time.Unix(t, 0).In(uiLoc) return &tm, nil case *sql.NullInt64: + if t.Int64 == 0 { + return &time.Time{}, nil + } tm := time.Unix(t.Int64, 0).In(uiLoc) return &tm, nil } diff --git a/convert/time_test.go b/convert/time_test.go index 4b1c2279..d7a9d5ad 100644 --- a/convert/time_test.go +++ b/convert/time_test.go @@ -15,7 +15,7 @@ func TestString2Time(t *testing.T) { expectedLoc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) - var kases = map[string]time.Time{ + cases := map[string]time.Time{ "2021-08-10": time.Date(2021, 8, 10, 8, 0, 0, 0, expectedLoc), "2021-07-11 10:44:00": time.Date(2021, 7, 11, 18, 44, 0, 0, expectedLoc), "2021-07-11 10:44:00.999": time.Date(2021, 7, 11, 18, 44, 0, 999000000, expectedLoc), @@ -25,12 +25,13 @@ func TestString2Time(t *testing.T) { "2021-06-06T22:58:20.999+08:00": time.Date(2021, 6, 6, 22, 58, 20, 999000000, expectedLoc), "2021-06-06T22:58:20.999999+08:00": time.Date(2021, 6, 6, 22, 58, 20, 999999000, expectedLoc), "2021-06-06T22:58:20.999999999+08:00": time.Date(2021, 6, 6, 22, 58, 20, 999999999, expectedLoc), - "2021-08-10T10:33:04Z": time.Date(2021, 8, 10, 18, 33, 04, 0, expectedLoc), - "2021-08-10T10:33:04.999Z": time.Date(2021, 8, 10, 18, 33, 04, 999000000, expectedLoc), - "2021-08-10T10:33:04.999999Z": time.Date(2021, 8, 10, 18, 33, 04, 999999000, expectedLoc), - "2021-08-10T10:33:04.999999999Z": time.Date(2021, 8, 10, 18, 33, 04, 999999999, expectedLoc), + "2021-08-10T10:33:04Z": time.Date(2021, 8, 10, 18, 33, 0o4, 0, expectedLoc), + "2021-08-10T10:33:04.999Z": time.Date(2021, 8, 10, 18, 33, 0o4, 999000000, expectedLoc), + "2021-08-10T10:33:04.999999Z": time.Date(2021, 8, 10, 18, 33, 0o4, 999999000, expectedLoc), + "2021-08-10T10:33:04.999999999Z": time.Date(2021, 8, 10, 18, 33, 0o4, 999999999, expectedLoc), + "10:22:33": time.Date(0, 1, 1, 18, 22, 33, 0, expectedLoc), } - for layout, tm := range kases { + for layout, tm := range cases { t.Run(layout, func(t *testing.T) { target, err := String2Time(layout, time.UTC, expectedLoc) assert.NoError(t, err) diff --git a/dialects/mssql.go b/dialects/mssql.go index dcac9c3f..2c64e637 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -320,11 +320,7 @@ func (db *mssql) SQLType(c *schemas.Column) string { res += "(MAX)" } case schemas.TimeStamp, schemas.DateTime: - if c.Length > 3 { - res = "DATETIME2" - } else { - return schemas.DateTime - } + return "DATETIME2" case schemas.TimeStampz: res = "DATETIMEOFFSET" c.Length = 7 diff --git a/dialects/postgres.go b/dialects/postgres.go index f1f6a2f2..53f66184 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -821,6 +821,7 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas } // Postgres: 9.5.22 on x86_64-pc-linux-gnu (Debian 9.5.22-1.pgdg90+1), compiled by gcc (Debian 6.3.0-18+deb9u1) 6.3.0 20170516, 64-bit + // Postgres: PostgreSQL 15.3, compiled by Visual C++ build 1914, 64-bit // CockroachDB CCL v19.2.4 (x86_64-unknown-linux-gnu, built if strings.HasPrefix(version, "CockroachDB") { versions := strings.Split(strings.TrimPrefix(version, "CockroachDB CCL "), " ") @@ -829,12 +830,22 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas Edition: "CockroachDB", }, nil } else if strings.HasPrefix(version, "PostgreSQL") { - versions := strings.Split(strings.TrimPrefix(version, "PostgreSQL "), " on ") - return &schemas.Version{ - Number: versions[0], - Level: versions[1], - Edition: "PostgreSQL", - }, nil + if strings.Contains(version, " on ") { + versions := strings.Split(strings.TrimPrefix(version, "PostgreSQL "), " on ") + return &schemas.Version{ + Number: versions[0], + Level: versions[1], + Edition: "PostgreSQL", + }, nil + } else { + versions := strings.Split(strings.TrimPrefix(version, "PostgreSQL "), ",") + return &schemas.Version{ + Number: versions[0], + Level: versions[1], + Edition: "PostgreSQL", + }, nil + } + } return nil, errors.New("unknow database version") diff --git a/dialects/time_test.go b/dialects/time_test.go new file mode 100644 index 00000000..670207c6 --- /dev/null +++ b/dialects/time_test.go @@ -0,0 +1,190 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "xorm.io/xorm/schemas" +) + +type dialect struct { + Dialect + dbType schemas.DBType +} + +func (d dialect) URI() *URI { + return &URI{ + DBType: d.dbType, + } +} + +func TestFormatColumnTime(t *testing.T) { + date := time.Date(2020, 10, 23, 10, 14, 15, 123456, time.Local) + tests := []struct { + name string + dialect Dialect + location *time.Location + column *schemas.Column + time time.Time + wantRes interface{} + wantErr error + }{ + { + name: "nullable", + dialect: nil, + location: nil, + column: &schemas.Column{Nullable: true}, + time: time.Time{}, + wantRes: nil, + wantErr: nil, + }, + { + name: "invalid sqltype", + dialect: nil, + location: nil, + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Bit}}, + time: time.Time{}, + wantRes: 0, + wantErr: nil, + }, + { + name: "return default", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Bit}}, + time: date, + wantRes: date, + wantErr: nil, + }, + { + name: "return default (set timezone)", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Bit}, TimeZone: time.UTC}, + time: date, + wantRes: date.In(time.UTC), + wantErr: nil, + }, + { + name: "format date", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Date}}, + time: date, + wantRes: date.Format("2006-01-02"), + wantErr: nil, + }, + { + name: "format time", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Time}}, + time: date, + wantRes: date.Format("15:04:05"), + wantErr: nil, + }, + { + name: "format time (set length)", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Time}, Length: 64}, + time: date, + wantRes: date.Format("15:04:05.999999999"), + wantErr: nil, + }, + { + name: "format datetime", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.DateTime}}, + time: date, + wantRes: date.Format("2006-01-02 15:04:05"), + wantErr: nil, + }, + { + name: "format datetime (set length)", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.DateTime}, Length: 64}, + time: date, + wantRes: date.Format("2006-01-02 15:04:05.999999999"), + wantErr: nil, + }, + { + name: "format timestamp", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.TimeStamp}}, + time: date, + wantRes: date.Format("2006-01-02 15:04:05"), + wantErr: nil, + }, + { + name: "format timestamp (set length)", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.TimeStamp}, Length: 64}, + time: date, + wantRes: date.Format("2006-01-02 15:04:05.999999999"), + wantErr: nil, + }, + { + name: "format varchar", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Varchar}}, + time: date, + wantRes: date.Format("2006-01-02 15:04:05"), + wantErr: nil, + }, + { + name: "format timestampz", + dialect: dialect{}, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.TimeStampz}}, + time: date, + wantRes: date.Format(time.RFC3339Nano), + wantErr: nil, + }, + { + name: "format timestampz (mssql)", + dialect: dialect{dbType: schemas.MSSQL}, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.TimeStampz}}, + time: date, + wantRes: date.Format("2006-01-02T15:04:05.9999999Z07:00"), + wantErr: nil, + }, + { + name: "format int", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.Int}}, + time: date, + wantRes: date.Unix(), + wantErr: nil, + }, + { + name: "format bigint", + dialect: nil, + location: date.Location(), + column: &schemas.Column{SQLType: schemas.SQLType{Name: schemas.BigInt}}, + time: date, + wantRes: date.Unix(), + wantErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := FormatColumnTime(tt.dialect, tt.location, tt.column, tt.time) + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.wantRes, got) + }) + } +} diff --git a/go.mod b/go.mod index 5f2f18e0..b48b3bbc 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module xorm.io/xorm -go 1.13 +go 1.16 require ( gitee.com/travelliu/dm v1.8.11192 diff --git a/go.sum b/go.sum index 9f197fae..5c780ec6 100644 --- a/go.sum +++ b/go.sum @@ -24,7 +24,6 @@ github.com/denisenkom/go-mssqldb v0.12.3/go.mod h1:k0mtMFOnU+AihqFxPMiF05rtiDror github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= @@ -39,7 +38,6 @@ github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZ github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= -github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= @@ -231,7 +229,6 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/interface.go b/interface.go index d10abe9e..03dfd236 100644 --- a/interface.go +++ b/interface.go @@ -121,6 +121,7 @@ type EngineInterface interface { ShowSQL(show ...bool) Sync(...interface{}) error Sync2(...interface{}) error + SyncWithOptions(SyncOptions, ...interface{}) (*SyncResult, error) StoreEngine(storeEngine string) *Session TableInfo(bean interface{}) (*schemas.Table, error) TableName(interface{}, ...bool) string diff --git a/internal/statements/delete.go b/internal/statements/delete.go index a77cf862..6e859399 100644 --- a/internal/statements/delete.go +++ b/internal/statements/delete.go @@ -20,6 +20,9 @@ func (statement *Statement) writeDeleteOrder(w *builder.BytesWriter) error { } if statement.LimitN != nil && *statement.LimitN > 0 { + if statement.Start > 0 { + return fmt.Errorf("Delete with Limit start is unsupported") + } limitNValue := *statement.LimitN if _, err := fmt.Fprintf(w, " LIMIT %d", limitNValue); err != nil { return err diff --git a/internal/statements/order_by.go b/internal/statements/order_by.go index 595c0430..54a3c6e0 100644 --- a/internal/statements/order_by.go +++ b/internal/statements/order_by.go @@ -5,6 +5,7 @@ package statements import ( + "errors" "fmt" "xorm.io/builder" @@ -16,6 +17,26 @@ type orderBy struct { direction string // ASC, DESC or "", "" means raw orderStr } +func (ob orderBy) CheckValid() error { + if ob.orderStr == nil { + return fmt.Errorf("order by string is nil") + } + switch t := ob.orderStr.(type) { + case string: + if t == "" { + return fmt.Errorf("order by string is empty") + } + return nil + case *builder.Expression: + if t.Content() == "" { + return fmt.Errorf("order by string is empty") + } + return nil + default: + return fmt.Errorf("order by string is not string or builder.Expression") + } +} + func (statement *Statement) HasOrderBy() bool { return len(statement.orderBy) > 0 } @@ -25,6 +46,8 @@ func (statement *Statement) ResetOrderBy() { statement.orderBy = []orderBy{} } +var ErrNoColumnName = errors.New("no column name") + func (statement *Statement) writeOrderBy(w *builder.BytesWriter, orderBy orderBy) error { switch t := orderBy.orderStr.(type) { case (*builder.Expression): @@ -75,22 +98,45 @@ func (statement *Statement) writeOrderBys(w *builder.BytesWriter) error { // OrderBy generate "Order By order" statement func (statement *Statement) OrderBy(order interface{}, args ...interface{}) *Statement { - statement.orderBy = append(statement.orderBy, orderBy{order, args, ""}) + ob := orderBy{order, args, ""} + if err := ob.CheckValid(); err != nil { + statement.LastError = err + return statement + } + statement.orderBy = append(statement.orderBy, ob) return statement } // Desc generate `ORDER BY xx DESC` func (statement *Statement) Desc(colNames ...string) *Statement { + if len(colNames) == 0 { + statement.LastError = ErrNoColumnName + return statement + } for _, colName := range colNames { - statement.orderBy = append(statement.orderBy, orderBy{colName, nil, "DESC"}) + ob := orderBy{colName, nil, "DESC"} + statement.orderBy = append(statement.orderBy, ob) + if err := ob.CheckValid(); err != nil { + statement.LastError = err + return statement + } } return statement } // Asc provide asc order by query condition, the input parameters are columns. func (statement *Statement) Asc(colNames ...string) *Statement { + if len(colNames) == 0 { + statement.LastError = ErrNoColumnName + return statement + } for _, colName := range colNames { - statement.orderBy = append(statement.orderBy, orderBy{colName, nil, "ASC"}) + ob := orderBy{colName, nil, "ASC"} + statement.orderBy = append(statement.orderBy, ob) + if err := ob.CheckValid(); err != nil { + statement.LastError = err + return statement + } } return statement } diff --git a/internal/statements/query.go b/internal/statements/query.go index 211ba268..216a2028 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -35,7 +35,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int } buf := builder.NewWriter() - if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { + if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -66,7 +66,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri } buf := builder.NewWriter() - if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true); err != nil { + if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true, true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -122,7 +122,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, } buf := builder.NewWriter() - if err := statement.writeSelect(buf, columnStr, true); err != nil { + if err := statement.writeSelect(buf, columnStr, true, true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -153,12 +153,6 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa selectSQL = "count(*)" } } - var subQuerySelect string - if statement.GroupByStr != "" { - subQuerySelect = statement.GroupByStr - } else { - subQuerySelect = selectSQL - } buf := builder.NewWriter() if statement.GroupByStr != "" { @@ -167,7 +161,14 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa } } - if err := statement.writeSelect(buf, subQuerySelect, false); err != nil { + var subQuerySelect string + if statement.GroupByStr != "" { + subQuerySelect = statement.GroupByStr + } else { + subQuerySelect = selectSQL + } + + if err := statement.writeSelect(buf, subQuerySelect, false, false); err != nil { return "", nil, err } @@ -243,14 +244,19 @@ func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr return err } -func (statement *Statement) writeWhere(w *builder.BytesWriter) error { - if !statement.cond.IsValid() { +func (statement *Statement) writeWhereCond(w *builder.BytesWriter, cond builder.Cond) error { + if !cond.IsValid() { return nil } + if _, err := fmt.Fprint(w, " WHERE "); err != nil { return err } - return statement.cond.WriteTo(statement.QuoteReplacer(w)) + return cond.WriteTo(statement.QuoteReplacer(w)) +} + +func (statement *Statement) writeWhere(w *builder.BytesWriter) error { + return statement.writeWhereCond(w, statement.cond) } func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error { @@ -359,7 +365,7 @@ func (statement *Statement) writeOracleLimit(w *builder.BytesWriter, columnStr s return err } -func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error { +func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit, needOrderBy bool) error { if err := statement.writeSelectColumns(buf, columnStr); err != nil { return err } @@ -375,8 +381,10 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri if err := statement.writeHaving(buf); err != nil { return err } - if err := statement.writeOrderBys(buf); err != nil { - return err + if needOrderBy { + if err := statement.writeOrderBys(buf); err != nil { + return err + } } dialect := statement.dialect @@ -514,7 +522,7 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa statement.cond = statement.cond.And(autoCond) buf := builder.NewWriter() - if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { + if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 017f40a5..c075ec54 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -644,6 +644,23 @@ func (statement *Statement) convertSQLOrArgs(sqlOrArgs ...interface{}) (string, newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05")) } else if v, ok := arg.(*time.Time); ok && v != nil { newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05")) + } else if v, ok := arg.(convert.ConversionTo); ok { + r, err := v.ToDB() + if err != nil { + return "", nil, err + } + if r != nil { + // for nvarchar column on mssql, bytes have to be converted as ucs-2 external of driver + // for binary column, a string will be converted as bytes directly. So we have to + // convert bytes as string + if statement.dialect.URI().DBType == schemas.MSSQL { + newArgs = append(newArgs, string(r)) + } else { + newArgs = append(newArgs, r) + } + } else { + newArgs = append(newArgs, nil) + } } else { newArgs = append(newArgs, arg) } @@ -690,10 +707,7 @@ func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { if col.SQLType.IsNumeric() { cond = builder.Eq{colName: 0} } else { - // FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value. - if statement.dialect.URI().DBType != schemas.MSSQL { - cond = builder.Eq{colName: utils.ZeroTime1} - } + cond = builder.Eq{colName: utils.ZeroTime1} } if col.Nullable { diff --git a/internal/statements/update.go b/internal/statements/update.go index f0914b0b..5d71f34d 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "reflect" - "strings" "time" "xorm.io/builder" @@ -311,84 +310,328 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value, return colNames, args, nil } -func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string) error { - whereWriter := builder.NewWriter() - if cond.IsValid() { - fmt.Fprint(whereWriter, "WHERE ") +func (statement *Statement) writeUpdateTop(updateWriter *builder.BytesWriter) error { + if statement.dialect.URI().DBType != schemas.MSSQL || statement.LimitN == nil { + return nil } - if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil { + + table := statement.RefTable + if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { + return nil + } + + _, err := fmt.Fprintf(updateWriter, " TOP (%d)", *statement.LimitN) + return err +} + +func (statement *Statement) writeUpdateTableName(updateWriter *builder.BytesWriter) error { + tableName := statement.quote(statement.TableName()) + if statement.TableAlias == "" { + _, err := fmt.Fprint(updateWriter, " ", tableName) return err } - if err := statement.writeOrderBys(whereWriter); err != nil { + + switch statement.dialect.URI().DBType { + case schemas.MSSQL: + _, err := fmt.Fprint(updateWriter, " ", statement.TableAlias) return err + default: + _, err := fmt.Fprint(updateWriter, " ", tableName, " AS ", statement.TableAlias) + return err + } +} + +func (statement *Statement) writeUpdateFrom(updateWriter *builder.BytesWriter) error { + if statement.dialect.URI().DBType != schemas.MSSQL || statement.TableAlias == "" { + return nil + } + + _, err := fmt.Fprint(updateWriter, " FROM ", statement.quote(statement.TableName()), " ", statement.TableAlias) + return err +} + +func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter, cond builder.Cond) error { + if statement.LimitN == nil { + return nil } table := statement.RefTable tableName := statement.TableName() - // TODO: Oracle support needed - var top string - if statement.LimitN != nil { - limitValue := *statement.LimitN - switch statement.dialect.URI().DBType { - case schemas.MYSQL: - fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) - case schemas.SQLITE: - fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) - cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", - statement.quote(tableName), whereWriter.String()), whereWriter.Args()...)) - - whereWriter = builder.NewWriter() - fmt.Fprint(whereWriter, "WHERE ") - if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil { + limitValue := *statement.LimitN + switch statement.dialect.URI().DBType { + case schemas.MYSQL: + _, err := fmt.Fprintf(updateWriter, " LIMIT %d", limitValue) + return err + case schemas.SQLITE: + if cond.IsValid() { + if _, err := fmt.Fprint(updateWriter, " AND "); err != nil { return err } - case schemas.POSTGRES: - fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) - - cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", - statement.quote(tableName), whereWriter.String()), whereWriter.Args()...)) - - whereWriter = builder.NewWriter() - fmt.Fprint(whereWriter, "WHERE ") - if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil { + } else { + if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil { return err } - case schemas.MSSQL: - if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { - cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", - table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], - statement.quote(tableName), whereWriter.String()), whereWriter.Args()...) - - whereWriter = builder.NewWriter() - fmt.Fprint(whereWriter, "WHERE ") - if err := cond.WriteTo(whereWriter); err != nil { - return err - } - } else { - top = fmt.Sprintf("TOP (%d) ", limitValue) + } + if _, err := fmt.Fprint(updateWriter, "rowid IN (SELECT rowid FROM ", statement.quote(tableName)); err != nil { + return err + } + if err := statement.writeWhereCond(updateWriter, cond); err != nil { + return err + } + if err := statement.writeOrderBys(updateWriter); err != nil { + return err + } + _, err := fmt.Fprintf(updateWriter, " LIMIT %d)", limitValue) + return err + case schemas.POSTGRES: + if cond.IsValid() { + if _, err := fmt.Fprint(updateWriter, " AND "); err != nil { + return err + } + } else { + if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil { + return err } } - } - - tableAlias := statement.quote(tableName) - var fromSQL string - if statement.TableAlias != "" { - switch statement.dialect.URI().DBType { - case schemas.MSSQL: - fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, statement.TableAlias) - tableAlias = statement.TableAlias - default: - tableAlias = fmt.Sprintf("%s AS %s", tableAlias, statement.TableAlias) + if _, err := fmt.Fprint(updateWriter, "CTID IN (SELECT CTID FROM ", statement.quote(tableName)); err != nil { + return err } + if err := statement.writeWhereCond(updateWriter, cond); err != nil { + return err + } + if err := statement.writeOrderBys(updateWriter); err != nil { + return err + } + _, err := fmt.Fprintf(updateWriter, " LIMIT %d)", limitValue) + return err + case schemas.MSSQL: + if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { + if _, err := fmt.Fprintf(updateWriter, " WHERE %s IN (SELECT TOP (%d) %s FROM %v", + table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], + statement.quote(tableName)); err != nil { + return err + } + if err := statement.writeWhereCond(updateWriter, cond); err != nil { + return err + } + if err := statement.writeOrderBys(updateWriter); err != nil { + return err + } + _, err := fmt.Fprint(updateWriter, ")") + return err + } + return nil + default: // TODO: Oracle support needed + return fmt.Errorf("not implemented") + } +} + +func (statement *Statement) GenConditionsFromMap(m interface{}) ([]builder.Cond, error) { + switch t := m.(type) { + case map[string]interface{}: + conds := []builder.Cond{} + for k, v := range t { + conds = append(conds, builder.Eq{k: v}) + } + return conds, nil + case map[string]string: + conds := []builder.Cond{} + for k, v := range t { + conds = append(conds, builder.Eq{k: v}) + } + return conds, nil + default: + return nil, fmt.Errorf("unsupported condition map type %v", t) + } +} + +func (statement *Statement) writeVersionIncrSet(w builder.Writer, v reflect.Value, hasPreviousSet bool) error { + if v.Type().Kind() != reflect.Struct { + return nil } - if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v", - top, - tableAlias, - strings.Join(colNames, ", "), - fromSQL); err != nil { + table := statement.RefTable + if !(statement.RefTable != nil && table.Version != "" && statement.CheckVersion) { + return nil + } + + verValue, err := table.VersionColumn().ValueOfV(&v) + if err != nil { return err } - return utils.WriteBuilder(updateWriter, whereWriter) + + if verValue == nil { + return nil + } + + if hasPreviousSet { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } + + if _, err := fmt.Fprint(w, statement.quote(table.Version), " = ", statement.quote(table.Version), " + 1"); err != nil { + return err + } + return nil +} + +func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool) error { + for i, expr := range statement.IncrColumns { + if i > 0 || hasPreviousSet { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } + if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " + ?"); err != nil { + return err + } + w.Append(expr.Arg) + } + return nil +} + +func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool) error { + // for update action to like "column = column - ?" + for i, expr := range statement.DecrColumns { + if i > 0 || hasPreviousSet { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } + if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " - ?"); err != nil { + return err + } + w.Append(expr.Arg) + } + return nil +} + +func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet bool) error { + // for update action to like "column = expression" + for i, expr := range statement.ExprColumns { + if i > 0 || hasPreviousSet { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } + switch tp := expr.Arg.(type) { + case string: + if len(tp) == 0 { + tp = "''" + } + if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", tp); err != nil { + return err + } + case *builder.Builder: + if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ("); err != nil { + return err + } + if err := tp.WriteTo(statement.QuoteReplacer(w)); err != nil { + return err + } + if _, err := fmt.Fprint(w, ")"); err != nil { + return err + } + default: + if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ?"); err != nil { + return err + } + w.Append(expr.Arg) + } + } + return nil +} + +func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error { + previousLen := w.Len() + for i, colName := range colNames { + if i > 0 { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } + if _, err := fmt.Fprint(w, colName); err != nil { + return err + } + } + w.Append(args...) + + if err := statement.writeIncrSets(w, w.Len() > previousLen); err != nil { + return err + } + + if err := statement.writeDecrSets(w, w.Len() > previousLen); err != nil { + return err + } + + if err := statement.writeExprSets(w, w.Len() > previousLen); err != nil { + return err + } + + if err := statement.writeVersionIncrSet(w, v, w.Len() > previousLen); err != nil { + return err + } + return nil +} + +var ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated") + +func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []interface{}) error { + if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil { + return err + } + + if err := statement.writeUpdateTop(updateWriter); err != nil { + return err + } + + if err := statement.writeUpdateTableName(updateWriter); err != nil { + return err + } + + // write set + if _, err := fmt.Fprint(updateWriter, " SET "); err != nil { + return err + } + previousLen := updateWriter.Len() + + if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil { + return err + } + + // if no columns to be updated, return error + if previousLen == updateWriter.Len() { + return ErrNoColumnsTobeUpdated + } + + // write from + if err := statement.writeUpdateFrom(updateWriter); err != nil { + return err + } + + if statement.dialect.URI().DBType == schemas.MSSQL { + table := statement.RefTable + if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { + } else { + // write where + if err := statement.writeWhereCond(updateWriter, cond); err != nil { + return err + } + } + } else { + // write where + if err := statement.writeWhereCond(updateWriter, cond); err != nil { + return err + } + } + + if statement.dialect.URI().DBType == schemas.MYSQL { + if err := statement.writeOrderBys(updateWriter); err != nil { + return err + } + } + + return statement.writeUpdateLimit(updateWriter, cond) } diff --git a/rows.go b/rows.go index a42eedb9..c539410e 100644 --- a/rows.go +++ b/rows.go @@ -144,6 +144,8 @@ func (rows *Rows) Close() error { defer rows.session.Close() } + defer rows.session.resetStatement() + if rows.rows != nil { return rows.rows.Close() } diff --git a/session_insert.go b/session_insert.go index 7003e0f7..7cc15241 100644 --- a/session_insert.go +++ b/session_insert.go @@ -471,7 +471,8 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } if col.IsDeleted { - arg, err := dialects.FormatColumnTime(session.engine.dialect, session.engine.DatabaseTZ, col, time.Time{}) + zeroTime := time.Date(1, 1, 1, 0, 0, 0, 0, session.engine.DatabaseTZ) + arg, err := dialects.FormatColumnTime(session.engine.dialect, session.engine.DatabaseTZ, col, zeroTime) if err != nil { return nil, nil, err } diff --git a/session_update.go b/session_update.go index 9a6964f1..b3640ad2 100644 --- a/session_update.go +++ b/session_update.go @@ -5,17 +5,17 @@ package xorm import ( - "errors" "reflect" "xorm.io/builder" + "xorm.io/xorm/internal/statements" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) // enumerated all errors var ( - ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated") + ErrNoColumnsTobeUpdated = statements.ErrNoColumnsTobeUpdated ) func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) { @@ -74,9 +74,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 v := utils.ReflectValue(bean) t := v.Type() - var colNames []string - var args []interface{} - // handle before update processors for _, closure := range session.beforeClosures { closure(bean) @@ -87,6 +84,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } // -- + var colNames []string + var args []interface{} var err error isMap := t.Kind() == reflect.Map isStruct := t.Kind() == reflect.Struct @@ -148,41 +147,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - // for update action to like "column = column + ?" - incColumns := session.statement.IncrColumns - for _, expr := range incColumns { - colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" + ?") - args = append(args, expr.Arg) - } - // for update action to like "column = column - ?" - decColumns := session.statement.DecrColumns - for _, expr := range decColumns { - colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" - ?") - args = append(args, expr.Arg) - } - // for update action to like "column = expression" - exprColumns := session.statement.ExprColumns - for _, expr := range exprColumns { - switch tp := expr.Arg.(type) { - case string: - if len(tp) == 0 { - tp = "''" - } - colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp) - case *builder.Builder: - subQuery, subArgs, err := builder.ToSQL(tp) - if err != nil { - return 0, err - } - subQuery = session.statement.ReplaceQuote(subQuery) - colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")") - args = append(args, subArgs...) - default: - colNames = append(colNames, session.engine.Quote(expr.ColName)+"=?") - args = append(args, expr.Arg) - } - } - if err = session.statement.ProcessIDParam(); err != nil { return 0, err } @@ -211,30 +175,25 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 verValue *reflect.Value ) if doIncVer { - verValue, err = table.VersionColumn().ValueOf(bean) + verValue, err = table.VersionColumn().ValueOfV(&v) if err != nil { return 0, err } if verValue != nil { cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()}) - colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1") } } - if len(colNames) == 0 { - return 0, ErrNoColumnsTobeUpdated - } - updateWriter := builder.NewWriter() - if err := session.statement.WriteUpdate(updateWriter, cond, colNames); err != nil { + if err := session.statement.WriteUpdate(updateWriter, cond, v, colNames, args); err != nil { return 0, err } tableName := session.statement.TableName() // table name must been get before exec because statement will be reset useCache := session.statement.UseCache - res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...) + res, err := session.exec(updateWriter.String(), updateWriter.Args()...) if err != nil { return 0, err } else if doIncVer { diff --git a/sync.go b/sync.go index 635a8ba9..9e1cb8c1 100644 --- a/sync.go +++ b/sync.go @@ -13,6 +13,10 @@ import ( type SyncOptions struct { WarnIfDatabaseColumnMissed bool + // IgnoreConstrains will not add, delete or update unique constrains + IgnoreConstrains bool + // IgnoreIndices will not add or delete indices + IgnoreIndices bool } type SyncResult struct{} @@ -49,6 +53,8 @@ func (session *Session) Sync2(beans ...interface{}) error { func (session *Session) Sync(beans ...interface{}) error { _, err := session.SyncWithOptions(SyncOptions{ WarnIfDatabaseColumnMissed: false, + IgnoreConstrains: false, + IgnoreIndices: false, }, beans...) return err } @@ -103,15 +109,20 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{}) return nil, err } - err = session.createUniques(bean) - if err != nil { - return nil, err + if !opts.IgnoreConstrains { + err = session.createUniques(bean) + if err != nil { + return nil, err + } } - err = session.createIndexes(bean) - if err != nil { - return nil, err + if !opts.IgnoreIndices { + err = session.createIndexes(bean) + if err != nil { + return nil, err + } } + continue } @@ -208,9 +219,12 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{}) } } + // indices found in orig table foundIndexNames := make(map[string]bool) + // indices to be added addedNames := make(map[string]*schemas.Index) + // drop indices that exist in orig and new table schema but are not equal for name, index := range table.Indexes { var oriIndex *schemas.Index for name2, index2 := range oriTable.Indexes { @@ -221,15 +235,13 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{}) } } - if oriIndex != nil { - if oriIndex.Type != index.Type { - sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex) - _, err = session.exec(sql) - if err != nil { - return nil, err - } - oriIndex = nil + if oriIndex != nil && oriIndex.Type != index.Type { + sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex) + _, err = session.exec(sql) + if err != nil { + return nil, err } + oriIndex = nil } if oriIndex == nil { @@ -237,8 +249,17 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{}) } } + // drop all indices that do not exist in new schema or have changed for name2, index2 := range oriTable.Indexes { if _, ok := foundIndexNames[name2]; !ok { + // ignore based on there type + if (index2.Type == schemas.IndexType && opts.IgnoreIndices) || + (index2.Type == schemas.UniqueType && opts.IgnoreConstrains) { + // make sure we do not add a index with same name later + delete(addedNames, name2) + continue + } + sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2) _, err = session.exec(sql) if err != nil { @@ -247,12 +268,13 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{}) } } + // Add new indices because either they did not exist before or were dropped to update them for name, index := range addedNames { - if index.Type == schemas.UniqueType { + if index.Type == schemas.UniqueType && !opts.IgnoreConstrains { session.statement.RefTable = table session.statement.SetTableName(tbNameWithSchema) err = session.addUnique(tbNameWithSchema, name) - } else if index.Type == schemas.IndexType { + } else if index.Type == schemas.IndexType && !opts.IgnoreIndices { session.statement.RefTable = table session.statement.SetTableName(tbNameWithSchema) err = session.addIndex(tbNameWithSchema, name) diff --git a/integrations/cache_test.go b/tests/cache_test.go similarity index 97% rename from integrations/cache_test.go rename to tests/cache_test.go index 2caeaa34..c3f84c77 100644 --- a/integrations/cache_test.go +++ b/tests/cache_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "testing" @@ -28,7 +28,7 @@ func TestCacheFind(t *testing.T) { assert.NoError(t, testEngine.Sync(new(MailBox))) - var inserts = []*MailBox{ + inserts := []*MailBox{ { Id: 0, Username: "user1", @@ -105,7 +105,7 @@ func TestCacheFind2(t *testing.T) { assert.NoError(t, testEngine.Sync(new(MailBox2))) - var inserts = []*MailBox2{ + inserts := []*MailBox2{ { Id: 0, Username: "user1", @@ -156,7 +156,7 @@ func TestCacheGet(t *testing.T) { assert.NoError(t, testEngine.Sync(new(MailBox3))) - var inserts = []*MailBox3{ + inserts := []*MailBox3{ { Username: "user1", Password: "pass1", diff --git a/integrations/engine_dm_test.go b/tests/engine_dm_test.go similarity index 93% rename from integrations/engine_dm_test.go rename to tests/engine_dm_test.go index 3b195ef8..5b25af29 100644 --- a/integrations/engine_dm_test.go +++ b/tests/engine_dm_test.go @@ -5,7 +5,7 @@ //go:build dm // +build dm -package integrations +package tests import "xorm.io/xorm/schemas" diff --git a/integrations/engine_group_test.go b/tests/engine_group_test.go similarity index 97% rename from integrations/engine_group_test.go rename to tests/engine_group_test.go index 635f73a6..629e0aa4 100644 --- a/integrations/engine_group_test.go +++ b/tests/engine_group_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "testing" diff --git a/integrations/engine_test.go b/tests/engine_test.go similarity index 99% rename from integrations/engine_test.go rename to tests/engine_test.go index 86ed7344..79ca42f5 100644 --- a/integrations/engine_test.go +++ b/tests/engine_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "context" diff --git a/integrations/main_test.go b/tests/main_test.go similarity index 91% rename from integrations/main_test.go rename to tests/main_test.go index 225ae45a..ab0ee802 100644 --- a/integrations/main_test.go +++ b/tests/main_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "testing" diff --git a/integrations/performance_test.go b/tests/performance_test.go similarity index 92% rename from integrations/performance_test.go rename to tests/performance_test.go index 49183717..d1d12161 100644 --- a/integrations/performance_test.go +++ b/tests/performance_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "testing" @@ -23,7 +23,7 @@ func BenchmarkGetVars(b *testing.B) { assert.NoError(b, testEngine.Sync(new(BenchmarkGetVars))) - var v = BenchmarkGetVars{ + v := BenchmarkGetVars{ Name: "myname", } _, err := testEngine.Insert(&v) @@ -54,7 +54,7 @@ func BenchmarkGetStruct(b *testing.B) { assert.NoError(b, testEngine.Sync(new(BenchmarkGetStruct))) - var v = BenchmarkGetStruct{ + v := BenchmarkGetStruct{ Name: "myname", } _, err := testEngine.Insert(&v) @@ -86,13 +86,13 @@ func BenchmarkFindStruct(b *testing.B) { assert.NoError(b, testEngine.Sync(new(BenchmarkFindStruct))) - var v = BenchmarkFindStruct{ + v := BenchmarkFindStruct{ Name: "myname", } _, err := testEngine.Insert(&v) assert.NoError(b, err) - var mynames = make([]BenchmarkFindStruct, 0, 1) + mynames := make([]BenchmarkFindStruct, 0, 1) b.StartTimer() for i := 0; i < b.N; i++ { err := testEngine.Find(&mynames) diff --git a/integrations/processors_test.go b/tests/processors_test.go similarity index 99% rename from integrations/processors_test.go rename to tests/processors_test.go index 4c383437..af2866e8 100644 --- a/integrations/processors_test.go +++ b/tests/processors_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "errors" @@ -200,7 +200,7 @@ func TestProcessors(t *testing.T) { // -- // test find map processors - var p2FindMap = make(map[int64]*ProcessorsStruct) + p2FindMap := make(map[int64]*ProcessorsStruct) err = testEngine.Find(&p2FindMap) assert.NoError(t, err) @@ -848,13 +848,13 @@ func TestAfterLoadProcessor(t *testing.T) { assertSync(t, new(AfterLoadStructA), new(AfterLoadStructB)) - var a = AfterLoadStructA{ + a := AfterLoadStructA{ Content: "testa", } _, err := testEngine.Insert(&a) assert.NoError(t, err) - var b = AfterLoadStructB{ + b := AfterLoadStructB{ Content: "testb", AId: a.Id, } diff --git a/integrations/rows_test.go b/tests/rows_test.go similarity index 99% rename from integrations/rows_test.go rename to tests/rows_test.go index e354b75e..fe4f374c 100644 --- a/integrations/rows_test.go +++ b/tests/rows_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "testing" diff --git a/integrations/schema_test.go b/tests/schema_test.go similarity index 87% rename from integrations/schema_test.go rename to tests/schema_test.go index 149c6394..db9f9e8f 100644 --- a/integrations/schema_test.go +++ b/tests/schema_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "errors" @@ -12,6 +12,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "xorm.io/xorm" "xorm.io/xorm/schemas" ) @@ -645,3 +646,101 @@ func TestCollate(t *testing.T) { }) assert.NoError(t, err) } + +type SyncWithOpts1 struct { + Id int64 + Index int `xorm:"index"` + Unique int `xorm:"unique"` + Group1 int `xorm:"index(ttt)"` + Group2 int `xorm:"index(ttt)"` + UniGroup1 int `xorm:"unique(lll)"` + UniGroup2 int `xorm:"unique(lll)"` +} + +func (*SyncWithOpts1) TableName() string { + return "sync_with_opts" +} + +type SyncWithOpts2 struct { + Id int64 + Index int `xorm:"index"` + Unique int `xorm:""` + Group1 int `xorm:"index(ttt)"` + Group2 int `xorm:"index(ttt)"` + UniGroup1 int `xorm:""` + UniGroup2 int `xorm:"unique(lll)"` +} + +func (*SyncWithOpts2) TableName() string { + return "sync_with_opts" +} + +type SyncWithOpts3 struct { + Id int64 + Index int `xorm:""` + Unique int `xorm:"unique"` + Group1 int `xorm:""` + Group2 int `xorm:"index(ttt)"` + UniGroup1 int `xorm:"unique(lll)"` + UniGroup2 int `xorm:"unique(lll)"` +} + +func (*SyncWithOpts3) TableName() string { + return "sync_with_opts" +} + +func TestSyncWithOptions(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + // ignore indices and constrains + result, err := testEngine.SyncWithOptions(xorm.SyncOptions{IgnoreIndices: true, IgnoreConstrains: true}, &SyncWithOpts1{}) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Len(t, getIndicesOfBeanFromDB(t, &SyncWithOpts1{}), 0) + + // only ignore indices + result, err = testEngine.SyncWithOptions(xorm.SyncOptions{IgnoreConstrains: true}, &SyncWithOpts2{}) + assert.NoError(t, err) + assert.NotNil(t, result) + indices := getIndicesOfBeanFromDB(t, &SyncWithOpts1{}) + assert.Len(t, indices, 2) + assert.ElementsMatch(t, []string{"ttt", "index"}, getKeysFromMap(indices)) + + // only ignore constrains + result, err = testEngine.SyncWithOptions(xorm.SyncOptions{IgnoreIndices: true}, &SyncWithOpts3{}) + assert.NoError(t, err) + assert.NotNil(t, result) + indices = getIndicesOfBeanFromDB(t, &SyncWithOpts1{}) + assert.Len(t, indices, 4) + assert.ElementsMatch(t, []string{"ttt", "index", "unique", "lll"}, getKeysFromMap(indices)) + + tableInfoFromStruct, _ := testEngine.TableInfo(&SyncWithOpts1{}) + assert.ElementsMatch(t, getKeysFromMap(tableInfoFromStruct.Indexes), getKeysFromMap(getIndicesOfBeanFromDB(t, &SyncWithOpts1{}))) + +} + +func getIndicesOfBeanFromDB(t *testing.T, bean interface{}) map[string]*schemas.Index { + dbm, err := testEngine.DBMetas() + assert.NoError(t, err) + + tName := testEngine.TableName(bean) + var tSchema *schemas.Table + for _, t := range dbm { + if t.Name == tName { + tSchema = t + break + } + } + if !assert.NotNil(t, tSchema) { + return nil + } + return tSchema.Indexes +} + +func getKeysFromMap(m map[string]*schemas.Index) []string { + var ss []string + for k := range m { + ss = append(ss, k) + } + return ss +} diff --git a/integrations/session_cols_test.go b/tests/session_cols_test.go similarity index 97% rename from integrations/session_cols_test.go rename to tests/session_cols_test.go index 462ea7c7..4a6ef39f 100644 --- a/integrations/session_cols_test.go +++ b/tests/session_cols_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "testing" @@ -22,7 +22,7 @@ func TestSetExpr(t *testing.T) { assert.NoError(t, testEngine.Sync(new(UserExprIssue))) - var issue = UserExprIssue{ + issue := UserExprIssue{ Title: "my issue", } cnt, err := testEngine.Insert(&issue) @@ -44,7 +44,7 @@ func TestSetExpr(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var not = "NOT" + not := "NOT" if testEngine.Dialect().URI().DBType == schemas.MSSQL || testEngine.Dialect().URI().DBType == schemas.DAMENG { not = "~" } @@ -118,7 +118,7 @@ func TestMustCol(t *testing.T) { assertSync(t, new(CustomerUpdate)) - var customer = CustomerUpdate{ + customer := CustomerUpdate{ ParentId: 1, } cnt, err := testEngine.Insert(&customer) diff --git a/integrations/session_cond_test.go b/tests/session_cond_test.go similarity index 98% rename from integrations/session_cond_test.go rename to tests/session_cond_test.go index 0597d74e..cbcd0cb5 100644 --- a/integrations/session_cond_test.go +++ b/tests/session_cond_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "errors" @@ -37,7 +37,7 @@ func TestBuilder(t *testing.T) { assert.NoError(t, err) var cond Condition - var q = testEngine.Quote + q := testEngine.Quote has, err := testEngine.Where(builder.Eq{q("col_name"): "col1"}).Get(&cond) assert.NoError(t, err) assert.Equal(t, true, has, "records should exist") @@ -90,7 +90,7 @@ func TestBuilder(t *testing.T) { assert.EqualValues(t, 0, len(conds), "records should not exist") // complex condtions - var where = builder.NewCond() + where := builder.NewCond() if true { where = where.And(builder.Eq{q("col_name"): "col1"}) where = where.Or(builder.And(builder.In(q("col_name"), "col1", "col2"), builder.Expr(q("col_name")+" = ?", "col1"))) diff --git a/integrations/session_count_test.go b/tests/session_count_test.go similarity index 93% rename from integrations/session_count_test.go rename to tests/session_count_test.go index 079602c3..24c5d24a 100644 --- a/integrations/session_count_test.go +++ b/tests/session_count_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "testing" @@ -125,6 +125,11 @@ func TestWithTableName(t *testing.T) { total, err = testEngine.OrderBy("count(`id`) desc").Count(CountWithTableName{}) assert.NoError(t, err) assert.EqualValues(t, 2, total) + + // the orderby will be ignored by count because some databases will return errors if the orderby columns not in group by + total, err = testEngine.OrderBy("`name`").Count(CountWithTableName{}) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) } func TestCountWithSelectCols(t *testing.T) { diff --git a/integrations/session_delete_test.go b/tests/session_delete_test.go similarity index 83% rename from integrations/session_delete_test.go rename to tests/session_delete_test.go index 680c3215..44d4ad7d 100644 --- a/integrations/session_delete_test.go +++ b/tests/session_delete_test.go @@ -2,9 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( + "os" "testing" "time" @@ -70,6 +71,63 @@ func TestDelete(t *testing.T) { assert.False(t, has) } +func TestDeleteLimit(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + if testEngine.Dialect().URI().DBType == schemas.MSSQL || os.Getenv("IGNORE_TEST_DELETE_LIMIT") == "true" { + t.Skip() + return + } + + type UserinfoDeleteLimit struct { + Uid int64 `xorm:"id pk not null autoincr"` + IsMan bool + } + + assert.NoError(t, testEngine.Sync2(new(UserinfoDeleteLimit))) + + session := testEngine.NewSession() + defer session.Close() + + var err error + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Begin() + assert.NoError(t, err) + _, err = session.Exec("SET IDENTITY_INSERT userinfo_delete_limit ON") + assert.NoError(t, err) + } + + user := UserinfoDeleteLimit{Uid: 1, IsMan: true} + cnt, err := session.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + user2 := UserinfoDeleteLimit{Uid: 2} + cnt, err = session.Insert(&user2) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Commit() + assert.NoError(t, err) + } + + cnt, err = testEngine.Limit(1, 1).Delete(&UserinfoDeleteLimit{}) + assert.Error(t, err) + assert.EqualValues(t, 0, cnt) + + cnt, err = testEngine.Limit(1).Desc("id").Delete(&UserinfoDeleteLimit{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var users []UserinfoDeleteLimit + err = testEngine.Find(&users) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users)) + assert.EqualValues(t, 1, users[0].Uid) + assert.EqualValues(t, true, users[0].IsMan) +} + func TestDeleted(t *testing.T) { assert.NoError(t, PrepareEngine()) diff --git a/integrations/session_exist_test.go b/tests/session_exist_test.go similarity index 98% rename from integrations/session_exist_test.go rename to tests/session_exist_test.go index ca1e66ad..a9e3a6a8 100644 --- a/integrations/session_exist_test.go +++ b/tests/session_exist_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "context" @@ -106,14 +106,14 @@ func TestExistStructForJoin(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var orderlist = OrderList{ + orderlist := OrderList{ Eid: ply.Id, } cnt, err = testEngine.Insert(&orderlist) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var um = Number{ + um := Number{ Lid: orderlist.Id, } cnt, err = testEngine.Insert(&um) diff --git a/integrations/session_find_test.go b/tests/session_find_test.go similarity index 98% rename from integrations/session_find_test.go rename to tests/session_find_test.go index 65df5aee..d991e6ba 100644 --- a/integrations/session_find_test.go +++ b/tests/session_find_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "testing" @@ -1237,3 +1237,20 @@ func TestBuilderDialect(t *testing.T) { err := testEngine.Table("test_builder_dialect").Where(builder.Eq{"age2": 2}).Join("INNER", inner, "test_builder_dialect_foo.dialect_id = test_builder_dialect.id").Find(&result) assert.NoError(t, err) } + +func TestFindInMaxID(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestFindInMaxId struct { + Id int64 + Name string `xorm:"index"` + Age2 int + } + + assertSync(t, new(TestFindInMaxId)) + + var res []TestFindInMaxId + tableName := testEngine.TableName("test_find_in_max_id", true) + err := testEngine.In("id", builder.Select("max(id)").From(testEngine.Quote(tableName))).Find(&res) + assert.NoError(t, err) +} diff --git a/integrations/session_get_test.go b/tests/session_get_test.go similarity index 99% rename from integrations/session_get_test.go rename to tests/session_get_test.go index 841ec709..2ff2f67d 100644 --- a/integrations/session_get_test.go +++ b/tests/session_get_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "database/sql" @@ -1012,4 +1012,12 @@ func TestGetBytesVars(t *testing.T) { assert.True(t, has) assert.EqualValues(t, []byte("bytes1-1"), gbv.Bytes1) assert.EqualValues(t, []byte("bytes2-2"), gbv.Bytes2) + + type MyID int64 + var myID MyID + + has, err = testEngine.Table("get_bytes_vars").Select("id").Desc("id").Get(&myID) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, gbv.Id, myID) } diff --git a/integrations/session_insert_test.go b/tests/session_insert_test.go similarity index 92% rename from integrations/session_insert_test.go rename to tests/session_insert_test.go index 084deb38..dd3e8405 100644 --- a/integrations/session_insert_test.go +++ b/tests/session_insert_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "fmt" @@ -142,8 +142,13 @@ func TestInsert(t *testing.T) { assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) - user := Userinfo{0, "xiaolunwen", "dev", "lunny", time.Now(), - Userdetail{Id: 1}, 1.78, []byte{1, 2, 3}, true} + user := Userinfo{ + 0, "xiaolunwen", "dev", "lunny", time.Now(), + Userdetail{Id: 1}, + 1.78, + []byte{1, 2, 3}, + true, + } cnt, err := testEngine.Insert(&user) assert.NoError(t, err) assert.EqualValues(t, 1, cnt, "insert not returned 1") @@ -161,8 +166,10 @@ func TestInsertAutoIncr(t *testing.T) { assertSync(t, new(Userinfo)) // auto increment insert - user := Userinfo{Username: "xiaolunwen2", Departname: "dev", Alias: "lunny", Created: time.Now(), - Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true} + user := Userinfo{ + Username: "xiaolunwen2", Departname: "dev", Alias: "lunny", Created: time.Now(), + Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true, + } cnt, err := testEngine.Insert(&user) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) @@ -184,7 +191,7 @@ func TestInsertDefault(t *testing.T) { err := testEngine.Sync(di) assert.NoError(t, err) - var di2 = DefaultInsert{Name: "test"} + di2 := DefaultInsert{Name: "test"} _, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("Status")).Insert(&di2) assert.NoError(t, err) @@ -210,7 +217,7 @@ func TestInsertDefault2(t *testing.T) { err := testEngine.Sync(di) assert.NoError(t, err) - var di2 = DefaultInsert2{Name: "test"} + di2 := DefaultInsert2{Name: "test"} _, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("CheckTime")).Insert(&di2) assert.NoError(t, err) @@ -438,7 +445,7 @@ func TestCreatedJsonTime(t *testing.T) { assert.True(t, has) assert.EqualValues(t, time.Time(ci5.Created).Unix(), time.Time(di5.Created).Unix()) - var dis = make([]MyJSONTime, 0) + dis := make([]MyJSONTime, 0) err = testEngine.Find(&dis) assert.NoError(t, err) } @@ -762,7 +769,7 @@ func TestInsertWhere(t *testing.T) { assert.NoError(t, PrepareEngine()) assertSync(t, new(InsertWhere)) - var i = InsertWhere{ + i := InsertWhere{ RepoId: 1, Width: 10, Height: 20, @@ -872,7 +879,7 @@ func TestInsertExpr2(t *testing.T) { assertSync(t, new(InsertExprsRelease)) - var ie = InsertExprsRelease{ + ie := InsertExprsRelease{ RepoId: 1, IsTag: true, } @@ -1047,7 +1054,7 @@ func TestInsertIntSlice(t *testing.T) { assert.NoError(t, testEngine.Sync(new(InsertIntSlice))) - var v = InsertIntSlice{ + v := InsertIntSlice{ NameIDs: []int{1, 2}, } cnt, err := testEngine.Insert(&v) @@ -1064,7 +1071,7 @@ func TestInsertIntSlice(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var v3 = InsertIntSlice{ + v3 := InsertIntSlice{ NameIDs: nil, } cnt, err = testEngine.Insert(&v3) @@ -1202,3 +1209,80 @@ func TestInsertMultipleMap(t *testing.T) { Name: "xiaolunwen", }, res[1]) } + +func TestInsertNotDeleted(t *testing.T) { + assert.NoError(t, PrepareEngine()) + zeroTime := time.Date(1, 1, 1, 0, 0, 0, 0, testEngine.GetTZDatabase()) + type TestInsertNotDeletedStructNotRight struct { + ID uint64 `xorm:"'ID' pk autoincr"` + DeletedAt time.Time `xorm:"'DELETED_AT' deleted notnull"` + } + // notnull tag will be ignored + err := testEngine.Sync(new(TestInsertNotDeletedStructNotRight)) + assert.NoError(t, err) + + type TestInsertNotDeletedStruct struct { + ID uint64 `xorm:"'ID' pk autoincr"` + DeletedAt time.Time `xorm:"'DELETED_AT' deleted"` + } + + assert.NoError(t, testEngine.Sync(new(TestInsertNotDeletedStruct))) + + var v1 TestInsertNotDeletedStructNotRight + _, err = testEngine.Insert(&v1) + assert.NoError(t, err) + + var v2 TestInsertNotDeletedStructNotRight + has, err := testEngine.Get(&v2) + assert.NoError(t, err) + assert.True(t, has) + assert.Equal(t, v2.DeletedAt.In(testEngine.GetTZDatabase()).Format("2006-01-02 15:04:05"), zeroTime.Format("2006-01-02 15:04:05")) + + var v3 TestInsertNotDeletedStruct + _, err = testEngine.Insert(&v3) + assert.NoError(t, err) + + var v4 TestInsertNotDeletedStruct + has, err = testEngine.Get(&v4) + assert.NoError(t, err) + assert.True(t, has) + assert.Equal(t, v4.DeletedAt.In(testEngine.GetTZDatabase()).Format("2006-01-02 15:04:05"), zeroTime.Format("2006-01-02 15:04:05")) +} + +type MyAutoTimeFields1 struct { + Id int64 + Dt time.Time `xorm:"created DATETIME"` +} + +func (MyAutoTimeFields1) TableName() string { + return "my_auto_time_fields" +} + +type MyAutoTimeFields2 struct { + Id int64 + Dt time.Time `xorm:"created"` +} + +func (MyAutoTimeFields2) TableName() string { + return "my_auto_time_fields" +} + +func TestAutoTimeFields(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + assertSync(t, new(MyAutoTimeFields1)) + + _, err := testEngine.Insert(&MyAutoTimeFields1{}) + assert.NoError(t, err) + + var res []MyAutoTimeFields2 + assert.NoError(t, testEngine.Find(&res)) + assert.EqualValues(t, 1, len(res)) + + _, err = testEngine.Insert(&MyAutoTimeFields2{}) + assert.NoError(t, err) + + res = []MyAutoTimeFields2{} + assert.NoError(t, testEngine.Find(&res)) + assert.EqualValues(t, 2, len(res)) +} diff --git a/integrations/session_iterate_test.go b/tests/session_iterate_test.go similarity index 98% rename from integrations/session_iterate_test.go rename to tests/session_iterate_test.go index c5ecc593..f2e36899 100644 --- a/integrations/session_iterate_test.go +++ b/tests/session_iterate_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "testing" @@ -59,7 +59,7 @@ func TestBufferIterate(t *testing.T) { assert.NoError(t, testEngine.Sync(new(UserBufferIterate))) - var size = 20 + size := 20 for i := 0; i < size; i++ { cnt, err := testEngine.Insert(&UserBufferIterate{ IsMan: true, @@ -68,7 +68,7 @@ func TestBufferIterate(t *testing.T) { assert.EqualValues(t, 1, cnt) } - var cnt = 0 + cnt := 0 err := testEngine.BufferSize(9).Iterate(new(UserBufferIterate), func(i int, bean interface{}) error { user := bean.(*UserBufferIterate) assert.EqualValues(t, cnt+1, user.Id) diff --git a/integrations/session_pk_test.go b/tests/session_pk_test.go similarity index 97% rename from integrations/session_pk_test.go rename to tests/session_pk_test.go index 0244937f..43de4eea 100644 --- a/integrations/session_pk_test.go +++ b/tests/session_pk_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "sort" @@ -53,17 +53,21 @@ type StringPK struct { Name string } -type ID int64 -type MyIntPK struct { - ID ID `xorm:"pk autoincr"` - Name string -} +type ( + ID int64 + MyIntPK struct { + ID ID `xorm:"pk autoincr"` + Name string + } +) -type StrID string -type MyStringPK struct { - ID StrID `xorm:"pk notnull"` - Name string -} +type ( + StrID string + MyStringPK struct { + ID StrID `xorm:"pk notnull"` + Name string + } +) func TestIntId(t *testing.T) { assert.NoError(t, PrepareEngine()) @@ -187,7 +191,7 @@ func TestUintId(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var inserts = []UintId{ + inserts := []UintId{ {Name: "test1"}, {Name: "test2"}, } @@ -390,7 +394,7 @@ func TestCompositeKey(t *testing.T) { assert.True(t, has) assert.EqualValues(t, compositeKeyVal, compositeKeyVal2) - var cps = make([]CompositeKey, 0) + cps := make([]CompositeKey, 0) err = testEngine.Find(&cps) assert.NoError(t, err) assert.EqualValues(t, 1, len(cps)) @@ -460,13 +464,15 @@ func TestCompositeKey2(t *testing.T) { assert.EqualValues(t, 1, cnt) } -type MyString string -type UserPK2 struct { - UserId MyString `xorm:"varchar(19) not null pk"` - NickName string `xorm:"varchar(19) not null"` - GameId uint32 `xorm:"integer pk"` - Score int32 `xorm:"integer"` -} +type ( + MyString string + UserPK2 struct { + UserId MyString `xorm:"varchar(19) not null pk"` + NickName string `xorm:"varchar(19) not null"` + GameId uint32 `xorm:"integer pk"` + Score int32 `xorm:"integer"` + } +) func TestCompositeKey3(t *testing.T) { assert.NoError(t, PrepareEngine()) diff --git a/integrations/session_query_test.go b/tests/session_query_test.go similarity index 90% rename from integrations/session_query_test.go rename to tests/session_query_test.go index 00b7d7a6..5a3a3631 100644 --- a/integrations/session_query_test.go +++ b/tests/session_query_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "bytes" @@ -30,7 +30,7 @@ func TestQueryString(t *testing.T) { assert.NoError(t, testEngine.Sync(new(GetVar2))) - var data = GetVar2{ + data := GetVar2{ Msg: "hi", Age: 28, Money: 1.5, @@ -58,7 +58,7 @@ func TestQueryString2(t *testing.T) { assert.NoError(t, testEngine.Sync(new(GetVar3))) - var data = GetVar3{ + data := GetVar3{ Msg: false, } _, err := testEngine.Insert(data) @@ -95,7 +95,7 @@ func TestQueryInterface(t *testing.T) { assert.NoError(t, testEngine.Sync(new(GetVarInterface))) - var data = GetVarInterface{ + data := GetVarInterface{ Msg: "hi", Age: 28, Money: 1.5, @@ -128,7 +128,7 @@ func TestQueryNoParams(t *testing.T) { assert.NoError(t, testEngine.Sync(new(QueryNoParams))) - var q = QueryNoParams{ + q := QueryNoParams{ Msg: "message", Age: 20, Money: 3000, @@ -172,7 +172,7 @@ func TestQueryStringNoParam(t *testing.T) { assert.NoError(t, testEngine.Sync(new(GetVar4))) - var data = GetVar4{ + data := GetVar4{ Msg: false, } _, err := testEngine.Insert(data) @@ -209,7 +209,7 @@ func TestQuerySliceStringNoParam(t *testing.T) { assert.NoError(t, testEngine.Sync(new(GetVar6))) - var data = GetVar6{ + data := GetVar6{ Msg: false, } _, err := testEngine.Insert(data) @@ -246,7 +246,7 @@ func TestQueryInterfaceNoParam(t *testing.T) { assert.NoError(t, testEngine.Sync(new(GetVar5))) - var data = GetVar5{ + data := GetVar5{ Msg: false, } _, err := testEngine.Insert(data) @@ -280,7 +280,7 @@ func TestQueryWithBuilder(t *testing.T) { assert.NoError(t, testEngine.Sync(new(QueryWithBuilder))) - var q = QueryWithBuilder{ + q := QueryWithBuilder{ Msg: "message", Age: 20, Money: 3000, @@ -329,14 +329,14 @@ func TestJoinWithSubQuery(t *testing.T) { assert.NoError(t, testEngine.Sync(new(JoinWithSubQuery1), new(JoinWithSubQueryDepart))) - var depart = JoinWithSubQueryDepart{ + depart := JoinWithSubQueryDepart{ Name: "depart1", } cnt, err := testEngine.Insert(&depart) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var q = JoinWithSubQuery1{ + q := JoinWithSubQuery1{ Msg: "message", DepartId: depart.Id, Money: 3000, @@ -401,7 +401,7 @@ func TestQueryBLOBInMySQL(t *testing.T) { } const N = 10 - var data = []Avatar{} + data := []Avatar{} for i := 0; i < N; i++ { // allocate a []byte that is as twice big as the last one // so that the underlying buffer will need to reallocate when querying @@ -448,3 +448,54 @@ func TestQueryBLOBInMySQL(t *testing.T) { } } } + +func TestRowsReset(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type RowsReset1 struct { + Id int64 + Name string + } + + type RowsReset2 struct { + Id int64 + Name string + } + + assert.NoError(t, testEngine.Sync(new(RowsReset1), new(RowsReset2))) + + data := []RowsReset1{ + {0, "1"}, + {0, "2"}, + {0, "3"}, + } + _, err := testEngine.Insert(data) + assert.NoError(t, err) + + data2 := []RowsReset2{ + {0, "4"}, + {0, "5"}, + {0, "6"}, + } + _, err = testEngine.Insert(data2) + assert.NoError(t, err) + + sess := testEngine.NewSession() + defer sess.Close() + + rows, err := sess.Rows(new(RowsReset1)) + assert.NoError(t, err) + for rows.Next() { + var data1 RowsReset1 + assert.NoError(t, rows.Scan(&data1)) + } + rows.Close() + + var rrs []RowsReset2 + assert.NoError(t, sess.Find(&rrs)) + + assert.Len(t, rrs, 3) + assert.EqualValues(t, "4", rrs[0].Name) + assert.EqualValues(t, "5", rrs[1].Name) + assert.EqualValues(t, "6", rrs[2].Name) +} diff --git a/integrations/session_raw_test.go b/tests/session_raw_test.go similarity index 72% rename from integrations/session_raw_test.go rename to tests/session_raw_test.go index e53cd009..9bdecf9b 100644 --- a/integrations/session_raw_test.go +++ b/tests/session_raw_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "database/sql/driver" @@ -10,6 +10,8 @@ import ( "testing" "time" + "xorm.io/xorm/convert" + "github.com/stretchr/testify/assert" ) @@ -104,3 +106,48 @@ func TestExecDriverValuer(t *testing.T) { assert.Equal(t, "user", results[0]["name"]) assert.EqualValues(t, "data", results[0]["data"]) } + +type ConversionData struct { + MyData string +} + +var _ convert.Conversion = new(ConversionData) + +func (c ConversionData) ToDB() ([]byte, error) { + return []byte(c.MyData), nil +} + +func (c *ConversionData) FromDB(bs []byte) error { + if bs != nil { + c.MyData = string(bs) + } + return nil +} + +func TestExecCustomTypes(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UserinfoExec struct { + Uid int + Name string + Data string + } + + assert.NoError(t, testEngine.Sync2(new(UserinfoExec))) + + res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_exec`", true)+" (uid, name,data) VALUES (?, ?, ?)", + 1, "user", ConversionData{"data"}) + assert.NoError(t, err) + cnt, err := res.RowsAffected() + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + results, err := testEngine.QueryString("select * from " + testEngine.TableName("userinfo_exec", true)) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + id, err := strconv.Atoi(results[0]["uid"]) + assert.NoError(t, err) + assert.EqualValues(t, 1, id) + assert.Equal(t, "user", results[0]["name"]) + assert.EqualValues(t, "data", results[0]["data"]) +} diff --git a/integrations/session_sum_test.go b/tests/session_sum_test.go similarity index 92% rename from integrations/session_sum_test.go rename to tests/session_sum_test.go index e000233b..926269ee 100644 --- a/integrations/session_sum_test.go +++ b/tests/session_sum_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "fmt" @@ -25,13 +25,11 @@ func TestSum(t *testing.T) { assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync(new(SumStruct))) - var ( - cases = []SumStruct{ - {1, 6.2}, - {2, 5.3}, - {92, -0.2}, - } - ) + cases := []SumStruct{ + {1, 6.2}, + {2, 5.3}, + {92, -0.2}, + } var i int var f float32 @@ -84,13 +82,11 @@ func TestSumWithTableName(t *testing.T) { assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync(new(SumStructWithTableName))) - var ( - cases = []SumStructWithTableName{ - {1, 6.2}, - {2, 5.3}, - {92, -0.2}, - } - ) + cases := []SumStructWithTableName{ + {1, 6.2}, + {2, 5.3}, + {92, -0.2}, + } var i int var f float32 @@ -138,13 +134,11 @@ func TestSumCustomColumn(t *testing.T) { Float float32 } - var ( - cases = []SumStruct2{ - {1, 6.2}, - {2, 5.3}, - {92, -0.2}, - } - ) + cases := []SumStruct2{ + {1, 6.2}, + {2, 5.3}, + {92, -0.2}, + } assert.NoError(t, testEngine.Sync(new(SumStruct2))) diff --git a/integrations/session_test.go b/tests/session_test.go similarity index 98% rename from integrations/session_test.go rename to tests/session_test.go index a36b81bf..261f2c48 100644 --- a/integrations/session_test.go +++ b/tests/session_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "database/sql" diff --git a/integrations/session_tx_test.go b/tests/session_tx_test.go similarity index 98% rename from integrations/session_tx_test.go rename to tests/session_tx_test.go index 890e755d..c9db40ba 100644 --- a/integrations/session_tx_test.go +++ b/tests/session_tx_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "fmt" @@ -24,7 +24,7 @@ func TestTransaction(t *testing.T) { } counter(t) - //defer counter() + // defer counter() session := testEngine.NewSession() defer session.Close() @@ -58,7 +58,7 @@ func TestCombineTransaction(t *testing.T) { } counter() - //defer counter() + // defer counter() session := testEngine.NewSession() defer session.Close() @@ -187,7 +187,6 @@ func TestMultipleTransaction(t *testing.T) { } func TestInsertMulti2InterfaceTransaction(t *testing.T) { - type Multi2InterfaceTransaction struct { ID uint64 `xorm:"id pk autoincr"` Name string diff --git a/integrations/session_update_test.go b/tests/session_update_test.go similarity index 96% rename from integrations/session_update_test.go rename to tests/session_update_test.go index 45338cad..c13468d9 100644 --- a/integrations/session_update_test.go +++ b/tests/session_update_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "fmt" @@ -28,7 +28,7 @@ func TestUpdateMap(t *testing.T) { } assert.NoError(t, testEngine.Sync(new(UpdateTable))) - var tb = UpdateTable{ + tb := UpdateTable{ Name: "test", Age: 35, } @@ -79,7 +79,7 @@ func TestUpdateLimit(t *testing.T) { } assert.NoError(t, testEngine.Sync(new(UpdateTable2))) - var tb = UpdateTable2{ + tb := UpdateTable2{ Name: "test1", Age: 35, } @@ -400,7 +400,7 @@ func TestUpdate1(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var s = "test" + s := "test" col1 := &UpdateAllCols{Ptr: &s} err = testEngine.Sync(col1) @@ -864,7 +864,7 @@ func TestCreatedUpdated2(t *testing.T) { assertSync(t, new(CreatedUpdatedStruct)) - var s = CreatedUpdatedStruct{ + s := CreatedUpdatedStruct{ Name: "test", } cnt, err := testEngine.Insert(&s) @@ -874,7 +874,7 @@ func TestCreatedUpdated2(t *testing.T) { time.Sleep(time.Second) - var s1 = CreatedUpdatedStruct{ + s1 := CreatedUpdatedStruct{ Name: "test1", CreateAt: s.CreateAt, UpdateAt: s.UpdateAt, @@ -907,7 +907,7 @@ func TestDeletedUpdate(t *testing.T) { assertSync(t, new(DeletedUpdatedStruct)) - var s = DeletedUpdatedStruct{ + s := DeletedUpdatedStruct{ Name: "test", } cnt, err := testEngine.Insert(&s) @@ -956,7 +956,7 @@ func TestUpdateMapCondition(t *testing.T) { assertSync(t, new(UpdateMapCondition)) - var c = UpdateMapCondition{ + c := UpdateMapCondition{ String: "string", } _, err := testEngine.Insert(&c) @@ -990,7 +990,7 @@ func TestUpdateMapContent(t *testing.T) { assertSync(t, new(UpdateMapContent)) - var c = UpdateMapContent{ + c := UpdateMapContent{ Name: "lunny", IsMan: true, Gender: 1, @@ -1126,7 +1126,7 @@ func TestUpdateDeleted(t *testing.T) { assertSync(t, new(UpdateDeletedStruct)) - var s = UpdateDeletedStruct{ + s := UpdateDeletedStruct{ Name: "test", } cnt, err := testEngine.Insert(&s) @@ -1232,7 +1232,7 @@ func TestUpdateExprs2(t *testing.T) { assertSync(t, new(UpdateExprsRelease)) - var uer = UpdateExprsRelease{ + uer := UpdateExprsRelease{ RepoId: 1, IsTag: false, IsDraft: false, @@ -1407,7 +1407,7 @@ func TestNilFromDB(t *testing.T) { assert.NoError(t, PrepareEngine()) assertSync(t, new(TestTable1)) - var tt0 = TestTable1{ + tt0 := TestTable1{ Field1: &TestFieldType1{ cb: []byte("string"), }, @@ -1437,7 +1437,7 @@ func TestNilFromDB(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var tt = TestTable1{ + tt := TestTable1{ UpdateTime: time.Now(), Field1: &TestFieldType1{ cb: nil, @@ -1453,7 +1453,7 @@ func TestNilFromDB(t *testing.T) { assert.True(t, has) assert.Nil(t, tt2.Field1) - var tt3 = TestTable1{ + tt3 := TestTable1{ UpdateTime: time.Now(), Field1: &TestFieldType1{ cb: []byte{}, @@ -1470,3 +1470,34 @@ func TestNilFromDB(t *testing.T) { assert.NotNil(t, tt4.Field1) assert.NotNil(t, tt4.Field1.cb) } + +/* +func TestUpdateWithJoin(t *testing.T) { + type TestUpdateWithJoin struct { + Id int64 + ExtId int64 + Name string + } + + type TestUpdateWithJoin2 struct { + Id int64 + Name string + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(TestUpdateWithJoin), new(TestUpdateWithJoin2)) + + b := TestUpdateWithJoin2{Name: "test"} + _, err := testEngine.Insert(&b) + assert.NoError(t, err) + + _, err = testEngine.Insert(&TestUpdateWithJoin{ExtId: b.Id, Name: "test"}) + assert.NoError(t, err) + + _, err = testEngine.Table("test_update_with_join"). + Join("INNER", "test_update_with_join2", "test_update_with_join.ext_id = test_update_with_join2.id"). + Where("test_update_with_join2.name = ?", "test"). + Update(&TestUpdateWithJoin{Name: "test2"}) + assert.NoError(t, err) +} +*/ diff --git a/integrations/tags_test.go b/tests/tags_test.go similarity index 99% rename from integrations/tags_test.go rename to tests/tags_test.go index 4c33d56c..14803462 100644 --- a/integrations/tags_test.go +++ b/tests/tags_test.go @@ -2,16 +2,16 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "fmt" "sort" - "strings" "testing" "time" "github.com/stretchr/testify/assert" + "xorm.io/xorm/convert" "xorm.io/xorm/internal/utils" "xorm.io/xorm/names" "xorm.io/xorm/schemas" @@ -1201,8 +1201,10 @@ func TestTagTime(t *testing.T) { has, err = testEngine.Table("tag_u_t_c_struct").Cols("created").Get(&tm) assert.NoError(t, err) assert.True(t, has) - assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), - strings.ReplaceAll(strings.ReplaceAll(tm, "T", " "), "Z", "")) + + tmTime, err := convert.String2Time(tm, time.UTC, time.UTC) + assert.NoError(t, err) + assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), tmTime.Format("2006-01-02 15:04:05")) } func TestTagAutoIncr(t *testing.T) { diff --git a/integrations/testdata/import1.sql b/tests/testdata/import1.sql similarity index 100% rename from integrations/testdata/import1.sql rename to tests/testdata/import1.sql diff --git a/integrations/testdata/import2.sql b/tests/testdata/import2.sql similarity index 100% rename from integrations/testdata/import2.sql rename to tests/testdata/import2.sql diff --git a/integrations/tests.go b/tests/tests.go similarity index 98% rename from integrations/tests.go rename to tests/tests.go index 59f4b29a..220e1c67 100644 --- a/integrations/tests.go +++ b/tests/tests.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "database/sql" @@ -162,7 +162,7 @@ func createEngine(dbType, connStr string) error { if err != nil { return err } - var tableNames = make([]interface{}, 0, len(tables)) + tableNames := make([]interface{}, 0, len(tables)) for _, table := range tables { tableNames = append(tableNames, table.Name) } diff --git a/integrations/time_test.go b/tests/time_test.go similarity index 96% rename from integrations/time_test.go rename to tests/time_test.go index 5a17417a..13b9ed15 100644 --- a/integrations/time_test.go +++ b/tests/time_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "fmt" @@ -10,6 +10,7 @@ import ( "strings" "testing" "time" + "xorm.io/xorm/convert" "xorm.io/xorm/internal/utils" @@ -18,7 +19,7 @@ import ( ) func formatTime(t time.Time, scales ...int) string { - var layout = "2006-01-02 15:04:05" + layout := "2006-01-02 15:04:05" if len(scales) > 0 && scales[0] > 0 { layout += "." + strings.Repeat("0", scales[0]) } @@ -35,7 +36,7 @@ func TestTimeUserTime(t *testing.T) { assertSync(t, new(TimeUser)) - var user = TimeUser{ + user := TimeUser{ Id: "lunny", OperTime: time.Now(), } @@ -80,7 +81,7 @@ func TestTimeUserTimeDiffLoc(t *testing.T) { assertSync(t, new(TimeUser2)) - var user = TimeUser2{ + user := TimeUser2{ Id: "lunny", OperTime: time.Now(), } @@ -110,7 +111,7 @@ func TestTimeUserCreated(t *testing.T) { assertSync(t, new(UserCreated)) - var user = UserCreated{ + user := UserCreated{ Id: "lunny", } @@ -154,7 +155,7 @@ func TestTimeUserCreatedDiffLoc(t *testing.T) { assertSync(t, new(UserCreated2)) - var user = UserCreated2{ + user := UserCreated2{ Id: "lunny", } @@ -184,7 +185,7 @@ func TestTimeUserUpdated(t *testing.T) { assertSync(t, new(UserUpdated)) - var user = UserUpdated{ + user := UserUpdated{ Id: "lunny", } @@ -204,7 +205,7 @@ func TestTimeUserUpdated(t *testing.T) { assert.EqualValues(t, formatTime(user.UpdatedAt), formatTime(user2.UpdatedAt)) fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt) - var user3 = UserUpdated{ + user3 := UserUpdated{ Id: "lunny2", } @@ -250,7 +251,7 @@ func TestTimeUserUpdatedDiffLoc(t *testing.T) { assertSync(t, new(UserUpdated2)) - var user = UserUpdated2{ + user := UserUpdated2{ Id: "lunny", } @@ -270,7 +271,7 @@ func TestTimeUserUpdatedDiffLoc(t *testing.T) { assert.EqualValues(t, formatTime(user.UpdatedAt), formatTime(user2.UpdatedAt)) fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt) - var user3 = UserUpdated2{ + user3 := UserUpdated2{ Id: "lunny2", } @@ -304,7 +305,7 @@ func TestTimeUserDeleted(t *testing.T) { assertSync(t, new(UserDeleted)) - var user = UserDeleted{ + user := UserDeleted{ Id: "lunny", } @@ -367,7 +368,7 @@ func TestTimeUserDeletedDiffLoc(t *testing.T) { assertSync(t, new(UserDeleted2)) - var user = UserDeleted2{ + user := UserDeleted2{ Id: "lunny", } @@ -412,7 +413,7 @@ func (j JSONDate) MarshalJSON() ([]byte, error) { } func (j *JSONDate) UnmarshalJSON(value []byte) error { - var v = strings.TrimSpace(strings.Trim(string(value), "\"")) + v := strings.TrimSpace(strings.Trim(string(value), "\"")) t, err := time.ParseInLocation("2006-01-02 15:04:05", v, time.Local) if err != nil { @@ -438,7 +439,7 @@ func TestCustomTimeUserDeleted(t *testing.T) { assertSync(t, new(UserDeleted3)) - var user = UserDeleted3{ + user := UserDeleted3{ Id: "lunny", } @@ -500,7 +501,7 @@ func TestCustomTimeUserDeletedDiffLoc(t *testing.T) { assertSync(t, new(UserDeleted4)) - var user = UserDeleted4{ + user := UserDeleted4{ Id: "lunny", } @@ -583,7 +584,7 @@ func TestTimestamp(t *testing.T) { assertSync(t, new(TimestampStruct)) - var d1 = TimestampStruct{ + d1 := TimestampStruct{ InsertTime: time.Now(), } cnt, err := testEngine.Insert(&d1) @@ -625,10 +626,10 @@ func TestTimestamp(t *testing.T) { func TestString2Time(t *testing.T) { loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) - var timeTmp1 = time.Date(2023, 7, 14, 11, 30, 0, 0, loc) - var timeTmp2 = time.Date(2023, 7, 14, 0, 0, 0, 0, loc) - var time1StampStr = strconv.FormatInt(timeTmp1.Unix(), 10) - var timeStr = "0000-00-00 00:00:00" + timeTmp1 := time.Date(2023, 7, 14, 11, 30, 0, 0, loc) + timeTmp2 := time.Date(2023, 7, 14, 0, 0, 0, 0, loc) + time1StampStr := strconv.FormatInt(timeTmp1.Unix(), 10) + timeStr := "0000-00-00 00:00:00" dt, err := convert.String2Time(timeStr, time.Local, time.Local) assert.NoError(t, err) assert.True(t, dt.Nanosecond() == 0) diff --git a/integrations/types_null_test.go b/tests/types_null_test.go similarity index 99% rename from integrations/types_null_test.go rename to tests/types_null_test.go index 8d98b456..d4fa250e 100644 --- a/integrations/types_null_test.go +++ b/tests/types_null_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "database/sql" diff --git a/integrations/types_test.go b/tests/types_test.go similarity index 99% rename from integrations/types_test.go rename to tests/types_test.go index 1c815b7a..dfdb4766 100644 --- a/integrations/types_test.go +++ b/tests/types_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package integrations +package tests import ( "errors"