From 9b0829fb302d43c904e578251dd8ae2a80148c28 Mon Sep 17 00:00:00 2001 From: FlyingOnion <731677080@qq.com> Date: Wed, 13 Sep 2023 17:30:17 +0800 Subject: [PATCH] implement offset fetch for oracle and sqlserver --- dialects/mssql.go | 10 +++ dialects/oracle.go | 12 +++ internal/statements/legacy_select.go | 55 ++++++++++++++ internal/statements/query.go | 109 ++++++++++++++++----------- tests/session_query_test.go | 5 -- 5 files changed, 143 insertions(+), 48 deletions(-) create mode 100644 internal/statements/legacy_select.go diff --git a/dialects/mssql.go b/dialects/mssql.go index 2c64e637..b321fd6f 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -217,6 +217,7 @@ type mssql struct { Base defaultVarchar string defaultChar string + useLegacy bool } func (db *mssql) Init(uri *URI) error { @@ -226,6 +227,8 @@ func (db *mssql) Init(uri *URI) error { return db.Base.Init(db, uri) } +func (db *mssql) UseLegacyLimitOffset() bool { return db.useLegacy } + func (db *mssql) SetParams(params map[string]string) { defaultVarchar, ok := params["DEFAULT_VARCHAR"] if ok { @@ -252,6 +255,13 @@ func (db *mssql) SetParams(params map[string]string) { } else { db.defaultChar = "CHAR" } + + useLegacy, ok := params["USE_LEGACY_LIMIT_OFFSET"] + if ok { + if b, _ := strconv.ParseBool(useLegacy); b { + db.useLegacy = true + } + } } func (db *mssql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) { diff --git a/dialects/oracle.go b/dialects/oracle.go index fbda9dda..ac0fb944 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -509,6 +509,7 @@ var ( type oracle struct { Base + useLegacy bool } func (db *oracle) Init(uri *URI) error { @@ -516,6 +517,17 @@ func (db *oracle) Init(uri *URI) error { return db.Base.Init(db, uri) } +func (db *oracle) UseLegacyLimitOffset() bool { return db.useLegacy } + +func (db *oracle) SetParams(params map[string]string) { + useLegacy, ok := params["USE_LEGACY_LIMIT_OFFSET"] + if ok { + if b, _ := strconv.ParseBool(useLegacy); b { + db.useLegacy = true + } + } +} + func (db *oracle) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) { rows, err := queryer.QueryContext(ctx, "select * from v$version where banner like 'Oracle%'") if err != nil { diff --git a/internal/statements/legacy_select.go b/internal/statements/legacy_select.go new file mode 100644 index 00000000..f3aee0d0 --- /dev/null +++ b/internal/statements/legacy_select.go @@ -0,0 +1,55 @@ +package statements + +import ( + "fmt" + + "xorm.io/builder" +) + +// isUsingLegacy returns true if xorm uses legacy LIMIT OFFSET. +// It's only available in sqlserver and oracle, if param USE_LEGACY_LIMIT_OFFSET is set to "true" +func (statement *Statement) isUsingLegacyLimitOffset() bool { + u, ok := statement.dialect.(interface{ UseLegacyLimitOffset() bool }) + return ok && u.UseLegacyLimitOffset() +} + +func (statement *Statement) writeSelectWithFns(buf *builder.BytesWriter, writeFuncs ...func(*builder.BytesWriter) error) (err error) { + for _, fn := range writeFuncs { + if err = fn(buf); err != nil { + return + } + } + return +} + +// write mssql legacy query sql +func (statement *Statement) writeMssqlLegacySelect(buf *builder.BytesWriter, columnStr string) error { + writeFns := []func(*builder.BytesWriter) error{ + func(bw *builder.BytesWriter) (err error) { + _, err = fmt.Fprintf(bw, "SELECT") + return + }, + func(bw *builder.BytesWriter) error { return statement.writeDistinct(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeTop(bw) }, + statement.writeFrom, + statement.writeWhereWithMssqlPagination, + func(bw *builder.BytesWriter) error { return statement.writeGroupBy(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeHaving(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeOrderBys(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeForUpdate(bw) }, + } + return statement.writeSelectWithFns(buf, writeFns...) +} + +func (statement *Statement) writeOracleLegacySelect(buf *builder.BytesWriter, columnStr string) error { + writeFns := []func(*builder.BytesWriter) error{ + func(bw *builder.BytesWriter) error { return statement.writeSelectColumns(bw, columnStr) }, + statement.writeFrom, + func(bw *builder.BytesWriter) error { return statement.writeOracleLimit(bw, columnStr) }, + func(bw *builder.BytesWriter) error { return statement.writeGroupBy(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeHaving(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeOrderBys(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeForUpdate(bw) }, + } + return statement.writeSelectWithFns(buf, writeFns...) +} diff --git a/internal/statements/query.go b/internal/statements/query.go index 216a2028..c8384760 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -35,7 +35,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int } buf := builder.NewWriter() - if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil { + if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -66,7 +66,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri } buf := builder.NewWriter() - if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true, true); err != nil { + if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -122,7 +122,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, } buf := builder.NewWriter() - if err := statement.writeSelect(buf, columnStr, true, true); err != nil { + if err := statement.writeSelect(buf, columnStr, true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -168,7 +168,7 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa subQuerySelect = selectSQL } - if err := statement.writeSelect(buf, subQuerySelect, false, false); err != nil { + if err := statement.writeSelect(buf, subQuerySelect, false); err != nil { return "", nil, err } @@ -200,7 +200,7 @@ func (statement *Statement) writeLimitOffset(w builder.Writer) error { _, err := fmt.Fprintf(w, " LIMIT %v OFFSET %v", *statement.LimitN, statement.Start) return err } - _, err := fmt.Fprintf(w, " LIMIT 0 OFFSET %v", statement.Start) + _, err := fmt.Fprintf(w, " OFFSET %v", statement.Start) return err } if statement.LimitN != nil { @@ -211,10 +211,20 @@ func (statement *Statement) writeLimitOffset(w builder.Writer) error { return nil } -func (statement *Statement) writeTop(w builder.Writer) error { - if statement.dialect.URI().DBType != schemas.MSSQL { - return nil +func (statement *Statement) writeOffsetFetch(w builder.Writer) error { + if statement.LimitN != nil { + _, err := fmt.Fprintf(w, " OFFSET %v ROWS FETCH NEXT %v ROWS ONLY", statement.Start, *statement.LimitN) + return err } + if statement.Start > 0 { + _, err := fmt.Fprintf(w, " OFFSET %v ROWS", statement.Start) + return err + } + return nil +} + +// write "TOP " (mssql only) +func (statement *Statement) writeTop(w builder.Writer) error { if statement.LimitN == nil { return nil } @@ -237,9 +247,6 @@ func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr if err := statement.writeDistinct(w); err != nil { return err } - if err := statement.writeTop(w); err != nil { - return err - } _, err := fmt.Fprint(w, " ", columnStr) return err } @@ -284,8 +291,10 @@ func (statement *Statement) writeForUpdate(w io.Writer) error { return err } +// write subquery to implement limit offset +// (mssql legacy only) func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) error { - if statement.dialect.URI().DBType != schemas.MSSQL || statement.Start <= 0 { + if statement.Start <= 0 { return nil } @@ -365,41 +374,55 @@ func (statement *Statement) writeOracleLimit(w *builder.BytesWriter, columnStr s return err } -func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit, needOrderBy bool) error { - if err := statement.writeSelectColumns(buf, columnStr); err != nil { - return err - } - if err := statement.writeFrom(buf); err != nil { - return err - } - if err := statement.writeWhereWithMssqlPagination(buf); err != nil { - return err - } - if err := statement.writeGroupBy(buf); err != nil { - return err - } - if err := statement.writeHaving(buf); err != nil { - return err - } - if needOrderBy { - if err := statement.writeOrderBys(buf); err != nil { - return err +func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error { + dbType := statement.dialect.URI().DBType + if statement.isUsingLegacyLimitOffset() { + if dbType == "mssql" { + return statement.writeMssqlLegacySelect(buf, columnStr) + } + if dbType == "oracle" { + return statement.writeOracleLegacySelect(buf, columnStr) } } - - dialect := statement.dialect - if needLimit { - if dialect.URI().DBType == schemas.ORACLE { - if err := statement.writeOracleLimit(buf, columnStr); err != nil { - return err + // TODO: modify all functions to func(w builder.Writer) error + writeFns := []func(*builder.BytesWriter) error{ + func(bw *builder.BytesWriter) error { return statement.writeSelectColumns(bw, columnStr) }, + statement.writeFrom, + statement.writeWhere, + func(bw *builder.BytesWriter) error { return statement.writeGroupBy(bw) }, + func(bw *builder.BytesWriter) error { return statement.writeHaving(bw) }, + func(bw *builder.BytesWriter) (err error) { + if dbType == "mssql" && len(statement.orderBy) == 0 && needLimit { + // ORDER BY is mandatory to use OFFSET and FETCH clause (only in sqlserver) + if statement.LimitN == nil && statement.Start == 0 { + // no need to add + return + } + if statement.IsDistinct || len(statement.GroupByStr) > 0 { + // the order-by column should be one of distincts or group-bys + // order by the first column + _, err = bw.WriteString(" ORDER BY 1 ASC") + return + } + if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { + // no primary key, order by the first column + _, err = bw.WriteString(" ORDER BY 1 ASC") + return + } + // order by primary key + statement.orderBy = []orderBy{{orderStr: statement.colName(statement.RefTable.GetColumn(statement.RefTable.PrimaryKeys[0]), statement.TableName()), direction: "ASC"}} } - } else if dialect.URI().DBType != schemas.MSSQL { - if err := statement.writeLimitOffset(buf); err != nil { - return err + return statement.writeOrderBys(bw) + }, + func(bw *builder.BytesWriter) error { + if dbType == "mssql" || dbType == "oracle" { + return statement.writeOffsetFetch(bw) } - } + return statement.writeLimitOffset(bw) + }, + func(bw *builder.BytesWriter) error { return statement.writeForUpdate(bw) }, } - return statement.writeForUpdate(buf) + return statement.writeSelectWithFns(buf, writeFns...) } // GenExistSQL generates Exist SQL @@ -522,7 +545,7 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa statement.cond = statement.cond.And(autoCond) buf := builder.NewWriter() - if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil { + if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil diff --git a/tests/session_query_test.go b/tests/session_query_test.go index 5a3a3631..4df85f79 100644 --- a/tests/session_query_test.go +++ b/tests/session_query_test.go @@ -365,11 +365,6 @@ func TestJoinWithSubQuery(t *testing.T) { func TestQueryStringWithLimit(t *testing.T) { assert.NoError(t, PrepareEngine()) - if testEngine.Dialect().URI().DBType == schemas.MSSQL { - t.SkipNow() - return - } - type QueryWithLimit struct { Id int64 `xorm:"autoincr pk"` Msg string `xorm:"varchar(255)"`