diff --git a/dialects/filter.go b/dialects/filter.go index bfe2e93e..c5f7daa3 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -14,13 +14,13 @@ type Filter interface { Do(sql string) string } -// SeqFilter filter SQL replace ?, ? ... to $1, $2 ... -type SeqFilter struct { +// postgresSeqFilter filter SQL replace ?, ? ... to $1, $2 ... +type postgresSeqFilter struct { Prefix string Start int } -func convertQuestionMark(sql, prefix string, start int) string { +func postgresSeqFilterConvertQuestionMark(sql, prefix string, start int) string { var buf strings.Builder var beginSingleQuote bool var isLineComment bool @@ -28,7 +28,73 @@ func convertQuestionMark(sql, prefix string, start int) string { var isMaybeLineComment bool var isMaybeComment bool var isMaybeCommentEnd bool - var index = start + var isMaybeJsonbQuestion bool + index := start + for i, c := range sql { + if !beginSingleQuote && !isLineComment && !isComment && !isMaybeJsonbQuestion && c == '?' { + buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) + index++ + } else { + if isMaybeJsonbQuestion && c == '?' { + isMaybeJsonbQuestion = false + } else if isMaybeLineComment { + if c == '-' { + isLineComment = true + } + isMaybeLineComment = false + } else if isMaybeComment { + if c == '*' { + isComment = true + } + isMaybeComment = false + } else if isMaybeCommentEnd { + if c == '/' { + isComment = false + } + isMaybeCommentEnd = false + } else if isLineComment { + if c == '\n' { + isLineComment = false + } + } else if isComment { + if c == '*' { + isMaybeCommentEnd = true + } + } else if !beginSingleQuote && c == '-' { + isMaybeLineComment = true + } else if !beginSingleQuote && c == '/' { + isMaybeComment = true + } else if !beginSingleQuote && c == ' ' && i >= 7 && strings.TrimSpace(sql[i-7:i]) == "::jsonb" { + isMaybeJsonbQuestion = true + } else if c == '\'' { + beginSingleQuote = !beginSingleQuote + } + buf.WriteRune(c) + } + } + return buf.String() +} + +// Do implements Filter +func (s *postgresSeqFilter) Do(sql string) string { + return postgresSeqFilterConvertQuestionMark(sql, s.Prefix, s.Start) +} + +// oracleSeqFilter filter SQL replace ?, ? ... to :1, :2 ... +type oracleSeqFilter struct { + Prefix string + Start int +} + +func oracleSeqFilterConvertQuestionMark(sql, prefix string, start int) string { + var buf strings.Builder + var beginSingleQuote bool + var isLineComment bool + var isComment bool + var isMaybeLineComment bool + var isMaybeComment bool + var isMaybeCommentEnd bool + index := start for _, c := range sql { if !beginSingleQuote && !isLineComment && !isComment && c == '?' { buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) @@ -71,6 +137,6 @@ func convertQuestionMark(sql, prefix string, start int) string { } // Do implements Filter -func (s *SeqFilter) Do(sql string) string { - return convertQuestionMark(sql, s.Prefix, s.Start) +func (s *oracleSeqFilter) Do(sql string) string { + return oracleSeqFilterConvertQuestionMark(sql, s.Prefix, s.Start) } diff --git a/dialects/filter_test.go b/dialects/filter_test.go index 15050656..e0da11df 100644 --- a/dialects/filter_test.go +++ b/dialects/filter_test.go @@ -7,7 +7,7 @@ import ( ) func TestSeqFilter(t *testing.T) { - var kases = map[string]string{ + kases := map[string]string{ "SELECT * FROM TABLE1 WHERE a=? AND b=?": "SELECT * FROM TABLE1 WHERE a=$1 AND b=$2", "SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=? AND b=?": "SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=$1 AND b=$2", "select '1''?' from issue": "select '1''?' from issue", @@ -16,12 +16,12 @@ func TestSeqFilter(t *testing.T) { "select '1\\''?',? from issue": "select '1\\''?',$1 from issue", } for sql, result := range kases { - assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) + assert.EqualValues(t, result, postgresSeqFilterConvertQuestionMark(sql, "$", 1)) } } func TestSeqFilterLineComment(t *testing.T) { - var kases = map[string]string{ + kases := map[string]string{ `SELECT * FROM TABLE1 WHERE foo='bar' @@ -49,12 +49,12 @@ func TestSeqFilterLineComment(t *testing.T) { AND b=$2`, } for sql, result := range kases { - assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) + assert.EqualValues(t, result, postgresSeqFilterConvertQuestionMark(sql, "$", 1)) } } func TestSeqFilterComment(t *testing.T) { - var kases = map[string]string{ + kases := map[string]string{ `SELECT * FROM TABLE1 WHERE a=? /* it's a comment */ @@ -73,6 +73,17 @@ func TestSeqFilterComment(t *testing.T) { AND b=$2`, } for sql, result := range kases { - assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) + assert.EqualValues(t, result, postgresSeqFilterConvertQuestionMark(sql, "$", 1)) + } +} + +func TestSeqFilterPostgresQuestions(t *testing.T) { + kases := map[string]string{ + `SELECT '{"a":1, "b":2}'::jsonb ? 'b' + FROM table1 WHERE c = ?`: `SELECT '{"a":1, "b":2}'::jsonb ? 'b' + FROM table1 WHERE c = $1`, + } + for sql, result := range kases { + assert.EqualValues(t, result, postgresSeqFilterConvertQuestionMark(sql, "$", 1)) } } diff --git a/dialects/oracle.go b/dialects/oracle.go index ce91cd5d..0ff52f84 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -859,7 +859,7 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam func (db *oracle) Filters() []Filter { return []Filter{ - &SeqFilter{Prefix: ":", Start: 1}, + &oracleSeqFilter{Prefix: ":", Start: 1}, } } diff --git a/dialects/postgres.go b/dialects/postgres.go index f9de5859..0c2e212d 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -1358,7 +1358,7 @@ func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, ta } func (db *postgres) Filters() []Filter { - return []Filter{&SeqFilter{Prefix: "$", Start: 1}} + return []Filter{&postgresSeqFilter{Prefix: "$", Start: 1}} } type pqDriver struct {