From 65822bd505b2c408a3aeb20cc3495925b84e8627 Mon Sep 17 00:00:00 2001 From: zhanghelong Date: Tue, 16 Jul 2019 08:03:49 +0000 Subject: [PATCH] Remove func QuoteStr() in interface Dialect (#51) --- column.go | 4 ++-- dialect.go | 18 +++++++++++++----- filter.go | 20 ++++++++++++++++++-- filter_test.go | 25 +++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 9 deletions(-) create mode 100644 filter_test.go diff --git a/column.go b/column.go index 40d8f926..8eaa5445 100644 --- a/column.go +++ b/column.go @@ -73,7 +73,7 @@ func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable // String generate column description string according dialect func (col *Column) String(d Dialect) string { - sql := d.QuoteStr() + col.Name + d.QuoteStr() + " " + sql := d.Quote(col.Name) + " " sql += d.SqlType(col) + " " @@ -101,7 +101,7 @@ func (col *Column) String(d Dialect) string { // StringNoPk generate column description string according dialect without primary keys func (col *Column) StringNoPk(d Dialect) string { - sql := d.QuoteStr() + col.Name + d.QuoteStr() + " " + sql := d.Quote(col.Name) sql += d.SqlType(col) + " " diff --git a/dialect.go b/dialect.go index 9d214f31..0a0c2d0d 100644 --- a/dialect.go +++ b/dialect.go @@ -40,9 +40,10 @@ type Dialect interface { DriverName() string DataSourceName() string - QuoteStr() string IsReserved(string) bool Quote(string) string + // Deprecated: use Quote(string) string instead + QuoteStr() string AndStr() string OrStr() string EqStr() string @@ -70,8 +71,8 @@ type Dialect interface { ForUpdateSql(query string) string - //CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error - //MustDropTable(tableName string) error + // CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error + // MustDropTable(tableName string) error GetColumns(tableName string) ([]string, map[string]*Column, error) GetTables() ([]*Table, error) @@ -173,8 +174,15 @@ func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) { } func (db *Base) IsColumnExist(tableName, colName string) (bool, error) { - query := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" - query = strings.Replace(query, "`", db.dialect.QuoteStr(), -1) + query := fmt.Sprintf( + "SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?", + db.dialect.Quote("COLUMN_NAME"), + db.dialect.Quote("INFORMATION_SCHEMA"), + db.dialect.Quote("COLUMNS"), + db.dialect.Quote("TABLE_SCHEMA"), + db.dialect.Quote("TABLE_NAME"), + db.dialect.Quote("COLUMN_NAME"), + ) return db.HasRecords(query, db.DbName, tableName, colName) } diff --git a/filter.go b/filter.go index 6aeed424..aeea1223 100644 --- a/filter.go +++ b/filter.go @@ -19,7 +19,23 @@ type QuoteFilter struct { } func (s *QuoteFilter) Do(sql string, dialect Dialect, table *Table) string { - return strings.Replace(sql, "`", dialect.QuoteStr(), -1) + dummy := dialect.Quote("") + if len(dummy) != 2 { + return sql + } + prefix, suffix := dummy[0], dummy[1] + raw := []byte(sql) + for i, cnt := 0, 0; i < len(raw); i = i + 1 { + if raw[i] == '`' { + if cnt%2 == 0 { + raw[i] = prefix + } else { + raw[i] = suffix + } + cnt++ + } + } + return string(raw) } // IdFilter filter SQL replace (id) to primary key column name @@ -35,7 +51,7 @@ func NewQuoter(dialect Dialect) *Quoter { } func (q *Quoter) Quote(content string) string { - return q.dialect.QuoteStr() + content + q.dialect.QuoteStr() + return q.dialect.Quote(content) } func (i *IdFilter) Do(sql string, dialect Dialect, table *Table) string { diff --git a/filter_test.go b/filter_test.go new file mode 100644 index 00000000..c9d836b9 --- /dev/null +++ b/filter_test.go @@ -0,0 +1,25 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type quoterOnly struct { + Dialect +} + +func (q *quoterOnly) Quote(item string) string { + return "[" + item + "]" +} + +func TestQuoteFilter_Do(t *testing.T) { + f := QuoteFilter{} + sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" + res := f.Do(sql, new(quoterOnly), nil) + assert.EqualValues(t, + "SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?", + res, + ) +}