From c3311fae2b60f7421d7cb4e054b5133c3a25d151 Mon Sep 17 00:00:00 2001 From: hsfish Date: Wed, 19 Feb 2025 18:31:43 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dstatement=20GenExistSQ?= =?UTF-8?q?L=20=E5=8F=AA=E6=9C=89=E8=A1=A8=E5=90=8D=EF=BC=8C=E6=B2=A1?= =?UTF-8?q?=E6=9C=89RefTable=E6=97=B6=EF=BC=8C=E6=9C=AA=E6=8B=BC=E6=8E=A5?= =?UTF-8?q?=E5=88=AB=E5=90=8DAlias=E7=9A=84BUG?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/statements/query.go | 15 ++++++++++++--- internal/statements/query_test.go | 19 +++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) create mode 100644 internal/statements/query_test.go diff --git a/internal/statements/query.go b/internal/statements/query.go index e817403c..03a40bb0 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -318,7 +318,10 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac buf := builder.NewWriter() if statement.dialect.URI().DBType == schemas.MSSQL { - if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s", tableName); err != nil { + if _, err := fmt.Fprintf(buf, "SELECT TOP 1 *"); err != nil { + return "", nil, err + } + if err := statement.writeFrom(buf); err != nil { return "", nil, err } if err := statement.writeJoins(buf); err != nil { @@ -328,7 +331,10 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac return "", nil, err } } else if statement.dialect.URI().DBType == schemas.ORACLE { - if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil { + if _, err := fmt.Fprintf(buf, "SELECT *"); err != nil { + return "", nil, err + } + if err := statement.writeFrom(buf); err != nil { return "", nil, err } if err := statement.writeJoins(buf); err != nil { @@ -349,7 +355,10 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac return "", nil, err } } else { - if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s", tableName); err != nil { + if _, err := fmt.Fprintf(buf, "SELECT 1"); err != nil { + return "", nil, err + } + if err := statement.writeFrom(buf); err != nil { return "", nil, err } if err := statement.writeJoins(buf); err != nil { diff --git a/internal/statements/query_test.go b/internal/statements/query_test.go new file mode 100644 index 00000000..edd7e5a3 --- /dev/null +++ b/internal/statements/query_test.go @@ -0,0 +1,19 @@ +package statements + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestGenExistSQL(t *testing.T) { + statement, err := createTestStatement() + assert.NoError(t, err) + + statement.RefTable = nil + statement.SetTable("testDB") + statement.Alias("tdb") + statement.Where(`tdb.id=1`) + sql, _, err := statement.GenExistSQL() + assert.NoError(t, err) + assert.Equal(t, "SELECT 1 FROM `testDB` AS `tdb` WHERE (tdb.id=1) LIMIT 1", sql) +}