From eeacd22674314a0712f91033c91185a33c83cacb Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 13 Sep 2023 02:02:12 +0000 Subject: [PATCH 01/10] Fix ci (#2330) Reviewed-on: https://gitea.com/xorm/xorm/pulls/2330 --- .gitea/workflows/release-tag.yml | 4 ++-- .gitea/workflows/test-cockroach.yml | 2 +- .gitea/workflows/test-mariadb.yml | 2 +- .gitea/workflows/test-mssql.yml | 2 +- .gitea/workflows/test-mysql.yml | 2 +- .gitea/workflows/test-mysql8.yml | 2 +- .gitea/workflows/test-postgres.yml | 2 +- .gitea/workflows/test-sqlite.yml | 2 +- .gitea/workflows/test-tidb.yml | 2 +- 9 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.gitea/workflows/release-tag.yml b/.gitea/workflows/release-tag.yml index 10ed831e..3b788e76 100644 --- a/.gitea/workflows/release-tag.yml +++ b/.gitea/workflows/release-tag.yml @@ -13,11 +13,11 @@ jobs: with: fetch-depth: 0 - name: setup go - uses: https://github.com/actions/setup-go@v4 + uses: actions/setup-go@v4 with: go-version: '>=1.20.1' - name: Use Go Action id: use-go-action - uses: actions/release-action@main + uses: https://gitea.com/actions/release-action@main with: api_key: '${{secrets.RELEASE_TOKEN}}' \ No newline at end of file diff --git a/.gitea/workflows/test-cockroach.yml b/.gitea/workflows/test-cockroach.yml index ba966dc9..cfcda89d 100644 --- a/.gitea/workflows/test-cockroach.yml +++ b/.gitea/workflows/test-cockroach.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: test cockroach env: TEST_COCKROACH_HOST: "cockroach:26257" diff --git a/.gitea/workflows/test-mariadb.yml b/.gitea/workflows/test-mariadb.yml index 466f3858..dbc819db 100644 --- a/.gitea/workflows/test-mariadb.yml +++ b/.gitea/workflows/test-mariadb.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: test mariadb env: TEST_MYSQL_HOST: mariadb diff --git a/.gitea/workflows/test-mssql.yml b/.gitea/workflows/test-mssql.yml index d02e6956..04b8031a 100644 --- a/.gitea/workflows/test-mssql.yml +++ b/.gitea/workflows/test-mssql.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: test mssql env: TEST_MSSQL_HOST: mssql diff --git a/.gitea/workflows/test-mysql.yml b/.gitea/workflows/test-mysql.yml index 03ee2725..e13354f0 100644 --- a/.gitea/workflows/test-mysql.yml +++ b/.gitea/workflows/test-mysql.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: test mysql utf8mb4 env: TEST_MYSQL_HOST: mysql diff --git a/.gitea/workflows/test-mysql8.yml b/.gitea/workflows/test-mysql8.yml index 3fbd7c30..7362065a 100644 --- a/.gitea/workflows/test-mysql8.yml +++ b/.gitea/workflows/test-mysql8.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: test mysql8 env: TEST_MYSQL_HOST: mysql8 diff --git a/.gitea/workflows/test-postgres.yml b/.gitea/workflows/test-postgres.yml index 89aa72c3..d4abb2ad 100644 --- a/.gitea/workflows/test-postgres.yml +++ b/.gitea/workflows/test-postgres.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: test postgres env: TEST_PGSQL_HOST: pgsql diff --git a/.gitea/workflows/test-sqlite.yml b/.gitea/workflows/test-sqlite.yml index cca2e786..164acc10 100644 --- a/.gitea/workflows/test-sqlite.yml +++ b/.gitea/workflows/test-sqlite.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: vet run: make vet - name: format check diff --git a/.gitea/workflows/test-tidb.yml b/.gitea/workflows/test-tidb.yml index fa6e27ad..ce898dcb 100644 --- a/.gitea/workflows/test-tidb.yml +++ b/.gitea/workflows/test-tidb.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: 1.20 - - uses: https://github.com/actions/checkout@v3 + - uses: actions/checkout@v3 - name: test tidb env: TEST_TIDB_HOST: "tidb:4000" From 407375c9b466dc551868f95ab7feb25e07d3ffc1 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 16 Sep 2023 13:48:49 +0000 Subject: [PATCH 02/10] Add test for max ( id ) (#2316) Reviewed-on: https://gitea.com/xorm/xorm/pulls/2316 --- tests/session_find_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/session_find_test.go b/tests/session_find_test.go index 2a754e2a..d991e6ba 100644 --- a/tests/session_find_test.go +++ b/tests/session_find_test.go @@ -1237,3 +1237,20 @@ func TestBuilderDialect(t *testing.T) { err := testEngine.Table("test_builder_dialect").Where(builder.Eq{"age2": 2}).Join("INNER", inner, "test_builder_dialect_foo.dialect_id = test_builder_dialect.id").Find(&result) assert.NoError(t, err) } + +func TestFindInMaxID(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestFindInMaxId struct { + Id int64 + Name string `xorm:"index"` + Age2 int + } + + assertSync(t, new(TestFindInMaxId)) + + var res []TestFindInMaxId + tableName := testEngine.TableName("test_find_in_max_id", true) + err := testEngine.In("id", builder.Select("max(id)").From(testEngine.Quote(tableName))).Find(&res) + assert.NoError(t, err) +} From 2885c88b77c37369c4d8edd66e87b9eed5ae21fd Mon Sep 17 00:00:00 2001 From: zzdboy <28206697@qq.com> Date: Sat, 16 Sep 2023 13:49:19 +0000 Subject: [PATCH 03/10] fix PostgreSQL version (#2332) Reviewed-on: https://gitea.com/xorm/xorm/pulls/2332 Co-authored-by: zzdboy <28206697@qq.com> Co-committed-by: zzdboy <28206697@qq.com> --- dialects/postgres.go | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/dialects/postgres.go b/dialects/postgres.go index f1f6a2f2..53f66184 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -821,6 +821,7 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas } // Postgres: 9.5.22 on x86_64-pc-linux-gnu (Debian 9.5.22-1.pgdg90+1), compiled by gcc (Debian 6.3.0-18+deb9u1) 6.3.0 20170516, 64-bit + // Postgres: PostgreSQL 15.3, compiled by Visual C++ build 1914, 64-bit // CockroachDB CCL v19.2.4 (x86_64-unknown-linux-gnu, built if strings.HasPrefix(version, "CockroachDB") { versions := strings.Split(strings.TrimPrefix(version, "CockroachDB CCL "), " ") @@ -829,12 +830,22 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas Edition: "CockroachDB", }, nil } else if strings.HasPrefix(version, "PostgreSQL") { - versions := strings.Split(strings.TrimPrefix(version, "PostgreSQL "), " on ") - return &schemas.Version{ - Number: versions[0], - Level: versions[1], - Edition: "PostgreSQL", - }, nil + if strings.Contains(version, " on ") { + versions := strings.Split(strings.TrimPrefix(version, "PostgreSQL "), " on ") + return &schemas.Version{ + Number: versions[0], + Level: versions[1], + Edition: "PostgreSQL", + }, nil + } else { + versions := strings.Split(strings.TrimPrefix(version, "PostgreSQL "), ",") + return &schemas.Version{ + Number: versions[0], + Level: versions[1], + Edition: "PostgreSQL", + }, nil + } + } return nil, errors.New("unknow database version") From e5be0f4129217de60b5dcf0976a282fb1b781690 Mon Sep 17 00:00:00 2001 From: 6543 <6543@obermui.de> Date: Sat, 16 Sep 2023 14:41:02 +0000 Subject: [PATCH 04/10] Remove dead code from session.SyncWithOptions() (#2323) https://gitea.com/xorm/xorm/src/commit/db7c2640627d24539aa4607f50bcba7037ddd9e6/sync.go#L229-L231 as oriIndex only is **not** nil if index.Equal(index2) and index.Equal(index2) check if `oriIndex.Type == index.Type` ... so it always is false Co-authored-by: Lunny Xiao Reviewed-on: https://gitea.com/xorm/xorm/pulls/2323 Reviewed-by: Lunny Xiao Co-authored-by: 6543 <6543@obermui.de> Co-committed-by: 6543 <6543@obermui.de> --- sync.go | 9 --------- tests/sync_test.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 9 deletions(-) create mode 100644 tests/sync_test.go diff --git a/sync.go b/sync.go index 9e1cb8c1..adc2d859 100644 --- a/sync.go +++ b/sync.go @@ -235,15 +235,6 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{}) } } - if oriIndex != nil && oriIndex.Type != index.Type { - sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex) - _, err = session.exec(sql) - if err != nil { - return nil, err - } - oriIndex = nil - } - if oriIndex == nil { addedNames[name] = index } diff --git a/tests/sync_test.go b/tests/sync_test.go new file mode 100644 index 00000000..dedd3343 --- /dev/null +++ b/tests/sync_test.go @@ -0,0 +1,34 @@ +// Copyright 2023 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tests + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type TestSync1 struct { + Id int64 + ClassId int64 `xorm:"index"` +} + +func (TestSync1) TableName() string { + return "test_sync" +} + +type TestSync2 struct { + Id int64 + ClassId int64 `xorm:"unique"` +} + +func (TestSync2) TableName() string { + return "test_sync" +} + +func TestSync(t *testing.T) { + assert.NoError(t, testEngine.Sync(new(TestSync1))) + assert.NoError(t, testEngine.Sync(new(TestSync2))) +} From ac88a5705aafc0109221cf98c74b4304b66d3a76 Mon Sep 17 00:00:00 2001 From: Ryan Liu Date: Sat, 16 Sep 2023 15:43:12 +0000 Subject: [PATCH 05/10] fix the error in mysql: unknown colType UNSIGNED FLOAT (#2143) Error: unknown colType UNSIGNED FLOAT in mysql 5.6.27 or 5.7.32 Fix #2123 Co-authored-by: Lunny Xiao Reviewed-on: https://gitea.com/xorm/xorm/pulls/2143 Co-authored-by: Ryan Liu Co-committed-by: Ryan Liu --- dialects/mssql.go | 2 +- dialects/mysql.go | 8 +++++--- dialects/postgres.go | 2 +- schemas/type.go | 22 +++++++++++---------- tests/session_pk_test.go | 42 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 61 insertions(+), 15 deletions(-) diff --git a/dialects/mssql.go b/dialects/mssql.go index 2c64e637..1413b441 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -328,7 +328,7 @@ func (db *mssql) SQLType(c *schemas.Column) string { res = schemas.Int case schemas.Text, schemas.MediumText, schemas.TinyText, schemas.LongText, schemas.Json: res = db.defaultVarchar + "(MAX)" - case schemas.Double: + case schemas.Double, schemas.UnsignedFloat: res = schemas.Real case schemas.Uuid: res = schemas.Varchar diff --git a/dialects/mysql.go b/dialects/mysql.go index 6b92752b..2c061a14 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -319,6 +319,9 @@ func (db *mysql) SQLType(c *schemas.Column) string { case schemas.UnsignedTinyInt: res = schemas.TinyInt isUnsigned = true + case schemas.UnsignedFloat: + res = schemas.Float + isUnsigned = true default: res = t } @@ -510,11 +513,10 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName } col.Length = len1 col.Length2 = len2 - if _, ok := schemas.SqlTypes[colType]; ok { - col.SQLType = schemas.SQLType{Name: colType, DefaultLength: len1, DefaultLength2: len2} - } else { + if _, ok := schemas.SqlTypes[colType]; !ok { return nil, nil, fmt.Errorf("unknown colType %v", colType) } + col.SQLType = schemas.SQLType{Name: colType, DefaultLength: len1, DefaultLength2: len2} if colKey == "PRI" { col.IsPrimaryKey = true diff --git a/dialects/postgres.go b/dialects/postgres.go index 53f66184..03966f2d 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -928,7 +928,7 @@ func (db *postgres) SQLType(c *schemas.Column) string { return schemas.Uuid case schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob: return schemas.Bytea - case schemas.Double: + case schemas.Double, schemas.UnsignedFloat: return "DOUBLE PRECISION" default: if c.IsAutoIncrement { diff --git a/schemas/type.go b/schemas/type.go index b8b30851..3dbcee7e 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -139,9 +139,10 @@ var ( Money = "MONEY" SmallMoney = "SMALLMONEY" - Real = "REAL" - Float = "FLOAT" - Double = "DOUBLE" + Real = "REAL" + Float = "FLOAT" + UnsignedFloat = "UNSIGNED FLOAT" + Double = "DOUBLE" Binary = "BINARY" VarBinary = "VARBINARY" @@ -208,13 +209,14 @@ var ( SmallDateTime: TIME_TYPE, Year: TIME_TYPE, - Decimal: NUMERIC_TYPE, - Numeric: NUMERIC_TYPE, - Real: NUMERIC_TYPE, - Float: NUMERIC_TYPE, - Double: NUMERIC_TYPE, - Money: NUMERIC_TYPE, - SmallMoney: NUMERIC_TYPE, + Decimal: NUMERIC_TYPE, + Numeric: NUMERIC_TYPE, + Real: NUMERIC_TYPE, + Float: NUMERIC_TYPE, + UnsignedFloat: NUMERIC_TYPE, + Double: NUMERIC_TYPE, + Money: NUMERIC_TYPE, + SmallMoney: NUMERIC_TYPE, Binary: BLOB_TYPE, VarBinary: BLOB_TYPE, diff --git a/tests/session_pk_test.go b/tests/session_pk_test.go index 43de4eea..baf84547 100644 --- a/tests/session_pk_test.go +++ b/tests/session_pk_test.go @@ -325,6 +325,48 @@ func TestUint64Id(t *testing.T) { assert.EqualValues(t, 1, cnt) } +func TestUnsignedfloat(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UnsignedFloat struct { + Id int64 + UnsignedFloat float64 `xorm:"UNSIGNED FLOAT"` + } + + err := testEngine.DropTables(&UnsignedFloat{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&UnsignedFloat{}) + assert.NoError(t, err) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + assert.EqualValues(t, 1, len(tables)) + cols := tables[0].Columns() + assert.EqualValues(t, 2, len(cols)) + if testEngine.Dialect().URI().DBType == schemas.MYSQL { + assert.EqualValues(t, "UNSIGNED FLOAT", cols[1].SQLType.Name) + } + + idbean := &UnsignedFloat{UnsignedFloat: 12345678.90123456} + cnt, err := testEngine.Insert(idbean) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + bean := new(UnsignedFloat) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, bean.Id, idbean.Id) + + beans := make([]UnsignedFloat, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + assert.EqualValues(t, *bean, beans[0]) +} + func TestStringPK(t *testing.T) { assert.NoError(t, PrepareEngine()) From 551de3767c94c023acf5d2e915c495b376492f5b Mon Sep 17 00:00:00 2001 From: FlyingOnion <731677080@qq.com> Date: Wed, 20 Sep 2023 02:07:03 +0000 Subject: [PATCH 06/10] modify limit offset implement (#2188) Oracle and SQLServer specific: When `LIMIT OFFSET` function is needed, use `OFFSET ROWS FETCH NEXT ROWS ONLY` to replace legacy subquery. SQLServer specific: When `ORDER BY` is not set and `OFFSET FETCH` is set, set `statement.orderStr` to `1` (`ORDER BY 1`). See [here](https://learn.microsoft.com/zh-cn/sql/t-sql/queries/select-order-by-clause-transact-sql?view=sql-server-ver16). MySQL specific: When limit is 0 and offset > 0, use `LIMIT 9223372036854775807` ($2^{63}-1$). See comments [here](https://gitea.com/xorm/xorm/src/commit/15d171ea55a011eec76910353d5d8d17397d78e4/internal/statements/query.go#L314). Reviewed-on: https://gitea.com/xorm/xorm/pulls/2188 Reviewed-by: Lunny Xiao Co-authored-by: FlyingOnion <731677080@qq.com> Co-committed-by: FlyingOnion <731677080@qq.com> --- dialects/mssql.go | 10 +++ dialects/oracle.go | 12 +++ internal/statements/legacy_select.go | 59 +++++++++++++++ internal/statements/query.go | 109 ++++++++++++++++----------- tests/session_count_test.go | 9 +-- tests/session_query_test.go | 5 -- 6 files changed, 149 insertions(+), 55 deletions(-) create mode 100644 internal/statements/legacy_select.go diff --git a/dialects/mssql.go b/dialects/mssql.go index 1413b441..e4edc466 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -217,6 +217,7 @@ type mssql struct { Base defaultVarchar string defaultChar string + useLegacy bool } func (db *mssql) Init(uri *URI) error { @@ -226,6 +227,8 @@ func (db *mssql) Init(uri *URI) error { return db.Base.Init(db, uri) } +func (db *mssql) UseLegacyLimitOffset() bool { return db.useLegacy } + func (db *mssql) SetParams(params map[string]string) { defaultVarchar, ok := params["DEFAULT_VARCHAR"] if ok { @@ -252,6 +255,13 @@ func (db *mssql) SetParams(params map[string]string) { } else { db.defaultChar = "CHAR" } + + useLegacy, ok := params["USE_LEGACY_LIMIT_OFFSET"] + if ok { + if b, _ := strconv.ParseBool(useLegacy); b { + db.useLegacy = true + } + } } func (db *mssql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) { diff --git a/dialects/oracle.go b/dialects/oracle.go index fbda9dda..ac0fb944 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -509,6 +509,7 @@ var ( type oracle struct { Base + useLegacy bool } func (db *oracle) Init(uri *URI) error { @@ -516,6 +517,17 @@ func (db *oracle) Init(uri *URI) error { return db.Base.Init(db, uri) } +func (db *oracle) UseLegacyLimitOffset() bool { return db.useLegacy } + +func (db *oracle) SetParams(params map[string]string) { + useLegacy, ok := params["USE_LEGACY_LIMIT_OFFSET"] + if ok { + if b, _ := strconv.ParseBool(useLegacy); b { + db.useLegacy = true + } + } +} + func (db *oracle) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) { rows, err := queryer.QueryContext(ctx, "select * from v$version where banner like 'Oracle%'") if err != nil { diff --git a/internal/statements/legacy_select.go b/internal/statements/legacy_select.go new file mode 100644 index 00000000..1015839e --- /dev/null +++ b/internal/statements/legacy_select.go @@ -0,0 +1,59 @@ +// Copyright 2022 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + + "xorm.io/builder" +) + +// isUsingLegacy returns true if xorm uses legacy LIMIT OFFSET. +// It's only available in sqlserver and oracle, if param USE_LEGACY_LIMIT_OFFSET is set to "true" +func (statement *Statement) isUsingLegacyLimitOffset() bool { + u, ok := statement.dialect.(interface{ UseLegacyLimitOffset() bool }) + return ok && u.UseLegacyLimitOffset() +} + +func (statement *Statement) writeSelectWithFns(buf *builder.BytesWriter, writeFuncs ...func(*builder.BytesWriter) error) (err error) { + for _, fn := range writeFuncs { + if err = fn(buf); err != nil { + return + } + } + return +} + +// write mssql legacy query sql +func (statement *Statement) writeMssqlLegacySelect(buf *builder.BytesWriter, columnStr string) error { + writeFns := []func(*builder.BytesWriter) error{ + func(bw *builder.BytesWriter) (err error) { + _, err = fmt.Fprintf(bw, "SELECT") + return + }, + func(bw *builder.BytesWriter) error { return statement.writeDistinct(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeTop(bw) }, + statement.writeFrom, + statement.writeWhereWithMssqlPagination, + func(bw *builder.BytesWriter) error { return statement.writeGroupBy(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeHaving(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeOrderBys(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeForUpdate(bw) }, + } + return statement.writeSelectWithFns(buf, writeFns...) +} + +func (statement *Statement) writeOracleLegacySelect(buf *builder.BytesWriter, columnStr string) error { + writeFns := []func(*builder.BytesWriter) error{ + func(bw *builder.BytesWriter) error { return statement.writeSelectColumns(bw, columnStr) }, + statement.writeFrom, + func(bw *builder.BytesWriter) error { return statement.writeOracleLimit(bw, columnStr) }, + func(bw *builder.BytesWriter) error { return statement.writeGroupBy(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeHaving(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeOrderBys(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeForUpdate(bw) }, + } + return statement.writeSelectWithFns(buf, writeFns...) +} diff --git a/internal/statements/query.go b/internal/statements/query.go index 216a2028..c8384760 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -35,7 +35,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int } buf := builder.NewWriter() - if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil { + if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -66,7 +66,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri } buf := builder.NewWriter() - if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true, true); err != nil { + if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -122,7 +122,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, } buf := builder.NewWriter() - if err := statement.writeSelect(buf, columnStr, true, true); err != nil { + if err := statement.writeSelect(buf, columnStr, true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -168,7 +168,7 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa subQuerySelect = selectSQL } - if err := statement.writeSelect(buf, subQuerySelect, false, false); err != nil { + if err := statement.writeSelect(buf, subQuerySelect, false); err != nil { return "", nil, err } @@ -200,7 +200,7 @@ func (statement *Statement) writeLimitOffset(w builder.Writer) error { _, err := fmt.Fprintf(w, " LIMIT %v OFFSET %v", *statement.LimitN, statement.Start) return err } - _, err := fmt.Fprintf(w, " LIMIT 0 OFFSET %v", statement.Start) + _, err := fmt.Fprintf(w, " OFFSET %v", statement.Start) return err } if statement.LimitN != nil { @@ -211,10 +211,20 @@ func (statement *Statement) writeLimitOffset(w builder.Writer) error { return nil } -func (statement *Statement) writeTop(w builder.Writer) error { - if statement.dialect.URI().DBType != schemas.MSSQL { - return nil +func (statement *Statement) writeOffsetFetch(w builder.Writer) error { + if statement.LimitN != nil { + _, err := fmt.Fprintf(w, " OFFSET %v ROWS FETCH NEXT %v ROWS ONLY", statement.Start, *statement.LimitN) + return err } + if statement.Start > 0 { + _, err := fmt.Fprintf(w, " OFFSET %v ROWS", statement.Start) + return err + } + return nil +} + +// write "TOP " (mssql only) +func (statement *Statement) writeTop(w builder.Writer) error { if statement.LimitN == nil { return nil } @@ -237,9 +247,6 @@ func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr if err := statement.writeDistinct(w); err != nil { return err } - if err := statement.writeTop(w); err != nil { - return err - } _, err := fmt.Fprint(w, " ", columnStr) return err } @@ -284,8 +291,10 @@ func (statement *Statement) writeForUpdate(w io.Writer) error { return err } +// write subquery to implement limit offset +// (mssql legacy only) func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) error { - if statement.dialect.URI().DBType != schemas.MSSQL || statement.Start <= 0 { + if statement.Start <= 0 { return nil } @@ -365,41 +374,55 @@ func (statement *Statement) writeOracleLimit(w *builder.BytesWriter, columnStr s return err } -func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit, needOrderBy bool) error { - if err := statement.writeSelectColumns(buf, columnStr); err != nil { - return err - } - if err := statement.writeFrom(buf); err != nil { - return err - } - if err := statement.writeWhereWithMssqlPagination(buf); err != nil { - return err - } - if err := statement.writeGroupBy(buf); err != nil { - return err - } - if err := statement.writeHaving(buf); err != nil { - return err - } - if needOrderBy { - if err := statement.writeOrderBys(buf); err != nil { - return err +func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error { + dbType := statement.dialect.URI().DBType + if statement.isUsingLegacyLimitOffset() { + if dbType == "mssql" { + return statement.writeMssqlLegacySelect(buf, columnStr) + } + if dbType == "oracle" { + return statement.writeOracleLegacySelect(buf, columnStr) } } - - dialect := statement.dialect - if needLimit { - if dialect.URI().DBType == schemas.ORACLE { - if err := statement.writeOracleLimit(buf, columnStr); err != nil { - return err + // TODO: modify all functions to func(w builder.Writer) error + writeFns := []func(*builder.BytesWriter) error{ + func(bw *builder.BytesWriter) error { return statement.writeSelectColumns(bw, columnStr) }, + statement.writeFrom, + statement.writeWhere, + func(bw *builder.BytesWriter) error { return statement.writeGroupBy(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeHaving(bw) }, + func(bw *builder.BytesWriter) (err error) { + if dbType == "mssql" && len(statement.orderBy) == 0 && needLimit { + // ORDER BY is mandatory to use OFFSET and FETCH clause (only in sqlserver) + if statement.LimitN == nil && statement.Start == 0 { + // no need to add + return + } + if statement.IsDistinct || len(statement.GroupByStr) > 0 { + // the order-by column should be one of distincts or group-bys + // order by the first column + _, err = bw.WriteString(" ORDER BY 1 ASC") + return + } + if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { + // no primary key, order by the first column + _, err = bw.WriteString(" ORDER BY 1 ASC") + return + } + // order by primary key + statement.orderBy = []orderBy{{orderStr: statement.colName(statement.RefTable.GetColumn(statement.RefTable.PrimaryKeys[0]), statement.TableName()), direction: "ASC"}} } - } else if dialect.URI().DBType != schemas.MSSQL { - if err := statement.writeLimitOffset(buf); err != nil { - return err + return statement.writeOrderBys(bw) + }, + func(bw *builder.BytesWriter) error { + if dbType == "mssql" || dbType == "oracle" { + return statement.writeOffsetFetch(bw) } - } + return statement.writeLimitOffset(bw) + }, + func(bw *builder.BytesWriter) error { return statement.writeForUpdate(bw) }, } - return statement.writeForUpdate(buf) + return statement.writeSelectWithFns(buf, writeFns...) } // GenExistSQL generates Exist SQL @@ -522,7 +545,7 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa statement.cond = statement.cond.And(autoCond) buf := builder.NewWriter() - if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil { + if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil diff --git a/tests/session_count_test.go b/tests/session_count_test.go index 24c5d24a..7b359812 100644 --- a/tests/session_count_test.go +++ b/tests/session_count_test.go @@ -118,16 +118,11 @@ func TestWithTableName(t *testing.T) { }) assert.NoError(t, err) - total, err := testEngine.OrderBy("count(`id`) desc").Count(new(CountWithTableName)) + total, err := testEngine.Count(new(CountWithTableName)) assert.NoError(t, err) assert.EqualValues(t, 2, total) - total, err = testEngine.OrderBy("count(`id`) desc").Count(CountWithTableName{}) - assert.NoError(t, err) - assert.EqualValues(t, 2, total) - - // the orderby will be ignored by count because some databases will return errors if the orderby columns not in group by - total, err = testEngine.OrderBy("`name`").Count(CountWithTableName{}) + total, err = testEngine.Count(CountWithTableName{}) assert.NoError(t, err) assert.EqualValues(t, 2, total) } diff --git a/tests/session_query_test.go b/tests/session_query_test.go index 5a3a3631..4df85f79 100644 --- a/tests/session_query_test.go +++ b/tests/session_query_test.go @@ -365,11 +365,6 @@ func TestJoinWithSubQuery(t *testing.T) { func TestQueryStringWithLimit(t *testing.T) { assert.NoError(t, PrepareEngine()) - if testEngine.Dialect().URI().DBType == schemas.MSSQL { - t.SkipNow() - return - } - type QueryWithLimit struct { Id int64 `xorm:"autoincr pk"` Msg string `xorm:"varchar(255)"` From 3eda0f7805020f4cfb449e20af0643ac84ed62b5 Mon Sep 17 00:00:00 2001 From: zzdboy <28206697@qq.com> Date: Sun, 24 Sep 2023 03:52:50 +0000 Subject: [PATCH 07/10] fix KingbaseES version (#2335) Reviewed-on: https://gitea.com/xorm/xorm/pulls/2335 Reviewed-by: Lunny Xiao Co-authored-by: zzdboy <28206697@qq.com> Co-committed-by: zzdboy <28206697@qq.com> --- dialects/postgres.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/dialects/postgres.go b/dialects/postgres.go index 03966f2d..99574459 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -822,6 +822,7 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas // Postgres: 9.5.22 on x86_64-pc-linux-gnu (Debian 9.5.22-1.pgdg90+1), compiled by gcc (Debian 6.3.0-18+deb9u1) 6.3.0 20170516, 64-bit // Postgres: PostgreSQL 15.3, compiled by Visual C++ build 1914, 64-bit + // KingbaseES V008R006C008B0014 on x64, compiled by Visual C++ build 1800, 64-bit // CockroachDB CCL v19.2.4 (x86_64-unknown-linux-gnu, built if strings.HasPrefix(version, "CockroachDB") { versions := strings.Split(strings.TrimPrefix(version, "CockroachDB CCL "), " ") @@ -845,7 +846,22 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas Edition: "PostgreSQL", }, nil } - + } else if strings.HasPrefix(version, "KingbaseES") { + if strings.Contains(version, " on ") { + versions := strings.Split(strings.TrimPrefix(version, "KingbaseES "), " on ") + return &schemas.Version{ + Number: versions[0], + Level: versions[1], + Edition: "KingbaseES", + }, nil + } else { + versions := strings.Split(strings.TrimPrefix(version, "KingbaseES "), ",") + return &schemas.Version{ + Number: versions[0], + Level: versions[1], + Edition: "KingbaseES", + }, nil + } } return nil, errors.New("unknow database version") From dbe499091a7eacb01d06f9a20ff00c1384df797d Mon Sep 17 00:00:00 2001 From: lng2020 Date: Tue, 17 Oct 2023 09:41:42 +0000 Subject: [PATCH 08/10] Revert "Fix deleted tag attribute zeroTime is not DatabaseTZ (#2299)" (#2341) Related #2339 Reviewed-on: https://gitea.com/xorm/xorm/pulls/2341 Co-authored-by: lng2020 Co-committed-by: lng2020 --- convert/time.go | 38 +--------------- dialects/mssql.go | 6 ++- internal/statements/statement.go | 5 ++- session_insert.go | 3 +- tests/session_insert_test.go | 77 -------------------------------- tests/tags_test.go | 8 ++-- 6 files changed, 15 insertions(+), 122 deletions(-) diff --git a/convert/time.go b/convert/time.go index d90dc428..c923e955 100644 --- a/convert/time.go +++ b/convert/time.go @@ -28,19 +28,14 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t dt = dt.In(convertedLocation) return &dt, nil } else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' { - if strings.HasPrefix(s, "0000-00-00T00:00:00") || strings.HasPrefix(s, "0001-01-01T00:00:00") { - return &time.Time{}, nil - } dt, err := time.ParseInLocation("2006-01-02T15:04:05", s[:19], originalLocation) if err != nil { return nil, err } dt = dt.In(convertedLocation) + dt.IsZero() return &dt, nil } else if len(s) == 25 && s[10] == 'T' && s[19] == '+' && s[22] == ':' { - if strings.HasPrefix(s, "0000-00-00T00:00:00") || strings.HasPrefix(s, "0001-01-01T00:00:00") { - return &time.Time{}, nil - } dt, err := time.Parse(time.RFC3339, s) if err != nil { return nil, err @@ -48,10 +43,6 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t dt = dt.In(convertedLocation) return &dt, nil } else if len(s) >= 21 && s[10] == 'T' && s[19] == '.' { - if strings.HasPrefix(s, "0000-00-00T00:00:00."+strings.Repeat("0", len(s)-20)) || - strings.HasPrefix(s, "0001-01-01T00:00:00."+strings.Repeat("0", len(s)-20)) { - return &time.Time{}, nil - } dt, err := time.Parse(time.RFC3339Nano, s) if err != nil { return nil, err @@ -59,10 +50,6 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t dt = dt.In(convertedLocation) return &dt, nil } else if len(s) >= 21 && s[19] == '.' { - if strings.HasPrefix(s, "0000-00-00T00:00:00."+strings.Repeat("0", len(s)-20)) || - strings.HasPrefix(s, "0001-01-01T00:00:00."+strings.Repeat("0", len(s)-20)) { - return &time.Time{}, nil - } layout := "2006-01-02 15:04:05." + strings.Repeat("0", len(s)-20) dt, err := time.ParseInLocation(layout, s, originalLocation) if err != nil { @@ -81,11 +68,11 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t dt = dt.In(convertedLocation) return &dt, nil } else if len(s) == 8 && s[2] == ':' && s[5] == ':' { + currentDate := time.Now() dt, err := time.ParseInLocation("15:04:05", s, originalLocation) if err != nil { return nil, err } - currentDate := time.Now() // add current date for correct time locations dt = dt.AddDate(currentDate.Year(), int(currentDate.Month()), currentDate.Day()) dt = dt.In(convertedLocation) @@ -95,9 +82,6 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t } else { i, err := strconv.ParseInt(s, 10, 64) if err == nil { - if i == 0 { - return &time.Time{}, nil - } tm := time.Unix(i, 0).In(convertedLocation) return &tm, nil } @@ -124,9 +108,6 @@ func AsTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time. if !t.Valid { return nil, nil } - if utils.IsTimeZero(t.Time) { - return &time.Time{}, nil - } z, _ := t.Time.Zone() if len(z) == 0 || t.Time.Year() == 0 || t.Time.Location().String() != dbLoc.String() { tm := time.Date(t.Time.Year(), t.Time.Month(), t.Time.Day(), t.Time.Hour(), @@ -136,9 +117,6 @@ func AsTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time. tm := t.Time.In(uiLoc) return &tm, nil case *time.Time: - if utils.IsTimeZero(*t) { - return &time.Time{}, nil - } z, _ := t.Zone() if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), @@ -148,9 +126,6 @@ func AsTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time. tm := t.In(uiLoc) return &tm, nil case time.Time: - if utils.IsTimeZero(t) { - return &time.Time{}, nil - } z, _ := t.Zone() if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), @@ -160,21 +135,12 @@ func AsTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time. tm := t.In(uiLoc) return &tm, nil case int: - if t == 0 { - return &time.Time{}, nil - } tm := time.Unix(int64(t), 0).In(uiLoc) return &tm, nil case int64: - if t == 0 { - return &time.Time{}, nil - } tm := time.Unix(t, 0).In(uiLoc) return &tm, nil case *sql.NullInt64: - if t.Int64 == 0 { - return &time.Time{}, nil - } tm := time.Unix(t.Int64, 0).In(uiLoc) return &tm, nil } diff --git a/dialects/mssql.go b/dialects/mssql.go index e4edc466..aaa40335 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -330,7 +330,11 @@ func (db *mssql) SQLType(c *schemas.Column) string { res += "(MAX)" } case schemas.TimeStamp, schemas.DateTime: - return "DATETIME2" + if c.Length > 3 { + res = "DATETIME2" + } else { + return schemas.DateTime + } case schemas.TimeStampz: res = "DATETIMEOFFSET" c.Length = 7 diff --git a/internal/statements/statement.go b/internal/statements/statement.go index c075ec54..68690bbe 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -707,7 +707,10 @@ func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { if col.SQLType.IsNumeric() { cond = builder.Eq{colName: 0} } else { - cond = builder.Eq{colName: utils.ZeroTime1} + // FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value. + if statement.dialect.URI().DBType != schemas.MSSQL { + cond = builder.Eq{colName: utils.ZeroTime1} + } } if col.Nullable { diff --git a/session_insert.go b/session_insert.go index 7cc15241..7003e0f7 100644 --- a/session_insert.go +++ b/session_insert.go @@ -471,8 +471,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } if col.IsDeleted { - zeroTime := time.Date(1, 1, 1, 0, 0, 0, 0, session.engine.DatabaseTZ) - arg, err := dialects.FormatColumnTime(session.engine.dialect, session.engine.DatabaseTZ, col, zeroTime) + arg, err := dialects.FormatColumnTime(session.engine.dialect, session.engine.DatabaseTZ, col, time.Time{}) if err != nil { return nil, nil, err } diff --git a/tests/session_insert_test.go b/tests/session_insert_test.go index dd3e8405..e45e6e54 100644 --- a/tests/session_insert_test.go +++ b/tests/session_insert_test.go @@ -1209,80 +1209,3 @@ func TestInsertMultipleMap(t *testing.T) { Name: "xiaolunwen", }, res[1]) } - -func TestInsertNotDeleted(t *testing.T) { - assert.NoError(t, PrepareEngine()) - zeroTime := time.Date(1, 1, 1, 0, 0, 0, 0, testEngine.GetTZDatabase()) - type TestInsertNotDeletedStructNotRight struct { - ID uint64 `xorm:"'ID' pk autoincr"` - DeletedAt time.Time `xorm:"'DELETED_AT' deleted notnull"` - } - // notnull tag will be ignored - err := testEngine.Sync(new(TestInsertNotDeletedStructNotRight)) - assert.NoError(t, err) - - type TestInsertNotDeletedStruct struct { - ID uint64 `xorm:"'ID' pk autoincr"` - DeletedAt time.Time `xorm:"'DELETED_AT' deleted"` - } - - assert.NoError(t, testEngine.Sync(new(TestInsertNotDeletedStruct))) - - var v1 TestInsertNotDeletedStructNotRight - _, err = testEngine.Insert(&v1) - assert.NoError(t, err) - - var v2 TestInsertNotDeletedStructNotRight - has, err := testEngine.Get(&v2) - assert.NoError(t, err) - assert.True(t, has) - assert.Equal(t, v2.DeletedAt.In(testEngine.GetTZDatabase()).Format("2006-01-02 15:04:05"), zeroTime.Format("2006-01-02 15:04:05")) - - var v3 TestInsertNotDeletedStruct - _, err = testEngine.Insert(&v3) - assert.NoError(t, err) - - var v4 TestInsertNotDeletedStruct - has, err = testEngine.Get(&v4) - assert.NoError(t, err) - assert.True(t, has) - assert.Equal(t, v4.DeletedAt.In(testEngine.GetTZDatabase()).Format("2006-01-02 15:04:05"), zeroTime.Format("2006-01-02 15:04:05")) -} - -type MyAutoTimeFields1 struct { - Id int64 - Dt time.Time `xorm:"created DATETIME"` -} - -func (MyAutoTimeFields1) TableName() string { - return "my_auto_time_fields" -} - -type MyAutoTimeFields2 struct { - Id int64 - Dt time.Time `xorm:"created"` -} - -func (MyAutoTimeFields2) TableName() string { - return "my_auto_time_fields" -} - -func TestAutoTimeFields(t *testing.T) { - assert.NoError(t, PrepareEngine()) - - assertSync(t, new(MyAutoTimeFields1)) - - _, err := testEngine.Insert(&MyAutoTimeFields1{}) - assert.NoError(t, err) - - var res []MyAutoTimeFields2 - assert.NoError(t, testEngine.Find(&res)) - assert.EqualValues(t, 1, len(res)) - - _, err = testEngine.Insert(&MyAutoTimeFields2{}) - assert.NoError(t, err) - - res = []MyAutoTimeFields2{} - assert.NoError(t, testEngine.Find(&res)) - assert.EqualValues(t, 2, len(res)) -} diff --git a/tests/tags_test.go b/tests/tags_test.go index 14803462..f8448b4a 100644 --- a/tests/tags_test.go +++ b/tests/tags_test.go @@ -7,11 +7,11 @@ package tests import ( "fmt" "sort" + "strings" "testing" "time" "github.com/stretchr/testify/assert" - "xorm.io/xorm/convert" "xorm.io/xorm/internal/utils" "xorm.io/xorm/names" "xorm.io/xorm/schemas" @@ -1201,10 +1201,8 @@ func TestTagTime(t *testing.T) { has, err = testEngine.Table("tag_u_t_c_struct").Cols("created").Get(&tm) assert.NoError(t, err) assert.True(t, has) - - tmTime, err := convert.String2Time(tm, time.UTC, time.UTC) - assert.NoError(t, err) - assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), tmTime.Format("2006-01-02 15:04:05")) + assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), + strings.ReplaceAll(strings.ReplaceAll(tm, "T", " "), "Z", "")) } func TestTagAutoIncr(t *testing.T) { From 0f085408afd85707635eadb2294ab52be04f3c0f Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 25 Oct 2023 07:11:18 +0000 Subject: [PATCH 09/10] some refactors for write functions (#2342) Reviewed-on: https://gitea.com/xorm/xorm/pulls/2342 --- internal/statements/insert.go | 2 +- internal/statements/legacy_select.go | 50 +++---- internal/statements/pagination.go | 148 ++++++++++++++++++++ internal/statements/query.go | 200 ++++----------------------- internal/statements/statement.go | 4 +- internal/statements/table_name.go | 4 +- internal/statements/writer.go | 37 +++++ tests/session_count_test.go | 20 +++ 8 files changed, 257 insertions(+), 208 deletions(-) create mode 100644 internal/statements/pagination.go create mode 100644 internal/statements/writer.go diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 9370c984..aa396431 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -89,7 +89,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) } if statement.Conds().IsValid() { - if _, err := buf.WriteString(" SELECT "); err != nil { + if err := statement.writeStrings(" SELECT ")(buf); err != nil { return "", nil, err } diff --git a/internal/statements/legacy_select.go b/internal/statements/legacy_select.go index 1015839e..144ad96d 100644 --- a/internal/statements/legacy_select.go +++ b/internal/statements/legacy_select.go @@ -5,8 +5,6 @@ package statements import ( - "fmt" - "xorm.io/builder" ) @@ -17,43 +15,29 @@ func (statement *Statement) isUsingLegacyLimitOffset() bool { return ok && u.UseLegacyLimitOffset() } -func (statement *Statement) writeSelectWithFns(buf *builder.BytesWriter, writeFuncs ...func(*builder.BytesWriter) error) (err error) { - for _, fn := range writeFuncs { - if err = fn(buf); err != nil { - return - } - } - return -} - // write mssql legacy query sql func (statement *Statement) writeMssqlLegacySelect(buf *builder.BytesWriter, columnStr string) error { - writeFns := []func(*builder.BytesWriter) error{ - func(bw *builder.BytesWriter) (err error) { - _, err = fmt.Fprintf(bw, "SELECT") - return - }, - func(bw *builder.BytesWriter) error { return statement.writeDistinct(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeTop(bw) }, + return statement.writeMultiple(buf, + statement.writeStrings("SELECT"), + statement.writeDistinct, + statement.writeTop, statement.writeFrom, statement.writeWhereWithMssqlPagination, - func(bw *builder.BytesWriter) error { return statement.writeGroupBy(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeHaving(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeOrderBys(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeForUpdate(bw) }, - } - return statement.writeSelectWithFns(buf, writeFns...) + statement.writeGroupBy, + statement.writeHaving, + statement.writeOrderBys, + statement.writeForUpdate, + ) } func (statement *Statement) writeOracleLegacySelect(buf *builder.BytesWriter, columnStr string) error { - writeFns := []func(*builder.BytesWriter) error{ - func(bw *builder.BytesWriter) error { return statement.writeSelectColumns(bw, columnStr) }, + return statement.writeMultiple(buf, + statement.writeSelectColumns(columnStr), statement.writeFrom, - func(bw *builder.BytesWriter) error { return statement.writeOracleLimit(bw, columnStr) }, - func(bw *builder.BytesWriter) error { return statement.writeGroupBy(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeHaving(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeOrderBys(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeForUpdate(bw) }, - } - return statement.writeSelectWithFns(buf, writeFns...) + statement.writeOracleLimit(columnStr), + statement.writeGroupBy, + statement.writeHaving, + statement.writeOrderBys, + statement.writeForUpdate, + ) } diff --git a/internal/statements/pagination.go b/internal/statements/pagination.go new file mode 100644 index 00000000..3c7a3913 --- /dev/null +++ b/internal/statements/pagination.go @@ -0,0 +1,148 @@ +// Copyright 2023 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "errors" + "fmt" + + "xorm.io/builder" + "xorm.io/xorm/internal/utils" +) + +func (statement *Statement) writePagination(bw *builder.BytesWriter) error { + dbType := statement.dialect.URI().DBType + if dbType == "mssql" || dbType == "oracle" { + return statement.writeOffsetFetch(bw) + } + return statement.writeLimitOffset(bw) +} + +func (statement *Statement) writeLimitOffset(w builder.Writer) error { + if statement.Start > 0 { + if statement.LimitN != nil { + _, err := fmt.Fprintf(w, " LIMIT %v OFFSET %v", *statement.LimitN, statement.Start) + return err + } + _, err := fmt.Fprintf(w, " OFFSET %v", statement.Start) + return err + } + if statement.LimitN != nil { + _, err := fmt.Fprint(w, " LIMIT ", *statement.LimitN) + return err + } + // no limit statement + return nil +} + +func (statement *Statement) writeOffsetFetch(w builder.Writer) error { + if statement.LimitN != nil { + _, err := fmt.Fprintf(w, " OFFSET %v ROWS FETCH NEXT %v ROWS ONLY", statement.Start, *statement.LimitN) + return err + } + if statement.Start > 0 { + _, err := fmt.Fprintf(w, " OFFSET %v ROWS", statement.Start) + return err + } + return nil +} + +func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error { + if !statement.cond.IsValid() { + return statement.writeMssqlPaginationCond(w) + } + if _, err := fmt.Fprint(w, " WHERE "); err != nil { + return err + } + if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil { + return err + } + return statement.writeMssqlPaginationCond(w) +} + +// write subquery to implement limit offset +// (mssql legacy only) +func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) error { + if statement.Start <= 0 { + return nil + } + + if statement.RefTable == nil { + return errors.New("unsupported query limit without reference table") + } + + var column string + if len(statement.RefTable.PKColumns()) == 0 { + for _, index := range statement.RefTable.Indexes { + if len(index.Cols) == 1 { + column = index.Cols[0] + break + } + } + if len(column) == 0 { + column = statement.RefTable.ColumnsSeq()[0] + } + } else { + column = statement.RefTable.PKColumns()[0].Name + } + if statement.NeedTableName() { + if len(statement.TableAlias) > 0 { + column = fmt.Sprintf("%s.%s", statement.TableAlias, column) + } else { + column = fmt.Sprintf("%s.%s", statement.TableName(), column) + } + } + + subWriter := builder.NewWriter() + if _, err := fmt.Fprintf(subWriter, "(%s NOT IN (SELECT TOP %d %s", + column, statement.Start, column); err != nil { + return err + } + if err := statement.writeFrom(subWriter); err != nil { + return err + } + if err := statement.writeWhere(subWriter); err != nil { + return err + } + if err := statement.writeOrderBys(subWriter); err != nil { + return err + } + if err := statement.writeGroupBy(subWriter); err != nil { + return err + } + if _, err := fmt.Fprint(subWriter, "))"); err != nil { + return err + } + + if statement.cond.IsValid() { + if _, err := fmt.Fprint(w, " AND "); err != nil { + return err + } + } else { + if _, err := fmt.Fprint(w, " WHERE "); err != nil { + return err + } + } + + return utils.WriteBuilder(w, subWriter) +} + +func (statement *Statement) writeOracleLimit(columnStr string) func(w *builder.BytesWriter) error { + return func(w *builder.BytesWriter) error { + if statement.LimitN == nil { + return nil + } + + oldString := w.String() + w.Reset() + rawColStr := columnStr + if rawColStr == "*" { + rawColStr = "at.*" + } + _, err := fmt.Fprintf(w, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", + columnStr, rawColStr, oldString, statement.Start+*statement.LimitN, statement.Start) + return err + } +} diff --git a/internal/statements/query.go b/internal/statements/query.go index c8384760..8a9e59e4 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -7,12 +7,10 @@ package statements import ( "errors" "fmt" - "io" "reflect" "strings" "xorm.io/builder" - "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -35,7 +33,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int } buf := builder.NewWriter() - if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { + if err := statement.writeSelect(buf, statement.genSelectColumnStr(), false); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -122,7 +120,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, } buf := builder.NewWriter() - if err := statement.writeSelect(buf, columnStr, true); err != nil { + if err := statement.writeSelect(buf, columnStr, false); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -168,7 +166,7 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa subQuerySelect = selectSQL } - if err := statement.writeSelect(buf, subQuerySelect, false); err != nil { + if err := statement.writeSelect(buf, subQuerySelect, true); err != nil { return "", nil, err } @@ -182,49 +180,16 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa } func (statement *Statement) writeFrom(w *builder.BytesWriter) error { - if _, err := fmt.Fprint(w, " FROM "); err != nil { - return err - } - if err := statement.writeTableName(w); err != nil { - return err - } - if err := statement.writeAlias(w); err != nil { - return err - } - return statement.writeJoins(w) -} - -func (statement *Statement) writeLimitOffset(w builder.Writer) error { - if statement.Start > 0 { - if statement.LimitN != nil { - _, err := fmt.Fprintf(w, " LIMIT %v OFFSET %v", *statement.LimitN, statement.Start) - return err - } - _, err := fmt.Fprintf(w, " OFFSET %v", statement.Start) - return err - } - if statement.LimitN != nil { - _, err := fmt.Fprint(w, " LIMIT ", *statement.LimitN) - return err - } - // no limit statement - return nil -} - -func (statement *Statement) writeOffsetFetch(w builder.Writer) error { - if statement.LimitN != nil { - _, err := fmt.Fprintf(w, " OFFSET %v ROWS FETCH NEXT %v ROWS ONLY", statement.Start, *statement.LimitN) - return err - } - if statement.Start > 0 { - _, err := fmt.Fprintf(w, " OFFSET %v ROWS", statement.Start) - return err - } - return nil + return statement.writeMultiple(w, + statement.writeStrings(" FROM "), + statement.writeTableName, + statement.writeAlias, + statement.writeJoins, + ) } // write "TOP " (mssql only) -func (statement *Statement) writeTop(w builder.Writer) error { +func (statement *Statement) writeTop(w *builder.BytesWriter) error { if statement.LimitN == nil { return nil } @@ -232,7 +197,7 @@ func (statement *Statement) writeTop(w builder.Writer) error { return err } -func (statement *Statement) writeDistinct(w builder.Writer) error { +func (statement *Statement) writeDistinct(w *builder.BytesWriter) error { if statement.IsDistinct && !strings.HasPrefix(statement.SelectStr, "count(") { _, err := fmt.Fprint(w, " DISTINCT") return err @@ -240,15 +205,12 @@ func (statement *Statement) writeDistinct(w builder.Writer) error { return nil } -func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr string) error { - if _, err := fmt.Fprintf(w, "SELECT"); err != nil { - return err - } - if err := statement.writeDistinct(w); err != nil { - return err - } - _, err := fmt.Fprint(w, " ", columnStr) - return err +func (statement *Statement) writeSelectColumns(columnStr string) func(w *builder.BytesWriter) error { + return statement.groupWriteFns( + statement.writeStrings("SELECT"), + statement.writeDistinct, + statement.writeStrings(" ", columnStr), + ) } func (statement *Statement) writeWhereCond(w *builder.BytesWriter, cond builder.Cond) error { @@ -266,20 +228,7 @@ func (statement *Statement) writeWhere(w *builder.BytesWriter) error { return statement.writeWhereCond(w, statement.cond) } -func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error { - if !statement.cond.IsValid() { - return statement.writeMssqlPaginationCond(w) - } - if _, err := fmt.Fprint(w, " WHERE "); err != nil { - return err - } - if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil { - return err - } - return statement.writeMssqlPaginationCond(w) -} - -func (statement *Statement) writeForUpdate(w io.Writer) error { +func (statement *Statement) writeForUpdate(w *builder.BytesWriter) error { if !statement.IsForUpdate { return nil } @@ -291,90 +240,7 @@ func (statement *Statement) writeForUpdate(w io.Writer) error { return err } -// write subquery to implement limit offset -// (mssql legacy only) -func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) error { - if statement.Start <= 0 { - return nil - } - - if statement.RefTable == nil { - return errors.New("unsupported query limit without reference table") - } - - var column string - if len(statement.RefTable.PKColumns()) == 0 { - for _, index := range statement.RefTable.Indexes { - if len(index.Cols) == 1 { - column = index.Cols[0] - break - } - } - if len(column) == 0 { - column = statement.RefTable.ColumnsSeq()[0] - } - } else { - column = statement.RefTable.PKColumns()[0].Name - } - if statement.NeedTableName() { - if len(statement.TableAlias) > 0 { - column = fmt.Sprintf("%s.%s", statement.TableAlias, column) - } else { - column = fmt.Sprintf("%s.%s", statement.TableName(), column) - } - } - - subWriter := builder.NewWriter() - if _, err := fmt.Fprintf(subWriter, "(%s NOT IN (SELECT TOP %d %s", - column, statement.Start, column); err != nil { - return err - } - if err := statement.writeFrom(subWriter); err != nil { - return err - } - if err := statement.writeWhere(subWriter); err != nil { - return err - } - if err := statement.writeOrderBys(subWriter); err != nil { - return err - } - if err := statement.writeGroupBy(subWriter); err != nil { - return err - } - if _, err := fmt.Fprint(subWriter, "))"); err != nil { - return err - } - - if statement.cond.IsValid() { - if _, err := fmt.Fprint(w, " AND "); err != nil { - return err - } - } else { - if _, err := fmt.Fprint(w, " WHERE "); err != nil { - return err - } - } - - return utils.WriteBuilder(w, subWriter) -} - -func (statement *Statement) writeOracleLimit(w *builder.BytesWriter, columnStr string) error { - if statement.LimitN == nil { - return nil - } - - oldString := w.String() - w.Reset() - rawColStr := columnStr - if rawColStr == "*" { - rawColStr = "at.*" - } - _, err := fmt.Fprintf(w, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", - columnStr, rawColStr, oldString, statement.Start+*statement.LimitN, statement.Start) - return err -} - -func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error { +func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, isCounting bool) error { dbType := statement.dialect.URI().DBType if statement.isUsingLegacyLimitOffset() { if dbType == "mssql" { @@ -384,21 +250,21 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri return statement.writeOracleLegacySelect(buf, columnStr) } } - // TODO: modify all functions to func(w builder.Writer) error - writeFns := []func(*builder.BytesWriter) error{ - func(bw *builder.BytesWriter) error { return statement.writeSelectColumns(bw, columnStr) }, + + return statement.writeMultiple(buf, + statement.writeSelectColumns(columnStr), statement.writeFrom, statement.writeWhere, - func(bw *builder.BytesWriter) error { return statement.writeGroupBy(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeHaving(bw) }, + statement.writeGroupBy, + statement.writeHaving, func(bw *builder.BytesWriter) (err error) { - if dbType == "mssql" && len(statement.orderBy) == 0 && needLimit { + if dbType == "mssql" && len(statement.orderBy) == 0 { // ORDER BY is mandatory to use OFFSET and FETCH clause (only in sqlserver) if statement.LimitN == nil && statement.Start == 0 { // no need to add return } - if statement.IsDistinct || len(statement.GroupByStr) > 0 { + if statement.IsDistinct || len(statement.GroupByStr) > 0 || isCounting { // the order-by column should be one of distincts or group-bys // order by the first column _, err = bw.WriteString(" ORDER BY 1 ASC") @@ -414,15 +280,9 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri } return statement.writeOrderBys(bw) }, - func(bw *builder.BytesWriter) error { - if dbType == "mssql" || dbType == "oracle" { - return statement.writeOffsetFetch(bw) - } - return statement.writeLimitOffset(bw) - }, - func(bw *builder.BytesWriter) error { return statement.writeForUpdate(bw) }, - } - return statement.writeSelectWithFns(buf, writeFns...) + statement.writePagination, + statement.writeForUpdate, + ) } // GenExistSQL generates Exist SQL @@ -545,7 +405,7 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa statement.cond = statement.cond.And(autoCond) buf := builder.NewWriter() - if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { + if err := statement.writeSelect(buf, statement.genSelectColumnStr(), false); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 68690bbe..55a3d89e 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -294,7 +294,7 @@ func (statement *Statement) GroupBy(keys string) *Statement { return statement } -func (statement *Statement) writeGroupBy(w builder.Writer) error { +func (statement *Statement) writeGroupBy(w *builder.BytesWriter) error { if statement.GroupByStr == "" { return nil } @@ -308,7 +308,7 @@ func (statement *Statement) Having(conditions string) *Statement { return statement } -func (statement *Statement) writeHaving(w builder.Writer) error { +func (statement *Statement) writeHaving(w *builder.BytesWriter) error { if statement.HavingStr == "" { return nil } diff --git a/internal/statements/table_name.go b/internal/statements/table_name.go index 8072a99d..1396b7df 100644 --- a/internal/statements/table_name.go +++ b/internal/statements/table_name.go @@ -27,7 +27,7 @@ func (statement *Statement) Alias(alias string) *Statement { return statement } -func (statement *Statement) writeAlias(w builder.Writer) error { +func (statement *Statement) writeAlias(w *builder.BytesWriter) error { if statement.TableAlias != "" { if statement.dialect.URI().DBType == schemas.ORACLE { if _, err := fmt.Fprint(w, " ", statement.quote(statement.TableAlias)); err != nil { @@ -42,7 +42,7 @@ func (statement *Statement) writeAlias(w builder.Writer) error { return nil } -func (statement *Statement) writeTableName(w builder.Writer) error { +func (statement *Statement) writeTableName(w *builder.BytesWriter) error { if statement.dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") { if _, err := fmt.Fprint(w, statement.TableName()); err != nil { return err diff --git a/internal/statements/writer.go b/internal/statements/writer.go new file mode 100644 index 00000000..b4ca8047 --- /dev/null +++ b/internal/statements/writer.go @@ -0,0 +1,37 @@ +// Copyright 2023 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + + "xorm.io/builder" +) + +func (statement *Statement) writeStrings(strs ...string) func(w *builder.BytesWriter) error { + return func(w *builder.BytesWriter) error { + for _, str := range strs { + if _, err := fmt.Fprint(w, str); err != nil { + return err + } + } + return nil + } +} + +func (statement *Statement) groupWriteFns(writeFuncs ...func(*builder.BytesWriter) error) func(*builder.BytesWriter) error { + return func(bw *builder.BytesWriter) error { + return statement.writeMultiple(bw, writeFuncs...) + } +} + +func (statement *Statement) writeMultiple(buf *builder.BytesWriter, writeFuncs ...func(*builder.BytesWriter) error) (err error) { + for _, fn := range writeFuncs { + if err = fn(buf); err != nil { + return + } + } + return +} diff --git a/tests/session_count_test.go b/tests/session_count_test.go index 7b359812..d9540f9e 100644 --- a/tests/session_count_test.go +++ b/tests/session_count_test.go @@ -170,3 +170,23 @@ func TestCountWithGroupBy(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, cnt) } + +func TestCountWithLimit(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + assertSync(t, new(CountWithTableName)) + + _, err := testEngine.Insert(&CountWithTableName{ + Name: "1", + }) + assert.NoError(t, err) + + _, err = testEngine.Insert(CountWithTableName{ + Name: "2", + }) + assert.NoError(t, err) + + cnt, err := testEngine.Limit(100).Count(new(CountWithTableName)) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) +} From 6ef0a7798fb51ae441f2609a9af5899baceefb5c Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 25 Oct 2023 11:01:46 +0000 Subject: [PATCH 10/10] Fix bug when join with alias start with `a` (#2343) Fix #2331 Reviewed-on: https://gitea.com/xorm/xorm/pulls/2343 --- schemas/quote.go | 14 ++++++++------ schemas/quote_test.go | 10 ++++++++-- tests/session_query_test.go | 2 +- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/schemas/quote.go b/schemas/quote.go index 6df7bf0b..63230bf4 100644 --- a/schemas/quote.go +++ b/schemas/quote.go @@ -111,7 +111,7 @@ func findStart(value string, start int) int { return start } - var k = -1 + k := -1 for j := start; j < len(value); j++ { if value[j] != ' ' { k = j @@ -122,7 +122,9 @@ func findStart(value string, start int) int { return len(value) } - if (value[k] == 'A' || value[k] == 'a') && (value[k+1] == 'S' || value[k+1] == 's') { + if k+1 < len(value) && + (value[k] == 'A' || value[k] == 'a') && + (value[k+1] == 'S' || value[k+1] == 's') { k += 2 } @@ -135,7 +137,7 @@ func findStart(value string, start int) int { } func (q Quoter) quoteWordTo(buf *strings.Builder, word string) error { - var realWord = word + realWord := word if (word[0] == CommanQuoteMark && word[len(word)-1] == CommanQuoteMark) || (word[0] == q.Prefix && word[len(word)-1] == q.Suffix) { realWord = word[1 : len(word)-1] @@ -188,7 +190,7 @@ func (q Quoter) QuoteTo(buf *strings.Builder, value string) error { return nil } - var nextEnd = findWord(value, start) + nextEnd := findWord(value, start) if err := q.quoteWordTo(buf, value[start:nextEnd]); err != nil { return err } @@ -199,7 +201,7 @@ func (q Quoter) QuoteTo(buf *strings.Builder, value string) error { // Strings quotes a slice of string func (q Quoter) Strings(s []string) []string { - var res = make([]string, 0, len(s)) + res := make([]string, 0, len(s)) for _, a := range s { res = append(res, q.Quote(a)) } @@ -218,7 +220,7 @@ func (q Quoter) Replace(sql string) string { var beginSingleQuote bool for i := 0; i < len(sql); i++ { if !beginSingleQuote && sql[i] == CommanQuoteMark { - var j = i + 1 + j := i + 1 for ; j < len(sql); j++ { if sql[j] == CommanQuoteMark { break diff --git a/schemas/quote_test.go b/schemas/quote_test.go index f84dfb7d..8f39db0d 100644 --- a/schemas/quote_test.go +++ b/schemas/quote_test.go @@ -131,6 +131,8 @@ func TestJoin(t *testing.T) { assert.EqualValues(t, "[a].*,[b].[c]", quoter.Join([]string{"a.*", " b.c"}, ",")) + assert.EqualValues(t, "[b] [a]", quoter.Join([]string{"b a"}, ",")) + assert.EqualValues(t, "[f1], [f2], [f3]", quoter.Join(cols, ", ")) quoter.IsReserved = AlwaysNoReserve @@ -146,7 +148,7 @@ func TestStrings(t *testing.T) { } func TestTrim(t *testing.T) { - var kases = map[string]string{ + kases := map[string]string{ "[table_name]": "table_name", "[schema].[table_name]": "schema.table_name", } @@ -159,7 +161,7 @@ func TestTrim(t *testing.T) { func TestReplace(t *testing.T) { q := Quoter{'[', ']', AlwaysReserve} - var kases = []struct { + kases := []struct { source string expected string }{ @@ -171,6 +173,10 @@ func TestReplace(t *testing.T) { "SELECT 'abc```test```''', `a` FROM b", "SELECT 'abc```test```''', [a] FROM b", }, + { + "SELECT * FROM `a` INNER JOIN `b` `c` WHERE `a`.`id` = `c`.`a_id`", + "SELECT * FROM [a] INNER JOIN [b] [c] WHERE [a].[id] = [c].[a_id]", + }, { "UPDATE table SET `a` = ~ `a`, `b`='abc`'", "UPDATE table SET [a] = ~ [a], [b]='abc`'", diff --git a/tests/session_query_test.go b/tests/session_query_test.go index 4df85f79..726b19e2 100644 --- a/tests/session_query_test.go +++ b/tests/session_query_test.go @@ -493,4 +493,4 @@ func TestRowsReset(t *testing.T) { assert.EqualValues(t, "4", rrs[0].Name) assert.EqualValues(t, "5", rrs[1].Name) assert.EqualValues(t, "6", rrs[2].Name) -} +} \ No newline at end of file