diff --git a/internal/statements/query.go b/internal/statements/query.go index 137ad02f..cc559c45 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -382,10 +382,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac return statement.GenRawSQL(), statement.RawParams, nil } - var sqlStr string - var args []interface{} var joinStr string - var err error var b interface{} if len(bean) > 0 { b = bean[0] @@ -404,45 +401,61 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac if len(tableName) <= 0 { return "", nil, ErrTableNotFound } - if statement.RefTable == nil { - tableName = statement.quote(tableName) - if len(statement.JoinStr) > 0 { - joinStr = statement.JoinStr - } + if statement.RefTable != nil { + return statement.Limit(1).GenGetSQL(b) + } + tableName = statement.quote(tableName) + if len(statement.JoinStr) > 0 { + joinStr = " " + statement.JoinStr + " " + } + + buf := builder.NewWriter() + if statement.dialect.URI().DBType == schemas.MSSQL { + if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s%s", tableName, joinStr); err != nil { + return "", nil, err + } if statement.Conds().IsValid() { - condSQL, condArgs, err := statement.GenCondSQL(statement.Conds()) - if err != nil { + if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { return "", nil, err } - - if statement.dialect.URI().DBType == schemas.MSSQL { - sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL) - } else if statement.dialect.URI().DBType == schemas.ORACLE { - sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL) - } else { - sqlStr = fmt.Sprintf("SELECT 1 FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL) + if err := statement.Conds().WriteTo(buf); err != nil { + return "", nil, err } - args = condArgs - } else { - if statement.dialect.URI().DBType == schemas.MSSQL { - sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr) - } else if statement.dialect.URI().DBType == schemas.ORACLE { - sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr) - } else { - sqlStr = fmt.Sprintf("SELECT 1 FROM %s %s LIMIT 1", tableName, joinStr) + } + } else if statement.dialect.URI().DBType == schemas.ORACLE { + if _, err := fmt.Fprintf(buf, "SELECT * FROM %s%s WHERE ", tableName, joinStr); err != nil { + return "", nil, err + } + if statement.Conds().IsValid() { + if err := statement.Conds().WriteTo(buf); err != nil { + return "", nil, err } - args = []interface{}{} + if _, err := fmt.Fprintf(buf, " AND "); err != nil { + return "", nil, err + } + } + if _, err := fmt.Fprintf(buf, "ROWNUM=1"); err != nil { + return "", nil, err } } else { - statement.Limit(1) - sqlStr, args, err = statement.GenGetSQL(b) - if err != nil { + if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s%s", tableName, joinStr); err != nil { + return "", nil, err + } + if statement.Conds().IsValid() { + if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { + return "", nil, err + } + if err := statement.Conds().WriteTo(buf); err != nil { + return "", nil, err + } + } + if _, err := fmt.Fprintf(buf, "LIMIT 1"); err != nil { return "", nil, err } } - return sqlStr, args, nil + return buf.String(), buf.Args(), nil } // GenFindSQL generates Find SQL