fix quote policy

This commit is contained in:
Lunny Xiao 2019-09-30 16:29:45 +08:00
parent 4289572f28
commit ea1825c2dd
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
9 changed files with 44 additions and 83 deletions

View File

@ -56,8 +56,8 @@ type Engine struct {
defaultContext context.Context
quotePolicy QuotePolicy
quoteMode QuoteMode
colQuoter Quoter
tableQuoter Quoter
}
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {
@ -419,7 +419,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
return err
}
quoter := newQuoter(dialect, engine.quoteMode, engine.quotePolicy)
colQuoter := newQuoter(dialect, engine.colQuoter.QuotePolicy())
for i, table := range tables {
if i > 0 {
@ -440,8 +440,8 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
}
cols := table.ColumnsSeq()
colNames := quoteJoin(engine, cols)
destColNames := quoteJoin(quoter, cols)
colNames := quoteJoin(engine.colQuoter, cols)
destColNames := quoteJoin(colQuoter, cols)
rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.quote(table.Name, false))
if err != nil {

View File

@ -21,34 +21,21 @@ const (
QuoteAddReserved
)
// QuoteMode quote on which types
type QuoteMode int
// All QuoteModes
const (
QuoteTableAndColumns QuoteMode = iota
QuoteTableOnly
QuoteColumnsOnly
)
// Quoter represents an object has Quote method
type Quoter interface {
Quotes() (byte, byte)
QuotePolicy() QuotePolicy
QuoteMode() QuoteMode
IsReserved(string) bool
}
type quoter struct {
dialect core.Dialect
quoteMode QuoteMode
quotePolicy QuotePolicy
}
func newQuoter(dialect core.Dialect, quoteMode QuoteMode, quotePolicy QuotePolicy) Quoter {
func newQuoter(dialect core.Dialect, quotePolicy QuotePolicy) Quoter {
return &quoter{
dialect: dialect,
quoteMode: quoteMode,
quotePolicy: quotePolicy,
}
}
@ -62,10 +49,6 @@ func (q *quoter) QuotePolicy() QuotePolicy {
return q.quotePolicy
}
func (q *quoter) QuoteMode() QuoteMode {
return q.quoteMode
}
func (q *quoter) IsReserved(value string) bool {
return q.dialect.IsReserved(value)
}
@ -77,21 +60,24 @@ func quoteColumns(quoter Quoter, columnStr string) string {
func quoteJoin(quoter Quoter, columns []string) string {
for i := 0; i < len(columns); i++ {
columns[i] = quote(quoter, columns[i], true)
columns[i] = quote(quoter, columns[i])
}
return strings.Join(columns, ",")
}
// quote Use QuoteStr quote the string sql
func quote(quoter Quoter, value string, isColumn bool) string {
func quote(quoter Quoter, value string) string {
buf := strings.Builder{}
quoteTo(quoter, &buf, value, isColumn)
quoteTo(quoter, &buf, value)
return buf.String()
}
// Quote add quotes to the value
func (engine *Engine) quote(value string, isColumn bool) string {
return quote(engine, value, isColumn)
if isColumn {
return quote(engine.colQuoter, value)
}
return quote(engine.tableQuoter, value)
}
// Quote add quotes to the value
@ -105,53 +91,25 @@ func (engine *Engine) Quotes() (byte, byte) {
return quotes[0], quotes[1]
}
// QuoteMode returns quote mode
func (engine *Engine) QuoteMode() QuoteMode {
return engine.quoteMode
}
// QuotePolicy returns quote policy
func (engine *Engine) QuotePolicy() QuotePolicy {
return engine.quotePolicy
}
// IsReserved return true if the value is a reserved word of the database
func (engine *Engine) IsReserved(value string) bool {
return engine.dialect.IsReserved(value)
}
// quoteTo quotes string and writes into the buffer
func quoteTo(quoter Quoter, buf *strings.Builder, value string, isColumn bool) {
if isColumn {
if quoter.QuoteMode() == QuoteTableAndColumns ||
quoter.QuoteMode() == QuoteColumnsOnly {
if quoter.QuotePolicy() == QuoteAddAlways {
realQuoteTo(quoter, buf, value)
return
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
realQuoteTo(quoter, buf, value)
return
}
}
buf.WriteString(value)
func quoteTo(quoter Quoter, buf *strings.Builder, value string) {
left, right := quoter.Quotes()
if quoter.QuotePolicy() == QuoteAddAlways {
realQuoteTo(left, right, buf, value)
return
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
realQuoteTo(left, right, buf, value)
return
}
if quoter.QuoteMode() == QuoteTableAndColumns ||
quoter.QuoteMode() == QuoteTableOnly {
if quoter.QuotePolicy() == QuoteAddAlways {
realQuoteTo(quoter, buf, value)
return
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
realQuoteTo(quoter, buf, value)
return
}
}
buf.WriteString(value)
return
}
func realQuoteTo(quoter Quoter, buf *strings.Builder, value string) {
func realQuoteTo(prefix, suffix byte, buf *strings.Builder, value string) {
if buf == nil {
return
}
@ -164,8 +122,6 @@ func realQuoteTo(quoter Quoter, buf *strings.Builder, value string) {
return
}
prefix, suffix := quoter.Quotes()
i := 0
for i < len(value) {
// start of a token; might be already quoted

View File

@ -141,7 +141,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
if session.statement.JoinStr == "" {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
} else {
columnStr = session.statement.genColumnStr()
}
@ -149,7 +149,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} else {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
} else {
columnStr = "*"
}

View File

@ -249,15 +249,15 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if session.engine.dialect.DBType() == core.ORACLE {
temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
session.engine.quote(tableName, false),
quoteJoin(session.engine, colNames))
quoteJoin(session.engine.colQuoter, colNames))
sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
session.engine.quote(tableName, false),
quoteJoin(session.engine, colNames),
quoteJoin(session.engine.colQuoter, colNames),
strings.Join(colMultiPlaces, temp))
} else {
sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
session.engine.quote(tableName, false),
quoteJoin(session.engine, colNames),
quoteJoin(session.engine.colQuoter, colNames),
strings.Join(colMultiPlaces, "),("))
}
res, err := session.exec(sql, args...)
@ -855,7 +855,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
session.engine.quote(tableName, false),
quoteJoin(session.engine, columns), qm)); err != nil {
quoteJoin(session.engine.colQuoter, columns), qm)); err != nil {
return 0, err
}
w.Append(args...)

View File

@ -35,7 +35,7 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa
if session.statement.JoinStr == "" {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
} else {
columnStr = session.statement.genColumnStr()
}
@ -43,7 +43,7 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa
} else {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
} else {
columnStr = "*"
}

View File

@ -100,7 +100,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
for idx, kv := range kvs {
sps := strings.SplitN(kv, "=", 2)
sps2 := strings.Split(sps[0], ".")
colName := unQuote(session.engine, sps2[len(sps2)-1])
colName := unQuote(session.engine.colQuoter, sps2[len(sps2)-1])
if col := table.GetColumn(colName); col != nil {
fieldValue, err := col.ValueOf(bean)

View File

@ -615,7 +615,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
newColumns := statement.colmap2NewColsWithQuote()
statement.ColumnStr = quoteJoin(statement.Engine, newColumns)
statement.ColumnStr = quoteJoin(statement.Engine.colQuoter, newColumns)
return statement
}
@ -650,7 +650,7 @@ func (statement *Statement) Omit(columns ...string) {
for _, nc := range newColumns {
statement.omitColumnMap = append(statement.omitColumnMap, nc)
}
statement.OmitStr = quoteJoin(statement.Engine, newColumns)
statement.OmitStr = quoteJoin(statement.Engine.colQuoter, newColumns)
}
// Nullable Update use only: update columns to null when value is nullable and zero-value
@ -744,7 +744,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
}
tbs := strings.Split(tp.TableName(), ".")
var aliasName = unQuote(statement.Engine, tbs[len(tbs)-1])
var aliasName = unQuote(statement.Engine.tableQuoter, tbs[len(tbs)-1])
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
case *builder.Builder:
@ -755,7 +755,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
}
tbs := strings.Split(tp.TableName(), ".")
var aliasName = unQuote(statement.Engine, tbs[len(tbs)-1])
var aliasName = unQuote(statement.Engine.tableQuoter, tbs[len(tbs)-1])
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default:
@ -821,7 +821,7 @@ func (statement *Statement) genColumnStr() string {
buf.WriteString(".")
}
quoteTo(statement.Engine, &buf, col.Name, true)
quoteTo(statement.Engine.colQuoter, &buf, col.Name)
}
return buf.String()
@ -940,7 +940,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
if len(statement.JoinStr) == 0 {
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = quoteColumns(statement.Engine, statement.GroupByStr)
columnStr = quoteColumns(statement.Engine.colQuoter, statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
@ -948,7 +948,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
} else {
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = quoteColumns(statement.Engine, statement.GroupByStr)
columnStr = quoteColumns(statement.Engine.colQuoter, statement.GroupByStr)
}
}
}

View File

@ -243,6 +243,9 @@ func TestCol2NewColsWithQuote(t *testing.T) {
statement := createTestStatement()
quotedCols := quoteJoin(statement.Engine, cols)
assert.EqualValues(t, []string{statement.Engine.Quote("f1", true), statement.Engine.Quote("f2", true), statement.Engine.Quote("t3.f3", true)}, quotedCols)
quotedCols := quoteJoin(statement.Engine.colQuoter, cols)
assert.EqualValues(t, statement.Engine.Quote("f1", true)+","+
statement.Engine.Quote("f2", true)+","+
statement.Engine.Quote("t3.f3", true),
quotedCols)
}

View File

@ -95,6 +95,8 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
tagHandlers: defaultTagHandlers,
cachers: make(map[string]core.Cacher),
defaultContext: context.Background(),
colQuoter: newQuoter(dialect, QuoteAddAlways),
tableQuoter: newQuoter(dialect, QuoteAddAlways),
}
if uri.DbType == core.SQLITE {