diff --git a/dialects/filter.go b/dialects/filter.go index add8cc7d..6968b6ce 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -7,8 +7,6 @@ package dialects import ( "fmt" "strings" - - "xorm.io/xorm/schemas" ) // Filter is an interface to filter SQL @@ -16,48 +14,6 @@ type Filter interface { Do(sql string) string } -// QuoteFilter filter SQL replace ` to database's own quote character -type QuoteFilter struct { - quoter schemas.Quoter -} - -func (s *QuoteFilter) Do(sql string) string { - if s.quoter.IsEmpty() { - return sql - } - - var buf strings.Builder - buf.Grow(len(sql)) - - var beginSingleQuote bool - for i := 0; i < len(sql); i++ { - if !beginSingleQuote && sql[i] == '`' { - var j = i + 1 - for ; j < len(sql); j++ { - if sql[j] == '`' { - break - } - } - word := sql[i+1 : j] - isReserved := s.quoter.IsReserved(word) - if isReserved { - buf.WriteByte(s.quoter.Prefix) - } - buf.WriteString(word) - if isReserved { - buf.WriteByte(s.quoter.Suffix) - } - i = j - } else { - if sql[i] == '\'' { - beginSingleQuote = !beginSingleQuote - } - buf.WriteByte(sql[i]) - } - } - return buf.String() -} - // SeqFilter filter SQL replace ?, ? ... to $1, $2 ... type SeqFilter struct { Prefix string diff --git a/dialects/filter_test.go b/dialects/filter_test.go index e8395156..7e2ef0a2 100644 --- a/dialects/filter_test.go +++ b/dialects/filter_test.go @@ -3,38 +3,9 @@ package dialects import ( "testing" - "xorm.io/xorm/schemas" - "github.com/stretchr/testify/assert" ) -func TestQuoteFilter_Do(t *testing.T) { - f := QuoteFilter{schemas.Quoter{'[', ']', schemas.AlwaysReserve}} - var kases = []struct { - source string - expected string - }{ - { - "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?", - "SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?", - }, - { - "SELECT 'abc```test```''', `a` FROM b", - "SELECT 'abc```test```''', [a] FROM b", - }, - { - "UPDATE table SET `a` = ~ `a`, `b`='abc`'", - "UPDATE table SET [a] = ~ [a], [b]='abc`'", - }, - } - - for _, kase := range kases { - t.Run(kase.source, func(t *testing.T) { - assert.EqualValues(t, kase.expected, f.Do(kase.source)) - }) - } -} - func TestSeqFilter(t *testing.T) { var kases = map[string]string{ "SELECT * FROM TABLE1 WHERE a=? AND b=?": "SELECT * FROM TABLE1 WHERE a=$1 AND b=$2", diff --git a/dialects/mssql.go b/dialects/mssql.go index 6ba2cd97..cad18a29 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -525,7 +525,7 @@ func (db *mssql) ForUpdateSQL(query string) string { } func (db *mssql) Filters() []Filter { - return []Filter{&QuoteFilter{db.Quoter()}} + return []Filter{} } type odbcDriver struct { diff --git a/dialects/oracle.go b/dialects/oracle.go index 045ad99b..c48d32b9 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -793,7 +793,6 @@ func (db *oracle) GetIndexes(ctx context.Context, tableName string) (map[string] func (db *oracle) Filters() []Filter { return []Filter{ - &QuoteFilter{db.Quoter()}, &SeqFilter{Prefix: ":", Start: 1}, } } diff --git a/dialects/postgres.go b/dialects/postgres.go index e393452f..8412ad40 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -1231,7 +1231,7 @@ func (db *postgres) GetIndexes(ctx context.Context, tableName string) (map[strin } func (db *postgres) Filters() []Filter { - return []Filter{&QuoteFilter{db.Quoter()}, &SeqFilter{Prefix: "$", Start: 1}} + return []Filter{&SeqFilter{Prefix: "$", Start: 1}} } type pqDriver struct { diff --git a/internal/statements/query.go b/internal/statements/query.go index 8f7aeebb..1568259e 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -16,11 +16,11 @@ import ( func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) { if len(sqlOrArgs) > 0 { - return ConvertSQLOrArgs(sqlOrArgs...) + return statement.ConvertSQLOrArgs(sqlOrArgs...) } if statement.RawSQL != "" { - return statement.RawSQL, statement.RawParams, nil + return statement.GenRawSQL(), statement.RawParams, nil } if len(statement.TableName()) <= 0 { @@ -74,7 +74,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { if statement.RawSQL != "" { - return statement.RawSQL, statement.RawParams, nil + return statement.GenRawSQL(), statement.RawParams, nil } statement.SetRefBean(bean) @@ -83,6 +83,8 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri for _, colName := range columns { if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") { colName = statement.quote(colName) + } else { + colName = statement.ReplaceQuote(colName) } sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName)) } @@ -153,7 +155,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interface{}, error) { if statement.RawSQL != "" { - return statement.RawSQL, statement.RawParams, nil + return statement.GenRawSQL(), statement.RawParams, nil } var condArgs []interface{} @@ -193,7 +195,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB distinct = "DISTINCT " } - condSQL, condArgs, err := builder.ToSQL(statement.cond) + condSQL, condArgs, err := statement.GenCondSQL(statement.cond) if err != nil { return "", nil, err } @@ -313,7 +315,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) { if statement.RawSQL != "" { - return statement.RawSQL, statement.RawParams, nil + return statement.GenRawSQL(), statement.RawParams, nil } var sqlStr string @@ -332,7 +334,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac } if statement.Conds().IsValid() { - condSQL, condArgs, err := builder.ToSQL(statement.Conds()) + condSQL, condArgs, err := statement.GenCondSQL(statement.Conds()) if err != nil { return "", nil, err } @@ -382,7 +384,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) { if statement.RawSQL != "" { - return statement.RawSQL, statement.RawParams, nil + return statement.GenRawSQL(), statement.RawParams, nil } var sqlStr string diff --git a/internal/statements/statement.go b/internal/statements/statement.go index f8cb6f2d..4beb9a7e 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -98,6 +98,27 @@ func (statement *Statement) omitStr() string { return statement.dialect.Quoter().Join(statement.OmitColumnMap, " ,") } +// GenRawSQL generates correct raw sql +func (statement *Statement) GenRawSQL() string { + return statement.ReplaceQuote(statement.RawSQL) +} + +func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []interface{}, error) { + condSQL, condArgs, err := builder.ToSQL(condOrBuilder) + if err != nil { + return "", nil, err + } + return statement.ReplaceQuote(condSQL), condArgs, nil +} + +func (statement *Statement) ReplaceQuote(sql string) string { + if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL || + statement.dialect.URI().DBType == schemas.SQLITE { + return sql + } + return statement.dialect.Quoter().Replace(sql) +} + func (statement *Statement) SetContextCache(ctxCache contexts.ContextCache) { statement.Context = ctxCache } @@ -348,7 +369,11 @@ func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { // SetExpr Generate "Update ... Set column = {expression}" statement func (statement *Statement) SetExpr(column string, expression interface{}) *Statement { - statement.ExprColumns.addParam(column, expression) + if e, ok := expression.(string); ok { + statement.ExprColumns.addParam(column, statement.dialect.Quoter().Replace(e)) + } else { + statement.ExprColumns.addParam(column, expression) + } return statement } @@ -367,7 +392,7 @@ func (statement *Statement) ForUpdate() *Statement { // Select replace select func (statement *Statement) Select(str string) *Statement { - statement.SelectStr = str + statement.SelectStr = statement.ReplaceQuote(str) return statement } @@ -458,7 +483,7 @@ func (statement *Statement) OrderBy(order string) *Statement { if len(statement.OrderStr) > 0 { statement.OrderStr += ", " } - statement.OrderStr += order + statement.OrderStr += statement.ReplaceQuote(order) return statement } @@ -537,7 +562,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) aliasName = schemas.CommonQuoter.Trim(aliasName) - fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) + fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), aliasName, statement.ReplaceQuote(condition)) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) case *builder.Builder: subSQL, subQueryArgs, err := tp.ToSQL() @@ -550,7 +575,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) aliasName = schemas.CommonQuoter.Trim(aliasName) - fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) + fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), aliasName, statement.ReplaceQuote(condition)) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) default: tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) @@ -559,7 +584,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition statement.dialect.Quoter().QuoteTo(&buf, tbName) tbName = buf.String() } - fmt.Fprintf(&buf, "%s ON %v", tbName, condition) + fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condition)) } statement.JoinStr = buf.String() @@ -578,13 +603,13 @@ func (statement *Statement) tbNameNoSchema(table *schemas.Table) string { // GroupBy generate "Group By keys" statement func (statement *Statement) GroupBy(keys string) *Statement { - statement.GroupByStr = keys + statement.GroupByStr = statement.ReplaceQuote(keys) return statement } // Having generate "Having conditions" statement func (statement *Statement) Having(conditions string) *Statement { - statement.HavingStr = fmt.Sprintf("HAVING %v", conditions) + statement.HavingStr = fmt.Sprintf("HAVING %v", statement.ReplaceQuote(conditions)) return statement } @@ -926,7 +951,7 @@ func (statement *Statement) GenConds(bean interface{}) (string, []interface{}, e return "", nil, err } - return builder.ToSQL(statement.cond) + return statement.GenCondSQL(statement.cond) } func (statement *Statement) quoteColumnStr(columnStr string) string { @@ -934,7 +959,15 @@ func (statement *Statement) quoteColumnStr(columnStr string) string { return statement.dialect.Quoter().Join(columns, ",") } -func ConvertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { +func (statement *Statement) ConvertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { + sql, args, err := convertSQLOrArgs(sqlOrArgs...) + if err != nil { + return "", nil, err + } + return statement.ReplaceQuote(sql), args, nil +} + +func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { switch sqlOrArgs[0].(type) { case string: return sqlOrArgs[0].(string), sqlOrArgs[1:], nil diff --git a/rows.go b/rows.go index aa5e66e3..a56ea1c9 100644 --- a/rows.go +++ b/rows.go @@ -80,7 +80,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { return nil, err } } else { - sqlStr = rows.session.statement.RawSQL + sqlStr = rows.session.statement.GenRawSQL() args = rows.session.statement.RawParams } diff --git a/schemas/quote.go b/schemas/quote.go index 10436270..2a03152e 100644 --- a/schemas/quote.go +++ b/schemas/quote.go @@ -196,3 +196,41 @@ func (q Quoter) Strings(s []string) []string { } return res } + +// Replace replaces common quote(`) as the quotes on the sql +func (q Quoter) Replace(sql string) string { + if q.IsEmpty() { + return sql + } + + var buf strings.Builder + buf.Grow(len(sql)) + + var beginSingleQuote bool + for i := 0; i < len(sql); i++ { + if !beginSingleQuote && sql[i] == CommanQuoteMark { + var j = i + 1 + for ; j < len(sql); j++ { + if sql[j] == CommanQuoteMark { + break + } + } + word := sql[i+1 : j] + isReserved := q.IsReserved(word) + if isReserved { + buf.WriteByte(q.Prefix) + } + buf.WriteString(word) + if isReserved { + buf.WriteByte(q.Suffix) + } + i = j + } else { + if sql[i] == '\'' { + beginSingleQuote = !beginSingleQuote + } + buf.WriteByte(sql[i]) + } + } + return buf.String() +} diff --git a/schemas/quote_test.go b/schemas/quote_test.go index c7990f92..7a43bd24 100644 --- a/schemas/quote_test.go +++ b/schemas/quote_test.go @@ -146,3 +146,30 @@ func TestTrim(t *testing.T) { assert.EqualValues(t, dst, Quoter{'[', ']', AlwaysReserve}.Trim(src)) } } + +func TestReplace(t *testing.T) { + q := Quoter{'[', ']', AlwaysReserve} + var kases = []struct { + source string + expected string + }{ + { + "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?", + "SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?", + }, + { + "SELECT 'abc```test```''', `a` FROM b", + "SELECT 'abc```test```''', [a] FROM b", + }, + { + "UPDATE table SET `a` = ~ `a`, `b`='abc`'", + "UPDATE table SET [a] = ~ [a], [b]='abc`'", + }, + } + + for _, kase := range kases { + t.Run(kase.source, func(t *testing.T) { + assert.EqualValues(t, kase.expected, q.Replace(kase.source)) + }) + } +} diff --git a/session_get.go b/session_get.go index 76918194..e56ef2d7 100644 --- a/session_get.go +++ b/session_get.go @@ -59,7 +59,7 @@ func (session *Session) get(bean interface{}) (bool, error) { return false, err } } else { - sqlStr = session.statement.RawSQL + sqlStr = session.statement.GenRawSQL() args = session.statement.RawParams } diff --git a/session_raw.go b/session_raw.go index 0cea60b7..4cfe297a 100644 --- a/session_raw.go +++ b/session_raw.go @@ -9,7 +9,6 @@ import ( "reflect" "xorm.io/xorm/core" - "xorm.io/xorm/internal/statements" ) func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { @@ -172,7 +171,7 @@ func (session *Session) Exec(sqlOrArgs ...interface{}) (sql.Result, error) { return nil, ErrUnSupportedType } - sqlStr, args, err := statements.ConvertSQLOrArgs(sqlOrArgs...) + sqlStr, args, err := session.statement.ConvertSQLOrArgs(sqlOrArgs...) if err != nil { return nil, err } diff --git a/session_update.go b/session_update.go index aa4718b6..f60f48e3 100644 --- a/session_update.go +++ b/session_update.go @@ -240,7 +240,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } colNames = append(colNames, session.engine.Quote(colName)+"="+tp) case *builder.Builder: - subQuery, subArgs, err := builder.ToSQL(tp) + subQuery, subArgs, err := session.statement.GenCondSQL(tp) if err != nil { return 0, err } @@ -317,7 +317,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - condSQL, condArgs, err = builder.ToSQL(cond) + if len(colNames) <= 0 { + return 0, errors.New("No content found to be updated") + } + + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } @@ -335,24 +339,25 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var top string if st.LimitN != nil { limitValue := *st.LimitN - if session.engine.dialect.URI().DBType == schemas.MYSQL { + switch session.engine.dialect.URI().DBType { + case schemas.MYSQL: condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue) - } else if session.engine.dialect.URI().DBType == schemas.SQLITE { + case schemas.SQLITE: tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) - condSQL, condArgs, err = builder.ToSQL(cond) + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if session.engine.dialect.URI().DBType == schemas.POSTGRES { + case schemas.POSTGRES: tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) - condSQL, condArgs, err = builder.ToSQL(cond) + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } @@ -360,14 +365,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if session.engine.dialect.URI().DBType == schemas.MSSQL { - if st.OrderStr != "" && session.engine.dialect.URI().DBType == schemas.MSSQL && - table != nil && len(table.PrimaryKeys) == 1 { + case schemas.MSSQL: + if st.OrderStr != "" && table != nil && len(table.PrimaryKeys) == 1 { cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], session.engine.Quote(tableName), condSQL), condArgs...) - condSQL, condArgs, err = builder.ToSQL(cond) + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } @@ -380,10 +384,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - if len(colNames) <= 0 { - return 0, errors.New("No content found to be updated") - } - var tableAlias = session.engine.Quote(tableName) var fromSQL string if session.statement.TableAlias != "" {