fix quote policy
This commit is contained in:
parent
4289572f28
commit
ea1825c2dd
10
engine.go
10
engine.go
|
@ -56,8 +56,8 @@ type Engine struct {
|
||||||
|
|
||||||
defaultContext context.Context
|
defaultContext context.Context
|
||||||
|
|
||||||
quotePolicy QuotePolicy
|
colQuoter Quoter
|
||||||
quoteMode QuoteMode
|
tableQuoter Quoter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {
|
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {
|
||||||
|
@ -419,7 +419,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
quoter := newQuoter(dialect, engine.quoteMode, engine.quotePolicy)
|
colQuoter := newQuoter(dialect, engine.colQuoter.QuotePolicy())
|
||||||
|
|
||||||
for i, table := range tables {
|
for i, table := range tables {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
|
@ -440,8 +440,8 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
|
||||||
}
|
}
|
||||||
|
|
||||||
cols := table.ColumnsSeq()
|
cols := table.ColumnsSeq()
|
||||||
colNames := quoteJoin(engine, cols)
|
colNames := quoteJoin(engine.colQuoter, cols)
|
||||||
destColNames := quoteJoin(quoter, cols)
|
destColNames := quoteJoin(colQuoter, cols)
|
||||||
|
|
||||||
rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.quote(table.Name, false))
|
rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.quote(table.Name, false))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -21,34 +21,21 @@ const (
|
||||||
QuoteAddReserved
|
QuoteAddReserved
|
||||||
)
|
)
|
||||||
|
|
||||||
// QuoteMode quote on which types
|
|
||||||
type QuoteMode int
|
|
||||||
|
|
||||||
// All QuoteModes
|
|
||||||
const (
|
|
||||||
QuoteTableAndColumns QuoteMode = iota
|
|
||||||
QuoteTableOnly
|
|
||||||
QuoteColumnsOnly
|
|
||||||
)
|
|
||||||
|
|
||||||
// Quoter represents an object has Quote method
|
// Quoter represents an object has Quote method
|
||||||
type Quoter interface {
|
type Quoter interface {
|
||||||
Quotes() (byte, byte)
|
Quotes() (byte, byte)
|
||||||
QuotePolicy() QuotePolicy
|
QuotePolicy() QuotePolicy
|
||||||
QuoteMode() QuoteMode
|
|
||||||
IsReserved(string) bool
|
IsReserved(string) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type quoter struct {
|
type quoter struct {
|
||||||
dialect core.Dialect
|
dialect core.Dialect
|
||||||
quoteMode QuoteMode
|
|
||||||
quotePolicy QuotePolicy
|
quotePolicy QuotePolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
func newQuoter(dialect core.Dialect, quoteMode QuoteMode, quotePolicy QuotePolicy) Quoter {
|
func newQuoter(dialect core.Dialect, quotePolicy QuotePolicy) Quoter {
|
||||||
return "er{
|
return "er{
|
||||||
dialect: dialect,
|
dialect: dialect,
|
||||||
quoteMode: quoteMode,
|
|
||||||
quotePolicy: quotePolicy,
|
quotePolicy: quotePolicy,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -62,10 +49,6 @@ func (q *quoter) QuotePolicy() QuotePolicy {
|
||||||
return q.quotePolicy
|
return q.quotePolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *quoter) QuoteMode() QuoteMode {
|
|
||||||
return q.quoteMode
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *quoter) IsReserved(value string) bool {
|
func (q *quoter) IsReserved(value string) bool {
|
||||||
return q.dialect.IsReserved(value)
|
return q.dialect.IsReserved(value)
|
||||||
}
|
}
|
||||||
|
@ -77,21 +60,24 @@ func quoteColumns(quoter Quoter, columnStr string) string {
|
||||||
|
|
||||||
func quoteJoin(quoter Quoter, columns []string) string {
|
func quoteJoin(quoter Quoter, columns []string) string {
|
||||||
for i := 0; i < len(columns); i++ {
|
for i := 0; i < len(columns); i++ {
|
||||||
columns[i] = quote(quoter, columns[i], true)
|
columns[i] = quote(quoter, columns[i])
|
||||||
}
|
}
|
||||||
return strings.Join(columns, ",")
|
return strings.Join(columns, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
// quote Use QuoteStr quote the string sql
|
// quote Use QuoteStr quote the string sql
|
||||||
func quote(quoter Quoter, value string, isColumn bool) string {
|
func quote(quoter Quoter, value string) string {
|
||||||
buf := strings.Builder{}
|
buf := strings.Builder{}
|
||||||
quoteTo(quoter, &buf, value, isColumn)
|
quoteTo(quoter, &buf, value)
|
||||||
return buf.String()
|
return buf.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quote add quotes to the value
|
// Quote add quotes to the value
|
||||||
func (engine *Engine) quote(value string, isColumn bool) string {
|
func (engine *Engine) quote(value string, isColumn bool) string {
|
||||||
return quote(engine, value, isColumn)
|
if isColumn {
|
||||||
|
return quote(engine.colQuoter, value)
|
||||||
|
}
|
||||||
|
return quote(engine.tableQuoter, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quote add quotes to the value
|
// Quote add quotes to the value
|
||||||
|
@ -105,53 +91,25 @@ func (engine *Engine) Quotes() (byte, byte) {
|
||||||
return quotes[0], quotes[1]
|
return quotes[0], quotes[1]
|
||||||
}
|
}
|
||||||
|
|
||||||
// QuoteMode returns quote mode
|
|
||||||
func (engine *Engine) QuoteMode() QuoteMode {
|
|
||||||
return engine.quoteMode
|
|
||||||
}
|
|
||||||
|
|
||||||
// QuotePolicy returns quote policy
|
|
||||||
func (engine *Engine) QuotePolicy() QuotePolicy {
|
|
||||||
return engine.quotePolicy
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsReserved return true if the value is a reserved word of the database
|
// IsReserved return true if the value is a reserved word of the database
|
||||||
func (engine *Engine) IsReserved(value string) bool {
|
func (engine *Engine) IsReserved(value string) bool {
|
||||||
return engine.dialect.IsReserved(value)
|
return engine.dialect.IsReserved(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// quoteTo quotes string and writes into the buffer
|
// quoteTo quotes string and writes into the buffer
|
||||||
func quoteTo(quoter Quoter, buf *strings.Builder, value string, isColumn bool) {
|
func quoteTo(quoter Quoter, buf *strings.Builder, value string) {
|
||||||
if isColumn {
|
left, right := quoter.Quotes()
|
||||||
if quoter.QuoteMode() == QuoteTableAndColumns ||
|
if quoter.QuotePolicy() == QuoteAddAlways {
|
||||||
quoter.QuoteMode() == QuoteColumnsOnly {
|
realQuoteTo(left, right, buf, value)
|
||||||
if quoter.QuotePolicy() == QuoteAddAlways {
|
return
|
||||||
realQuoteTo(quoter, buf, value)
|
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
|
||||||
return
|
realQuoteTo(left, right, buf, value)
|
||||||
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
|
|
||||||
realQuoteTo(quoter, buf, value)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
buf.WriteString(value)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if quoter.QuoteMode() == QuoteTableAndColumns ||
|
|
||||||
quoter.QuoteMode() == QuoteTableOnly {
|
|
||||||
if quoter.QuotePolicy() == QuoteAddAlways {
|
|
||||||
realQuoteTo(quoter, buf, value)
|
|
||||||
return
|
|
||||||
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
|
|
||||||
realQuoteTo(quoter, buf, value)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
buf.WriteString(value)
|
buf.WriteString(value)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func realQuoteTo(quoter Quoter, buf *strings.Builder, value string) {
|
func realQuoteTo(prefix, suffix byte, buf *strings.Builder, value string) {
|
||||||
if buf == nil {
|
if buf == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -164,8 +122,6 @@ func realQuoteTo(quoter Quoter, buf *strings.Builder, value string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix, suffix := quoter.Quotes()
|
|
||||||
|
|
||||||
i := 0
|
i := 0
|
||||||
for i < len(value) {
|
for i < len(value) {
|
||||||
// start of a token; might be already quoted
|
// start of a token; might be already quoted
|
||||||
|
|
|
@ -141,7 +141,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
|
||||||
if session.statement.JoinStr == "" {
|
if session.statement.JoinStr == "" {
|
||||||
if columnStr == "" {
|
if columnStr == "" {
|
||||||
if session.statement.GroupByStr != "" {
|
if session.statement.GroupByStr != "" {
|
||||||
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
|
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
|
||||||
} else {
|
} else {
|
||||||
columnStr = session.statement.genColumnStr()
|
columnStr = session.statement.genColumnStr()
|
||||||
}
|
}
|
||||||
|
@ -149,7 +149,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
|
||||||
} else {
|
} else {
|
||||||
if columnStr == "" {
|
if columnStr == "" {
|
||||||
if session.statement.GroupByStr != "" {
|
if session.statement.GroupByStr != "" {
|
||||||
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
|
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
|
||||||
} else {
|
} else {
|
||||||
columnStr = "*"
|
columnStr = "*"
|
||||||
}
|
}
|
||||||
|
|
|
@ -249,15 +249,15 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
|
||||||
if session.engine.dialect.DBType() == core.ORACLE {
|
if session.engine.dialect.DBType() == core.ORACLE {
|
||||||
temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
|
temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
|
||||||
session.engine.quote(tableName, false),
|
session.engine.quote(tableName, false),
|
||||||
quoteJoin(session.engine, colNames))
|
quoteJoin(session.engine.colQuoter, colNames))
|
||||||
sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
|
sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
|
||||||
session.engine.quote(tableName, false),
|
session.engine.quote(tableName, false),
|
||||||
quoteJoin(session.engine, colNames),
|
quoteJoin(session.engine.colQuoter, colNames),
|
||||||
strings.Join(colMultiPlaces, temp))
|
strings.Join(colMultiPlaces, temp))
|
||||||
} else {
|
} else {
|
||||||
sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
|
sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
|
||||||
session.engine.quote(tableName, false),
|
session.engine.quote(tableName, false),
|
||||||
quoteJoin(session.engine, colNames),
|
quoteJoin(session.engine.colQuoter, colNames),
|
||||||
strings.Join(colMultiPlaces, "),("))
|
strings.Join(colMultiPlaces, "),("))
|
||||||
}
|
}
|
||||||
res, err := session.exec(sql, args...)
|
res, err := session.exec(sql, args...)
|
||||||
|
@ -855,7 +855,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
|
||||||
|
|
||||||
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
||||||
session.engine.quote(tableName, false),
|
session.engine.quote(tableName, false),
|
||||||
quoteJoin(session.engine, columns), qm)); err != nil {
|
quoteJoin(session.engine.colQuoter, columns), qm)); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
w.Append(args...)
|
w.Append(args...)
|
||||||
|
|
|
@ -35,7 +35,7 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa
|
||||||
if session.statement.JoinStr == "" {
|
if session.statement.JoinStr == "" {
|
||||||
if columnStr == "" {
|
if columnStr == "" {
|
||||||
if session.statement.GroupByStr != "" {
|
if session.statement.GroupByStr != "" {
|
||||||
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
|
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
|
||||||
} else {
|
} else {
|
||||||
columnStr = session.statement.genColumnStr()
|
columnStr = session.statement.genColumnStr()
|
||||||
}
|
}
|
||||||
|
@ -43,7 +43,7 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa
|
||||||
} else {
|
} else {
|
||||||
if columnStr == "" {
|
if columnStr == "" {
|
||||||
if session.statement.GroupByStr != "" {
|
if session.statement.GroupByStr != "" {
|
||||||
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
|
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
|
||||||
} else {
|
} else {
|
||||||
columnStr = "*"
|
columnStr = "*"
|
||||||
}
|
}
|
||||||
|
|
|
@ -100,7 +100,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
|
||||||
for idx, kv := range kvs {
|
for idx, kv := range kvs {
|
||||||
sps := strings.SplitN(kv, "=", 2)
|
sps := strings.SplitN(kv, "=", 2)
|
||||||
sps2 := strings.Split(sps[0], ".")
|
sps2 := strings.Split(sps[0], ".")
|
||||||
colName := unQuote(session.engine, sps2[len(sps2)-1])
|
colName := unQuote(session.engine.colQuoter, sps2[len(sps2)-1])
|
||||||
|
|
||||||
if col := table.GetColumn(colName); col != nil {
|
if col := table.GetColumn(colName); col != nil {
|
||||||
fieldValue, err := col.ValueOf(bean)
|
fieldValue, err := col.ValueOf(bean)
|
||||||
|
|
14
statement.go
14
statement.go
|
@ -615,7 +615,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
|
||||||
|
|
||||||
newColumns := statement.colmap2NewColsWithQuote()
|
newColumns := statement.colmap2NewColsWithQuote()
|
||||||
|
|
||||||
statement.ColumnStr = quoteJoin(statement.Engine, newColumns)
|
statement.ColumnStr = quoteJoin(statement.Engine.colQuoter, newColumns)
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -650,7 +650,7 @@ func (statement *Statement) Omit(columns ...string) {
|
||||||
for _, nc := range newColumns {
|
for _, nc := range newColumns {
|
||||||
statement.omitColumnMap = append(statement.omitColumnMap, nc)
|
statement.omitColumnMap = append(statement.omitColumnMap, nc)
|
||||||
}
|
}
|
||||||
statement.OmitStr = quoteJoin(statement.Engine, newColumns)
|
statement.OmitStr = quoteJoin(statement.Engine.colQuoter, newColumns)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Nullable Update use only: update columns to null when value is nullable and zero-value
|
// Nullable Update use only: update columns to null when value is nullable and zero-value
|
||||||
|
@ -744,7 +744,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
|
||||||
}
|
}
|
||||||
tbs := strings.Split(tp.TableName(), ".")
|
tbs := strings.Split(tp.TableName(), ".")
|
||||||
|
|
||||||
var aliasName = unQuote(statement.Engine, tbs[len(tbs)-1])
|
var aliasName = unQuote(statement.Engine.tableQuoter, tbs[len(tbs)-1])
|
||||||
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
||||||
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
||||||
case *builder.Builder:
|
case *builder.Builder:
|
||||||
|
@ -755,7 +755,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
|
||||||
}
|
}
|
||||||
tbs := strings.Split(tp.TableName(), ".")
|
tbs := strings.Split(tp.TableName(), ".")
|
||||||
|
|
||||||
var aliasName = unQuote(statement.Engine, tbs[len(tbs)-1])
|
var aliasName = unQuote(statement.Engine.tableQuoter, tbs[len(tbs)-1])
|
||||||
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
|
||||||
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
|
||||||
default:
|
default:
|
||||||
|
@ -821,7 +821,7 @@ func (statement *Statement) genColumnStr() string {
|
||||||
buf.WriteString(".")
|
buf.WriteString(".")
|
||||||
}
|
}
|
||||||
|
|
||||||
quoteTo(statement.Engine, &buf, col.Name, true)
|
quoteTo(statement.Engine.colQuoter, &buf, col.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return buf.String()
|
return buf.String()
|
||||||
|
@ -940,7 +940,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
|
||||||
if len(statement.JoinStr) == 0 {
|
if len(statement.JoinStr) == 0 {
|
||||||
if len(columnStr) == 0 {
|
if len(columnStr) == 0 {
|
||||||
if len(statement.GroupByStr) > 0 {
|
if len(statement.GroupByStr) > 0 {
|
||||||
columnStr = quoteColumns(statement.Engine, statement.GroupByStr)
|
columnStr = quoteColumns(statement.Engine.colQuoter, statement.GroupByStr)
|
||||||
} else {
|
} else {
|
||||||
columnStr = statement.genColumnStr()
|
columnStr = statement.genColumnStr()
|
||||||
}
|
}
|
||||||
|
@ -948,7 +948,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
|
||||||
} else {
|
} else {
|
||||||
if len(columnStr) == 0 {
|
if len(columnStr) == 0 {
|
||||||
if len(statement.GroupByStr) > 0 {
|
if len(statement.GroupByStr) > 0 {
|
||||||
columnStr = quoteColumns(statement.Engine, statement.GroupByStr)
|
columnStr = quoteColumns(statement.Engine.colQuoter, statement.GroupByStr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -243,6 +243,9 @@ func TestCol2NewColsWithQuote(t *testing.T) {
|
||||||
|
|
||||||
statement := createTestStatement()
|
statement := createTestStatement()
|
||||||
|
|
||||||
quotedCols := quoteJoin(statement.Engine, cols)
|
quotedCols := quoteJoin(statement.Engine.colQuoter, cols)
|
||||||
assert.EqualValues(t, []string{statement.Engine.Quote("f1", true), statement.Engine.Quote("f2", true), statement.Engine.Quote("t3.f3", true)}, quotedCols)
|
assert.EqualValues(t, statement.Engine.Quote("f1", true)+","+
|
||||||
|
statement.Engine.Quote("f2", true)+","+
|
||||||
|
statement.Engine.Quote("t3.f3", true),
|
||||||
|
quotedCols)
|
||||||
}
|
}
|
||||||
|
|
2
xorm.go
2
xorm.go
|
@ -95,6 +95,8 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
|
||||||
tagHandlers: defaultTagHandlers,
|
tagHandlers: defaultTagHandlers,
|
||||||
cachers: make(map[string]core.Cacher),
|
cachers: make(map[string]core.Cacher),
|
||||||
defaultContext: context.Background(),
|
defaultContext: context.Background(),
|
||||||
|
colQuoter: newQuoter(dialect, QuoteAddAlways),
|
||||||
|
tableQuoter: newQuoter(dialect, QuoteAddAlways),
|
||||||
}
|
}
|
||||||
|
|
||||||
if uri.DbType == core.SQLITE {
|
if uri.DbType == core.SQLITE {
|
||||||
|
|
Loading…
Reference in New Issue