diff --git a/internal/statements/query.go b/internal/statements/query.go index 83c9cfd5..384b8a62 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -269,7 +269,11 @@ func (statement *Statement) writeForUpdate(w io.Writer) error { } func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) error { - if statement.dialect.URI().DBType != schemas.MSSQL || statement.LimitN == nil { + if statement.RefTable == nil { + return errors.New("unsupported query limit without reference table") + } + + if statement.dialect.URI().DBType != schemas.MSSQL || statement.Start <= 0 { return nil } @@ -277,51 +281,48 @@ func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) err if err := statement.writeTop(mssqlCondi); err != nil { return err } - if statement.Start > 0 { - if statement.RefTable == nil { - return errors.New("unsupported query limit without reference table") + + var column string + if len(statement.RefTable.PKColumns()) == 0 { + for _, index := range statement.RefTable.Indexes { + if len(index.Cols) == 1 { + column = index.Cols[0] + break + } } - var column string - if len(statement.RefTable.PKColumns()) == 0 { - for _, index := range statement.RefTable.Indexes { - if len(index.Cols) == 1 { - column = index.Cols[0] - break - } - } - if len(column) == 0 { - column = statement.RefTable.ColumnsSeq()[0] - } + if len(column) == 0 { + column = statement.RefTable.ColumnsSeq()[0] + } + } else { + column = statement.RefTable.PKColumns()[0].Name + } + if statement.NeedTableName() { + if len(statement.TableAlias) > 0 { + column = fmt.Sprintf("%s.%s", statement.TableAlias, column) } else { - column = statement.RefTable.PKColumns()[0].Name - } - if statement.NeedTableName() { - if len(statement.TableAlias) > 0 { - column = fmt.Sprintf("%s.%s", statement.TableAlias, column) - } else { - column = fmt.Sprintf("%s.%s", statement.TableName(), column) - } - } - if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s", - column, statement.Start, column); err != nil { - return err - } - if err := statement.writeFrom(mssqlCondi); err != nil { - return err - } - if err := statement.writeWhere(mssqlCondi); err != nil { - return err - } - if err := statement.WriteOrderBy(mssqlCondi); err != nil { - return err - } - if err := statement.writeGroupBy(mssqlCondi); err != nil { - return err - } - if _, err := fmt.Fprint(mssqlCondi, "))"); err != nil { - return err + column = fmt.Sprintf("%s.%s", statement.TableName(), column) } } + if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s", + column, statement.Start, column); err != nil { + return err + } + if err := statement.writeFrom(mssqlCondi); err != nil { + return err + } + if err := statement.writeWhere(mssqlCondi); err != nil { + return err + } + if err := statement.WriteOrderBy(mssqlCondi); err != nil { + return err + } + if err := statement.writeGroupBy(mssqlCondi); err != nil { + return err + } + if _, err := fmt.Fprint(mssqlCondi, "))"); err != nil { + return err + } + if statement.cond.IsValid() { if _, err := fmt.Fprint(w, " AND "); err != nil { return err