From 04e6027f525a91b7ad30382a2c5ca8979a595ce6 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 1 Mar 2020 21:19:59 +0800 Subject: [PATCH] Fix tests --- dialects/filter.go | 23 +++++++++++------------ dialects/filter_test.go | 29 +++++++++++++++++++++++------ 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/dialects/filter.go b/dialects/filter.go index 3f7550a8..0722363f 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -29,13 +29,10 @@ func (s *QuoteFilter) Do(sql string) string { var buf strings.Builder buf.Grow(len(sql)) - var inSingleQuote bool + var beginSingleQuote bool var prefix = true for i := 0; i < len(sql); i++ { - if sql[i] == '\'' && (i == 0 || sql[i-1] != '\\') { - inSingleQuote = !inSingleQuote - } - if !inSingleQuote && sql[i] == '`' { + if !beginSingleQuote && sql[i] == '`' { if prefix { buf.WriteByte(s.quoter.Prefix) } else { @@ -43,6 +40,9 @@ func (s *QuoteFilter) Do(sql string) string { } prefix = !prefix } else { + if sql[i] == '\'' { + beginSingleQuote = !beginSingleQuote + } buf.WriteByte(sql[i]) } } @@ -57,23 +57,22 @@ type SeqFilter struct { } func convertQuestionMark(sql, prefix string, start int) string { - var ( - buf strings.Builder - beginSingleQuote bool - index = start - ) - for i, c := range sql { + var buf strings.Builder + var beginSingleQuote bool + var index = start + for _, c := range sql { if !beginSingleQuote && c == '?' { buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) index++ } else { - if c == '\'' && (i > 0 && sql[i-1] != '\\') { + if c == '\'' { beginSingleQuote = !beginSingleQuote } buf.WriteRune(c) } } return buf.String() + } func (s *SeqFilter) Do(sql string) string { diff --git a/dialects/filter_test.go b/dialects/filter_test.go index 6b5101c8..f75fb4d9 100644 --- a/dialects/filter_test.go +++ b/dialects/filter_test.go @@ -10,12 +10,29 @@ import ( func TestQuoteFilter_Do(t *testing.T) { f := QuoteFilter{schemas.Quoter{'[', ']', schemas.AlwaysReverse}} - sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" - res := f.Do(sql) - assert.EqualValues(t, - "SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?", - res, - ) + var kases = []struct { + source string + expected string + }{ + { + "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?", + "SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?", + }, + { + "SELECT 'abc```test```''', `a` FROM b", + "SELECT 'abc```test```''', [a] FROM b", + }, + { + "UPDATE table SET `a` = ~ `a`, `b`='abc`'", + "UPDATE table SET [a] = ~ [a], [b]='abc`'", + }, + } + + for _, kase := range kases { + t.Run(kase.source, func(t *testing.T) { + assert.EqualValues(t, kase.expected, f.Do(kase.source)) + }) + } } func TestSeqFilter(t *testing.T) {