From ea1825c2ddf766f148d8c5f2df70048ad897d889 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 30 Sep 2019 16:29:45 +0800 Subject: [PATCH] fix quote policy --- engine.go | 10 +++---- engine_quote.go | 76 ++++++++++------------------------------------- session_find.go | 4 +-- session_insert.go | 8 ++--- session_query.go | 4 +-- session_update.go | 2 +- statement.go | 14 ++++----- statement_test.go | 7 +++-- xorm.go | 2 ++ 9 files changed, 44 insertions(+), 83 deletions(-) diff --git a/engine.go b/engine.go index f0649697..4ab69a7b 100644 --- a/engine.go +++ b/engine.go @@ -56,8 +56,8 @@ type Engine struct { defaultContext context.Context - quotePolicy QuotePolicy - quoteMode QuoteMode + colQuoter Quoter + tableQuoter Quoter } func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { @@ -419,7 +419,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D return err } - quoter := newQuoter(dialect, engine.quoteMode, engine.quotePolicy) + colQuoter := newQuoter(dialect, engine.colQuoter.QuotePolicy()) for i, table := range tables { if i > 0 { @@ -440,8 +440,8 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D } cols := table.ColumnsSeq() - colNames := quoteJoin(engine, cols) - destColNames := quoteJoin(quoter, cols) + colNames := quoteJoin(engine.colQuoter, cols) + destColNames := quoteJoin(colQuoter, cols) rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.quote(table.Name, false)) if err != nil { diff --git a/engine_quote.go b/engine_quote.go index 1d0063fe..1fe04395 100644 --- a/engine_quote.go +++ b/engine_quote.go @@ -21,34 +21,21 @@ const ( QuoteAddReserved ) -// QuoteMode quote on which types -type QuoteMode int - -// All QuoteModes -const ( - QuoteTableAndColumns QuoteMode = iota - QuoteTableOnly - QuoteColumnsOnly -) - // Quoter represents an object has Quote method type Quoter interface { Quotes() (byte, byte) QuotePolicy() QuotePolicy - QuoteMode() QuoteMode IsReserved(string) bool } type quoter struct { dialect core.Dialect - quoteMode QuoteMode quotePolicy QuotePolicy } -func newQuoter(dialect core.Dialect, quoteMode QuoteMode, quotePolicy QuotePolicy) Quoter { +func newQuoter(dialect core.Dialect, quotePolicy QuotePolicy) Quoter { return "er{ dialect: dialect, - quoteMode: quoteMode, quotePolicy: quotePolicy, } } @@ -62,10 +49,6 @@ func (q *quoter) QuotePolicy() QuotePolicy { return q.quotePolicy } -func (q *quoter) QuoteMode() QuoteMode { - return q.quoteMode -} - func (q *quoter) IsReserved(value string) bool { return q.dialect.IsReserved(value) } @@ -77,21 +60,24 @@ func quoteColumns(quoter Quoter, columnStr string) string { func quoteJoin(quoter Quoter, columns []string) string { for i := 0; i < len(columns); i++ { - columns[i] = quote(quoter, columns[i], true) + columns[i] = quote(quoter, columns[i]) } return strings.Join(columns, ",") } // quote Use QuoteStr quote the string sql -func quote(quoter Quoter, value string, isColumn bool) string { +func quote(quoter Quoter, value string) string { buf := strings.Builder{} - quoteTo(quoter, &buf, value, isColumn) + quoteTo(quoter, &buf, value) return buf.String() } // Quote add quotes to the value func (engine *Engine) quote(value string, isColumn bool) string { - return quote(engine, value, isColumn) + if isColumn { + return quote(engine.colQuoter, value) + } + return quote(engine.tableQuoter, value) } // Quote add quotes to the value @@ -105,53 +91,25 @@ func (engine *Engine) Quotes() (byte, byte) { return quotes[0], quotes[1] } -// QuoteMode returns quote mode -func (engine *Engine) QuoteMode() QuoteMode { - return engine.quoteMode -} - -// QuotePolicy returns quote policy -func (engine *Engine) QuotePolicy() QuotePolicy { - return engine.quotePolicy -} - // IsReserved return true if the value is a reserved word of the database func (engine *Engine) IsReserved(value string) bool { return engine.dialect.IsReserved(value) } // quoteTo quotes string and writes into the buffer -func quoteTo(quoter Quoter, buf *strings.Builder, value string, isColumn bool) { - if isColumn { - if quoter.QuoteMode() == QuoteTableAndColumns || - quoter.QuoteMode() == QuoteColumnsOnly { - if quoter.QuotePolicy() == QuoteAddAlways { - realQuoteTo(quoter, buf, value) - return - } else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) { - realQuoteTo(quoter, buf, value) - return - } - } - buf.WriteString(value) +func quoteTo(quoter Quoter, buf *strings.Builder, value string) { + left, right := quoter.Quotes() + if quoter.QuotePolicy() == QuoteAddAlways { + realQuoteTo(left, right, buf, value) + return + } else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) { + realQuoteTo(left, right, buf, value) return } - - if quoter.QuoteMode() == QuoteTableAndColumns || - quoter.QuoteMode() == QuoteTableOnly { - if quoter.QuotePolicy() == QuoteAddAlways { - realQuoteTo(quoter, buf, value) - return - } else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) { - realQuoteTo(quoter, buf, value) - return - } - } buf.WriteString(value) - return } -func realQuoteTo(quoter Quoter, buf *strings.Builder, value string) { +func realQuoteTo(prefix, suffix byte, buf *strings.Builder, value string) { if buf == nil { return } @@ -164,8 +122,6 @@ func realQuoteTo(quoter Quoter, buf *strings.Builder, value string) { return } - prefix, suffix := quoter.Quotes() - i := 0 for i < len(value) { // start of a token; might be already quoted diff --git a/session_find.go b/session_find.go index 0d8985a4..e4bc3a55 100644 --- a/session_find.go +++ b/session_find.go @@ -141,7 +141,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) if session.statement.JoinStr == "" { if columnStr == "" { if session.statement.GroupByStr != "" { - columnStr = quoteColumns(session.engine, session.statement.GroupByStr) + columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr) } else { columnStr = session.statement.genColumnStr() } @@ -149,7 +149,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } else { if columnStr == "" { if session.statement.GroupByStr != "" { - columnStr = quoteColumns(session.engine, session.statement.GroupByStr) + columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr) } else { columnStr = "*" } diff --git a/session_insert.go b/session_insert.go index 4ccf08fe..f82b8fdd 100644 --- a/session_insert.go +++ b/session_insert.go @@ -249,15 +249,15 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error if session.engine.dialect.DBType() == core.ORACLE { temp := fmt.Sprintf(") INTO %s (%v) VALUES (", session.engine.quote(tableName, false), - quoteJoin(session.engine, colNames)) + quoteJoin(session.engine.colQuoter, colNames)) sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL", session.engine.quote(tableName, false), - quoteJoin(session.engine, colNames), + quoteJoin(session.engine.colQuoter, colNames), strings.Join(colMultiPlaces, temp)) } else { sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)", session.engine.quote(tableName, false), - quoteJoin(session.engine, colNames), + quoteJoin(session.engine.colQuoter, colNames), strings.Join(colMultiPlaces, "),(")) } res, err := session.exec(sql, args...) @@ -855,7 +855,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", session.engine.quote(tableName, false), - quoteJoin(session.engine, columns), qm)); err != nil { + quoteJoin(session.engine.colQuoter, columns), qm)); err != nil { return 0, err } w.Append(args...) diff --git a/session_query.go b/session_query.go index 435349d7..8ea86df3 100644 --- a/session_query.go +++ b/session_query.go @@ -35,7 +35,7 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa if session.statement.JoinStr == "" { if columnStr == "" { if session.statement.GroupByStr != "" { - columnStr = quoteColumns(session.engine, session.statement.GroupByStr) + columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr) } else { columnStr = session.statement.genColumnStr() } @@ -43,7 +43,7 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa } else { if columnStr == "" { if session.statement.GroupByStr != "" { - columnStr = quoteColumns(session.engine, session.statement.GroupByStr) + columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr) } else { columnStr = "*" } diff --git a/session_update.go b/session_update.go index e899e7b6..322c9f91 100644 --- a/session_update.go +++ b/session_update.go @@ -100,7 +100,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, for idx, kv := range kvs { sps := strings.SplitN(kv, "=", 2) sps2 := strings.Split(sps[0], ".") - colName := unQuote(session.engine, sps2[len(sps2)-1]) + colName := unQuote(session.engine.colQuoter, sps2[len(sps2)-1]) if col := table.GetColumn(colName); col != nil { fieldValue, err := col.ValueOf(bean) diff --git a/statement.go b/statement.go index 16b35d31..589ebb01 100644 --- a/statement.go +++ b/statement.go @@ -615,7 +615,7 @@ func (statement *Statement) Cols(columns ...string) *Statement { newColumns := statement.colmap2NewColsWithQuote() - statement.ColumnStr = quoteJoin(statement.Engine, newColumns) + statement.ColumnStr = quoteJoin(statement.Engine.colQuoter, newColumns) return statement } @@ -650,7 +650,7 @@ func (statement *Statement) Omit(columns ...string) { for _, nc := range newColumns { statement.omitColumnMap = append(statement.omitColumnMap, nc) } - statement.OmitStr = quoteJoin(statement.Engine, newColumns) + statement.OmitStr = quoteJoin(statement.Engine.colQuoter, newColumns) } // Nullable Update use only: update columns to null when value is nullable and zero-value @@ -744,7 +744,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition } tbs := strings.Split(tp.TableName(), ".") - var aliasName = unQuote(statement.Engine, tbs[len(tbs)-1]) + var aliasName = unQuote(statement.Engine.tableQuoter, tbs[len(tbs)-1]) fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) case *builder.Builder: @@ -755,7 +755,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition } tbs := strings.Split(tp.TableName(), ".") - var aliasName = unQuote(statement.Engine, tbs[len(tbs)-1]) + var aliasName = unQuote(statement.Engine.tableQuoter, tbs[len(tbs)-1]) fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) default: @@ -821,7 +821,7 @@ func (statement *Statement) genColumnStr() string { buf.WriteString(".") } - quoteTo(statement.Engine, &buf, col.Name, true) + quoteTo(statement.Engine.colQuoter, &buf, col.Name) } return buf.String() @@ -940,7 +940,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, if len(statement.JoinStr) == 0 { if len(columnStr) == 0 { if len(statement.GroupByStr) > 0 { - columnStr = quoteColumns(statement.Engine, statement.GroupByStr) + columnStr = quoteColumns(statement.Engine.colQuoter, statement.GroupByStr) } else { columnStr = statement.genColumnStr() } @@ -948,7 +948,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, } else { if len(columnStr) == 0 { if len(statement.GroupByStr) > 0 { - columnStr = quoteColumns(statement.Engine, statement.GroupByStr) + columnStr = quoteColumns(statement.Engine.colQuoter, statement.GroupByStr) } } } diff --git a/statement_test.go b/statement_test.go index 84cad131..03d73ff7 100644 --- a/statement_test.go +++ b/statement_test.go @@ -243,6 +243,9 @@ func TestCol2NewColsWithQuote(t *testing.T) { statement := createTestStatement() - quotedCols := quoteJoin(statement.Engine, cols) - assert.EqualValues(t, []string{statement.Engine.Quote("f1", true), statement.Engine.Quote("f2", true), statement.Engine.Quote("t3.f3", true)}, quotedCols) + quotedCols := quoteJoin(statement.Engine.colQuoter, cols) + assert.EqualValues(t, statement.Engine.Quote("f1", true)+","+ + statement.Engine.Quote("f2", true)+","+ + statement.Engine.Quote("t3.f3", true), + quotedCols) } diff --git a/xorm.go b/xorm.go index e1c83b56..f6cfcc2b 100644 --- a/xorm.go +++ b/xorm.go @@ -95,6 +95,8 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { tagHandlers: defaultTagHandlers, cachers: make(map[string]core.Cacher), defaultContext: context.Background(), + colQuoter: newQuoter(dialect, QuoteAddAlways), + tableQuoter: newQuoter(dialect, QuoteAddAlways), } if uri.DbType == core.SQLITE {