Fix question mark replacement on postgres

This commit is contained in:
Lunny Xiao 2022-12-12 18:23:34 +08:00
parent f1bfc5ce98
commit 412319af23
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
4 changed files with 91 additions and 14 deletions

View File

@ -14,13 +14,13 @@ type Filter interface {
Do(sql string) string
}
// SeqFilter filter SQL replace ?, ? ... to $1, $2 ...
type SeqFilter struct {
// postgresSeqFilter filter SQL replace ?, ? ... to $1, $2 ...
type postgresSeqFilter struct {
Prefix string
Start int
}
func convertQuestionMark(sql, prefix string, start int) string {
func postgresSeqFilterConvertQuestionMark(sql, prefix string, start int) string {
var buf strings.Builder
var beginSingleQuote bool
var isLineComment bool
@ -28,7 +28,73 @@ func convertQuestionMark(sql, prefix string, start int) string {
var isMaybeLineComment bool
var isMaybeComment bool
var isMaybeCommentEnd bool
var index = start
var isMaybeJsonbQuestion bool
index := start
for i, c := range sql {
if !beginSingleQuote && !isLineComment && !isComment && !isMaybeJsonbQuestion && c == '?' {
buf.WriteString(fmt.Sprintf("%s%v", prefix, index))
index++
} else {
if isMaybeJsonbQuestion && c == '?' {
isMaybeJsonbQuestion = false
} else if isMaybeLineComment {
if c == '-' {
isLineComment = true
}
isMaybeLineComment = false
} else if isMaybeComment {
if c == '*' {
isComment = true
}
isMaybeComment = false
} else if isMaybeCommentEnd {
if c == '/' {
isComment = false
}
isMaybeCommentEnd = false
} else if isLineComment {
if c == '\n' {
isLineComment = false
}
} else if isComment {
if c == '*' {
isMaybeCommentEnd = true
}
} else if !beginSingleQuote && c == '-' {
isMaybeLineComment = true
} else if !beginSingleQuote && c == '/' {
isMaybeComment = true
} else if !beginSingleQuote && c == ' ' && i >= 7 && strings.TrimSpace(sql[i-7:i]) == "::jsonb" {
isMaybeJsonbQuestion = true
} else if c == '\'' {
beginSingleQuote = !beginSingleQuote
}
buf.WriteRune(c)
}
}
return buf.String()
}
// Do implements Filter
func (s *postgresSeqFilter) Do(sql string) string {
return postgresSeqFilterConvertQuestionMark(sql, s.Prefix, s.Start)
}
// oracleSeqFilter filter SQL replace ?, ? ... to :1, :2 ...
type oracleSeqFilter struct {
Prefix string
Start int
}
func oracleSeqFilterConvertQuestionMark(sql, prefix string, start int) string {
var buf strings.Builder
var beginSingleQuote bool
var isLineComment bool
var isComment bool
var isMaybeLineComment bool
var isMaybeComment bool
var isMaybeCommentEnd bool
index := start
for _, c := range sql {
if !beginSingleQuote && !isLineComment && !isComment && c == '?' {
buf.WriteString(fmt.Sprintf("%s%v", prefix, index))
@ -71,6 +137,6 @@ func convertQuestionMark(sql, prefix string, start int) string {
}
// Do implements Filter
func (s *SeqFilter) Do(sql string) string {
return convertQuestionMark(sql, s.Prefix, s.Start)
func (s *oracleSeqFilter) Do(sql string) string {
return oracleSeqFilterConvertQuestionMark(sql, s.Prefix, s.Start)
}

View File

@ -7,7 +7,7 @@ import (
)
func TestSeqFilter(t *testing.T) {
var kases = map[string]string{
kases := map[string]string{
"SELECT * FROM TABLE1 WHERE a=? AND b=?": "SELECT * FROM TABLE1 WHERE a=$1 AND b=$2",
"SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=? AND b=?": "SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=$1 AND b=$2",
"select '1''?' from issue": "select '1''?' from issue",
@ -16,12 +16,12 @@ func TestSeqFilter(t *testing.T) {
"select '1\\''?',? from issue": "select '1\\''?',$1 from issue",
}
for sql, result := range kases {
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1))
assert.EqualValues(t, result, postgresSeqFilterConvertQuestionMark(sql, "$", 1))
}
}
func TestSeqFilterLineComment(t *testing.T) {
var kases = map[string]string{
kases := map[string]string{
`SELECT *
FROM TABLE1
WHERE foo='bar'
@ -49,12 +49,12 @@ func TestSeqFilterLineComment(t *testing.T) {
AND b=$2`,
}
for sql, result := range kases {
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1))
assert.EqualValues(t, result, postgresSeqFilterConvertQuestionMark(sql, "$", 1))
}
}
func TestSeqFilterComment(t *testing.T) {
var kases = map[string]string{
kases := map[string]string{
`SELECT *
FROM TABLE1
WHERE a=? /* it's a comment */
@ -73,6 +73,17 @@ func TestSeqFilterComment(t *testing.T) {
AND b=$2`,
}
for sql, result := range kases {
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1))
assert.EqualValues(t, result, postgresSeqFilterConvertQuestionMark(sql, "$", 1))
}
}
func TestSeqFilterPostgresQuestions(t *testing.T) {
kases := map[string]string{
`SELECT '{"a":1, "b":2}'::jsonb ? 'b'
FROM table1 WHERE c = ?`: `SELECT '{"a":1, "b":2}'::jsonb ? 'b'
FROM table1 WHERE c = $1`,
}
for sql, result := range kases {
assert.EqualValues(t, result, postgresSeqFilterConvertQuestionMark(sql, "$", 1))
}
}

View File

@ -859,7 +859,7 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam
func (db *oracle) Filters() []Filter {
return []Filter{
&SeqFilter{Prefix: ":", Start: 1},
&oracleSeqFilter{Prefix: ":", Start: 1},
}
}

View File

@ -1358,7 +1358,7 @@ func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, ta
}
func (db *postgres) Filters() []Filter {
return []Filter{&SeqFilter{Prefix: "$", Start: 1}}
return []Filter{&postgresSeqFilter{Prefix: "$", Start: 1}}
}
type pqDriver struct {