fix quote policy
This commit is contained in:
parent
4289572f28
commit
ea1825c2dd
10
engine.go
10
engine.go
|
@ -56,8 +56,8 @@ type Engine struct {
|
|||
|
||||
defaultContext context.Context
|
||||
|
||||
quotePolicy QuotePolicy
|
||||
quoteMode QuoteMode
|
||||
colQuoter Quoter
|
||||
tableQuoter Quoter
|
||||
}
|
||||
|
||||
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {
|
||||
|
@ -419,7 +419,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
|
|||
return err
|
||||
}
|
||||
|
||||
quoter := newQuoter(dialect, engine.quoteMode, engine.quotePolicy)
|
||||
colQuoter := newQuoter(dialect, engine.colQuoter.QuotePolicy())
|
||||
|
||||
for i, table := range tables {
|
||||
if i > 0 {
|
||||
|
@ -440,8 +440,8 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
|
|||
}
|
||||
|
||||
cols := table.ColumnsSeq()
|
||||
colNames := quoteJoin(engine, cols)
|
||||
destColNames := quoteJoin(quoter, cols)
|
||||
colNames := quoteJoin(engine.colQuoter, cols)
|
||||
destColNames := quoteJoin(colQuoter, cols)
|
||||
|
||||
rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.quote(table.Name, false))
|
||||
if err != nil {
|
||||
|
|
|
@ -21,34 +21,21 @@ const (
|
|||
QuoteAddReserved
|
||||
)
|
||||
|
||||
// QuoteMode quote on which types
|
||||
type QuoteMode int
|
||||
|
||||
// All QuoteModes
|
||||
const (
|
||||
QuoteTableAndColumns QuoteMode = iota
|
||||
QuoteTableOnly
|
||||
QuoteColumnsOnly
|
||||
)
|
||||
|
||||
// Quoter represents an object has Quote method
|
||||
type Quoter interface {
|
||||
Quotes() (byte, byte)
|
||||
QuotePolicy() QuotePolicy
|
||||
QuoteMode() QuoteMode
|
||||
IsReserved(string) bool
|
||||
}
|
||||
|
||||
type quoter struct {
|
||||
dialect core.Dialect
|
||||
quoteMode QuoteMode
|
||||
quotePolicy QuotePolicy
|
||||
}
|
||||
|
||||
func newQuoter(dialect core.Dialect, quoteMode QuoteMode, quotePolicy QuotePolicy) Quoter {
|
||||
func newQuoter(dialect core.Dialect, quotePolicy QuotePolicy) Quoter {
|
||||
return "er{
|
||||
dialect: dialect,
|
||||
quoteMode: quoteMode,
|
||||
quotePolicy: quotePolicy,
|
||||
}
|
||||
}
|
||||
|
@ -62,10 +49,6 @@ func (q *quoter) QuotePolicy() QuotePolicy {
|
|||
return q.quotePolicy
|
||||
}
|
||||
|
||||
func (q *quoter) QuoteMode() QuoteMode {
|
||||
return q.quoteMode
|
||||
}
|
||||
|
||||
func (q *quoter) IsReserved(value string) bool {
|
||||
return q.dialect.IsReserved(value)
|
||||
}
|
||||
|
@ -77,21 +60,24 @@ func quoteColumns(quoter Quoter, columnStr string) string {
|
|||
|
||||
func quoteJoin(quoter Quoter, columns []string) string {
|
||||
for i := 0; i < len(columns); i++ {
|
||||
columns[i] = quote(quoter, columns[i], true)
|
||||
columns[i] = quote(quoter, columns[i])
|
||||
}
|
||||
return strings.Join(columns, ",")
|
||||
}
|
||||
|
||||
// quote Use QuoteStr quote the string sql
|
||||
func quote(quoter Quoter, value string, isColumn bool) string {
|
||||
func quote(quoter Quoter, value string) string {
|
||||
buf := strings.Builder{}
|
||||
quoteTo(quoter, &buf, value, isColumn)
|
||||
quoteTo(quoter, &buf, value)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// Quote add quotes to the value
|
||||
func (engine *Engine) quote(value string, isColumn bool) string {
|
||||
return quote(engine, value, isColumn)
|
||||
if isColumn {
|
||||
return quote(engine.colQuoter, value)
|
||||
}
|
||||
return quote(engine.tableQuoter, value)
|
||||
}
|
||||
|
||||
// Quote add quotes to the value
|
||||
|
@ -105,53 +91,25 @@ func (engine *Engine) Quotes() (byte, byte) {
|
|||
return quotes[0], quotes[1]
|
||||
}
|
||||
|
||||
// QuoteMode returns quote mode
|
||||
func (engine *Engine) QuoteMode() QuoteMode {
|
||||
return engine.quoteMode
|
||||
}
|
||||
|
||||
// QuotePolicy returns quote policy
|
||||
func (engine *Engine) QuotePolicy() QuotePolicy {
|
||||
return engine.quotePolicy
|
||||
}
|
||||
|
||||
// IsReserved return true if the value is a reserved word of the database
|
||||
func (engine *Engine) IsReserved(value string) bool {
|
||||
return engine.dialect.IsReserved(value)
|
||||
}
|
||||
|
||||
// quoteTo quotes string and writes into the buffer
|
||||
func quoteTo(quoter Quoter, buf *strings.Builder, value string, isColumn bool) {
|
||||
if isColumn {
|
||||
if quoter.QuoteMode() == QuoteTableAndColumns ||
|
||||
quoter.QuoteMode() == QuoteColumnsOnly {
|
||||
if quoter.QuotePolicy() == QuoteAddAlways {
|
||||
realQuoteTo(quoter, buf, value)
|
||||
return
|
||||
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
|
||||
realQuoteTo(quoter, buf, value)
|
||||
return
|
||||
}
|
||||
}
|
||||
buf.WriteString(value)
|
||||
func quoteTo(quoter Quoter, buf *strings.Builder, value string) {
|
||||
left, right := quoter.Quotes()
|
||||
if quoter.QuotePolicy() == QuoteAddAlways {
|
||||
realQuoteTo(left, right, buf, value)
|
||||
return
|
||||
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
|
||||
realQuoteTo(left, right, buf, value)
|
||||
return
|
||||
}
|
||||
|
||||
if quoter.QuoteMode() == QuoteTableAndColumns ||
|
||||
quoter.QuoteMode() == QuoteTableOnly {
|
||||
if quoter.QuotePolicy() == QuoteAddAlways {
|
||||
realQuoteTo(quoter, buf, value)
|
||||
return
|
||||
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
|
||||
realQuoteTo(quoter, buf, value)
|
||||
return
|
||||
}
|
||||
}
|
||||
buf.WriteString(value)
|
||||
return
|
||||
}
|
||||
|
||||
func realQuoteTo(quoter Quoter, buf *strings.Builder, value string) {
|
||||
func realQuoteTo(prefix, suffix byte, buf *strings.Builder, value string) {
|
||||
if buf == nil {
|
||||
return
|
||||
}
|
||||
|
@ -164,8 +122,6 @@ func realQuoteTo(quoter Quoter, buf *strings.Builder, value string) {
|
|||
return
|
||||
}
|
||||
|
||||
prefix, suffix := quoter.Quotes()
|
||||
|
||||
i := 0
|
||||
for i < len(value) {
|
||||
// start of a token; might be already quoted
|
||||
|
|
|
@ -141,7 +141,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
|
|||
if session.statement.JoinStr == "" {
|
||||
if columnStr == "" {
|
||||
if session.statement.GroupByStr != "" {
|
||||
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
|
||||
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
|
||||
} else {
|
||||
columnStr = session.statement.genColumnStr()
|
||||
}
|
||||
|
@ -149,7 +149,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
|
|||
} else {
|
||||
if columnStr == "" {
|
||||
if session.statement.GroupByStr != "" {
|
||||
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
|
||||
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
|
||||
} else {
|
||||
columnStr = "*"
|
||||
}
|
||||
|
|
|
@ -249,15 +249,15 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
|
|||
if session.engine.dialect.DBType() == core.ORACLE {
|
||||
temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
|
||||
session.engine.quote(tableName, false),
|
||||
quoteJoin(session.engine, colNames))
|
||||
quoteJoin(session.engine.colQuoter, colNames))
|
||||
sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
|
||||
session.engine.quote(tableName, false),
|
||||
quoteJoin(session.engine, colNames),
|
||||
quoteJoin(session.engine.colQuoter, colNames),
|
||||
strings.Join(colMultiPlaces, temp))
|
||||
} else {
|
||||
sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
|
||||
session.engine.quote(tableName, false),
|
||||
quoteJoin(session.engine, colNames),
|
||||
quoteJoin(session.engine.colQuoter, colNames),
|
||||
strings.Join(colMultiPlaces, "),("))
|
||||
}
|
||||
res, err := session.exec(sql, args...)
|
||||
|
@ -855,7 +855,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
|
|||
|
||||
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
||||
session.engine.quote(tableName, false),
|
||||
quoteJoin(session.engine, columns), qm)); err != nil {
|
||||
quoteJoin(session.engine.colQuoter, columns), qm)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.Append(args...)
|
||||
|
|
|
@ -35,7 +35,7 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa
|
|||
if session.statement.JoinStr == "" {
|
||||
if columnStr == "" {
|
||||
if session.statement.GroupByStr != "" {
|
||||
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
|
||||
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
|
||||
} else {
|
||||
columnStr = session.statement.genColumnStr()
|
||||
}
|
||||
|
@ -43,7 +43,7 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa
|
|||
} else {
|
||||
if columnStr == "" {
|
||||
if session.statement.GroupByStr != "" {
|
||||
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
|
||||
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
|
||||
} else {
|
||||
columnStr = "*"
|
||||
}
|
||||
|
|
|
@ -100,7 +100,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
|
|||
for idx, kv := range kvs {
|
||||
sps := strings.SplitN(kv, "=", 2)
|
||||
sps2 := strings.Split(sps[0], ".")
|
||||
colName := unQuote(session.engine, sps2[len(sps2)-1])
|
||||
colName := unQuote(session.engine.colQuoter, sps2[len(sps2)-1])
|
||||
|
||||
if col := table.GetColumn(colName); col != nil {
|
||||
fieldValue, err := col.ValueOf(bean)
|
||||
|
|
14
statement.go
14
statement.go
|
@ -615,7 +615,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
|
|||
|
||||
newColumns := statement.colmap2NewColsWithQuote()
|
||||
|
||||
statement.ColumnStr = quoteJoin(statement.Engine, newColumns)
|
||||
statement.ColumnStr = quoteJoin(statement.Engine.colQuoter, newColumns)
|
||||
return statement
|
||||
}
|
||||
|
||||
|
@ -650,7 +650,7 @@ func (statement *Statement) Omit(columns ...string) {
|
|||
for _, nc := range newColumns {
|
||||
statement.omitColumnMap = append(statement.omitColumnMap, nc)
|
||||
}
|
||||
statement.OmitStr = quoteJoin(statement.Engine, newColumns)
|
||||
statement.OmitStr = quoteJoin(statement.Engine.colQuoter, newColumns)
|
||||
}
|
||||
|
||||
// Nullable Update use only: update columns to null when value is nullable and zero-value
|
||||
|
@ -744,7 +744,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
|
|||
}
|
||||
tbs := strings.Split(tp.TableName(), ".")
|
||||
|
||||
var aliasName = unQuote(statement.Engine, tbs[len(tbs)-1])
|
||||
var aliasName = unQuote(statement.Engine.tableQuoter, tbs[len(tbs)-1])
|
||||
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
||||
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
||||
case *builder.Builder:
|
||||
|
@ -755,7 +755,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
|
|||
}
|
||||
tbs := strings.Split(tp.TableName(), ".")
|
||||
|
||||
var aliasName = unQuote(statement.Engine, tbs[len(tbs)-1])
|
||||
var aliasName = unQuote(statement.Engine.tableQuoter, tbs[len(tbs)-1])
|
||||
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
||||
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
||||
default:
|
||||
|
@ -821,7 +821,7 @@ func (statement *Statement) genColumnStr() string {
|
|||
buf.WriteString(".")
|
||||
}
|
||||
|
||||
quoteTo(statement.Engine, &buf, col.Name, true)
|
||||
quoteTo(statement.Engine.colQuoter, &buf, col.Name)
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
|
@ -940,7 +940,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
|
|||
if len(statement.JoinStr) == 0 {
|
||||
if len(columnStr) == 0 {
|
||||
if len(statement.GroupByStr) > 0 {
|
||||
columnStr = quoteColumns(statement.Engine, statement.GroupByStr)
|
||||
columnStr = quoteColumns(statement.Engine.colQuoter, statement.GroupByStr)
|
||||
} else {
|
||||
columnStr = statement.genColumnStr()
|
||||
}
|
||||
|
@ -948,7 +948,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
|
|||
} else {
|
||||
if len(columnStr) == 0 {
|
||||
if len(statement.GroupByStr) > 0 {
|
||||
columnStr = quoteColumns(statement.Engine, statement.GroupByStr)
|
||||
columnStr = quoteColumns(statement.Engine.colQuoter, statement.GroupByStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -243,6 +243,9 @@ func TestCol2NewColsWithQuote(t *testing.T) {
|
|||
|
||||
statement := createTestStatement()
|
||||
|
||||
quotedCols := quoteJoin(statement.Engine, cols)
|
||||
assert.EqualValues(t, []string{statement.Engine.Quote("f1", true), statement.Engine.Quote("f2", true), statement.Engine.Quote("t3.f3", true)}, quotedCols)
|
||||
quotedCols := quoteJoin(statement.Engine.colQuoter, cols)
|
||||
assert.EqualValues(t, statement.Engine.Quote("f1", true)+","+
|
||||
statement.Engine.Quote("f2", true)+","+
|
||||
statement.Engine.Quote("t3.f3", true),
|
||||
quotedCols)
|
||||
}
|
||||
|
|
2
xorm.go
2
xorm.go
|
@ -95,6 +95,8 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
|
|||
tagHandlers: defaultTagHandlers,
|
||||
cachers: make(map[string]core.Cacher),
|
||||
defaultContext: context.Background(),
|
||||
colQuoter: newQuoter(dialect, QuoteAddAlways),
|
||||
tableQuoter: newQuoter(dialect, QuoteAddAlways),
|
||||
}
|
||||
|
||||
if uri.DbType == core.SQLITE {
|
||||
|
|
Loading…
Reference in New Issue