Fix writes

This commit is contained in:
Lunny Xiao 2023-10-19 17:29:52 +08:00
parent 9126dc31c1
commit f58fb48eba
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
3 changed files with 156 additions and 148 deletions

View File

@ -0,0 +1,148 @@
// Copyright 2023 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"errors"
"fmt"
"xorm.io/builder"
"xorm.io/xorm/internal/utils"
)
func (statement *Statement) writePagination(bw *builder.BytesWriter) error {
dbType := statement.dialect.URI().DBType
if dbType == "mssql" || dbType == "oracle" {
return statement.writeOffsetFetch(bw)
}
return statement.writeLimitOffset(bw)
}
func (statement *Statement) writeLimitOffset(w builder.Writer) error {
if statement.Start > 0 {
if statement.LimitN != nil {
_, err := fmt.Fprintf(w, " LIMIT %v OFFSET %v", *statement.LimitN, statement.Start)
return err
}
_, err := fmt.Fprintf(w, " OFFSET %v", statement.Start)
return err
}
if statement.LimitN != nil {
_, err := fmt.Fprint(w, " LIMIT ", *statement.LimitN)
return err
}
// no limit statement
return nil
}
func (statement *Statement) writeOffsetFetch(w builder.Writer) error {
if statement.LimitN != nil {
_, err := fmt.Fprintf(w, " OFFSET %v ROWS FETCH NEXT %v ROWS ONLY", statement.Start, *statement.LimitN)
return err
}
if statement.Start > 0 {
_, err := fmt.Fprintf(w, " OFFSET %v ROWS", statement.Start)
return err
}
return nil
}
func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error {
if !statement.cond.IsValid() {
return statement.writeMssqlPaginationCond(w)
}
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil {
return err
}
return statement.writeMssqlPaginationCond(w)
}
// write subquery to implement limit offset
// (mssql legacy only)
func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) error {
if statement.Start <= 0 {
return nil
}
if statement.RefTable == nil {
return errors.New("unsupported query limit without reference table")
}
var column string
if len(statement.RefTable.PKColumns()) == 0 {
for _, index := range statement.RefTable.Indexes {
if len(index.Cols) == 1 {
column = index.Cols[0]
break
}
}
if len(column) == 0 {
column = statement.RefTable.ColumnsSeq()[0]
}
} else {
column = statement.RefTable.PKColumns()[0].Name
}
if statement.NeedTableName() {
if len(statement.TableAlias) > 0 {
column = fmt.Sprintf("%s.%s", statement.TableAlias, column)
} else {
column = fmt.Sprintf("%s.%s", statement.TableName(), column)
}
}
subWriter := builder.NewWriter()
if _, err := fmt.Fprintf(subWriter, "(%s NOT IN (SELECT TOP %d %s",
column, statement.Start, column); err != nil {
return err
}
if err := statement.writeFrom(subWriter); err != nil {
return err
}
if err := statement.writeWhere(subWriter); err != nil {
return err
}
if err := statement.writeOrderBys(subWriter); err != nil {
return err
}
if err := statement.writeGroupBy(subWriter); err != nil {
return err
}
if _, err := fmt.Fprint(subWriter, "))"); err != nil {
return err
}
if statement.cond.IsValid() {
if _, err := fmt.Fprint(w, " AND "); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
}
return utils.WriteBuilder(w, subWriter)
}
func (statement *Statement) writeOracleLimit(columnStr string) func(w *builder.BytesWriter) error {
return func(w *builder.BytesWriter) error {
if statement.LimitN == nil {
return nil
}
oldString := w.String()
w.Reset()
rawColStr := columnStr
if rawColStr == "*" {
rawColStr = "at.*"
}
_, err := fmt.Fprintf(w, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
columnStr, rawColStr, oldString, statement.Start+*statement.LimitN, statement.Start)
return err
}
}

View File

@ -11,7 +11,6 @@ import (
"strings" "strings"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -181,53 +180,12 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
} }
func (statement *Statement) writeFrom(w *builder.BytesWriter) error { func (statement *Statement) writeFrom(w *builder.BytesWriter) error {
if _, err := fmt.Fprint(w, " FROM "); err != nil { return statement.writeMultiple(w,
return err statement.writeStrings(" FROM "),
} statement.writeTableName,
if err := statement.writeTableName(w); err != nil { statement.writeAlias,
return err statement.writeJoins,
} )
if err := statement.writeAlias(w); err != nil {
return err
}
return statement.writeJoins(w)
}
func (statement *Statement) writePagination(bw *builder.BytesWriter) error {
dbType := statement.dialect.URI().DBType
if dbType == "mssql" || dbType == "oracle" {
return statement.writeOffsetFetch(bw)
}
return statement.writeLimitOffset(bw)
}
func (statement *Statement) writeLimitOffset(w builder.Writer) error {
if statement.Start > 0 {
if statement.LimitN != nil {
_, err := fmt.Fprintf(w, " LIMIT %v OFFSET %v", *statement.LimitN, statement.Start)
return err
}
_, err := fmt.Fprintf(w, " OFFSET %v", statement.Start)
return err
}
if statement.LimitN != nil {
_, err := fmt.Fprint(w, " LIMIT ", *statement.LimitN)
return err
}
// no limit statement
return nil
}
func (statement *Statement) writeOffsetFetch(w builder.Writer) error {
if statement.LimitN != nil {
_, err := fmt.Fprintf(w, " OFFSET %v ROWS FETCH NEXT %v ROWS ONLY", statement.Start, *statement.LimitN)
return err
}
if statement.Start > 0 {
_, err := fmt.Fprintf(w, " OFFSET %v ROWS", statement.Start)
return err
}
return nil
} }
// write "TOP <n>" (mssql only) // write "TOP <n>" (mssql only)
@ -270,19 +228,6 @@ func (statement *Statement) writeWhere(w *builder.BytesWriter) error {
return statement.writeWhereCond(w, statement.cond) return statement.writeWhereCond(w, statement.cond)
} }
func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error {
if !statement.cond.IsValid() {
return statement.writeMssqlPaginationCond(w)
}
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil {
return err
}
return statement.writeMssqlPaginationCond(w)
}
func (statement *Statement) writeForUpdate(w *builder.BytesWriter) error { func (statement *Statement) writeForUpdate(w *builder.BytesWriter) error {
if !statement.IsForUpdate { if !statement.IsForUpdate {
return nil return nil
@ -295,91 +240,6 @@ func (statement *Statement) writeForUpdate(w *builder.BytesWriter) error {
return err return err
} }
// write subquery to implement limit offset
// (mssql legacy only)
func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) error {
if statement.Start <= 0 {
return nil
}
if statement.RefTable == nil {
return errors.New("unsupported query limit without reference table")
}
var column string
if len(statement.RefTable.PKColumns()) == 0 {
for _, index := range statement.RefTable.Indexes {
if len(index.Cols) == 1 {
column = index.Cols[0]
break
}
}
if len(column) == 0 {
column = statement.RefTable.ColumnsSeq()[0]
}
} else {
column = statement.RefTable.PKColumns()[0].Name
}
if statement.NeedTableName() {
if len(statement.TableAlias) > 0 {
column = fmt.Sprintf("%s.%s", statement.TableAlias, column)
} else {
column = fmt.Sprintf("%s.%s", statement.TableName(), column)
}
}
subWriter := builder.NewWriter()
if _, err := fmt.Fprintf(subWriter, "(%s NOT IN (SELECT TOP %d %s",
column, statement.Start, column); err != nil {
return err
}
if err := statement.writeFrom(subWriter); err != nil {
return err
}
if err := statement.writeWhere(subWriter); err != nil {
return err
}
if err := statement.writeOrderBys(subWriter); err != nil {
return err
}
if err := statement.writeGroupBy(subWriter); err != nil {
return err
}
if _, err := fmt.Fprint(subWriter, "))"); err != nil {
return err
}
if statement.cond.IsValid() {
if _, err := fmt.Fprint(w, " AND "); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
}
return utils.WriteBuilder(w, subWriter)
}
func (statement *Statement) writeOracleLimit(columnStr string) func(w *builder.BytesWriter) error {
return func(w *builder.BytesWriter) error {
if statement.LimitN == nil {
return nil
}
oldString := w.String()
w.Reset()
rawColStr := columnStr
if rawColStr == "*" {
rawColStr = "at.*"
}
_, err := fmt.Fprintf(w, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
columnStr, rawColStr, oldString, statement.Start+*statement.LimitN, statement.Start)
return err
}
}
func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error { func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error {
dbType := statement.dialect.URI().DBType dbType := statement.dialect.URI().DBType
if statement.isUsingLegacyLimitOffset() { if statement.isUsingLegacyLimitOffset() {

View File

@ -27,7 +27,7 @@ func (statement *Statement) Alias(alias string) *Statement {
return statement return statement
} }
func (statement *Statement) writeAlias(w builder.Writer) error { func (statement *Statement) writeAlias(w *builder.BytesWriter) error {
if statement.TableAlias != "" { if statement.TableAlias != "" {
if statement.dialect.URI().DBType == schemas.ORACLE { if statement.dialect.URI().DBType == schemas.ORACLE {
if _, err := fmt.Fprint(w, " ", statement.quote(statement.TableAlias)); err != nil { if _, err := fmt.Fprint(w, " ", statement.quote(statement.TableAlias)); err != nil {
@ -42,7 +42,7 @@ func (statement *Statement) writeAlias(w builder.Writer) error {
return nil return nil
} }
func (statement *Statement) writeTableName(w builder.Writer) error { func (statement *Statement) writeTableName(w *builder.BytesWriter) error {
if statement.dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") { if statement.dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
if _, err := fmt.Fprint(w, statement.TableName()); err != nil { if _, err := fmt.Fprint(w, statement.TableName()); err != nil {
return err return err