diff --git a/engine.go b/engine.go index a7e52ea4..f0649697 100644 --- a/engine.go +++ b/engine.go @@ -55,6 +55,9 @@ type Engine struct { cacherLock sync.RWMutex defaultContext context.Context + + quotePolicy QuotePolicy + quoteMode QuoteMode } func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { @@ -175,85 +178,6 @@ func (engine *Engine) SupportInsertMany() bool { return engine.dialect.SupportInsertMany() } -func (engine *Engine) quoteColumns(columnStr string) string { - columns := strings.Split(columnStr, ",") - for i := 0; i < len(columns); i++ { - columns[i] = engine.Quote(strings.TrimSpace(columns[i])) - } - return strings.Join(columns, ",") -} - -// Quote Use QuoteStr quote the string sql -func (engine *Engine) Quote(value string) string { - value = strings.TrimSpace(value) - if len(value) == 0 { - return value - } - - buf := strings.Builder{} - engine.QuoteTo(&buf, value) - - return buf.String() -} - -// QuoteTo quotes string and writes into the buffer -func (engine *Engine) QuoteTo(buf *strings.Builder, value string) { - if buf == nil { - return - } - - value = strings.TrimSpace(value) - if value == "" { - return - } - - quoteTo(buf, engine.dialect.Quote(""), value) -} - -func quoteTo(buf *strings.Builder, quotePair string, value string) { - if len(quotePair) < 2 { // no quote - _, _ = buf.WriteString(value) - return - } - - prefix, suffix := quotePair[0], quotePair[1] - - i := 0 - for i < len(value) { - // start of a token; might be already quoted - if value[i] == '.' { - _ = buf.WriteByte('.') - i++ - } else if value[i] == prefix || value[i] == '`' { - // Has quotes; skip/normalize `name` to prefix+name+sufix - var ch byte - if value[i] == prefix { - ch = suffix - } else { - ch = '`' - } - i++ - _ = buf.WriteByte(prefix) - for ; i < len(value) && value[i] != ch; i++ { - _ = buf.WriteByte(value[i]) - } - _ = buf.WriteByte(suffix) - i++ - } else { - // Requires quotes - _ = buf.WriteByte(prefix) - for ; i < len(value) && value[i] != '.'; i++ { - _ = buf.WriteByte(value[i]) - } - _ = buf.WriteByte(suffix) - } - } -} - -func (engine *Engine) quote(sql string) string { - return engine.dialect.Quote(sql) -} - // SqlType will be deprecated, please use SQLType instead // // Deprecated: use SQLType instead @@ -495,6 +419,8 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D return err } + quoter := newQuoter(dialect, engine.quoteMode, engine.quotePolicy) + for i, table := range tables { if i > 0 { _, err = io.WriteString(w, "\n") @@ -514,10 +440,10 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D } cols := table.ColumnsSeq() - colNames := engine.dialect.Quote(strings.Join(cols, engine.dialect.Quote(", "))) - destColNames := dialect.Quote(strings.Join(cols, dialect.Quote(", "))) + colNames := quoteJoin(engine, cols) + destColNames := quoteJoin(quoter, cols) - rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.Quote(table.Name)) + rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.quote(table.Name, false)) if err != nil { return err } diff --git a/engine_cond.go b/engine_cond.go index 702ac804..a323f3ee 100644 --- a/engine_cond.go +++ b/engine_cond.go @@ -44,9 +44,9 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{}, if len(aliasName) > 0 { nm = aliasName } - colName = engine.Quote(nm) + "." + engine.Quote(col.Name) + colName = engine.quote(nm, false) + "." + engine.quote(col.Name, true) } else { - colName = engine.Quote(col.Name) + colName = engine.quote(col.Name, true) } fieldValuePtr, err := col.ValueOf(bean) diff --git a/engine_quote.go b/engine_quote.go new file mode 100644 index 00000000..1d0063fe --- /dev/null +++ b/engine_quote.go @@ -0,0 +1,211 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "fmt" + "strings" + + "xorm.io/core" +) + +// QuotePolicy describes quote handle policy +type QuotePolicy int + +// All QuotePolicies +const ( + QuoteAddAlways QuotePolicy = iota + QuoteNoAdd + QuoteAddReserved +) + +// QuoteMode quote on which types +type QuoteMode int + +// All QuoteModes +const ( + QuoteTableAndColumns QuoteMode = iota + QuoteTableOnly + QuoteColumnsOnly +) + +// Quoter represents an object has Quote method +type Quoter interface { + Quotes() (byte, byte) + QuotePolicy() QuotePolicy + QuoteMode() QuoteMode + IsReserved(string) bool +} + +type quoter struct { + dialect core.Dialect + quoteMode QuoteMode + quotePolicy QuotePolicy +} + +func newQuoter(dialect core.Dialect, quoteMode QuoteMode, quotePolicy QuotePolicy) Quoter { + return "er{ + dialect: dialect, + quoteMode: quoteMode, + quotePolicy: quotePolicy, + } +} + +func (q *quoter) Quotes() (byte, byte) { + quotes := q.dialect.Quote("") + return quotes[0], quotes[1] +} + +func (q *quoter) QuotePolicy() QuotePolicy { + return q.quotePolicy +} + +func (q *quoter) QuoteMode() QuoteMode { + return q.quoteMode +} + +func (q *quoter) IsReserved(value string) bool { + return q.dialect.IsReserved(value) +} + +func quoteColumns(quoter Quoter, columnStr string) string { + columns := strings.Split(columnStr, ",") + return quoteJoin(quoter, columns) +} + +func quoteJoin(quoter Quoter, columns []string) string { + for i := 0; i < len(columns); i++ { + columns[i] = quote(quoter, columns[i], true) + } + return strings.Join(columns, ",") +} + +// quote Use QuoteStr quote the string sql +func quote(quoter Quoter, value string, isColumn bool) string { + buf := strings.Builder{} + quoteTo(quoter, &buf, value, isColumn) + return buf.String() +} + +// Quote add quotes to the value +func (engine *Engine) quote(value string, isColumn bool) string { + return quote(engine, value, isColumn) +} + +// Quote add quotes to the value +func (engine *Engine) Quote(value string, isColumn bool) string { + return engine.quote(value, isColumn) +} + +// Quotes return the left quote and right quote +func (engine *Engine) Quotes() (byte, byte) { + quotes := engine.dialect.Quote("") + return quotes[0], quotes[1] +} + +// QuoteMode returns quote mode +func (engine *Engine) QuoteMode() QuoteMode { + return engine.quoteMode +} + +// QuotePolicy returns quote policy +func (engine *Engine) QuotePolicy() QuotePolicy { + return engine.quotePolicy +} + +// IsReserved return true if the value is a reserved word of the database +func (engine *Engine) IsReserved(value string) bool { + return engine.dialect.IsReserved(value) +} + +// quoteTo quotes string and writes into the buffer +func quoteTo(quoter Quoter, buf *strings.Builder, value string, isColumn bool) { + if isColumn { + if quoter.QuoteMode() == QuoteTableAndColumns || + quoter.QuoteMode() == QuoteColumnsOnly { + if quoter.QuotePolicy() == QuoteAddAlways { + realQuoteTo(quoter, buf, value) + return + } else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) { + realQuoteTo(quoter, buf, value) + return + } + } + buf.WriteString(value) + return + } + + if quoter.QuoteMode() == QuoteTableAndColumns || + quoter.QuoteMode() == QuoteTableOnly { + if quoter.QuotePolicy() == QuoteAddAlways { + realQuoteTo(quoter, buf, value) + return + } else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) { + realQuoteTo(quoter, buf, value) + return + } + } + buf.WriteString(value) + return +} + +func realQuoteTo(quoter Quoter, buf *strings.Builder, value string) { + if buf == nil { + return + } + + value = strings.TrimSpace(value) + if value == "" { + return + } else if value == "*" { + buf.WriteString("*") + return + } + + prefix, suffix := quoter.Quotes() + + i := 0 + for i < len(value) { + // start of a token; might be already quoted + if value[i] == '.' { + _ = buf.WriteByte('.') + i++ + } else if value[i] == prefix || value[i] == '`' { + // Has quotes; skip/normalize `name` to prefix+name+sufix + var ch byte + if value[i] == prefix { + ch = suffix + } else { + ch = '`' + } + i++ + _ = buf.WriteByte(prefix) + for ; i < len(value) && value[i] != ch; i++ { + _ = buf.WriteByte(value[i]) + } + _ = buf.WriteByte(suffix) + i++ + } else { + // Requires quotes + _ = buf.WriteByte(prefix) + for ; i < len(value) && value[i] != '.'; i++ { + _ = buf.WriteByte(value[i]) + } + _ = buf.WriteByte(suffix) + } + } +} + +func unQuote(quoter Quoter, value string) string { + left, right := quoter.Quotes() + return strings.Trim(value, fmt.Sprintf("%v%v`", left, right)) +} + +func quoteJoinFunc(cols []string, quoteFunc func(string) string, sep string) string { + for i := range cols { + cols[i] = quoteFunc(cols[i]) + } + return strings.Join(cols, sep+" ") +} diff --git a/engine_quote_test.go b/engine_quote_test.go new file mode 100644 index 00000000..db38cb8b --- /dev/null +++ b/engine_quote_test.go @@ -0,0 +1,20 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestQuoteColumns(t *testing.T) { + cols := []string{"f1", "f2", "f3"} + quoteFunc := func(value string) string { + return "[" + value + "]" + } + + assert.EqualValues(t, "[f1], [f2], [f3]", quoteJoinFunc(cols, quoteFunc, ",")) +} diff --git a/engine_table.go b/engine_table.go index eb5aa850..c7018d75 100644 --- a/engine_table.go +++ b/engine_table.go @@ -63,9 +63,9 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string { case []string: t := tablename.([]string) if len(t) > 1 { - return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1])) + return fmt.Sprintf("%v AS %v", engine.quote(t[0], false), engine.quote(t[1], false)) } else if len(t) == 1 { - return engine.Quote(t[0]) + return engine.quote(t[0], false) } case []interface{}: t := tablename.([]interface{}) @@ -84,15 +84,15 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string { if t.Kind() == reflect.Struct { table = engine.tbNameForMap(v) } else { - table = engine.Quote(fmt.Sprintf("%v", f)) + table = engine.quote(fmt.Sprintf("%v", f), false) } } } if l > 1 { - return fmt.Sprintf("%v AS %v", engine.Quote(table), - engine.Quote(fmt.Sprintf("%v", t[1]))) + return fmt.Sprintf("%v AS %v", engine.quote(table, false), + engine.quote(fmt.Sprintf("%v", t[1]), false)) } else if l == 1 { - return engine.Quote(table) + return engine.quote(table, false) } case TableName: return tablename.(TableName).TableName() @@ -107,7 +107,7 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string { if t.Kind() == reflect.Struct { return engine.tbNameForMap(v) } - return engine.Quote(fmt.Sprintf("%v", tablename)) + return engine.quote(fmt.Sprintf("%v", tablename), false) } return "" } diff --git a/helpers.go b/helpers.go index a31e922c..8ad99277 100644 --- a/helpers.go +++ b/helpers.go @@ -323,10 +323,3 @@ func eraseAny(value string, strToErase ...string) string { return replacer.Replace(value) } - -func quoteColumns(cols []string, quoteFunc func(string) string, sep string) string { - for i := range cols { - cols[i] = quoteFunc(cols[i]) - } - return strings.Join(cols, sep+" ") -} diff --git a/helpers_test.go b/helpers_test.go index caf7b9f0..fc9ece27 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -16,12 +16,3 @@ func TestEraseAny(t *testing.T) { assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`")) assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]")) } - -func TestQuoteColumns(t *testing.T) { - cols := []string{"f1", "f2", "f3"} - quoteFunc := func(value string) string { - return "[" + value + "]" - } - - assert.EqualValues(t, "[f1], [f2], [f3]", quoteColumns(cols, quoteFunc, ",")) -} diff --git a/interface.go b/interface.go index a564db12..473f9d1a 100644 --- a/interface.go +++ b/interface.go @@ -89,7 +89,7 @@ type EngineInterface interface { MapCacher(interface{}, core.Cacher) error NewSession() *Session NoAutoTime() *Session - Quote(string) string + Quote(string, bool) string SetCacher(string, core.Cacher) SetConnMaxLifetime(time.Duration) SetDefaultCacher(core.Cacher) diff --git a/rows_test.go b/rows_test.go index af333861..f022a888 100644 --- a/rows_test.go +++ b/rows_test.go @@ -70,7 +70,7 @@ func TestRows(t *testing.T) { } assert.EqualValues(t, 1, cnt) - var tbName = testEngine.Quote(testEngine.TableName(user, true)) + var tbName = testEngine.Quote(testEngine.TableName(user, true), false) rows2, err := testEngine.SQL("SELECT * FROM " + tbName).Rows(new(UserRows)) assert.NoError(t, err) defer rows2.Close() diff --git a/session_delete.go b/session_delete.go index 675d4d8c..8b7c9a07 100644 --- a/session_delete.go +++ b/session_delete.go @@ -106,7 +106,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { } var tableNameNoQuote = session.statement.TableName() - var tableName = session.engine.Quote(tableNameNoQuote) + var tableName = session.engine.quote(tableNameNoQuote, false) var table = session.statement.RefTable var deleteSQL string if len(condSQL) > 0 { @@ -160,8 +160,8 @@ func (session *Session) Delete(bean interface{}) (int64, error) { deletedColumn := table.DeletedColumn() realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", - session.engine.Quote(session.statement.TableName()), - session.engine.Quote(deletedColumn.Name), + session.engine.quote(session.statement.TableName(), false), + session.engine.quote(deletedColumn.Name, true), condSQL) if len(orderSQL) > 0 { diff --git a/session_exist.go b/session_exist.go index 660cc47e..c76aa7f5 100644 --- a/session_exist.go +++ b/session_exist.go @@ -34,7 +34,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { return false, ErrTableNotFound } - tableName = session.statement.Engine.Quote(tableName) + tableName = session.statement.Engine.quote(tableName, false) if session.statement.cond.IsValid() { condSQL, condArgs, err := builder.ToSQL(session.statement.cond) diff --git a/session_find.go b/session_find.go index e16ae54c..0d8985a4 100644 --- a/session_find.go +++ b/session_find.go @@ -112,13 +112,13 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) // !oinume! Add " IS NULL" to WHERE whatever condiBean is given. // See https://gitea.com/xorm/xorm/issues/179 if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled - var colName = session.engine.Quote(col.Name) + var colName = session.engine.quote(col.Name, true) if addedTableName { var nm = session.statement.TableName() if len(session.statement.TableAlias) > 0 { nm = session.statement.TableAlias } - colName = session.engine.Quote(nm) + "." + colName + colName = session.engine.quote(nm, false) + "." + colName } autoCond = session.engine.CondDeleted(colName) @@ -141,7 +141,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) if session.statement.JoinStr == "" { if columnStr == "" { if session.statement.GroupByStr != "" { - columnStr = session.engine.quoteColumns(session.statement.GroupByStr) + columnStr = quoteColumns(session.engine, session.statement.GroupByStr) } else { columnStr = session.statement.genColumnStr() } @@ -149,7 +149,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } else { if columnStr == "" { if session.statement.GroupByStr != "" { - columnStr = session.engine.quoteColumns(session.statement.GroupByStr) + columnStr = quoteColumns(session.engine, session.statement.GroupByStr) } else { columnStr = "*" } diff --git a/session_find_test.go b/session_find_test.go index f805f06e..0f8ebd14 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/core" ) func TestJoinLimit(t *testing.T) { @@ -103,7 +103,7 @@ func TestFind(t *testing.T) { } users2 := make([]Userinfo, 0) - var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true)) + var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true), false) err = testEngine.SQL("select * from " + tbName).Find(&users2) assert.NoError(t, err) } diff --git a/session_insert.go b/session_insert.go index 1e19ce7a..4ccf08fe 100644 --- a/session_insert.go +++ b/session_insert.go @@ -248,16 +248,16 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error var sql string if session.engine.dialect.DBType() == core.ORACLE { temp := fmt.Sprintf(") INTO %s (%v) VALUES (", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ",")) + session.engine.quote(tableName, false), + quoteJoin(session.engine, colNames)) sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ","), + session.engine.quote(tableName, false), + quoteJoin(session.engine, colNames), strings.Join(colMultiPlaces, temp)) } else { sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ","), + session.engine.quote(tableName, false), + quoteJoin(session.engine, colNames), strings.Join(colMultiPlaces, "),(")) } res, err := session.exec(sql, args...) @@ -358,7 +358,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } var buf = builder.NewWriter() - if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s", session.engine.Quote(tableName))); err != nil { + if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s", session.engine.quote(tableName, false))); err != nil { return 0, err } @@ -399,7 +399,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 0, err } - if _, err := buf.WriteString(fmt.Sprintf(" FROM %v WHERE ", session.engine.Quote(tableName))); err != nil { + if _, err := buf.WriteString(fmt.Sprintf(" FROM %v WHERE ", session.engine.quote(tableName, false))); err != nil { return 0, err } @@ -426,7 +426,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES { - if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil { + if _, err := buf.WriteString(" RETURNING " + session.engine.quote(table.AutoIncrement, true)); err != nil { return 0, err } } @@ -731,7 +731,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err w := builder.NewWriter() if session.statement.cond.IsValid() { - if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { + if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.quote(tableName, false))); err != nil { return 0, err } @@ -756,7 +756,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err } } - if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil { + if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.quote(tableName, false))); err != nil { return 0, err } @@ -767,7 +767,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err qm := strings.Repeat("?,", len(columns)) qm = qm[:len(qm)-1] - if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil { + if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.quote(tableName, false), strings.Join(columns, "`,`"), qm)); err != nil { return 0, err } w.Append(args...) @@ -817,7 +817,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { w := builder.NewWriter() if session.statement.cond.IsValid() { - if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { + if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.quote(tableName, false))); err != nil { return 0, err } @@ -842,7 +842,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { } } - if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil { + if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.quote(tableName, false))); err != nil { return 0, err } @@ -853,7 +853,9 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { qm := strings.Repeat("?,", len(columns)) qm = qm[:len(qm)-1] - if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil { + if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", + session.engine.quote(tableName, false), + quoteJoin(session.engine, columns), qm)); err != nil { return 0, err } w.Append(args...) diff --git a/session_query.go b/session_query.go index 21c00b8d..435349d7 100644 --- a/session_query.go +++ b/session_query.go @@ -35,7 +35,7 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa if session.statement.JoinStr == "" { if columnStr == "" { if session.statement.GroupByStr != "" { - columnStr = session.engine.quoteColumns(session.statement.GroupByStr) + columnStr = quoteColumns(session.engine, session.statement.GroupByStr) } else { columnStr = session.statement.genColumnStr() } @@ -43,7 +43,7 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa } else { if columnStr == "" { if session.statement.GroupByStr != "" { - columnStr = session.engine.quoteColumns(session.statement.GroupByStr) + columnStr = quoteColumns(session.engine, session.statement.GroupByStr) } else { columnStr = "*" } diff --git a/session_query_test.go b/session_query_test.go index 772206a8..b831754c 100644 --- a/session_query_test.go +++ b/session_query_test.go @@ -372,7 +372,7 @@ func TestJoinWithSubQuery(t *testing.T) { assert.EqualValues(t, 1, cnt) var querys []JoinWithSubQuery1 - err = testEngine.Join("INNER", builder.Select("id").From(testEngine.Quote(testEngine.TableName("join_with_sub_query_depart", true))), + err = testEngine.Join("INNER", builder.Select("id").From(testEngine.Quote(testEngine.TableName("join_with_sub_query_depart", true), false)), "join_with_sub_query_depart.id = join_with_sub_query1.depart_id").Find(&querys) assert.NoError(t, err) assert.EqualValues(t, 1, len(querys)) diff --git a/session_schema.go b/session_schema.go index 5e576c29..722e01ec 100644 --- a/session_schema.go +++ b/session_schema.go @@ -168,7 +168,7 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { func (session *Session) isTableEmpty(tableName string) (bool, error) { var total int64 - sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true))) + sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.quote(session.engine.TableName(tableName, true), false)) err := session.queryRow(sqlStr).Scan(&total) if err != nil { if err == sql.ErrNoRows { diff --git a/session_update.go b/session_update.go index 231163e0..3993eb15 100644 --- a/session_update.go +++ b/session_update.go @@ -102,7 +102,8 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, sps2 := strings.Split(sps[0], ".") colName := sps2[len(sps2)-1] // treat quote prefix, suffix and '`' as quotes - quotes := append(strings.Split(session.engine.Quote(""), ""), "`") + left, right := session.engine.Quotes() + quotes := []string{string(left), string(right)} if strings.ContainsAny(colName, strings.Join(quotes, "")) { colName = strings.TrimSpace(eraseAny(colName, quotes...)) } else { @@ -195,7 +196,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 bValue := reflect.Indirect(reflect.ValueOf(bean)) for _, v := range bValue.MapKeys() { - colNames = append(colNames, session.engine.Quote(v.String())+" = ?") + colNames = append(colNames, session.engine.quote(v.String(), true)+" = ?") args = append(args, bValue.MapIndex(v).Interface()) } } else { @@ -207,7 +208,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if session.statement.UseAutoTime && table != nil && table.Updated != "" { if !session.statement.columnMap.contain(table.Updated) && !session.statement.omitColumnMap.contain(table.Updated) { - colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?") + colNames = append(colNames, session.engine.quote(table.Updated, true)+" = ?") col := table.UpdatedColumn() val, t := session.engine.nowTime(col) args = append(args, val) @@ -225,13 +226,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 // for update action to like "column = column + ?" incColumns := session.statement.incrColumns for i, colName := range incColumns.colNames { - colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?") + colNames = append(colNames, session.engine.quote(colName, true)+" = "+session.engine.quote(colName, true)+" + ?") args = append(args, incColumns.args[i]) } // for update action to like "column = column - ?" decColumns := session.statement.decrColumns for i, colName := range decColumns.colNames { - colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?") + colNames = append(colNames, session.engine.quote(colName, true)+" = "+session.engine.quote(colName, true)+" - ?") args = append(args, decColumns.args[i]) } // for update action to like "column = expression" @@ -239,13 +240,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 for i, colName := range exprColumns.colNames { switch tp := exprColumns.args[i].(type) { case string: - colNames = append(colNames, session.engine.Quote(colName)+" = "+tp) + colNames = append(colNames, session.engine.quote(colName, true)+" = "+tp) case *builder.Builder: subQuery, subArgs, err := builder.ToSQL(tp) if err != nil { return 0, err } - colNames = append(colNames, session.engine.Quote(colName)+" = ("+subQuery+")") + colNames = append(colNames, session.engine.quote(colName, true)+" = ("+subQuery+")") args = append(args, subArgs...) } } @@ -281,7 +282,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if !condBeanIsStruct && table != nil { if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled - autoCond1 := session.engine.CondDeleted(session.engine.Quote(col.Name)) + autoCond1 := session.engine.CondDeleted(session.engine.quote(col.Name, true)) if autoCond == nil { autoCond = autoCond1 @@ -307,8 +308,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return 0, err } - cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()}) - colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1") + cond = cond.And(builder.Eq{session.engine.quote(table.Version, true): verValue.Interface()}) + colNames = append(colNames, session.engine.quote(table.Version, true)+" = "+session.engine.quote(table.Version, true)+" + 1") } condSQL, condArgs, err = builder.ToSQL(cond) @@ -333,7 +334,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } else if st.Engine.dialect.DBType() == core.SQLITE { tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", - session.engine.Quote(tableName), tempCondSQL), condArgs...)) + session.engine.quote(tableName, false), tempCondSQL), condArgs...)) condSQL, condArgs, err = builder.ToSQL(cond) if err != nil { return 0, err @@ -344,7 +345,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } else if st.Engine.dialect.DBType() == core.POSTGRES { tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", - session.engine.Quote(tableName), tempCondSQL), condArgs...)) + session.engine.quote(tableName, false), tempCondSQL), condArgs...)) condSQL, condArgs, err = builder.ToSQL(cond) if err != nil { return 0, err @@ -358,7 +359,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 table != nil && len(table.PrimaryKeys) == 1 { cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0], - session.engine.Quote(tableName), condSQL), condArgs...) + session.engine.quote(tableName, false), condSQL), condArgs...) condSQL, condArgs, err = builder.ToSQL(cond) if err != nil { @@ -377,7 +378,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return 0, errors.New("No content found to be updated") } - var tableAlias = session.engine.Quote(tableName) + var tableAlias = session.engine.Quote(tableName, false) var fromSQL string if session.statement.TableAlias != "" { switch session.engine.dialect.DBType() { @@ -532,7 +533,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac args = append(args, arg) } - colNames = append(colNames, session.engine.Quote(col.Name)+" = ?") + colNames = append(colNames, session.engine.quote(col.Name, true)+" = ?") } return colNames, args, nil } diff --git a/statement.go b/statement.go index 67e35213..1b51fe54 100644 --- a/statement.go +++ b/statement.go @@ -152,7 +152,7 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme queryMap := query.(map[string]interface{}) newMap := make(map[string]interface{}) for k, v := range queryMap { - newMap[statement.Engine.Quote(k)] = v + newMap[statement.Engine.Quote(k, true)] = v } statement.cond = statement.cond.And(builder.Eq(newMap)) case builder.Cond: @@ -195,14 +195,14 @@ func (statement *Statement) Or(query interface{}, args ...interface{}) *Statemen // In generate "Where column IN (?) " statement func (statement *Statement) In(column string, args ...interface{}) *Statement { - in := builder.In(statement.Engine.Quote(column), args...) + in := builder.In(statement.Engine.quote(column, true), args...) statement.cond = statement.cond.And(in) return statement } // NotIn generate "Where column NOT IN (?) " statement func (statement *Statement) NotIn(column string, args ...interface{}) *Statement { - notIn := builder.NotIn(statement.Engine.Quote(column), args...) + notIn := builder.NotIn(statement.Engine.quote(column, true), args...) statement.cond = statement.cond.And(notIn) return statement } @@ -339,7 +339,7 @@ func (statement *Statement) buildUpdates(bean interface{}, if fieldValue.IsNil() { if includeNil { args = append(args, nil) - colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) + colNames = append(colNames, fmt.Sprintf("%v=?", engine.quote(col.Name, true))) } continue } else if !fieldValue.IsValid() { @@ -486,7 +486,7 @@ func (statement *Statement) buildUpdates(bean interface{}, if col.IsPrimaryKey && engine.dialect.DBType() == "ql" { continue } - colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name))) + colNames = append(colNames, fmt.Sprintf("%v = ?", engine.quote(col.Name, true))) } return colNames, args @@ -502,9 +502,9 @@ func (statement *Statement) colName(col *core.Column, tableName string) string { if len(statement.TableAlias) > 0 { nm = statement.TableAlias } - return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name) + return statement.Engine.quote(nm, false) + "." + statement.Engine.quote(col.Name, true) } - return statement.Engine.Quote(col.Name) + return statement.Engine.quote(col.Name, true) } // TableName return current tableName @@ -572,9 +572,10 @@ func (statement *Statement) SetExpr(column string, expression interface{}) *Stat func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { newColumns := make([]string, 0) - quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") + left, right := statement.Engine.Quotes() + quotes := []string{string(left), string(right)} for _, col := range columns { - newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...))) + newColumns = append(newColumns, statement.Engine.quote(eraseAny(col, quotes...), true)) } return newColumns } @@ -583,7 +584,7 @@ func (statement *Statement) colmap2NewColsWithQuote() []string { newColumns := make([]string, len(statement.columnMap), len(statement.columnMap)) copy(newColumns, statement.columnMap) for i := 0; i < len(statement.columnMap); i++ { - newColumns[i] = statement.Engine.Quote(newColumns[i]) + newColumns[i] = statement.Engine.quote(newColumns[i], true) } return newColumns } @@ -616,8 +617,7 @@ func (statement *Statement) Cols(columns ...string) *Statement { newColumns := statement.colmap2NewColsWithQuote() - statement.ColumnStr = strings.Join(newColumns, ", ") - statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) + statement.ColumnStr = quoteJoin(statement.Engine, newColumns) return statement } @@ -652,7 +652,7 @@ func (statement *Statement) Omit(columns ...string) { for _, nc := range newColumns { statement.omitColumnMap = append(statement.omitColumnMap, nc) } - statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) + statement.OmitStr = quoteJoin(statement.Engine, newColumns) } // Nullable Update use only: update columns to null when value is nullable and zero-value @@ -745,9 +745,8 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition return statement } tbs := strings.Split(tp.TableName(), ".") - quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") - var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, "")) + var aliasName = unQuote(statement.Engine, tbs[len(tbs)-1]) fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) case *builder.Builder: @@ -757,9 +756,8 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition return statement } tbs := strings.Split(tp.TableName(), ".") - quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") - var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, "")) + var aliasName = unQuote(statement.Engine, tbs[len(tbs)-1]) fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) default: @@ -825,7 +823,7 @@ func (statement *Statement) genColumnStr() string { buf.WriteString(".") } - statement.Engine.QuoteTo(&buf, col.Name) + quoteTo(statement.Engine, &buf, col.Name, true) } return buf.String() @@ -880,9 +878,9 @@ func (statement *Statement) genDelIndexSQL() []string { } else if index.Type == core.IndexType { rIdxName = indexName(idxPrefixName, idxName) } - sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true))) + sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.quote(statement.Engine.TableName(rIdxName, true), false)) if statement.Engine.dialect.IndexOnTable() { - sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName)) + sql += fmt.Sprintf(" ON %v", statement.Engine.quote(tbName, false)) } sqls = append(sqls, sql) } @@ -890,8 +888,8 @@ func (statement *Statement) genDelIndexSQL() []string { } func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) { - quote := statement.Engine.Quote - sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()), + quote := statement.Engine.quote + sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName(), false), col.String(statement.Engine.dialect)) if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 { sql += " COMMENT '" + col.Comment + "'" @@ -944,7 +942,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, if len(statement.JoinStr) == 0 { if len(columnStr) == 0 { if len(statement.GroupByStr) > 0 { - columnStr = statement.Engine.quoteColumns(statement.GroupByStr) + columnStr = quoteColumns(statement.Engine, statement.GroupByStr) } else { columnStr = statement.genColumnStr() } @@ -952,7 +950,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, } else { if len(columnStr) == 0 { if len(statement.GroupByStr) > 0 { - columnStr = statement.Engine.quoteColumns(statement.GroupByStr) + columnStr = quoteColumns(statement.Engine, statement.GroupByStr) } } } @@ -1020,7 +1018,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri var sumStrs = make([]string, 0, len(columns)) for _, colName := range columns { if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") { - colName = statement.Engine.Quote(colName) + colName = statement.Engine.quote(colName, true) } sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName)) } @@ -1043,7 +1041,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n var ( distinct string dialect = statement.Engine.Dialect() - quote = statement.Engine.Quote + quote = statement.Engine.quote fromStr = " FROM " top, mssqlCondi, whereStr string ) @@ -1057,14 +1055,14 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") { fromStr += statement.TableName() } else { - fromStr += quote(statement.TableName()) + fromStr += quote(statement.TableName(), false) } if statement.TableAlias != "" { if dialect.DBType() == core.ORACLE { - fromStr += " " + quote(statement.TableAlias) + fromStr += " " + quote(statement.TableAlias, false) } else { - fromStr += " AS " + quote(statement.TableAlias) + fromStr += " AS " + quote(statement.TableAlias, false) } } if statement.JoinStr != "" { @@ -1181,10 +1179,10 @@ func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bo var colnames = make([]string, len(cols)) for i, col := range cols { if includeTableName { - colnames[i] = statement.Engine.Quote(statement.TableName()) + - "." + statement.Engine.Quote(col.Name) + colnames[i] = statement.Engine.quote(statement.TableName(), false) + + "." + statement.Engine.quote(col.Name, true) } else { - colnames[i] = statement.Engine.Quote(col.Name) + colnames[i] = statement.Engine.quote(col.Name, true) } } return strings.Join(colnames, ", ") @@ -1224,7 +1222,7 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) { if len(sqls) != 2 { if len(sqls) == 1 { return sqls[0], fmt.Sprintf("SELECT %v FROM %v", - colstrs, statement.Engine.Quote(statement.TableName())) + colstrs, statement.Engine.quote(statement.TableName(), false)) } return "", "" } @@ -1251,6 +1249,6 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) { } return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", - colstrs, statement.Engine.Quote(statement.TableName()), + colstrs, statement.Engine.quote(statement.TableName(), false), whereStr) } diff --git a/statement_test.go b/statement_test.go index acc542ab..832deca4 100644 --- a/statement_test.go +++ b/statement_test.go @@ -244,5 +244,5 @@ func TestCol2NewColsWithQuote(t *testing.T) { statement := createTestStatement() quotedCols := statement.col2NewColsWithQuote(cols...) - assert.EqualValues(t, []string{statement.Engine.Quote("f1"), statement.Engine.Quote("f2"), statement.Engine.Quote("t3.f3")}, quotedCols) + assert.EqualValues(t, []string{statement.Engine.Quote("f1", true), statement.Engine.Quote("f2", true), statement.Engine.Quote("t3.f3", true)}, quotedCols) } diff --git a/tag_extends_test.go b/tag_extends_test.go index 5a8031f0..78f3fbe4 100644 --- a/tag_extends_test.go +++ b/tag_extends_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/core" ) type tempUser struct { @@ -157,7 +157,7 @@ func TestExtends(t *testing.T) { uiid := testEngine.GetColumnMapper().Obj2Table("Id") udid := "detail_id" sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s", - qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid)) + qt(ui, false), qt(ud, false), qt(ui, false), qt(udid, true), qt(ud, false), qt(uiid, true)) b, err := testEngine.SQL(sql).NoCascade().Get(&info) assert.NoError(t, err) if !b { @@ -175,7 +175,7 @@ func TestExtends(t *testing.T) { fmt.Println("----join--info2") var info2 UserAndDetail b, err = testEngine.Table(&Userinfo{}). - Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). + Join("LEFT", qt(ud, false), qt(ui, false)+"."+qt("detail_id", true)+" = "+qt(ud, false)+"."+qt(uiid, true)). NoCascade().Get(&info2) if err != nil { t.Error(err) @@ -196,7 +196,7 @@ func TestExtends(t *testing.T) { fmt.Println("----join--infos2") var infos2 = make([]UserAndDetail, 0) err = testEngine.Table(&Userinfo{}). - Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). + Join("LEFT", qt(ud, false), qt(ui, false)+"."+qt("detail_id", true)+" = "+qt(ud, false)+"."+qt(uiid, true)). NoCascade(). Find(&infos2) assert.NoError(t, err) @@ -286,9 +286,9 @@ func TestExtends2(t *testing.T) { var mapper = testEngine.GetTableMapper().Obj2Table var quote = testEngine.Quote - userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) - typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) - msgTableName := quote(testEngine.TableName(mapper("Message"), true)) + userTableName := quote(testEngine.TableName(mapper("MessageUser"), true), false) + typeTableName := quote(testEngine.TableName(mapper("MessageType"), true), false) + msgTableName := quote(testEngine.TableName(mapper("Message"), true), false) list := make([]Message, 0) err = session.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). @@ -355,9 +355,9 @@ func TestExtends3(t *testing.T) { var mapper = testEngine.GetTableMapper().Obj2Table var quote = testEngine.Quote - userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) - typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) - msgTableName := quote(testEngine.TableName(mapper("Message"), true)) + userTableName := quote(testEngine.TableName(mapper("MessageUser"), true), false) + typeTableName := quote(testEngine.TableName(mapper("MessageType"), true), false) + msgTableName := quote(testEngine.TableName(mapper("Message"), true), false) list := make([]MessageExtend3, 0) err = session.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). @@ -449,9 +449,9 @@ func TestExtends4(t *testing.T) { var mapper = testEngine.GetTableMapper().Obj2Table var quote = testEngine.Quote - userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) - typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) - msgTableName := quote(testEngine.TableName(mapper("Message"), true)) + userTableName := quote(testEngine.TableName(mapper("MessageUser"), true), false) + typeTableName := quote(testEngine.TableName(mapper("MessageType"), true), false) + msgTableName := quote(testEngine.TableName(mapper("Message"), true), false) list := make([]MessageExtend4, 0) err = session.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). @@ -549,21 +549,21 @@ func TestExtends5(t *testing.T) { var mapper = testEngine.GetTableMapper().Obj2Table var quote = testEngine.Quote - bookTableName := quote(testEngine.TableName(mapper("Book"), true)) - sizeTableName := quote(testEngine.TableName(mapper("Size"), true)) + bookTableName := quote(testEngine.TableName(mapper("Book"), true), false) + sizeTableName := quote(testEngine.TableName(mapper("Size"), true), false) list := make([]Book, 0) err = session. Select(fmt.Sprintf( "%s.%s, sc.%s AS %s, sc.%s AS %s, s.%s, s.%s", - quote(bookTableName), - quote("id"), - quote("Width"), - quote("ClosedWidth"), - quote("Height"), - quote("ClosedHeight"), - quote("Width"), - quote("Height"), + quote(bookTableName, false), + quote("id", true), + quote("Width", true), + quote("ClosedWidth", true), + quote("Height", true), + quote("ClosedHeight", true), + quote("Width", true), + quote("Height", true), )). Table(bookTableName). Join( diff --git a/types_test.go b/types_test.go index 274609b2..780035e8 100644 --- a/types_test.go +++ b/types_test.go @@ -9,8 +9,8 @@ import ( "fmt" "testing" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/core" ) func TestArrayField(t *testing.T) { @@ -305,7 +305,7 @@ func TestCustomType2(t *testing.T) { assert.NoError(t, err) tableName := testEngine.TableName(&uc, true) - _, err = testEngine.Exec("delete from " + testEngine.Quote(tableName)) + _, err = testEngine.Exec("delete from " + testEngine.Quote(tableName, false)) assert.NoError(t, err) session := testEngine.NewSession()