add quote policy methods for engine interface

This commit is contained in:
Lunny Xiao 2019-09-30 22:38:16 +08:00
parent ea1825c2dd
commit 2fc4ecd998
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
10 changed files with 119 additions and 24 deletions

View File

@ -281,7 +281,7 @@ func (db *mssql) SupportInsertMany() bool {
}
func (db *mssql) IsReserved(name string) bool {
_, ok := mssqlReservedWords[name]
_, ok := mssqlReservedWords[strings.ToUpper(name)]
return ok
}

View File

@ -270,7 +270,7 @@ func (db *mysql) SupportInsertMany() bool {
}
func (db *mysql) IsReserved(name string) bool {
_, ok := mysqlReservedWords[name]
_, ok := mysqlReservedWords[strings.ToUpper(name)]
return ok
}

View File

@ -547,7 +547,7 @@ func (db *oracle) SupportInsertMany() bool {
}
func (db *oracle) IsReserved(name string) bool {
_, ok := oracleReservedWords[name]
_, ok := oracleReservedWords[strings.ToUpper(name)]
return ok
}

View File

@ -854,7 +854,7 @@ func (db *postgres) SupportInsertMany() bool {
}
func (db *postgres) IsReserved(name string) bool {
_, ok := postgresReservedWords[name]
_, ok := postgresReservedWords[strings.ToUpper(name)]
return ok
}

View File

@ -194,7 +194,7 @@ func (db *sqlite3) SupportInsertMany() bool {
}
func (db *sqlite3) IsReserved(name string) bool {
_, ok := sqlite3ReservedWords[name]
_, ok := sqlite3ReservedWords[strings.ToUpper(name)]
return ok
}

View File

@ -8,6 +8,7 @@ import (
"fmt"
"strings"
"xorm.io/builder"
"xorm.io/core"
)
@ -26,6 +27,7 @@ type Quoter interface {
Quotes() (byte, byte)
QuotePolicy() QuotePolicy
IsReserved(string) bool
WriteTo(w *builder.BytesWriter, value string) error
}
type quoter struct {
@ -53,6 +55,29 @@ func (q *quoter) IsReserved(value string) bool {
return q.dialect.IsReserved(value)
}
func (q *quoter) needQuote(value string) bool {
return q.quotePolicy == QuoteAddAlways || (q.quotePolicy == QuoteAddReserved && q.IsReserved(value))
}
func (q *quoter) WriteTo(w *builder.BytesWriter, name string) error {
leftQuote, rightQuote := q.Quotes()
needQuote := q.needQuote(name)
if needQuote && name[0] != '`' {
if err := w.WriteByte(leftQuote); err != nil {
return err
}
}
if _, err := w.WriteString(name); err != nil {
return err
}
if needQuote && name[len(name)-1] != '`' {
if err := w.WriteByte(rightQuote); err != nil {
return err
}
}
return nil
}
func quoteColumns(quoter Quoter, columnStr string) string {
columns := strings.Split(columnStr, ",")
return quoteJoin(quoter, columns)
@ -96,13 +121,21 @@ func (engine *Engine) IsReserved(value string) bool {
return engine.dialect.IsReserved(value)
}
// SetTableQuotePolicy set table quote policy
func (engine *Engine) SetTableQuotePolicy(policy QuotePolicy) {
engine.tableQuoter = newQuoter(engine.dialect, policy)
}
// SetColumnQuotePolicy set column quote policy
func (engine *Engine) SetColumnQuotePolicy(policy QuotePolicy) {
engine.colQuoter = newQuoter(engine.dialect, policy)
}
// quoteTo quotes string and writes into the buffer
func quoteTo(quoter Quoter, buf *strings.Builder, value string) {
left, right := quoter.Quotes()
if quoter.QuotePolicy() == QuoteAddAlways {
realQuoteTo(left, right, buf, value)
return
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
if (quoter.QuotePolicy() == QuoteAddAlways) ||
(quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value)) {
realQuoteTo(left, right, buf, value)
return
}

View File

@ -18,3 +18,73 @@ func TestQuoteColumns(t *testing.T) {
assert.EqualValues(t, "[f1], [f2], [f3]", quoteJoinFunc(cols, quoteFunc, ","))
}
func TestChangeQuotePolicy(t *testing.T) {
assert.NoError(t, prepareEngine())
type ChangeQuotePolicy struct {
Id int64
Name string
}
testEngine.SetColumnQuotePolicy(QuoteNoAdd)
assertSync(t, new(ChangeQuotePolicy))
var obj1 = ChangeQuotePolicy{
Name: "obj1",
}
_, err := testEngine.Insert(&obj1)
assert.NoError(t, err)
var obj2 ChangeQuotePolicy
_, err = testEngine.ID(obj1.Id).Get(&obj2)
assert.NoError(t, err)
var objs []ChangeQuotePolicy
err = testEngine.Find(&objs)
assert.NoError(t, err)
_, err = testEngine.ID(obj1.Id).Update(&ChangeQuotePolicy{
Name: "obj2",
})
assert.NoError(t, err)
_, err = testEngine.ID(obj1.Id).Delete(new(ChangeQuotePolicy))
assert.NoError(t, err)
}
func TestChangeQuotePolicy2(t *testing.T) {
assert.NoError(t, prepareEngine())
type ChangeQuotePolicy2 struct {
Id int64
Name string
User string
Index int
}
testEngine.SetColumnQuotePolicy(QuoteAddReserved)
assertSync(t, new(ChangeQuotePolicy2))
var obj1 = ChangeQuotePolicy2{
Name: "obj1",
}
_, err := testEngine.Insert(&obj1)
assert.NoError(t, err)
var obj2 ChangeQuotePolicy2
_, err = testEngine.ID(obj1.Id).Get(&obj2)
assert.NoError(t, err)
var objs []ChangeQuotePolicy2
err = testEngine.Find(&objs)
assert.NoError(t, err)
_, err = testEngine.ID(obj1.Id).Update(&ChangeQuotePolicy2{
Name: "obj2",
})
assert.NoError(t, err)
_, err = testEngine.ID(obj1.Id).Delete(new(ChangeQuotePolicy2))
assert.NoError(t, err)
}

View File

@ -91,6 +91,7 @@ type EngineInterface interface {
NoAutoTime() *Session
Quote(string, bool) string
SetCacher(string, core.Cacher)
SetColumnQuotePolicy(policy QuotePolicy)
SetConnMaxLifetime(time.Duration)
SetDefaultCacher(core.Cacher)
SetLogger(logger core.ILogger)
@ -99,6 +100,7 @@ type EngineInterface interface {
SetMaxOpenConns(int)
SetMaxIdleConns(int)
SetSchema(string)
SetTableQuotePolicy(policy QuotePolicy)
SetTZDatabase(tz *time.Location)
SetTZLocation(tz *time.Location)
ShowExecTime(...bool)

View File

@ -377,7 +377,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 0, err
}
if err := writeStrings(buf, append(colNames, exprs.colNames...), "`", "`"); err != nil {
if err := writeStrings(buf, append(colNames, exprs.colNames...), session.engine.colQuoter); err != nil {
return 0, err
}
@ -735,7 +735,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
return 0, err
}
if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil {
if err := writeStrings(w, append(columns, exprs.colNames...), session.engine.colQuoter); err != nil {
return 0, err
}
@ -821,7 +821,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
return 0, err
}
if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil {
if err := writeStrings(w, append(columns, exprs.colNames...), session.engine.colQuoter); err != nil {
return 0, err
}

View File

@ -145,21 +145,11 @@ func (statement *Statement) writeArgs(w *builder.BytesWriter, args []interface{}
return nil
}
func writeStrings(w *builder.BytesWriter, cols []string, leftQuote, rightQuote string) error {
func writeStrings(w *builder.BytesWriter, cols []string, quoter Quoter) error {
for i, colName := range cols {
if len(leftQuote) > 0 && colName[0] != '`' {
if _, err := w.WriteString(leftQuote); err != nil {
return err
}
}
if _, err := w.WriteString(colName); err != nil {
if err := quoter.WriteTo(w, colName); err != nil {
return err
}
if len(rightQuote) > 0 && colName[len(colName)-1] != '`' {
if _, err := w.WriteString(rightQuote); err != nil {
return err
}
}
if i+1 != len(cols) {
if _, err := w.WriteString(","); err != nil {
return err