From 7c240fc21ee6d2a73a9ba884833ca4c95f073c93 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 1 Mar 2020 16:16:44 +0800 Subject: [PATCH] Improve quote policy --- dialects/dialect.go | 10 +- dialects/filter.go | 5 +- dialects/mssql.go | 23 +++- dialects/mysql.go | 28 ++++- dialects/oracle.go | 28 ++++- dialects/postgres.go | 26 +++- dialects/quote.go | 15 +++ dialects/sqlite3.go | 26 +++- engine.go | 4 + schemas/quote.go | 269 ++++++++++++++++++++---------------------- schemas/quote_test.go | 133 ++++++++++++++++----- 11 files changed, 364 insertions(+), 203 deletions(-) create mode 100644 dialects/quote.go diff --git a/dialects/dialect.go b/dialects/dialect.go index c591cc7b..d89f1ebe 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -41,6 +41,7 @@ type Dialect interface { IsReserved(string) bool Quoter() schemas.Quoter + SetQuotePolicy(quotePolicy QuotePolicy) AutoIncrStr() string @@ -79,6 +80,11 @@ type Base struct { db *core.DB dialect Dialect uri *URI + quoter schemas.Quoter +} + +func (b *Base) Quoter() schemas.Quoter { + return b.quoter } func (b *Base) DB() *core.DB { @@ -210,7 +216,7 @@ func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string { idxName = index.XName(tableName) return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique, quoter.Quote(idxName), quoter.Quote(tableName), - quoter.Quote(strings.Join(index.Cols, quoter.ReverseQuote(",")))) + quoter.Join(index.Cols, ",")) } func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string { @@ -258,7 +264,7 @@ func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, char if len(pkList) > 1 { sql += "PRIMARY KEY ( " - sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(","))) + sql += quoter.Join(pkList, ",") sql += " ), " } diff --git a/dialects/filter.go b/dialects/filter.go index 0f9b4107..0095fdb5 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -26,14 +26,13 @@ func (s *QuoteFilter) Do(sql string) string { return sql } - prefix, suffix := s.quoter[0][0], s.quoter[1][0] raw := []byte(sql) for i, cnt := 0, 0; i < len(raw); i = i + 1 { if raw[i] == '`' { if cnt%2 == 0 { - raw[i] = prefix + raw[i] = s.quoter.Prefix } else { - raw[i] = suffix + raw[i] = s.quoter.Suffix } cnt++ } diff --git a/dialects/mssql.go b/dialects/mssql.go index 558abdfc..8d092886 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -204,13 +204,17 @@ var ( "EXIT": true, "PROC": true, } + + mssqlQuoter = schemas.Quoter{'[', ']', schemas.AlwaysReverse} ) type mssql struct { Base + quoter schemas.Quoter } func (db *mssql) Init(d *core.DB, uri *URI) error { + db.quoter = mssqlQuoter return db.Base.Init(d, db, uri) } @@ -283,12 +287,25 @@ func (db *mssql) SupportInsertMany() bool { } func (db *mssql) IsReserved(name string) bool { - _, ok := mssqlReservedWords[name] + _, ok := mssqlReservedWords[strings.ToUpper(name)] return ok } -func (db *mssql) Quoter() schemas.Quoter { - return schemas.Quoter{"[", "]"} +func (db *mssql) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = mssqlQuoter + q.IsReverse = schemas.AlwaysNoReverse + db.quoter = q + case QuotePolicyReserved: + var q = mssqlQuoter + q.IsReverse = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = mssqlQuoter + } } func (db *mssql) SupportEngine() bool { diff --git a/dialects/mysql.go b/dialects/mysql.go index 939a7cf1..cf768fd6 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -161,6 +161,8 @@ var ( "YEAR_MONTH": true, "ZEROFILL": true, } + + mysqlQuoter = schemas.Quoter{'`', '`', schemas.AlwaysReverse} ) type mysql struct { @@ -178,6 +180,7 @@ type mysql struct { } func (db *mysql) Init(d *core.DB, uri *URI) error { + db.quoter = mysqlQuoter return db.Base.Init(d, db, uri) } @@ -272,14 +275,10 @@ func (db *mysql) SupportInsertMany() bool { } func (db *mysql) IsReserved(name string) bool { - _, ok := mysqlReservedWords[name] + _, ok := mysqlReservedWords[strings.ToUpper(name)] return ok } -func (db *mysql) Quoter() schemas.Quoter { - return schemas.Quoter{"`", "`"} -} - func (db *mysql) SupportEngine() bool { return true } @@ -458,6 +457,23 @@ func (db *mysql) GetTables(ctx context.Context) ([]*schemas.Table, error) { return tables, nil } +func (db *mysql) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = mysqlQuoter + q.IsReverse = schemas.AlwaysNoReverse + db.quoter = q + case QuotePolicyReserved: + var q = mysqlQuoter + q.IsReverse = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = mysqlQuoter + } +} + func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{db.uri.DBName, tableName} s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" @@ -538,7 +554,7 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch if len(pkList) > 1 { sql += "PRIMARY KEY ( " - sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(","))) + sql += quoter.Join(pkList, ",") sql += " ), " } diff --git a/dialects/oracle.go b/dialects/oracle.go index 4a8162ac..204623b7 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -498,6 +498,8 @@ var ( "YEAR": true, "ZONE": true, } + + oracleQuoter = schemas.Quoter{'[', ']', schemas.AlwaysReverse} ) type oracle struct { @@ -505,6 +507,7 @@ type oracle struct { } func (db *oracle) Init(d *core.DB, uri *URI) error { + db.quoter = oracleQuoter return db.Base.Init(d, db, uri) } @@ -549,14 +552,10 @@ func (db *oracle) SupportInsertMany() bool { } func (db *oracle) IsReserved(name string) bool { - _, ok := oracleReservedWords[name] + _, ok := oracleReservedWords[strings.ToUpper(name)] return ok } -func (db *oracle) Quoter() schemas.Quoter { - return schemas.Quoter{"\"", "\""} -} - func (db *oracle) SupportEngine() bool { return false } @@ -601,7 +600,7 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c if len(pkList) > 0 { sql += "PRIMARY KEY ( " - sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(","))) + sql += quoter.Join(pkList, ",") sql += " ), " } @@ -620,6 +619,23 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c return sql } +func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = oracleQuoter + q.IsReverse = schemas.AlwaysNoReverse + db.quoter = q + case QuotePolicyReserved: + var q = oracleQuoter + q.IsReverse = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = oracleQuoter + } +} + func (db *oracle) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { args := []interface{}{tableName, idxName} return `SELECT INDEX_NAME FROM USER_INDEXES ` + diff --git a/dialects/postgres.go b/dialects/postgres.go index f92202cd..4d5c29bb 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -766,6 +766,8 @@ var ( "YES": true, "ZONE": true, } + + postgresQuoter = schemas.Quoter{'"', '"', schemas.AlwaysReverse} ) const postgresPublicSchema = "public" @@ -775,6 +777,7 @@ type postgres struct { } func (db *postgres) Init(d *core.DB, uri *URI) error { + db.quoter = postgresQuoter err := db.Base.Init(d, db, uri) if err != nil { return err @@ -785,6 +788,23 @@ func (db *postgres) Init(d *core.DB, uri *URI) error { return nil } +func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = postgresQuoter + q.IsReverse = schemas.AlwaysNoReverse + db.quoter = q + case QuotePolicyReserved: + var q = postgresQuoter + q.IsReverse = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = postgresQuoter + } +} + func (db *postgres) DefaultSchema() string { return postgresPublicSchema } @@ -857,14 +877,10 @@ func (db *postgres) SupportInsertMany() bool { } func (db *postgres) IsReserved(name string) bool { - _, ok := postgresReservedWords[name] + _, ok := postgresReservedWords[strings.ToUpper(name)] return ok } -func (db *postgres) Quoter() schemas.Quoter { - return schemas.Quoter{`"`, `"`} -} - func (db *postgres) AutoIncrStr() string { return "" } diff --git a/dialects/quote.go b/dialects/quote.go new file mode 100644 index 00000000..da4e0dd6 --- /dev/null +++ b/dialects/quote.go @@ -0,0 +1,15 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +// QuotePolicy describes quote handle policy +type QuotePolicy int + +// All QuotePolicies +const ( + QuotePolicyAlways QuotePolicy = iota + QuotePolicyNone + QuotePolicyReserved +) diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 39138b13..bee99942 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -143,6 +143,8 @@ var ( "WITH": true, "WITHOUT": true, } + + sqlite3Quoter = schemas.Quoter{'`', '`', schemas.AlwaysReverse} ) type sqlite3 struct { @@ -150,9 +152,27 @@ type sqlite3 struct { } func (db *sqlite3) Init(d *core.DB, uri *URI) error { + db.quoter = sqlite3Quoter return db.Base.Init(d, db, uri) } +func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = sqlite3Quoter + q.IsReverse = schemas.AlwaysNoReverse + db.quoter = q + case QuotePolicyReserved: + var q = sqlite3Quoter + q.IsReverse = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = sqlite3Quoter + } +} + func (db *sqlite3) SQLType(c *schemas.Column) string { switch t := c.SQLType.Name; t { case schemas.Bool: @@ -196,14 +216,10 @@ func (db *sqlite3) SupportInsertMany() bool { } func (db *sqlite3) IsReserved(name string) bool { - _, ok := sqlite3ReservedWords[name] + _, ok := sqlite3ReservedWords[strings.ToUpper(name)] return ok } -func (db *sqlite3) Quoter() schemas.Quoter { - return schemas.Quoter{"`", "`"} -} - func (db *sqlite3) AutoIncrStr() string { return "AUTOINCREMENT" } diff --git a/engine.go b/engine.go index cc8a74a0..c657cd1f 100644 --- a/engine.go +++ b/engine.go @@ -54,6 +54,10 @@ func (engine *Engine) GetCacher(tableName string) caches.Cacher { return engine.cacherMgr.GetCacher(tableName) } +func (engine *Engine) SetQuotePolicy(quotePolicy dialects.QuotePolicy) { + engine.dialect.SetQuotePolicy(quotePolicy) +} + // BufferSize sets buffer size for iterate func (engine *Engine) BufferSize(size int) *Session { session := engine.NewSession() diff --git a/schemas/quote.go b/schemas/quote.go index 736b774a..6c521d0f 100644 --- a/schemas/quote.go +++ b/schemas/quote.go @@ -8,14 +8,29 @@ import ( "strings" ) -// Quoter represents two quote characters -type Quoter [2]string +// Quoter represents a quoter to the SQL table name and column name +type Quoter struct { + Prefix byte + Suffix byte + IsReverse func(string) bool +} -// CommonQuoter represetns a common quoter -var CommonQuoter = Quoter{"`", "`"} +var ( + // AlwaysFalseReverse always think it's not a reverse word + AlwaysNoReverse = func(string) bool { return false } + + // AlwaysReverse always reverse the word + AlwaysReverse = func(string) bool { return true } + + // CommanQuoteMark represnets the common quote mark + CommanQuoteMark byte = '`' + + // CommonQuoter represetns a common quoter + CommonQuoter = Quoter{CommanQuoteMark, CommanQuoteMark, AlwaysReverse} +) func (q Quoter) IsEmpty() bool { - return q[0] == "" && q[1] == "" + return q.Prefix == 0 && q.Suffix == 0 } func (q Quoter) Quote(s string) string { @@ -24,42 +39,6 @@ func (q Quoter) Quote(s string) string { return buf.String() } -func (q Quoter) Replace(sql string, newQuoter Quoter) string { - if q.IsEmpty() { - return sql - } - - if newQuoter.IsEmpty() { - var buf strings.Builder - for i := 0; i < len(sql); i = i + 1 { - if sql[i] != q[0][0] && sql[i] != q[1][0] { - _ = buf.WriteByte(sql[i]) - } - } - return buf.String() - } - - prefix, suffix := newQuoter[0][0], newQuoter[1][0] - var buf strings.Builder - for i, cnt := 0, 0; i < len(sql); i = i + 1 { - if cnt == 0 && sql[i] == q[0][0] { - _ = buf.WriteByte(prefix) - cnt = 1 - } else if cnt == 1 && sql[i] == q[1][0] { - _ = buf.WriteByte(suffix) - cnt = 0 - } else { - _ = buf.WriteByte(sql[i]) - } - } - return buf.String() -} - -func (q Quoter) ReverseQuote(s string) string { - reverseQuoter := Quoter{q[1], q[0]} - return reverseQuoter.Quote(s) -} - // Trim removes quotes from s func (q Quoter) Trim(s string) string { if len(s) < 2 { @@ -69,10 +48,10 @@ func (q Quoter) Trim(s string) string { var buf strings.Builder for i := 0; i < len(s); i++ { switch { - case i == 0 && s[i:i+1] == q[0]: - case i == len(s)-1 && s[i:i+1] == q[1]: - case s[i:i+1] == q[1] && s[i+1] == '.': - case s[i:i+1] == q[0] && s[i-1] == '.': + case i == 0 && s[i] == q.Prefix: + case i == len(s)-1 && s[i] == q.Suffix: + case s[i] == q.Suffix && s[i+1] == '.': + case s[i] == q.Prefix && s[i-1] == '.': default: buf.WriteByte(s[i]) } @@ -81,31 +60,8 @@ func (q Quoter) Trim(s string) string { } func (q Quoter) Join(a []string, sep string) string { - switch len(a) { - case 0: - return "" - case 1: - return a[0] - } - n := len(sep) * (len(a) - 1) - for i := 0; i < len(a); i++ { - n += len(a[i]) - } - var b strings.Builder - b.Grow(n) - for i, s := range a { - if i > 0 { - b.WriteString(sep) - } - if q[0] != "" && s != "*" { - b.WriteString(q[0]) - } - b.WriteString(strings.TrimSpace(s)) - if q[1] != "" && s != "*" { - b.WriteString(q[1]) - } - } + q.JoinWrite(&b, a, sep) return b.String() } @@ -126,23 +82,113 @@ func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error { return err } } - if q[0] != "" && s != "*" && s[0] != '`' { - if _, err := b.WriteString(q[0]); err != nil { - return err - } - } - if _, err := b.WriteString(strings.TrimSpace(s)); err != nil { - return err - } - if q[1] != "" && s != "*" && s[0] != '`' { - if _, err := b.WriteString(q[1]); err != nil { - return err - } + if s != "*" { + q.QuoteTo(b, strings.TrimSpace(s)) } } return nil } +func findWord(v string, start int) int { + for j := start; j < len(v); j++ { + switch v[j] { + case '.', ' ': + return j + } + } + return len(v) +} + +func findStart(value string, start int) int { + if value[start] == '.' { + return start + 1 + } + if value[start] != ' ' { + return start + } + + var k int + for j := start; j < len(value); j++ { + if value[j] != ' ' { + k = j + break + } + } + if k-1 == len(value) { + return len(value) + } + if (value[k] == 'A' || value[k] == 'a') && (value[k+1] == 'S' || value[k+1] == 's') { + k = k + 2 + } + + for j := k; j < len(value); j++ { + if value[j] != ' ' { + return j + } + } + return len(value) +} + +func (q Quoter) quoteWordTo(buf *strings.Builder, word string) error { + var realWord = word + if (word[0] == CommanQuoteMark && word[len(word)-1] == CommanQuoteMark) || + (word[0] == q.Prefix && word[len(word)-1] == q.Suffix) { + realWord = word[1 : len(word)-1] + } + + if q.IsEmpty() { + _, err := buf.WriteString(realWord) + return err + } + + isReverse := q.IsReverse(realWord) + if isReverse { + if err := buf.WriteByte(q.Prefix); err != nil { + return err + } + } + if _, err := buf.WriteString(realWord); err != nil { + return err + } + if isReverse { + return buf.WriteByte(q.Suffix) + } + + return nil +} + +// QuoteTo quotes the table or column names. i.e. if the quotes are [ and ] +// name -> [name] +// `name` -> [name] +// [name] -> [name] +// schema.name -> [schema].[name] +// `schema`.`name` -> [schema].[name] +// `schema`.name -> [schema].[name] +// schema.`name` -> [schema].[name] +// [schema].name -> [schema].[name] +// schema.[name] -> [schema].[name] +// name AS a -> [name] AS a +// schema.name AS a -> [schema].[name] AS a +func (q Quoter) QuoteTo(buf *strings.Builder, value string) error { + var i int + for i < len(value) { + start := findStart(value, i) + if start > i { + if _, err := buf.WriteString(value[i:start]); err != nil { + return err + } + } + var nextEnd = findWord(value, start) + + if err := q.quoteWordTo(buf, value[start:nextEnd]); err != nil { + return err + } + i = nextEnd + } + return nil +} + +// Strings quotes a slice of string func (q Quoter) Strings(s []string) []string { var res = make([]string, 0, len(s)) for _, a := range s { @@ -150,64 +196,3 @@ func (q Quoter) Strings(s []string) []string { } return res } - -func (q Quoter) QuoteTo(buf *strings.Builder, value string) { - if q.IsEmpty() { - buf.WriteString(value) - return - } - - prefix, suffix := q[0][0], q[1][0] - lastCh := 0 // 0 prefix, 1 char, 2 suffix - i := 0 - for i < len(value) { - // start of a token; might be already quoted - if value[i] == '.' { - _ = buf.WriteByte('.') - lastCh = 1 - i++ - } else if value[i] == prefix || value[i] == '`' { - // Has quotes; skip/normalize `name` to prefix+name+sufix - var ch byte - if value[i] == prefix { - ch = suffix - } else { - ch = '`' - } - _ = buf.WriteByte(prefix) - i++ - lastCh = 0 - for ; i < len(value) && value[i] != ch && value[i] != ' '; i++ { - _ = buf.WriteByte(value[i]) - lastCh = 1 - } - _ = buf.WriteByte(suffix) - lastCh = 2 - i++ - } else if value[i] == ' ' { - if lastCh != 2 { - _ = buf.WriteByte(suffix) - lastCh = 2 - } - - // a AS b or a b - for ; i < len(value); i++ { - if value[i] != ' ' && value[i-1] == ' ' && (len(value) > i+1 && !strings.EqualFold(value[i:i+2], "AS")) { - break - } - - _ = buf.WriteByte(value[i]) - lastCh = 1 - } - } else { - // Requires quotes - _ = buf.WriteByte(prefix) - for ; i < len(value) && value[i] != '.' && value[i] != ' '; i++ { - _ = buf.WriteByte(value[i]) - lastCh = 1 - } - _ = buf.WriteByte(suffix) - lastCh = 2 - } - } -} diff --git a/schemas/quote_test.go b/schemas/quote_test.go index 24739377..730fa5f4 100644 --- a/schemas/quote_test.go +++ b/schemas/quote_test.go @@ -11,54 +11,125 @@ import ( "github.com/stretchr/testify/assert" ) -func TestQuoteTo(t *testing.T) { - var quoter = Quoter{"[", "]"} +func TestAlwaysQuoteTo(t *testing.T) { + var ( + quoter = Quoter{'[', ']', AlwaysReverse} + kases = []struct { + expected string + value string + }{ + {"[mytable]", "mytable"}, + {"[mytable]", "`mytable`"}, + {"[mytable]", `[mytable]`}, + {`["mytable"]`, `"mytable"`}, + {"[myschema].[mytable]", "myschema.mytable"}, + {"[myschema].[mytable]", "`myschema`.mytable"}, + {"[myschema].[mytable]", "myschema.`mytable`"}, + {"[myschema].[mytable]", "`myschema`.`mytable`"}, + {"[myschema].[mytable]", `[myschema].mytable`}, + {"[myschema].[mytable]", `myschema.[mytable]`}, + {"[myschema].[mytable]", `[myschema].[mytable]`}, + {`["myschema].[mytable"]`, `"myschema.mytable"`}, + {"[message_user] AS [sender]", "`message_user` AS `sender`"}, + {"[myschema].[mytable] AS [table]", "myschema.mytable AS table"}, + } + ) - test := func(t *testing.T, expected string, value string) { - buf := &strings.Builder{} - quoter.QuoteTo(buf, value) - assert.EqualValues(t, expected, buf.String()) + for _, v := range kases { + t.Run(v.value, func(t *testing.T) { + buf := &strings.Builder{} + quoter.QuoteTo(buf, v.value) + assert.EqualValues(t, v.expected, buf.String()) + }) } +} - test(t, "[mytable]", "mytable") - test(t, "[mytable]", "`mytable`") - test(t, "[mytable]", `[mytable]`) +func TestReversedQuoteTo(t *testing.T) { + var ( + quoter = Quoter{'[', ']', func(s string) bool { + if s == "mytable" { + return true + } + return false + }} + kases = []struct { + expected string + value string + }{ + {"[mytable]", "mytable"}, + {"[mytable]", "`mytable`"}, + {"[mytable]", `[mytable]`}, + {`"mytable"`, `"mytable"`}, + {"myschema.[mytable]", "myschema.mytable"}, + {"myschema.[mytable]", "`myschema`.mytable"}, + {"myschema.[mytable]", "myschema.`mytable`"}, + {"myschema.[mytable]", "`myschema`.`mytable`"}, + {"myschema.[mytable]", `[myschema].mytable`}, + {"myschema.[mytable]", `myschema.[mytable]`}, + {"myschema.[mytable]", `[myschema].[mytable]`}, + {`"myschema.mytable"`, `"myschema.mytable"`}, + {"message_user AS sender", "`message_user` AS `sender`"}, + {"myschema.[mytable] AS table", "myschema.mytable AS table"}, + } + ) - test(t, `["mytable"]`, `"mytable"`) + for _, v := range kases { + t.Run(v.value, func(t *testing.T) { + buf := &strings.Builder{} + quoter.QuoteTo(buf, v.value) + assert.EqualValues(t, v.expected, buf.String()) + }) + } +} - test(t, "[myschema].[mytable]", "myschema.mytable") - test(t, "[myschema].[mytable]", "`myschema`.mytable") - test(t, "[myschema].[mytable]", "myschema.`mytable`") - test(t, "[myschema].[mytable]", "`myschema`.`mytable`") - test(t, "[myschema].[mytable]", `[myschema].mytable`) - test(t, "[myschema].[mytable]", `myschema.[mytable]`) - test(t, "[myschema].[mytable]", `[myschema].[mytable]`) +func TestNoQuoteTo(t *testing.T) { + var ( + quoter = Quoter{'[', ']', AlwaysNoReverse} + kases = []struct { + expected string + value string + }{ + {"mytable", "mytable"}, + {"mytable", "`mytable`"}, + {"mytable", `[mytable]`}, + {`"mytable"`, `"mytable"`}, + {"myschema.mytable", "myschema.mytable"}, + {"myschema.mytable", "`myschema`.mytable"}, + {"myschema.mytable", "myschema.`mytable`"}, + {"myschema.mytable", "`myschema`.`mytable`"}, + {"myschema.mytable", `[myschema].mytable`}, + {"myschema.mytable", `myschema.[mytable]`}, + {"myschema.mytable", `[myschema].[mytable]`}, + {`"myschema.mytable"`, `"myschema.mytable"`}, + {"message_user AS sender", "`message_user` AS `sender`"}, + {"myschema.mytable AS table", "myschema.mytable AS table"}, + } + ) - test(t, `["myschema].[mytable"]`, `"myschema.mytable"`) - - test(t, "[message_user] AS [sender]", "`message_user` AS `sender`") - - assert.EqualValues(t, "[a],[b]", quoter.Join([]string{"a", " b"}, ",")) - - buf := &strings.Builder{} - quoter = Quoter{"", ""} - quoter.QuoteTo(buf, "noquote") - assert.EqualValues(t, "noquote", buf.String()) + for _, v := range kases { + t.Run(v.value, func(t *testing.T) { + buf := &strings.Builder{} + quoter.QuoteTo(buf, v.value) + assert.EqualValues(t, v.expected, buf.String()) + }) + } } func TestJoin(t *testing.T) { cols := []string{"f1", "f2", "f3"} - quoter := Quoter{"[", "]"} + quoter := Quoter{'[', ']', AlwaysReverse} + + assert.EqualValues(t, "[a],[b]", quoter.Join([]string{"a", " b"}, ",")) assert.EqualValues(t, "[f1], [f2], [f3]", quoter.Join(cols, ", ")) - quoter = Quoter{"", ""} + quoter.IsReverse = AlwaysNoReverse assert.EqualValues(t, "f1, f2, f3", quoter.Join(cols, ", ")) } func TestStrings(t *testing.T) { cols := []string{"f1", "f2", "t3.f3"} - quoter := Quoter{"[", "]"} + quoter := Quoter{'[', ']', AlwaysReverse} quotedCols := quoter.Strings(cols) assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols) @@ -72,6 +143,6 @@ func TestTrim(t *testing.T) { for src, dst := range kases { assert.EqualValues(t, src, CommonQuoter.Trim(src)) - assert.EqualValues(t, dst, Quoter{"[", "]"}.Trim(src)) + assert.EqualValues(t, dst, Quoter{'[', ']', AlwaysReverse}.Trim(src)) } }