Fix question mark replacement on postgres

This commit is contained in:
Lunny Xiao 2022-12-12 18:23:34 +08:00
parent f1bfc5ce98
commit 412319af23
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
4 changed files with 91 additions and 14 deletions

View File

@ -14,13 +14,13 @@ type Filter interface {
Do(sql string) string Do(sql string) string
} }
// SeqFilter filter SQL replace ?, ? ... to $1, $2 ... // postgresSeqFilter filter SQL replace ?, ? ... to $1, $2 ...
type SeqFilter struct { type postgresSeqFilter struct {
Prefix string Prefix string
Start int Start int
} }
func convertQuestionMark(sql, prefix string, start int) string { func postgresSeqFilterConvertQuestionMark(sql, prefix string, start int) string {
var buf strings.Builder var buf strings.Builder
var beginSingleQuote bool var beginSingleQuote bool
var isLineComment bool var isLineComment bool
@ -28,7 +28,73 @@ func convertQuestionMark(sql, prefix string, start int) string {
var isMaybeLineComment bool var isMaybeLineComment bool
var isMaybeComment bool var isMaybeComment bool
var isMaybeCommentEnd bool var isMaybeCommentEnd bool
var index = start var isMaybeJsonbQuestion bool
index := start
for i, c := range sql {
if !beginSingleQuote && !isLineComment && !isComment && !isMaybeJsonbQuestion && c == '?' {
buf.WriteString(fmt.Sprintf("%s%v", prefix, index))
index++
} else {
if isMaybeJsonbQuestion && c == '?' {
isMaybeJsonbQuestion = false
} else if isMaybeLineComment {
if c == '-' {
isLineComment = true
}
isMaybeLineComment = false
} else if isMaybeComment {
if c == '*' {
isComment = true
}
isMaybeComment = false
} else if isMaybeCommentEnd {
if c == '/' {
isComment = false
}
isMaybeCommentEnd = false
} else if isLineComment {
if c == '\n' {
isLineComment = false
}
} else if isComment {
if c == '*' {
isMaybeCommentEnd = true
}
} else if !beginSingleQuote && c == '-' {
isMaybeLineComment = true
} else if !beginSingleQuote && c == '/' {
isMaybeComment = true
} else if !beginSingleQuote && c == ' ' && i >= 7 && strings.TrimSpace(sql[i-7:i]) == "::jsonb" {
isMaybeJsonbQuestion = true
} else if c == '\'' {
beginSingleQuote = !beginSingleQuote
}
buf.WriteRune(c)
}
}
return buf.String()
}
// Do implements Filter
func (s *postgresSeqFilter) Do(sql string) string {
return postgresSeqFilterConvertQuestionMark(sql, s.Prefix, s.Start)
}
// oracleSeqFilter filter SQL replace ?, ? ... to :1, :2 ...
type oracleSeqFilter struct {
Prefix string
Start int
}
func oracleSeqFilterConvertQuestionMark(sql, prefix string, start int) string {
var buf strings.Builder
var beginSingleQuote bool
var isLineComment bool
var isComment bool
var isMaybeLineComment bool
var isMaybeComment bool
var isMaybeCommentEnd bool
index := start
for _, c := range sql { for _, c := range sql {
if !beginSingleQuote && !isLineComment && !isComment && c == '?' { if !beginSingleQuote && !isLineComment && !isComment && c == '?' {
buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) buf.WriteString(fmt.Sprintf("%s%v", prefix, index))
@ -71,6 +137,6 @@ func convertQuestionMark(sql, prefix string, start int) string {
} }
// Do implements Filter // Do implements Filter
func (s *SeqFilter) Do(sql string) string { func (s *oracleSeqFilter) Do(sql string) string {
return convertQuestionMark(sql, s.Prefix, s.Start) return oracleSeqFilterConvertQuestionMark(sql, s.Prefix, s.Start)
} }

View File

@ -7,7 +7,7 @@ import (
) )
func TestSeqFilter(t *testing.T) { func TestSeqFilter(t *testing.T) {
var kases = map[string]string{ kases := map[string]string{
"SELECT * FROM TABLE1 WHERE a=? AND b=?": "SELECT * FROM TABLE1 WHERE a=$1 AND b=$2", "SELECT * FROM TABLE1 WHERE a=? AND b=?": "SELECT * FROM TABLE1 WHERE a=$1 AND b=$2",
"SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=? AND b=?": "SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=$1 AND b=$2", "SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=? AND b=?": "SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=$1 AND b=$2",
"select '1''?' from issue": "select '1''?' from issue", "select '1''?' from issue": "select '1''?' from issue",
@ -16,12 +16,12 @@ func TestSeqFilter(t *testing.T) {
"select '1\\''?',? from issue": "select '1\\''?',$1 from issue", "select '1\\''?',? from issue": "select '1\\''?',$1 from issue",
} }
for sql, result := range kases { for sql, result := range kases {
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) assert.EqualValues(t, result, postgresSeqFilterConvertQuestionMark(sql, "$", 1))
} }
} }
func TestSeqFilterLineComment(t *testing.T) { func TestSeqFilterLineComment(t *testing.T) {
var kases = map[string]string{ kases := map[string]string{
`SELECT * `SELECT *
FROM TABLE1 FROM TABLE1
WHERE foo='bar' WHERE foo='bar'
@ -49,12 +49,12 @@ func TestSeqFilterLineComment(t *testing.T) {
AND b=$2`, AND b=$2`,
} }
for sql, result := range kases { for sql, result := range kases {
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) assert.EqualValues(t, result, postgresSeqFilterConvertQuestionMark(sql, "$", 1))
} }
} }
func TestSeqFilterComment(t *testing.T) { func TestSeqFilterComment(t *testing.T) {
var kases = map[string]string{ kases := map[string]string{
`SELECT * `SELECT *
FROM TABLE1 FROM TABLE1
WHERE a=? /* it's a comment */ WHERE a=? /* it's a comment */
@ -73,6 +73,17 @@ func TestSeqFilterComment(t *testing.T) {
AND b=$2`, AND b=$2`,
} }
for sql, result := range kases { for sql, result := range kases {
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) assert.EqualValues(t, result, postgresSeqFilterConvertQuestionMark(sql, "$", 1))
}
}
func TestSeqFilterPostgresQuestions(t *testing.T) {
kases := map[string]string{
`SELECT '{"a":1, "b":2}'::jsonb ? 'b'
FROM table1 WHERE c = ?`: `SELECT '{"a":1, "b":2}'::jsonb ? 'b'
FROM table1 WHERE c = $1`,
}
for sql, result := range kases {
assert.EqualValues(t, result, postgresSeqFilterConvertQuestionMark(sql, "$", 1))
} }
} }

View File

@ -859,7 +859,7 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam
func (db *oracle) Filters() []Filter { func (db *oracle) Filters() []Filter {
return []Filter{ return []Filter{
&SeqFilter{Prefix: ":", Start: 1}, &oracleSeqFilter{Prefix: ":", Start: 1},
} }
} }

View File

@ -1358,7 +1358,7 @@ func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, ta
} }
func (db *postgres) Filters() []Filter { func (db *postgres) Filters() []Filter {
return []Filter{&SeqFilter{Prefix: "$", Start: 1}} return []Filter{&postgresSeqFilter{Prefix: "$", Start: 1}}
} }
type pqDriver struct { type pqDriver struct {