From 0f085408afd85707635eadb2294ab52be04f3c0f Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 25 Oct 2023 07:11:18 +0000 Subject: [PATCH] some refactors for write functions (#2342) Reviewed-on: https://gitea.com/xorm/xorm/pulls/2342 --- internal/statements/insert.go | 2 +- internal/statements/legacy_select.go | 50 +++---- internal/statements/pagination.go | 148 ++++++++++++++++++++ internal/statements/query.go | 200 ++++----------------------- internal/statements/statement.go | 4 +- internal/statements/table_name.go | 4 +- internal/statements/writer.go | 37 +++++ tests/session_count_test.go | 20 +++ 8 files changed, 257 insertions(+), 208 deletions(-) create mode 100644 internal/statements/pagination.go create mode 100644 internal/statements/writer.go diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 9370c984..aa396431 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -89,7 +89,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) } if statement.Conds().IsValid() { - if _, err := buf.WriteString(" SELECT "); err != nil { + if err := statement.writeStrings(" SELECT ")(buf); err != nil { return "", nil, err } diff --git a/internal/statements/legacy_select.go b/internal/statements/legacy_select.go index 1015839e..144ad96d 100644 --- a/internal/statements/legacy_select.go +++ b/internal/statements/legacy_select.go @@ -5,8 +5,6 @@ package statements import ( - "fmt" - "xorm.io/builder" ) @@ -17,43 +15,29 @@ func (statement *Statement) isUsingLegacyLimitOffset() bool { return ok && u.UseLegacyLimitOffset() } -func (statement *Statement) writeSelectWithFns(buf *builder.BytesWriter, writeFuncs ...func(*builder.BytesWriter) error) (err error) { - for _, fn := range writeFuncs { - if err = fn(buf); err != nil { - return - } - } - return -} - // write mssql legacy query sql func (statement *Statement) writeMssqlLegacySelect(buf *builder.BytesWriter, columnStr string) error { - writeFns := []func(*builder.BytesWriter) error{ - func(bw *builder.BytesWriter) (err error) { - _, err = fmt.Fprintf(bw, "SELECT") - return - }, - func(bw *builder.BytesWriter) error { return statement.writeDistinct(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeTop(bw) }, + return statement.writeMultiple(buf, + statement.writeStrings("SELECT"), + statement.writeDistinct, + statement.writeTop, statement.writeFrom, statement.writeWhereWithMssqlPagination, - func(bw *builder.BytesWriter) error { return statement.writeGroupBy(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeHaving(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeOrderBys(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeForUpdate(bw) }, - } - return statement.writeSelectWithFns(buf, writeFns...) + statement.writeGroupBy, + statement.writeHaving, + statement.writeOrderBys, + statement.writeForUpdate, + ) } func (statement *Statement) writeOracleLegacySelect(buf *builder.BytesWriter, columnStr string) error { - writeFns := []func(*builder.BytesWriter) error{ - func(bw *builder.BytesWriter) error { return statement.writeSelectColumns(bw, columnStr) }, + return statement.writeMultiple(buf, + statement.writeSelectColumns(columnStr), statement.writeFrom, - func(bw *builder.BytesWriter) error { return statement.writeOracleLimit(bw, columnStr) }, - func(bw *builder.BytesWriter) error { return statement.writeGroupBy(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeHaving(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeOrderBys(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeForUpdate(bw) }, - } - return statement.writeSelectWithFns(buf, writeFns...) + statement.writeOracleLimit(columnStr), + statement.writeGroupBy, + statement.writeHaving, + statement.writeOrderBys, + statement.writeForUpdate, + ) } diff --git a/internal/statements/pagination.go b/internal/statements/pagination.go new file mode 100644 index 00000000..3c7a3913 --- /dev/null +++ b/internal/statements/pagination.go @@ -0,0 +1,148 @@ +// Copyright 2023 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "errors" + "fmt" + + "xorm.io/builder" + "xorm.io/xorm/internal/utils" +) + +func (statement *Statement) writePagination(bw *builder.BytesWriter) error { + dbType := statement.dialect.URI().DBType + if dbType == "mssql" || dbType == "oracle" { + return statement.writeOffsetFetch(bw) + } + return statement.writeLimitOffset(bw) +} + +func (statement *Statement) writeLimitOffset(w builder.Writer) error { + if statement.Start > 0 { + if statement.LimitN != nil { + _, err := fmt.Fprintf(w, " LIMIT %v OFFSET %v", *statement.LimitN, statement.Start) + return err + } + _, err := fmt.Fprintf(w, " OFFSET %v", statement.Start) + return err + } + if statement.LimitN != nil { + _, err := fmt.Fprint(w, " LIMIT ", *statement.LimitN) + return err + } + // no limit statement + return nil +} + +func (statement *Statement) writeOffsetFetch(w builder.Writer) error { + if statement.LimitN != nil { + _, err := fmt.Fprintf(w, " OFFSET %v ROWS FETCH NEXT %v ROWS ONLY", statement.Start, *statement.LimitN) + return err + } + if statement.Start > 0 { + _, err := fmt.Fprintf(w, " OFFSET %v ROWS", statement.Start) + return err + } + return nil +} + +func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error { + if !statement.cond.IsValid() { + return statement.writeMssqlPaginationCond(w) + } + if _, err := fmt.Fprint(w, " WHERE "); err != nil { + return err + } + if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil { + return err + } + return statement.writeMssqlPaginationCond(w) +} + +// write subquery to implement limit offset +// (mssql legacy only) +func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) error { + if statement.Start <= 0 { + return nil + } + + if statement.RefTable == nil { + return errors.New("unsupported query limit without reference table") + } + + var column string + if len(statement.RefTable.PKColumns()) == 0 { + for _, index := range statement.RefTable.Indexes { + if len(index.Cols) == 1 { + column = index.Cols[0] + break + } + } + if len(column) == 0 { + column = statement.RefTable.ColumnsSeq()[0] + } + } else { + column = statement.RefTable.PKColumns()[0].Name + } + if statement.NeedTableName() { + if len(statement.TableAlias) > 0 { + column = fmt.Sprintf("%s.%s", statement.TableAlias, column) + } else { + column = fmt.Sprintf("%s.%s", statement.TableName(), column) + } + } + + subWriter := builder.NewWriter() + if _, err := fmt.Fprintf(subWriter, "(%s NOT IN (SELECT TOP %d %s", + column, statement.Start, column); err != nil { + return err + } + if err := statement.writeFrom(subWriter); err != nil { + return err + } + if err := statement.writeWhere(subWriter); err != nil { + return err + } + if err := statement.writeOrderBys(subWriter); err != nil { + return err + } + if err := statement.writeGroupBy(subWriter); err != nil { + return err + } + if _, err := fmt.Fprint(subWriter, "))"); err != nil { + return err + } + + if statement.cond.IsValid() { + if _, err := fmt.Fprint(w, " AND "); err != nil { + return err + } + } else { + if _, err := fmt.Fprint(w, " WHERE "); err != nil { + return err + } + } + + return utils.WriteBuilder(w, subWriter) +} + +func (statement *Statement) writeOracleLimit(columnStr string) func(w *builder.BytesWriter) error { + return func(w *builder.BytesWriter) error { + if statement.LimitN == nil { + return nil + } + + oldString := w.String() + w.Reset() + rawColStr := columnStr + if rawColStr == "*" { + rawColStr = "at.*" + } + _, err := fmt.Fprintf(w, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", + columnStr, rawColStr, oldString, statement.Start+*statement.LimitN, statement.Start) + return err + } +} diff --git a/internal/statements/query.go b/internal/statements/query.go index c8384760..8a9e59e4 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -7,12 +7,10 @@ package statements import ( "errors" "fmt" - "io" "reflect" "strings" "xorm.io/builder" - "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -35,7 +33,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int } buf := builder.NewWriter() - if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { + if err := statement.writeSelect(buf, statement.genSelectColumnStr(), false); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -122,7 +120,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, } buf := builder.NewWriter() - if err := statement.writeSelect(buf, columnStr, true); err != nil { + if err := statement.writeSelect(buf, columnStr, false); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil @@ -168,7 +166,7 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa subQuerySelect = selectSQL } - if err := statement.writeSelect(buf, subQuerySelect, false); err != nil { + if err := statement.writeSelect(buf, subQuerySelect, true); err != nil { return "", nil, err } @@ -182,49 +180,16 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa } func (statement *Statement) writeFrom(w *builder.BytesWriter) error { - if _, err := fmt.Fprint(w, " FROM "); err != nil { - return err - } - if err := statement.writeTableName(w); err != nil { - return err - } - if err := statement.writeAlias(w); err != nil { - return err - } - return statement.writeJoins(w) -} - -func (statement *Statement) writeLimitOffset(w builder.Writer) error { - if statement.Start > 0 { - if statement.LimitN != nil { - _, err := fmt.Fprintf(w, " LIMIT %v OFFSET %v", *statement.LimitN, statement.Start) - return err - } - _, err := fmt.Fprintf(w, " OFFSET %v", statement.Start) - return err - } - if statement.LimitN != nil { - _, err := fmt.Fprint(w, " LIMIT ", *statement.LimitN) - return err - } - // no limit statement - return nil -} - -func (statement *Statement) writeOffsetFetch(w builder.Writer) error { - if statement.LimitN != nil { - _, err := fmt.Fprintf(w, " OFFSET %v ROWS FETCH NEXT %v ROWS ONLY", statement.Start, *statement.LimitN) - return err - } - if statement.Start > 0 { - _, err := fmt.Fprintf(w, " OFFSET %v ROWS", statement.Start) - return err - } - return nil + return statement.writeMultiple(w, + statement.writeStrings(" FROM "), + statement.writeTableName, + statement.writeAlias, + statement.writeJoins, + ) } // write "TOP " (mssql only) -func (statement *Statement) writeTop(w builder.Writer) error { +func (statement *Statement) writeTop(w *builder.BytesWriter) error { if statement.LimitN == nil { return nil } @@ -232,7 +197,7 @@ func (statement *Statement) writeTop(w builder.Writer) error { return err } -func (statement *Statement) writeDistinct(w builder.Writer) error { +func (statement *Statement) writeDistinct(w *builder.BytesWriter) error { if statement.IsDistinct && !strings.HasPrefix(statement.SelectStr, "count(") { _, err := fmt.Fprint(w, " DISTINCT") return err @@ -240,15 +205,12 @@ func (statement *Statement) writeDistinct(w builder.Writer) error { return nil } -func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr string) error { - if _, err := fmt.Fprintf(w, "SELECT"); err != nil { - return err - } - if err := statement.writeDistinct(w); err != nil { - return err - } - _, err := fmt.Fprint(w, " ", columnStr) - return err +func (statement *Statement) writeSelectColumns(columnStr string) func(w *builder.BytesWriter) error { + return statement.groupWriteFns( + statement.writeStrings("SELECT"), + statement.writeDistinct, + statement.writeStrings(" ", columnStr), + ) } func (statement *Statement) writeWhereCond(w *builder.BytesWriter, cond builder.Cond) error { @@ -266,20 +228,7 @@ func (statement *Statement) writeWhere(w *builder.BytesWriter) error { return statement.writeWhereCond(w, statement.cond) } -func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error { - if !statement.cond.IsValid() { - return statement.writeMssqlPaginationCond(w) - } - if _, err := fmt.Fprint(w, " WHERE "); err != nil { - return err - } - if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil { - return err - } - return statement.writeMssqlPaginationCond(w) -} - -func (statement *Statement) writeForUpdate(w io.Writer) error { +func (statement *Statement) writeForUpdate(w *builder.BytesWriter) error { if !statement.IsForUpdate { return nil } @@ -291,90 +240,7 @@ func (statement *Statement) writeForUpdate(w io.Writer) error { return err } -// write subquery to implement limit offset -// (mssql legacy only) -func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) error { - if statement.Start <= 0 { - return nil - } - - if statement.RefTable == nil { - return errors.New("unsupported query limit without reference table") - } - - var column string - if len(statement.RefTable.PKColumns()) == 0 { - for _, index := range statement.RefTable.Indexes { - if len(index.Cols) == 1 { - column = index.Cols[0] - break - } - } - if len(column) == 0 { - column = statement.RefTable.ColumnsSeq()[0] - } - } else { - column = statement.RefTable.PKColumns()[0].Name - } - if statement.NeedTableName() { - if len(statement.TableAlias) > 0 { - column = fmt.Sprintf("%s.%s", statement.TableAlias, column) - } else { - column = fmt.Sprintf("%s.%s", statement.TableName(), column) - } - } - - subWriter := builder.NewWriter() - if _, err := fmt.Fprintf(subWriter, "(%s NOT IN (SELECT TOP %d %s", - column, statement.Start, column); err != nil { - return err - } - if err := statement.writeFrom(subWriter); err != nil { - return err - } - if err := statement.writeWhere(subWriter); err != nil { - return err - } - if err := statement.writeOrderBys(subWriter); err != nil { - return err - } - if err := statement.writeGroupBy(subWriter); err != nil { - return err - } - if _, err := fmt.Fprint(subWriter, "))"); err != nil { - return err - } - - if statement.cond.IsValid() { - if _, err := fmt.Fprint(w, " AND "); err != nil { - return err - } - } else { - if _, err := fmt.Fprint(w, " WHERE "); err != nil { - return err - } - } - - return utils.WriteBuilder(w, subWriter) -} - -func (statement *Statement) writeOracleLimit(w *builder.BytesWriter, columnStr string) error { - if statement.LimitN == nil { - return nil - } - - oldString := w.String() - w.Reset() - rawColStr := columnStr - if rawColStr == "*" { - rawColStr = "at.*" - } - _, err := fmt.Fprintf(w, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", - columnStr, rawColStr, oldString, statement.Start+*statement.LimitN, statement.Start) - return err -} - -func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error { +func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, isCounting bool) error { dbType := statement.dialect.URI().DBType if statement.isUsingLegacyLimitOffset() { if dbType == "mssql" { @@ -384,21 +250,21 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri return statement.writeOracleLegacySelect(buf, columnStr) } } - // TODO: modify all functions to func(w builder.Writer) error - writeFns := []func(*builder.BytesWriter) error{ - func(bw *builder.BytesWriter) error { return statement.writeSelectColumns(bw, columnStr) }, + + return statement.writeMultiple(buf, + statement.writeSelectColumns(columnStr), statement.writeFrom, statement.writeWhere, - func(bw *builder.BytesWriter) error { return statement.writeGroupBy(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeHaving(bw) }, + statement.writeGroupBy, + statement.writeHaving, func(bw *builder.BytesWriter) (err error) { - if dbType == "mssql" && len(statement.orderBy) == 0 && needLimit { + if dbType == "mssql" && len(statement.orderBy) == 0 { // ORDER BY is mandatory to use OFFSET and FETCH clause (only in sqlserver) if statement.LimitN == nil && statement.Start == 0 { // no need to add return } - if statement.IsDistinct || len(statement.GroupByStr) > 0 { + if statement.IsDistinct || len(statement.GroupByStr) > 0 || isCounting { // the order-by column should be one of distincts or group-bys // order by the first column _, err = bw.WriteString(" ORDER BY 1 ASC") @@ -414,15 +280,9 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri } return statement.writeOrderBys(bw) }, - func(bw *builder.BytesWriter) error { - if dbType == "mssql" || dbType == "oracle" { - return statement.writeOffsetFetch(bw) - } - return statement.writeLimitOffset(bw) - }, - func(bw *builder.BytesWriter) error { return statement.writeForUpdate(bw) }, - } - return statement.writeSelectWithFns(buf, writeFns...) + statement.writePagination, + statement.writeForUpdate, + ) } // GenExistSQL generates Exist SQL @@ -545,7 +405,7 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa statement.cond = statement.cond.And(autoCond) buf := builder.NewWriter() - if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { + if err := statement.writeSelect(buf, statement.genSelectColumnStr(), false); err != nil { return "", nil, err } return buf.String(), buf.Args(), nil diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 68690bbe..55a3d89e 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -294,7 +294,7 @@ func (statement *Statement) GroupBy(keys string) *Statement { return statement } -func (statement *Statement) writeGroupBy(w builder.Writer) error { +func (statement *Statement) writeGroupBy(w *builder.BytesWriter) error { if statement.GroupByStr == "" { return nil } @@ -308,7 +308,7 @@ func (statement *Statement) Having(conditions string) *Statement { return statement } -func (statement *Statement) writeHaving(w builder.Writer) error { +func (statement *Statement) writeHaving(w *builder.BytesWriter) error { if statement.HavingStr == "" { return nil } diff --git a/internal/statements/table_name.go b/internal/statements/table_name.go index 8072a99d..1396b7df 100644 --- a/internal/statements/table_name.go +++ b/internal/statements/table_name.go @@ -27,7 +27,7 @@ func (statement *Statement) Alias(alias string) *Statement { return statement } -func (statement *Statement) writeAlias(w builder.Writer) error { +func (statement *Statement) writeAlias(w *builder.BytesWriter) error { if statement.TableAlias != "" { if statement.dialect.URI().DBType == schemas.ORACLE { if _, err := fmt.Fprint(w, " ", statement.quote(statement.TableAlias)); err != nil { @@ -42,7 +42,7 @@ func (statement *Statement) writeAlias(w builder.Writer) error { return nil } -func (statement *Statement) writeTableName(w builder.Writer) error { +func (statement *Statement) writeTableName(w *builder.BytesWriter) error { if statement.dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") { if _, err := fmt.Fprint(w, statement.TableName()); err != nil { return err diff --git a/internal/statements/writer.go b/internal/statements/writer.go new file mode 100644 index 00000000..b4ca8047 --- /dev/null +++ b/internal/statements/writer.go @@ -0,0 +1,37 @@ +// Copyright 2023 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + + "xorm.io/builder" +) + +func (statement *Statement) writeStrings(strs ...string) func(w *builder.BytesWriter) error { + return func(w *builder.BytesWriter) error { + for _, str := range strs { + if _, err := fmt.Fprint(w, str); err != nil { + return err + } + } + return nil + } +} + +func (statement *Statement) groupWriteFns(writeFuncs ...func(*builder.BytesWriter) error) func(*builder.BytesWriter) error { + return func(bw *builder.BytesWriter) error { + return statement.writeMultiple(bw, writeFuncs...) + } +} + +func (statement *Statement) writeMultiple(buf *builder.BytesWriter, writeFuncs ...func(*builder.BytesWriter) error) (err error) { + for _, fn := range writeFuncs { + if err = fn(buf); err != nil { + return + } + } + return +} diff --git a/tests/session_count_test.go b/tests/session_count_test.go index 7b359812..d9540f9e 100644 --- a/tests/session_count_test.go +++ b/tests/session_count_test.go @@ -170,3 +170,23 @@ func TestCountWithGroupBy(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, cnt) } + +func TestCountWithLimit(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + assertSync(t, new(CountWithTableName)) + + _, err := testEngine.Insert(&CountWithTableName{ + Name: "1", + }) + assert.NoError(t, err) + + _, err = testEngine.Insert(CountWithTableName{ + Name: "2", + }) + assert.NoError(t, err) + + cnt, err := testEngine.Limit(100).Count(new(CountWithTableName)) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) +}