diff --git a/internal/statements/query.go b/internal/statements/query.go index 31fb0f96..2cb8458a 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -193,7 +193,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB distinct = "DISTINCT " } - condSQL, condArgs, err := builder.ToSQL(statement.cond) + condSQL, condArgs, err := statement.GenCondSQL(statement.cond) if err != nil { return "", nil, err } @@ -332,7 +332,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac } if statement.Conds().IsValid() { - condSQL, condArgs, err := builder.ToSQL(statement.Conds()) + condSQL, condArgs, err := statement.GenCondSQL(statement.Conds()) if err != nil { return "", nil, err } diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 99a99c58..a9bd6ade 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -100,11 +100,23 @@ func (statement *Statement) omitStr() string { // GenRawSQL generates correct raw sql func (statement *Statement) GenRawSQL() string { - if statement.RawSQL == "" || statement.dialect.URI().DBType == schemas.MYSQL || - statement.dialect.URI().DBType == schemas.SQLITE { - return statement.RawSQL + return statement.ReplaceQuote(statement.RawSQL) +} + +func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []interface{}, error) { + condSQL, condArgs, err := builder.ToSQL(condOrBuilder) + if err != nil { + return "", nil, err } - return statement.dialect.Quoter().Replace(statement.RawSQL) + return statement.ReplaceQuote(condSQL), condArgs, nil +} + +func (statement *Statement) ReplaceQuote(sql string) string { + if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL || + statement.dialect.URI().DBType == schemas.SQLITE { + return sql + } + return statement.dialect.Quoter().Replace(sql) } func (statement *Statement) SetContextCache(ctxCache contexts.ContextCache) { @@ -357,7 +369,11 @@ func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { // SetExpr Generate "Update ... Set column = {expression}" statement func (statement *Statement) SetExpr(column string, expression interface{}) *Statement { - statement.ExprColumns.addParam(column, expression) + if e, ok := expression.(string); ok { + statement.ExprColumns.addParam(column, statement.dialect.Quoter().Replace(e)) + } else { + statement.ExprColumns.addParam(column, expression) + } return statement } @@ -935,7 +951,7 @@ func (statement *Statement) GenConds(bean interface{}) (string, []interface{}, e return "", nil, err } - return builder.ToSQL(statement.cond) + return statement.GenCondSQL(statement.cond) } func (statement *Statement) quoteColumnStr(columnStr string) string { diff --git a/session_update.go b/session_update.go index aa4718b6..f60f48e3 100644 --- a/session_update.go +++ b/session_update.go @@ -240,7 +240,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } colNames = append(colNames, session.engine.Quote(colName)+"="+tp) case *builder.Builder: - subQuery, subArgs, err := builder.ToSQL(tp) + subQuery, subArgs, err := session.statement.GenCondSQL(tp) if err != nil { return 0, err } @@ -317,7 +317,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - condSQL, condArgs, err = builder.ToSQL(cond) + if len(colNames) <= 0 { + return 0, errors.New("No content found to be updated") + } + + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } @@ -335,24 +339,25 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var top string if st.LimitN != nil { limitValue := *st.LimitN - if session.engine.dialect.URI().DBType == schemas.MYSQL { + switch session.engine.dialect.URI().DBType { + case schemas.MYSQL: condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue) - } else if session.engine.dialect.URI().DBType == schemas.SQLITE { + case schemas.SQLITE: tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) - condSQL, condArgs, err = builder.ToSQL(cond) + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if session.engine.dialect.URI().DBType == schemas.POSTGRES { + case schemas.POSTGRES: tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) - condSQL, condArgs, err = builder.ToSQL(cond) + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } @@ -360,14 +365,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if session.engine.dialect.URI().DBType == schemas.MSSQL { - if st.OrderStr != "" && session.engine.dialect.URI().DBType == schemas.MSSQL && - table != nil && len(table.PrimaryKeys) == 1 { + case schemas.MSSQL: + if st.OrderStr != "" && table != nil && len(table.PrimaryKeys) == 1 { cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], session.engine.Quote(tableName), condSQL), condArgs...) - condSQL, condArgs, err = builder.ToSQL(cond) + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } @@ -380,10 +384,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - if len(colNames) <= 0 { - return 0, errors.New("No content found to be updated") - } - var tableAlias = session.engine.Quote(tableName) var fromSQL string if session.statement.TableAlias != "" {