diff --git a/dialects/dameng.go b/dialects/dameng.go index f4a075d5..70a22666 100644 --- a/dialects/dameng.go +++ b/dialects/dameng.go @@ -509,13 +509,17 @@ var ( "ZONE": true, } - damengQuoter = schemas.Quoter{ - Prefix: '"', - Suffix: '"', - IsReserved: schemas.AlwaysReserve, - } + damengQuoter schemas.Quoter ) +func init() { + var err error + damengQuoter, err = schemas.NewQuoter('"', '"', schemas.AlwaysReserve) + if err != nil { + panic(err) + } +} + type dameng struct { Base } @@ -729,13 +733,11 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl func (db *dameng) SetQuotePolicy(quotePolicy QuotePolicy) { switch quotePolicy { case QuotePolicyNone: - var q = damengQuoter - q.IsReserved = schemas.AlwaysNoReserve - db.quoter = q + db.quoter = damengQuoter + db.quoter.SetIsReserved(schemas.AlwaysNoReserve) case QuotePolicyReserved: - var q = damengQuoter - q.IsReserved = db.IsReserved - db.quoter = q + db.quoter = damengQuoter + db.quoter.SetIsReserved(db.IsReserved) case QuotePolicyAlways: fallthrough default: diff --git a/dialects/mssql.go b/dialects/mssql.go index 706a754a..5724aec2 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -206,13 +206,17 @@ var ( "PROC": true, } - mssqlQuoter = schemas.Quoter{ - Prefix: '[', - Suffix: ']', - IsReserved: schemas.AlwaysReserve, - } + mssqlQuoter schemas.Quoter ) +func init() { + var err error + mssqlQuoter, err = schemas.NewQuoter('[', ']', schemas.AlwaysReserve) + if err != nil { + panic(err) + } +} + type mssql struct { Base defaultVarchar string @@ -403,13 +407,11 @@ func (db *mssql) IsReserved(name string) bool { func (db *mssql) SetQuotePolicy(quotePolicy QuotePolicy) { switch quotePolicy { case QuotePolicyNone: - var q = mssqlQuoter - q.IsReserved = schemas.AlwaysNoReserve - db.quoter = q + db.quoter = mssqlQuoter + db.quoter.SetIsReserved(schemas.AlwaysNoReserve) case QuotePolicyReserved: - var q = mssqlQuoter - q.IsReserved = db.IsReserved - db.quoter = q + db.quoter = mssqlQuoter + db.quoter.SetIsReserved(db.IsReserved) case QuotePolicyAlways: fallthrough default: diff --git a/dialects/mysql.go b/dialects/mysql.go index 31e7b788..eb1c25c6 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -162,13 +162,17 @@ var ( "ZEROFILL": true, } - mysqlQuoter = schemas.Quoter{ - Prefix: '`', - Suffix: '`', - IsReserved: schemas.AlwaysReserve, - } + mysqlQuoter schemas.Quoter ) +func init() { + var err error + mysqlQuoter, err = schemas.NewQuoter('`', '`', schemas.AlwaysReserve) + if err != nil { + panic(err) + } +} + type mysql struct { Base rowFormat string @@ -560,13 +564,11 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema func (db *mysql) SetQuotePolicy(quotePolicy QuotePolicy) { switch quotePolicy { case QuotePolicyNone: - q := mysqlQuoter - q.IsReserved = schemas.AlwaysNoReserve - db.quoter = q + db.quoter = mysqlQuoter + db.quoter.SetIsReserved(schemas.AlwaysNoReserve) case QuotePolicyReserved: - q := mysqlQuoter - q.IsReserved = db.IsReserved - db.quoter = q + db.quoter = mysqlQuoter + db.quoter.SetIsReserved(db.IsReserved) case QuotePolicyAlways: fallthrough default: diff --git a/dialects/oracle.go b/dialects/oracle.go index 04652bd6..4213b967 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -500,13 +500,17 @@ var ( "ZONE": true, } - oracleQuoter = schemas.Quoter{ - Prefix: '"', - Suffix: '"', - IsReserved: schemas.AlwaysReserve, - } + oracleQuoter schemas.Quoter ) +func init() { + var err error + oracleQuoter, err = schemas.NewQuoter('"', '"', schemas.AlwaysReserve) + if err != nil { + panic(err) + } +} + type oracle struct { Base } @@ -641,13 +645,11 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { switch quotePolicy { case QuotePolicyNone: - var q = oracleQuoter - q.IsReserved = schemas.AlwaysNoReserve - db.quoter = q + db.quoter = oracleQuoter + db.quoter.SetIsReserved(schemas.AlwaysNoReserve) case QuotePolicyReserved: - var q = oracleQuoter - q.IsReserved = db.IsReserved - db.quoter = q + db.quoter = oracleQuoter + db.quoter.SetIsReserved(db.IsReserved) case QuotePolicyAlways: fallthrough default: diff --git a/dialects/postgres.go b/dialects/postgres.go index ba73aad7..bdbf2ec6 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -768,13 +768,17 @@ var ( "ZONE": true, } - postgresQuoter = schemas.Quoter{ - Prefix: '"', - Suffix: '"', - IsReserved: schemas.AlwaysReserve, - } + postgresQuoter schemas.Quoter ) +func init() { + var err error + postgresQuoter, err = schemas.NewQuoter('"', '"', schemas.AlwaysReserve) + if err != nil { + panic(err) + } +} + var ( // DefaultPostgresSchema default postgres schema DefaultPostgresSchema = "public" @@ -862,13 +866,11 @@ func (db *postgres) needQuote(name string) bool { func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) { switch quotePolicy { case QuotePolicyNone: - q := postgresQuoter - q.IsReserved = schemas.AlwaysNoReserve - db.quoter = q + db.quoter = postgresQuoter + db.quoter.SetIsReserved(schemas.AlwaysNoReserve) case QuotePolicyReserved: - q := postgresQuoter - q.IsReserved = db.needQuote - db.quoter = q + db.quoter = postgresQuoter + db.quoter.SetIsReserved(db.IsReserved) case QuotePolicyAlways: fallthrough default: diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 4ff9a39e..031a7882 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -144,13 +144,17 @@ var ( "WITHOUT": true, } - sqlite3Quoter = schemas.Quoter{ - Prefix: '`', - Suffix: '`', - IsReserved: schemas.AlwaysReserve, - } + sqlite3Quoter schemas.Quoter ) +func init() { + var err error + sqlite3Quoter, err = schemas.NewQuoter('`', '`', schemas.AlwaysReserve) + if err != nil { + panic(err) + } +} + type sqlite3 struct { Base } @@ -193,13 +197,11 @@ func (db *sqlite3) Features() *DialectFeatures { func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) { switch quotePolicy { case QuotePolicyNone: - var q = sqlite3Quoter - q.IsReserved = schemas.AlwaysNoReserve - db.quoter = q + db.quoter = sqlite3Quoter + db.quoter.SetIsReserved(schemas.AlwaysNoReserve) case QuotePolicyReserved: - var q = sqlite3Quoter - q.IsReserved = db.IsReserved - db.quoter = q + db.quoter = sqlite3Quoter + db.quoter.SetIsReserved(db.IsReserved) case QuotePolicyAlways: fallthrough default: diff --git a/schemas/quote.go b/schemas/quote.go index 4cab30fe..587b39e5 100644 --- a/schemas/quote.go +++ b/schemas/quote.go @@ -5,14 +5,52 @@ package schemas import ( + "fmt" + "regexp" "strings" ) // Quoter represents a quoter to the SQL table name and column name type Quoter struct { - Prefix byte - Suffix byte - IsReserved func(string) bool + prefix byte + suffix byte + isReserved func(string) bool + re *regexp.Regexp +} + +var regexCharsToEscape = map[byte]struct{}{ + '[': {}, ']': {}, '(': {}, ')': {}, '{': {}, '}': {}, '*': {}, '+': {}, '?': {}, '|': {}, '^': {}, '$': {}, '.': {}, '\\': {}, '`': {}, +} + +func escapedRegexBytes(in byte) []byte { + if _, ok := regexCharsToEscape[in]; !ok { + return []byte{in} + } + return []byte{'\\', in} +} + +func NewQuoter(prefix byte, suffix byte, isReserved func(string) bool) (Quoter, error) { + regexPrefix := escapedRegexBytes(prefix) + regexSuffix := escapedRegexBytes(suffix) + + regex := fmt.Sprintf(`(?i)^\s*([^.\s]+|\x60[^.\s]+\x60|%s[^.\s]+%s)(?:\s*\.\s*([^.\s]+|\x60[^.\s]+\x60|%s[^.\s]+%s))?\s*?(?:\s+as\s+([^.\s]+|\x60[^.\s]+\x60|%s[^.\s]+%s))?(?:\s+(use|force)\s+index\s+\(([^.\s]+|\x60[^.\s]+\x60|%s[^.\s]+%s)\))?\s*$`, + regexPrefix, regexSuffix, regexPrefix, regexSuffix, regexPrefix, regexSuffix, regexPrefix, regexSuffix, + ) + re, err := regexp.Compile(regex) + if err != nil { + return Quoter{}, err + } + + return Quoter{ + prefix: prefix, + suffix: suffix, + isReserved: isReserved, + re: re, + }, nil +} + +func (q *Quoter) SetIsReserved(isReserved func(string) bool) { + q.isReserved = isReserved } var ( @@ -22,16 +60,24 @@ var ( // AlwaysReserve always reverse the word AlwaysReserve = func(string) bool { return true } - // CommanQuoteMark represnets the common quote mark - CommanQuoteMark byte = '`' + // CommonQuoteMark represents the common quote mark + CommonQuoteMark byte = '`' - // CommonQuoter represetns a common quoter - CommonQuoter = Quoter{CommanQuoteMark, CommanQuoteMark, AlwaysReserve} + // CommonQuoter represents a common quoter + CommonQuoter Quoter ) +func init() { + var err error + CommonQuoter, err = NewQuoter(CommonQuoteMark, CommonQuoteMark, AlwaysReserve) + if err != nil { + panic(err) + } +} + // IsEmpty return true if no prefix and suffix func (q Quoter) IsEmpty() bool { - return q.Prefix == 0 && q.Suffix == 0 + return q.prefix == 0 && q.suffix == 0 } // Quote quote a string @@ -50,10 +96,10 @@ func (q Quoter) Trim(s string) string { var buf strings.Builder for i := 0; i < len(s); i++ { switch { - case i == 0 && s[i] == q.Prefix: - case i == len(s)-1 && s[i] == q.Suffix: - case s[i] == q.Suffix && s[i+1] == '.': - case s[i] == q.Prefix && s[i-1] == '.': + case i == 0 && s[i] == q.prefix: + case i == len(s)-1 && s[i] == q.suffix: + case s[i] == q.suffix && s[i+1] == '.': + case s[i] == q.prefix && s[i-1] == '.': default: buf.WriteByte(s[i]) } @@ -93,51 +139,10 @@ func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error { return nil } -func findWord(v string, start int) int { - for j := start; j < len(v); j++ { - switch v[j] { - case '.', ' ': - return j - } - } - return len(v) -} - -func findStart(value string, start int) int { - if value[start] == '.' { - return start + 1 - } - if value[start] != ' ' { - return start - } - - var k = -1 - for j := start; j < len(value); j++ { - if value[j] != ' ' { - k = j - break - } - } - if k == -1 { - return len(value) - } - - if (value[k] == 'A' || value[k] == 'a') && (value[k+1] == 'S' || value[k+1] == 's') { - k += 2 - } - - for j := k; j < len(value); j++ { - if value[j] != ' ' { - return j - } - } - return len(value) -} - func (q Quoter) quoteWordTo(buf *strings.Builder, word string) error { var realWord = word - if (word[0] == CommanQuoteMark && word[len(word)-1] == CommanQuoteMark) || - (word[0] == q.Prefix && word[len(word)-1] == q.Suffix) { + if (word[0] == CommonQuoteMark && word[len(word)-1] == CommonQuoteMark) || + (word[0] == q.prefix && word[len(word)-1] == q.suffix) { realWord = word[1 : len(word)-1] } @@ -146,9 +151,9 @@ func (q Quoter) quoteWordTo(buf *strings.Builder, word string) error { return err } - isReserved := q.IsReserved(realWord) + isReserved := q.isReserved(realWord) if isReserved && realWord != "*" { - if err := buf.WriteByte(q.Prefix); err != nil { + if err := buf.WriteByte(q.prefix); err != nil { return err } } @@ -156,7 +161,7 @@ func (q Quoter) quoteWordTo(buf *strings.Builder, word string) error { return err } if isReserved && realWord != "*" { - return buf.WriteByte(q.Suffix) + return buf.WriteByte(q.suffix) } return nil @@ -174,25 +179,48 @@ func (q Quoter) quoteWordTo(buf *strings.Builder, word string) error { // schema.[name] -> [schema].[name] // name AS a -> [name] AS a // schema.name AS a -> [schema].[name] AS a -func (q Quoter) QuoteTo(buf *strings.Builder, value string) error { - var i int - for i < len(value) { - start := findStart(value, i) - if start > i { - if _, err := buf.WriteString(value[i:start]); err != nil { - return err - } - } - if start == len(value) { - return nil - } +func (q Quoter) QuoteTo(buf *strings.Builder, value string) (err error) { + matches := q.re.FindStringSubmatch(value) + if len(matches) != 6 { + return fmt.Errorf("unable to determine quoting for %q", value) + } - var nextEnd = findWord(value, start) - if err := q.quoteWordTo(buf, value[start:nextEnd]); err != nil { + schema := matches[1] + table := matches[2] + alias := matches[3] + indexMethod := matches[4] + index := matches[5] + if table == "" { + table = schema + schema = "" + } + + if schema != "" { + if err = q.quoteWordTo(buf, schema); err != nil { return err } - i = nextEnd + buf.WriteByte('.') } + if err = q.quoteWordTo(buf, table); err != nil { + return err + } + if alias != "" { + buf.WriteString(" AS ") + if err = q.quoteWordTo(buf, alias); err != nil { + return err + } + } + if index != "" { + _, err = fmt.Fprintf(buf, " %s index (", indexMethod) + if err != nil { + return err + } + if err = q.quoteWordTo(buf, index); err != nil { + return err + } + buf.WriteByte(')') + } + return nil } @@ -216,21 +244,21 @@ func (q Quoter) Replace(sql string) string { var beginSingleQuote bool for i := 0; i < len(sql); i++ { - if !beginSingleQuote && sql[i] == CommanQuoteMark { + if !beginSingleQuote && sql[i] == CommonQuoteMark { var j = i + 1 for ; j < len(sql); j++ { - if sql[j] == CommanQuoteMark { + if sql[j] == CommonQuoteMark { break } } word := sql[i+1 : j] - isReserved := q.IsReserved(word) + isReserved := q.isReserved(word) if isReserved { - buf.WriteByte(q.Prefix) + buf.WriteByte(q.prefix) } buf.WriteString(word) if isReserved { - buf.WriteByte(q.Suffix) + buf.WriteByte(q.suffix) } i = j } else { diff --git a/schemas/quote_test.go b/schemas/quote_test.go index f84dfb7d..2cc6235a 100644 --- a/schemas/quote_test.go +++ b/schemas/quote_test.go @@ -12,9 +12,10 @@ import ( ) func TestAlwaysQuoteTo(t *testing.T) { + quoter, err := NewQuoter('[', ']', AlwaysReserve) + assert.NoError(t, err) var ( - quoter = Quoter{'[', ']', AlwaysReserve} - kases = []struct { + kases = []struct { expected string value string }{ @@ -33,12 +34,15 @@ func TestAlwaysQuoteTo(t *testing.T) { {`["myschema].[mytable"]`, `"myschema.mytable"`}, {"[message_user] AS [sender]", "`message_user` AS `sender`"}, {"[myschema].[mytable] AS [table]", "myschema.mytable AS table"}, - {" [mytable]", " mytable"}, - {" [mytable]", " mytable"}, - {"[mytable] ", "mytable "}, - {"[mytable] ", "mytable "}, - {" [mytable] ", " mytable "}, - {" [mytable] ", " mytable "}, + {"[mytable]", " mytable"}, + {"[mytable]", " mytable"}, + {"[mytable]", "mytable "}, + {"[mytable]", "mytable "}, + {"[mytable]", " mytable "}, + {"[mytable]", " mytable "}, + {"[table] AS [t] use index ([myindex])", "`table` AS `t` use index (`myindex`)"}, + {"[table] AS [t] use index ([myindex])", "`table` AS `t` use index (`myindex`) "}, + {"[table] AS [t] force index ([myindex])", "table AS t force index (myindex) "}, } ) @@ -53,10 +57,11 @@ func TestAlwaysQuoteTo(t *testing.T) { } func TestReversedQuoteTo(t *testing.T) { + quoter, err := NewQuoter('[', ']', func(s string) bool { + return s == "mytable" + }) + assert.NoError(t, err) var ( - quoter = Quoter{'[', ']', func(s string) bool { - return s == "mytable" - }} kases = []struct { expected string value string @@ -82,16 +87,18 @@ func TestReversedQuoteTo(t *testing.T) { for _, v := range kases { t.Run(v.value, func(t *testing.T) { buf := &strings.Builder{} - quoter.QuoteTo(buf, v.value) + err := quoter.QuoteTo(buf, v.value) + assert.NoError(t, err) assert.EqualValues(t, v.expected, buf.String()) }) } } func TestNoQuoteTo(t *testing.T) { + quoter, err := NewQuoter('[', ']', AlwaysNoReserve) + assert.NoError(t, err) var ( - quoter = Quoter{'[', ']', AlwaysNoReserve} - kases = []struct { + kases = []struct { expected string value string }{ @@ -125,7 +132,8 @@ func TestNoQuoteTo(t *testing.T) { func TestJoin(t *testing.T) { cols := []string{"f1", "f2", "f3"} - quoter := Quoter{'[', ']', AlwaysReserve} + quoter, err := NewQuoter('[', ']', AlwaysReserve) + assert.NoError(t, err) assert.EqualValues(t, "[a],[b]", quoter.Join([]string{"a", " b"}, ",")) @@ -133,13 +141,14 @@ func TestJoin(t *testing.T) { assert.EqualValues(t, "[f1], [f2], [f3]", quoter.Join(cols, ", ")) - quoter.IsReserved = AlwaysNoReserve + quoter.SetIsReserved(AlwaysNoReserve) assert.EqualValues(t, "f1, f2, f3", quoter.Join(cols, ", ")) } func TestStrings(t *testing.T) { cols := []string{"f1", "f2", "t3.f3", "t4.*"} - quoter := Quoter{'[', ']', AlwaysReserve} + quoter, err := NewQuoter('[', ']', AlwaysReserve) + assert.NoError(t, err) quotedCols := quoter.Strings(cols) assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]", "[t4].*"}, quotedCols) @@ -151,14 +160,18 @@ func TestTrim(t *testing.T) { "[schema].[table_name]": "schema.table_name", } + quoter, err := NewQuoter('[', ']', AlwaysReserve) + assert.NoError(t, err) + for src, dst := range kases { assert.EqualValues(t, src, CommonQuoter.Trim(src)) - assert.EqualValues(t, dst, Quoter{'[', ']', AlwaysReserve}.Trim(src)) + assert.EqualValues(t, dst, quoter.Trim(src)) } } func TestReplace(t *testing.T) { - q := Quoter{'[', ']', AlwaysReserve} + q, err := NewQuoter('[', ']', AlwaysReserve) + assert.NoError(t, err) var kases = []struct { source string expected string