From 30667e9d3806313f976ac902c8d1dfbeffcd89c3 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 19 Oct 2023 14:39:58 +0800 Subject: [PATCH] some refactors for write functions --- internal/statements/insert.go | 2 +- internal/statements/legacy_select.go | 50 ++++++------------ internal/statements/query.go | 78 ++++++++++++++-------------- internal/statements/statement.go | 4 +- internal/statements/writer.go | 39 ++++++++++++++ tests/session_count_test.go | 20 +++++++ 6 files changed, 118 insertions(+), 75 deletions(-) create mode 100644 internal/statements/writer.go diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 9370c984..0bb656ca 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -89,7 +89,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) } if statement.Conds().IsValid() { - if _, err := buf.WriteString(" SELECT "); err != nil { + if err := statement.writeString(" SELECT ")(buf); err != nil { return "", nil, err } diff --git a/internal/statements/legacy_select.go b/internal/statements/legacy_select.go index 1015839e..7f55c822 100644 --- a/internal/statements/legacy_select.go +++ b/internal/statements/legacy_select.go @@ -5,8 +5,6 @@ package statements import ( - "fmt" - "xorm.io/builder" ) @@ -17,43 +15,29 @@ func (statement *Statement) isUsingLegacyLimitOffset() bool { return ok && u.UseLegacyLimitOffset() } -func (statement *Statement) writeSelectWithFns(buf *builder.BytesWriter, writeFuncs ...func(*builder.BytesWriter) error) (err error) { - for _, fn := range writeFuncs { - if err = fn(buf); err != nil { - return - } - } - return -} - // write mssql legacy query sql func (statement *Statement) writeMssqlLegacySelect(buf *builder.BytesWriter, columnStr string) error { - writeFns := []func(*builder.BytesWriter) error{ - func(bw *builder.BytesWriter) (err error) { - _, err = fmt.Fprintf(bw, "SELECT") - return - }, - func(bw *builder.BytesWriter) error { return statement.writeDistinct(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeTop(bw) }, + return statement.writeMultiple(buf, + statement.writeString("SELECT"), + statement.writeDistinct, + statement.writeTop, statement.writeFrom, statement.writeWhereWithMssqlPagination, - func(bw *builder.BytesWriter) error { return statement.writeGroupBy(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeHaving(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeOrderBys(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeForUpdate(bw) }, - } - return statement.writeSelectWithFns(buf, writeFns...) + statement.writeGroupBy, + statement.writeHaving, + statement.writeOrderBys, + statement.writeForUpdate, + ) } func (statement *Statement) writeOracleLegacySelect(buf *builder.BytesWriter, columnStr string) error { - writeFns := []func(*builder.BytesWriter) error{ - func(bw *builder.BytesWriter) error { return statement.writeSelectColumns(bw, columnStr) }, + return statement.writeMultiple(buf, + statement.writeSelectColumns(columnStr), statement.writeFrom, - func(bw *builder.BytesWriter) error { return statement.writeOracleLimit(bw, columnStr) }, - func(bw *builder.BytesWriter) error { return statement.writeGroupBy(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeHaving(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeOrderBys(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeForUpdate(bw) }, - } - return statement.writeSelectWithFns(buf, writeFns...) + statement.writeOracleLimit(columnStr), + statement.writeGroupBy, + statement.writeHaving, + statement.writeOrderBys, + statement.writeForUpdate, + ) } diff --git a/internal/statements/query.go b/internal/statements/query.go index c8384760..492fe915 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -7,7 +7,6 @@ package statements import ( "errors" "fmt" - "io" "reflect" "strings" @@ -194,6 +193,14 @@ func (statement *Statement) writeFrom(w *builder.BytesWriter) error { return statement.writeJoins(w) } +func (statement *Statement) writePagination(bw *builder.BytesWriter) error { + dbType := statement.dialect.URI().DBType + if dbType == "mssql" || dbType == "oracle" { + return statement.writeOffsetFetch(bw) + } + return statement.writeLimitOffset(bw) +} + func (statement *Statement) writeLimitOffset(w builder.Writer) error { if statement.Start > 0 { if statement.LimitN != nil { @@ -224,7 +231,7 @@ func (statement *Statement) writeOffsetFetch(w builder.Writer) error { } // write "TOP " (mssql only) -func (statement *Statement) writeTop(w builder.Writer) error { +func (statement *Statement) writeTop(w *builder.BytesWriter) error { if statement.LimitN == nil { return nil } @@ -232,7 +239,7 @@ func (statement *Statement) writeTop(w builder.Writer) error { return err } -func (statement *Statement) writeDistinct(w builder.Writer) error { +func (statement *Statement) writeDistinct(w *builder.BytesWriter) error { if statement.IsDistinct && !strings.HasPrefix(statement.SelectStr, "count(") { _, err := fmt.Fprint(w, " DISTINCT") return err @@ -240,15 +247,12 @@ func (statement *Statement) writeDistinct(w builder.Writer) error { return nil } -func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr string) error { - if _, err := fmt.Fprintf(w, "SELECT"); err != nil { - return err - } - if err := statement.writeDistinct(w); err != nil { - return err - } - _, err := fmt.Fprint(w, " ", columnStr) - return err +func (statement *Statement) writeSelectColumns(columnStr string) func(w *builder.BytesWriter) error { + return statement.groupWriteFns( + statement.writeString("SELECT"), + statement.writeDistinct, + statement.writeString(columnStr), + ) } func (statement *Statement) writeWhereCond(w *builder.BytesWriter, cond builder.Cond) error { @@ -279,7 +283,7 @@ func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter return statement.writeMssqlPaginationCond(w) } -func (statement *Statement) writeForUpdate(w io.Writer) error { +func (statement *Statement) writeForUpdate(w *builder.BytesWriter) error { if !statement.IsForUpdate { return nil } @@ -358,20 +362,22 @@ func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) err return utils.WriteBuilder(w, subWriter) } -func (statement *Statement) writeOracleLimit(w *builder.BytesWriter, columnStr string) error { - if statement.LimitN == nil { - return nil - } +func (statement *Statement) writeOracleLimit(columnStr string) func(w *builder.BytesWriter) error { + return func(w *builder.BytesWriter) error { + if statement.LimitN == nil { + return nil + } - oldString := w.String() - w.Reset() - rawColStr := columnStr - if rawColStr == "*" { - rawColStr = "at.*" + oldString := w.String() + w.Reset() + rawColStr := columnStr + if rawColStr == "*" { + rawColStr = "at.*" + } + _, err := fmt.Fprintf(w, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", + columnStr, rawColStr, oldString, statement.Start+*statement.LimitN, statement.Start) + return err } - _, err := fmt.Fprintf(w, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", - columnStr, rawColStr, oldString, statement.Start+*statement.LimitN, statement.Start) - return err } func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error { @@ -384,13 +390,13 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri return statement.writeOracleLegacySelect(buf, columnStr) } } - // TODO: modify all functions to func(w builder.Writer) error - writeFns := []func(*builder.BytesWriter) error{ - func(bw *builder.BytesWriter) error { return statement.writeSelectColumns(bw, columnStr) }, + + return statement.writeMultiple(buf, + statement.writeSelectColumns(columnStr), statement.writeFrom, statement.writeWhere, - func(bw *builder.BytesWriter) error { return statement.writeGroupBy(bw) }, - func(bw *builder.BytesWriter) error { return statement.writeHaving(bw) }, + statement.writeGroupBy, + statement.writeHaving, func(bw *builder.BytesWriter) (err error) { if dbType == "mssql" && len(statement.orderBy) == 0 && needLimit { // ORDER BY is mandatory to use OFFSET and FETCH clause (only in sqlserver) @@ -414,15 +420,9 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri } return statement.writeOrderBys(bw) }, - func(bw *builder.BytesWriter) error { - if dbType == "mssql" || dbType == "oracle" { - return statement.writeOffsetFetch(bw) - } - return statement.writeLimitOffset(bw) - }, - func(bw *builder.BytesWriter) error { return statement.writeForUpdate(bw) }, - } - return statement.writeSelectWithFns(buf, writeFns...) + statement.writePagination, + statement.writeForUpdate, + ) } // GenExistSQL generates Exist SQL diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 68690bbe..55a3d89e 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -294,7 +294,7 @@ func (statement *Statement) GroupBy(keys string) *Statement { return statement } -func (statement *Statement) writeGroupBy(w builder.Writer) error { +func (statement *Statement) writeGroupBy(w *builder.BytesWriter) error { if statement.GroupByStr == "" { return nil } @@ -308,7 +308,7 @@ func (statement *Statement) Having(conditions string) *Statement { return statement } -func (statement *Statement) writeHaving(w builder.Writer) error { +func (statement *Statement) writeHaving(w *builder.BytesWriter) error { if statement.HavingStr == "" { return nil } diff --git a/internal/statements/writer.go b/internal/statements/writer.go new file mode 100644 index 00000000..b0a8b5c7 --- /dev/null +++ b/internal/statements/writer.go @@ -0,0 +1,39 @@ +// Copyright 2023 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + + "xorm.io/builder" +) + +func (statement *Statement) writeString(str string) func(w *builder.BytesWriter) error { + return func(w *builder.BytesWriter) error { + if _, err := fmt.Fprint(w, str); err != nil { + return err + } + return nil + } +} + +func (statement *Statement) writeSpace(w *builder.BytesWriter) error { + return statement.writeString(" ")(w) +} + +func (statement *Statement) groupWriteFns(writeFuncs ...func(*builder.BytesWriter) error) func(*builder.BytesWriter) error { + return func(bw *builder.BytesWriter) error { + return statement.writeMultiple(bw, writeFuncs...) + } +} + +func (statement *Statement) writeMultiple(buf *builder.BytesWriter, writeFuncs ...func(*builder.BytesWriter) error) (err error) { + for _, fn := range writeFuncs { + if err = fn(buf); err != nil { + return + } + } + return +} diff --git a/tests/session_count_test.go b/tests/session_count_test.go index 7b359812..d9540f9e 100644 --- a/tests/session_count_test.go +++ b/tests/session_count_test.go @@ -170,3 +170,23 @@ func TestCountWithGroupBy(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, cnt) } + +func TestCountWithLimit(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + assertSync(t, new(CountWithTableName)) + + _, err := testEngine.Insert(&CountWithTableName{ + Name: "1", + }) + assert.NoError(t, err) + + _, err = testEngine.Insert(CountWithTableName{ + Name: "2", + }) + assert.NoError(t, err) + + cnt, err := testEngine.Limit(100).Count(new(CountWithTableName)) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) +}