diff --git a/internal/statements/query.go b/internal/statements/query.go index e817403c..03a40bb0 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -318,7 +318,10 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac buf := builder.NewWriter() if statement.dialect.URI().DBType == schemas.MSSQL { - if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s", tableName); err != nil { + if _, err := fmt.Fprintf(buf, "SELECT TOP 1 *"); err != nil { + return "", nil, err + } + if err := statement.writeFrom(buf); err != nil { return "", nil, err } if err := statement.writeJoins(buf); err != nil { @@ -328,7 +331,10 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac return "", nil, err } } else if statement.dialect.URI().DBType == schemas.ORACLE { - if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil { + if _, err := fmt.Fprintf(buf, "SELECT *"); err != nil { + return "", nil, err + } + if err := statement.writeFrom(buf); err != nil { return "", nil, err } if err := statement.writeJoins(buf); err != nil { @@ -349,7 +355,10 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac return "", nil, err } } else { - if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s", tableName); err != nil { + if _, err := fmt.Fprintf(buf, "SELECT 1"); err != nil { + return "", nil, err + } + if err := statement.writeFrom(buf); err != nil { return "", nil, err } if err := statement.writeJoins(buf); err != nil { diff --git a/internal/statements/query_test.go b/internal/statements/query_test.go new file mode 100644 index 00000000..edd7e5a3 --- /dev/null +++ b/internal/statements/query_test.go @@ -0,0 +1,19 @@ +package statements + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestGenExistSQL(t *testing.T) { + statement, err := createTestStatement() + assert.NoError(t, err) + + statement.RefTable = nil + statement.SetTable("testDB") + statement.Alias("tdb") + statement.Where(`tdb.id=1`) + sql, _, err := statement.GenExistSQL() + assert.NoError(t, err) + assert.Equal(t, "SELECT 1 FROM `testDB` AS `tdb` WHERE (tdb.id=1) LIMIT 1", sql) +}