diff --git a/dialect_mssql.go b/dialect_mssql.go index 29070da2..430811ff 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -281,7 +281,7 @@ func (db *mssql) SupportInsertMany() bool { } func (db *mssql) IsReserved(name string) bool { - _, ok := mssqlReservedWords[name] + _, ok := mssqlReservedWords[strings.ToUpper(name)] return ok } diff --git a/dialect_mysql.go b/dialect_mysql.go index cf1dbb6f..2e393439 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -270,7 +270,7 @@ func (db *mysql) SupportInsertMany() bool { } func (db *mysql) IsReserved(name string) bool { - _, ok := mysqlReservedWords[name] + _, ok := mysqlReservedWords[strings.ToUpper(name)] return ok } diff --git a/dialect_oracle.go b/dialect_oracle.go index 15010ca5..537ee50b 100644 --- a/dialect_oracle.go +++ b/dialect_oracle.go @@ -547,7 +547,7 @@ func (db *oracle) SupportInsertMany() bool { } func (db *oracle) IsReserved(name string) bool { - _, ok := oracleReservedWords[name] + _, ok := oracleReservedWords[strings.ToUpper(name)] return ok } diff --git a/dialect_postgres.go b/dialect_postgres.go index ccef3086..c12fad40 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -854,7 +854,7 @@ func (db *postgres) SupportInsertMany() bool { } func (db *postgres) IsReserved(name string) bool { - _, ok := postgresReservedWords[name] + _, ok := postgresReservedWords[strings.ToUpper(name)] return ok } diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index 0a290f3c..181fad29 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -194,7 +194,7 @@ func (db *sqlite3) SupportInsertMany() bool { } func (db *sqlite3) IsReserved(name string) bool { - _, ok := sqlite3ReservedWords[name] + _, ok := sqlite3ReservedWords[strings.ToUpper(name)] return ok } diff --git a/engine_quote.go b/engine_quote.go index 1fe04395..30b15f4b 100644 --- a/engine_quote.go +++ b/engine_quote.go @@ -8,6 +8,7 @@ import ( "fmt" "strings" + "xorm.io/builder" "xorm.io/core" ) @@ -26,6 +27,7 @@ type Quoter interface { Quotes() (byte, byte) QuotePolicy() QuotePolicy IsReserved(string) bool + WriteTo(w *builder.BytesWriter, value string) error } type quoter struct { @@ -53,6 +55,29 @@ func (q *quoter) IsReserved(value string) bool { return q.dialect.IsReserved(value) } +func (q *quoter) needQuote(value string) bool { + return q.quotePolicy == QuoteAddAlways || (q.quotePolicy == QuoteAddReserved && q.IsReserved(value)) +} + +func (q *quoter) WriteTo(w *builder.BytesWriter, name string) error { + leftQuote, rightQuote := q.Quotes() + needQuote := q.needQuote(name) + if needQuote && name[0] != '`' { + if err := w.WriteByte(leftQuote); err != nil { + return err + } + } + if _, err := w.WriteString(name); err != nil { + return err + } + if needQuote && name[len(name)-1] != '`' { + if err := w.WriteByte(rightQuote); err != nil { + return err + } + } + return nil +} + func quoteColumns(quoter Quoter, columnStr string) string { columns := strings.Split(columnStr, ",") return quoteJoin(quoter, columns) @@ -96,13 +121,21 @@ func (engine *Engine) IsReserved(value string) bool { return engine.dialect.IsReserved(value) } +// SetTableQuotePolicy set table quote policy +func (engine *Engine) SetTableQuotePolicy(policy QuotePolicy) { + engine.tableQuoter = newQuoter(engine.dialect, policy) +} + +// SetColumnQuotePolicy set column quote policy +func (engine *Engine) SetColumnQuotePolicy(policy QuotePolicy) { + engine.colQuoter = newQuoter(engine.dialect, policy) +} + // quoteTo quotes string and writes into the buffer func quoteTo(quoter Quoter, buf *strings.Builder, value string) { left, right := quoter.Quotes() - if quoter.QuotePolicy() == QuoteAddAlways { - realQuoteTo(left, right, buf, value) - return - } else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) { + if (quoter.QuotePolicy() == QuoteAddAlways) || + (quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value)) { realQuoteTo(left, right, buf, value) return } diff --git a/engine_quote_test.go b/engine_quote_test.go index db38cb8b..85a498eb 100644 --- a/engine_quote_test.go +++ b/engine_quote_test.go @@ -18,3 +18,73 @@ func TestQuoteColumns(t *testing.T) { assert.EqualValues(t, "[f1], [f2], [f3]", quoteJoinFunc(cols, quoteFunc, ",")) } + +func TestChangeQuotePolicy(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type ChangeQuotePolicy struct { + Id int64 + Name string + } + + testEngine.SetColumnQuotePolicy(QuoteNoAdd) + assertSync(t, new(ChangeQuotePolicy)) + + var obj1 = ChangeQuotePolicy{ + Name: "obj1", + } + _, err := testEngine.Insert(&obj1) + assert.NoError(t, err) + + var obj2 ChangeQuotePolicy + _, err = testEngine.ID(obj1.Id).Get(&obj2) + assert.NoError(t, err) + + var objs []ChangeQuotePolicy + err = testEngine.Find(&objs) + assert.NoError(t, err) + + _, err = testEngine.ID(obj1.Id).Update(&ChangeQuotePolicy{ + Name: "obj2", + }) + assert.NoError(t, err) + + _, err = testEngine.ID(obj1.Id).Delete(new(ChangeQuotePolicy)) + assert.NoError(t, err) +} + +func TestChangeQuotePolicy2(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type ChangeQuotePolicy2 struct { + Id int64 + Name string + User string + Index int + } + + testEngine.SetColumnQuotePolicy(QuoteAddReserved) + assertSync(t, new(ChangeQuotePolicy2)) + + var obj1 = ChangeQuotePolicy2{ + Name: "obj1", + } + _, err := testEngine.Insert(&obj1) + assert.NoError(t, err) + + var obj2 ChangeQuotePolicy2 + _, err = testEngine.ID(obj1.Id).Get(&obj2) + assert.NoError(t, err) + + var objs []ChangeQuotePolicy2 + err = testEngine.Find(&objs) + assert.NoError(t, err) + + _, err = testEngine.ID(obj1.Id).Update(&ChangeQuotePolicy2{ + Name: "obj2", + }) + assert.NoError(t, err) + + _, err = testEngine.ID(obj1.Id).Delete(new(ChangeQuotePolicy2)) + assert.NoError(t, err) +} diff --git a/interface.go b/interface.go index 473f9d1a..3da9bd69 100644 --- a/interface.go +++ b/interface.go @@ -91,6 +91,7 @@ type EngineInterface interface { NoAutoTime() *Session Quote(string, bool) string SetCacher(string, core.Cacher) + SetColumnQuotePolicy(policy QuotePolicy) SetConnMaxLifetime(time.Duration) SetDefaultCacher(core.Cacher) SetLogger(logger core.ILogger) @@ -99,6 +100,7 @@ type EngineInterface interface { SetMaxOpenConns(int) SetMaxIdleConns(int) SetSchema(string) + SetTableQuotePolicy(policy QuotePolicy) SetTZDatabase(tz *time.Location) SetTZLocation(tz *time.Location) ShowExecTime(...bool) diff --git a/session_insert.go b/session_insert.go index f82b8fdd..1620db70 100644 --- a/session_insert.go +++ b/session_insert.go @@ -377,7 +377,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 0, err } - if err := writeStrings(buf, append(colNames, exprs.colNames...), "`", "`"); err != nil { + if err := writeStrings(buf, append(colNames, exprs.colNames...), session.engine.colQuoter); err != nil { return 0, err } @@ -735,7 +735,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err return 0, err } - if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { + if err := writeStrings(w, append(columns, exprs.colNames...), session.engine.colQuoter); err != nil { return 0, err } @@ -821,7 +821,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { return 0, err } - if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { + if err := writeStrings(w, append(columns, exprs.colNames...), session.engine.colQuoter); err != nil { return 0, err } diff --git a/statement_args.go b/statement_args.go index 310f24d6..e5c85c15 100644 --- a/statement_args.go +++ b/statement_args.go @@ -145,21 +145,11 @@ func (statement *Statement) writeArgs(w *builder.BytesWriter, args []interface{} return nil } -func writeStrings(w *builder.BytesWriter, cols []string, leftQuote, rightQuote string) error { +func writeStrings(w *builder.BytesWriter, cols []string, quoter Quoter) error { for i, colName := range cols { - if len(leftQuote) > 0 && colName[0] != '`' { - if _, err := w.WriteString(leftQuote); err != nil { - return err - } - } - if _, err := w.WriteString(colName); err != nil { + if err := quoter.WriteTo(w, colName); err != nil { return err } - if len(rightQuote) > 0 && colName[len(colName)-1] != '`' { - if _, err := w.WriteString(rightQuote); err != nil { - return err - } - } if i+1 != len(cols) { if _, err := w.WriteString(","); err != nil { return err