This commit is contained in:
Lunny Xiao 2022-05-30 15:01:02 +08:00
parent fd823ae5bd
commit d459641f96
8 changed files with 423 additions and 383 deletions

View File

@ -0,0 +1,91 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import "xorm.io/builder"
// Where add Where statement
func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
return statement.And(query, args...)
}
// And add Where & and statement
func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
switch qr := query.(type) {
case string:
cond := builder.Expr(qr, args...)
statement.cond = statement.cond.And(cond)
case map[string]interface{}:
cond := make(builder.Eq)
for k, v := range qr {
cond[statement.quote(k)] = v
}
statement.cond = statement.cond.And(cond)
case builder.Cond:
statement.cond = statement.cond.And(qr)
for _, v := range args {
if vv, ok := v.(builder.Cond); ok {
statement.cond = statement.cond.And(vv)
}
}
default:
statement.LastError = ErrConditionType
}
return statement
}
// Or add Where & Or statement
func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
switch qr := query.(type) {
case string:
cond := builder.Expr(qr, args...)
statement.cond = statement.cond.Or(cond)
case map[string]interface{}:
cond := make(builder.Eq)
for k, v := range qr {
cond[statement.quote(k)] = v
}
statement.cond = statement.cond.Or(cond)
case builder.Cond:
statement.cond = statement.cond.Or(qr)
for _, v := range args {
if vv, ok := v.(builder.Cond); ok {
statement.cond = statement.cond.Or(vv)
}
}
default:
statement.LastError = ErrConditionType
}
return statement
}
// In generate "Where column IN (?) " statement
func (statement *Statement) In(column string, args ...interface{}) *Statement {
in := builder.In(statement.quote(column), args...)
statement.cond = statement.cond.And(in)
return statement
}
// NotIn generate "Where column NOT IN (?) " statement
func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
notIn := builder.NotIn(statement.quote(column), args...)
statement.cond = statement.cond.And(notIn)
return statement
}
// SetNoAutoCondition if you do not want convert bean's field as query condition, then use this function
func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement {
statement.NoAutoCondition = true
if len(no) > 0 {
statement.NoAutoCondition = no[0]
}
return statement
}
// Conds returns condtions
func (statement *Statement) Conds() builder.Cond {
return statement.cond
}

View File

@ -0,0 +1,78 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"fmt"
"strings"
"xorm.io/builder"
"xorm.io/xorm/dialects"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
var buf strings.Builder
if len(statement.JoinStr) > 0 {
fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
} else {
fmt.Fprintf(&buf, "%v JOIN ", joinOP)
}
switch tp := tablename.(type) {
case builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition))
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
case *builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition))
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default:
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
if !utils.IsSubQuery(tbName) {
var buf strings.Builder
_ = statement.dialect.Quoter().QuoteTo(&buf, tbName)
tbName = buf.String()
} else {
tbName = statement.ReplaceQuote(tbName)
}
fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condition))
}
statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...)
return statement
}
func (statement *Statement) writeJoin(w builder.Writer) error {
if statement.JoinStr != "" {
if _, err := fmt.Fprint(w, " ", statement.JoinStr); err != nil {
return err
}
w.Append(statement.joinArgs...)
}
return nil
}

View File

@ -0,0 +1,79 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"fmt"
"strings"
"xorm.io/builder"
)
func (statement *Statement) HasOrderBy() bool {
return statement.OrderStr != ""
}
// ResetOrderBy reset ordery conditions
func (statement *Statement) ResetOrderBy() {
statement.OrderStr = ""
statement.orderArgs = nil
}
// WriteOrderBy write order by to writer
func (statement *Statement) WriteOrderBy(w builder.Writer) error {
if len(statement.OrderStr) > 0 {
if _, err := fmt.Fprintf(w, " ORDER BY %s", statement.OrderStr); err != nil {
return err
}
w.Append(statement.orderArgs...)
}
return nil
}
// OrderBy generate "Order By order" statement
func (statement *Statement) OrderBy(order string, args ...interface{}) *Statement {
if len(statement.OrderStr) > 0 {
statement.OrderStr += ", "
}
statement.OrderStr += statement.ReplaceQuote(order)
if len(args) > 0 {
statement.orderArgs = append(statement.orderArgs, args...)
}
return statement
}
// Desc generate `ORDER BY xx DESC`
func (statement *Statement) Desc(colNames ...string) *Statement {
var buf strings.Builder
if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, statement.OrderStr, ", ")
}
for i, col := range colNames {
if i > 0 {
fmt.Fprint(&buf, ", ")
}
_ = statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " DESC")
}
statement.OrderStr = buf.String()
return statement
}
// Asc provide asc order by query condition, the input parameters are columns.
func (statement *Statement) Asc(colNames ...string) *Statement {
var buf strings.Builder
if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, statement.OrderStr, ", ")
}
for i, col := range colNames {
if i > 0 {
fmt.Fprint(&buf, ", ")
}
_ = statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " ASC")
}
statement.OrderStr = buf.String()
return statement
}

View File

@ -59,18 +59,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
return "", nil, err return "", nil, err
} }
sqlStr, args, err := statement.genSelectSQL(columnStr, true, true) return statement.genSelectSQL(columnStr, true, true)
if err != nil {
return "", nil, err
}
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}
return sqlStr, args, nil
} }
// GenSumSQL generates sum SQL // GenSumSQL generates sum SQL
@ -199,16 +188,6 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
return sqlStr, condArgs, nil return sqlStr, condArgs, nil
} }
func (statement *Statement) writeJoin(w builder.Writer) error {
if statement.JoinStr != "" {
if _, err := fmt.Fprint(w, " ", statement.JoinStr); err != nil {
return err
}
w.Append(statement.joinArgs...)
}
return nil
}
func (statement *Statement) writeAlias(w builder.Writer) error { func (statement *Statement) writeAlias(w builder.Writer) error {
if statement.TableAlias != "" { if statement.TableAlias != "" {
if statement.dialect.URI().DBType == schemas.ORACLE { if statement.dialect.URI().DBType == schemas.ORACLE {
@ -422,7 +401,6 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
return statement.GenRawSQL(), statement.RawParams, nil return statement.GenRawSQL(), statement.RawParams, nil
} }
var joinStr string
var b interface{} var b interface{}
if len(bean) > 0 { if len(bean) > 0 {
b = bean[0] b = bean[0]
@ -446,13 +424,13 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
} }
tableName = statement.quote(tableName) tableName = statement.quote(tableName)
if len(statement.JoinStr) > 0 {
joinStr = " " + statement.JoinStr + " "
}
buf := builder.NewWriter() buf := builder.NewWriter()
if statement.dialect.URI().DBType == schemas.MSSQL { if statement.dialect.URI().DBType == schemas.MSSQL {
if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s%s", tableName, joinStr); err != nil { if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s", tableName); err != nil {
return "", nil, err
}
if err := statement.writeJoin(buf); err != nil {
return "", nil, err return "", nil, err
} }
if statement.Conds().IsValid() { if statement.Conds().IsValid() {
@ -464,7 +442,13 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
} }
} }
} else if statement.dialect.URI().DBType == schemas.ORACLE { } else if statement.dialect.URI().DBType == schemas.ORACLE {
if _, err := fmt.Fprintf(buf, "SELECT * FROM %s%s WHERE ", tableName, joinStr); err != nil { if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil {
return "", nil, err
}
if err := statement.writeJoin(buf); err != nil {
return "", nil, err
}
if _, err := fmt.Fprintf(buf, " WHERE "); err != nil {
return "", nil, err return "", nil, err
} }
if statement.Conds().IsValid() { if statement.Conds().IsValid() {
@ -479,7 +463,10 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
return "", nil, err return "", nil, err
} }
} else { } else {
if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s%s", tableName, joinStr); err != nil { if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s", tableName); err != nil {
return "", nil, err
}
if err := statement.writeJoin(buf); err != nil {
return "", nil, err return "", nil, err
} }
if statement.Conds().IsValid() { if statement.Conds().IsValid() {

View File

@ -0,0 +1,137 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"fmt"
"strings"
"xorm.io/xorm/schemas"
)
// Select replace select
func (statement *Statement) Select(str string) *Statement {
statement.SelectStr = statement.ReplaceQuote(str)
return statement
}
func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0, len(columns))
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
newColumns = append(newColumns, strings.TrimSpace(c))
}
}
return newColumns
}
// Cols generate "col1, col2" statement
func (statement *Statement) Cols(columns ...string) *Statement {
cols := col2NewCols(columns...)
for _, nc := range cols {
statement.ColumnMap.Add(nc)
}
return statement
}
// ColumnStr returns column string
func (statement *Statement) ColumnStr() string {
return statement.dialect.Quoter().Join(statement.ColumnMap, ", ")
}
// AllCols update use only: update all columns
func (statement *Statement) AllCols() *Statement {
statement.useAllCols = true
return statement
}
// MustCols update use only: must update columns
func (statement *Statement) MustCols(columns ...string) *Statement {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.MustColumnMap[strings.ToLower(nc)] = true
}
return statement
}
// UseBool indicates that use bool fields as update contents and query contiditions
func (statement *Statement) UseBool(columns ...string) *Statement {
if len(columns) > 0 {
statement.MustCols(columns...)
} else {
statement.allUseBool = true
}
return statement
}
// Omit do not use the columns
func (statement *Statement) Omit(columns ...string) {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.OmitColumnMap = append(statement.OmitColumnMap, nc)
}
}
func (statement *Statement) genColumnStr() string {
if statement.RefTable == nil {
return ""
}
var buf strings.Builder
columns := statement.RefTable.Columns()
for _, col := range columns {
if statement.OmitColumnMap.Contain(col.Name) {
continue
}
if len(statement.ColumnMap) > 0 && !statement.ColumnMap.Contain(col.Name) {
continue
}
if col.MapType == schemas.ONLYTODB {
continue
}
if buf.Len() != 0 {
buf.WriteString(", ")
}
if statement.JoinStr != "" {
if statement.TableAlias != "" {
buf.WriteString(statement.TableAlias)
} else {
buf.WriteString(statement.TableName())
}
buf.WriteString(".")
}
statement.dialect.Quoter().QuoteTo(&buf, col.Name)
}
return buf.String()
}
func (statement *Statement) colName(col *schemas.Column, tableName string) string {
if statement.needTableName() {
nm := tableName
if len(statement.TableAlias) > 0 {
nm = statement.TableAlias
}
return fmt.Sprintf("%s.%s", statement.quote(nm), statement.quote(col.Name))
}
return statement.quote(col.Name)
}
// Distinct generates "DISTINCT col1, col2 " statement
func (statement *Statement) Distinct(columns ...string) *Statement {
statement.IsDistinct = true
statement.Cols(columns...)
return statement
}

View File

@ -102,15 +102,6 @@ func (statement *Statement) GenRawSQL() string {
return statement.ReplaceQuote(statement.RawSQL) return statement.ReplaceQuote(statement.RawSQL)
} }
// GenCondSQL generates condition SQL
func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []interface{}, error) {
condSQL, condArgs, err := builder.ToSQL(condOrBuilder)
if err != nil {
return "", nil, err
}
return statement.ReplaceQuote(condSQL), condArgs, nil
}
// ReplaceQuote replace sql key words with quote // ReplaceQuote replace sql key words with quote
func (statement *Statement) ReplaceQuote(sql string) string { func (statement *Statement) ReplaceQuote(sql string) string {
if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL || if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL ||
@ -165,21 +156,6 @@ func (statement *Statement) Reset() {
statement.LastError = nil statement.LastError = nil
} }
// SetNoAutoCondition if you do not want convert bean's field as query condition, then use this function
func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement {
statement.NoAutoCondition = true
if len(no) > 0 {
statement.NoAutoCondition = no[0]
}
return statement
}
// Alias set the table alias
func (statement *Statement) Alias(alias string) *Statement {
statement.TableAlias = alias
return statement
}
// SQL adds raw sql statement // SQL adds raw sql statement
func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement { func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
switch query.(type) { switch query.(type) {
@ -199,80 +175,10 @@ func (statement *Statement) SQL(query interface{}, args ...interface{}) *Stateme
return statement return statement
} }
// Where add Where statement
func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
return statement.And(query, args...)
}
func (statement *Statement) quote(s string) string { func (statement *Statement) quote(s string) string {
return statement.dialect.Quoter().Quote(s) return statement.dialect.Quoter().Quote(s)
} }
// And add Where & and statement
func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
switch qr := query.(type) {
case string:
cond := builder.Expr(qr, args...)
statement.cond = statement.cond.And(cond)
case map[string]interface{}:
cond := make(builder.Eq)
for k, v := range qr {
cond[statement.quote(k)] = v
}
statement.cond = statement.cond.And(cond)
case builder.Cond:
statement.cond = statement.cond.And(qr)
for _, v := range args {
if vv, ok := v.(builder.Cond); ok {
statement.cond = statement.cond.And(vv)
}
}
default:
statement.LastError = ErrConditionType
}
return statement
}
// Or add Where & Or statement
func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
switch qr := query.(type) {
case string:
cond := builder.Expr(qr, args...)
statement.cond = statement.cond.Or(cond)
case map[string]interface{}:
cond := make(builder.Eq)
for k, v := range qr {
cond[statement.quote(k)] = v
}
statement.cond = statement.cond.Or(cond)
case builder.Cond:
statement.cond = statement.cond.Or(qr)
for _, v := range args {
if vv, ok := v.(builder.Cond); ok {
statement.cond = statement.cond.Or(vv)
}
}
default:
statement.LastError = ErrConditionType
}
return statement
}
// In generate "Where column IN (?) " statement
func (statement *Statement) In(column string, args ...interface{}) *Statement {
in := builder.In(statement.quote(column), args...)
statement.cond = statement.cond.And(in)
return statement
}
// NotIn generate "Where column NOT IN (?) " statement
func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
notIn := builder.NotIn(statement.quote(column), args...)
statement.cond = statement.cond.And(notIn)
return statement
}
// SetRefValue set ref value // SetRefValue set ref value
func (statement *Statement) SetRefValue(v reflect.Value) error { func (statement *Statement) SetRefValue(v reflect.Value) error {
var err error var err error
@ -303,26 +209,6 @@ func (statement *Statement) needTableName() bool {
return len(statement.JoinStr) > 0 return len(statement.JoinStr) > 0
} }
func (statement *Statement) colName(col *schemas.Column, tableName string) string {
if statement.needTableName() {
nm := tableName
if len(statement.TableAlias) > 0 {
nm = statement.TableAlias
}
return fmt.Sprintf("%s.%s", statement.quote(nm), statement.quote(col.Name))
}
return statement.quote(col.Name)
}
// TableName return current tableName
func (statement *Statement) TableName() string {
if statement.AltTableName != "" {
return statement.AltTableName
}
return statement.tableName
}
// Incr Generate "Update ... Set column = column + arg" statement // Incr Generate "Update ... Set column = column + arg" statement
func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
if len(arg) > 0 { if len(arg) > 0 {
@ -353,85 +239,12 @@ func (statement *Statement) SetExpr(column string, expression interface{}) *Stat
return statement return statement
} }
// Distinct generates "DISTINCT col1, col2 " statement
func (statement *Statement) Distinct(columns ...string) *Statement {
statement.IsDistinct = true
statement.Cols(columns...)
return statement
}
// ForUpdate generates "SELECT ... FOR UPDATE" statement // ForUpdate generates "SELECT ... FOR UPDATE" statement
func (statement *Statement) ForUpdate() *Statement { func (statement *Statement) ForUpdate() *Statement {
statement.IsForUpdate = true statement.IsForUpdate = true
return statement return statement
} }
// Select replace select
func (statement *Statement) Select(str string) *Statement {
statement.SelectStr = statement.ReplaceQuote(str)
return statement
}
func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0, len(columns))
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
newColumns = append(newColumns, strings.TrimSpace(c))
}
}
return newColumns
}
// Cols generate "col1, col2" statement
func (statement *Statement) Cols(columns ...string) *Statement {
cols := col2NewCols(columns...)
for _, nc := range cols {
statement.ColumnMap.Add(nc)
}
return statement
}
// ColumnStr returns column string
func (statement *Statement) ColumnStr() string {
return statement.dialect.Quoter().Join(statement.ColumnMap, ", ")
}
// AllCols update use only: update all columns
func (statement *Statement) AllCols() *Statement {
statement.useAllCols = true
return statement
}
// MustCols update use only: must update columns
func (statement *Statement) MustCols(columns ...string) *Statement {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.MustColumnMap[strings.ToLower(nc)] = true
}
return statement
}
// UseBool indicates that use bool fields as update contents and query contiditions
func (statement *Statement) UseBool(columns ...string) *Statement {
if len(columns) > 0 {
statement.MustCols(columns...)
} else {
statement.allUseBool = true
}
return statement
}
// Omit do not use the columns
func (statement *Statement) Omit(columns ...string) {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.OmitColumnMap = append(statement.OmitColumnMap, nc)
}
}
// Nullable Update use only: update columns to null when value is nullable and zero-value // Nullable Update use only: update columns to null when value is nullable and zero-value
func (statement *Statement) Nullable(columns ...string) { func (statement *Statement) Nullable(columns ...string) {
newColumns := col2NewCols(columns...) newColumns := col2NewCols(columns...)
@ -455,78 +268,6 @@ func (statement *Statement) Limit(limit int, start ...int) *Statement {
return statement return statement
} }
func (statement *Statement) HasOrderBy() bool {
return statement.OrderStr != ""
}
// ResetOrderBy reset ordery conditions
func (statement *Statement) ResetOrderBy() {
statement.OrderStr = ""
statement.orderArgs = nil
}
// WriteOrderBy write order by to writer
func (statement *Statement) WriteOrderBy(w builder.Writer) error {
if len(statement.OrderStr) > 0 {
if _, err := fmt.Fprintf(w, " ORDER BY %s", statement.OrderStr); err != nil {
return err
}
w.Append(statement.orderArgs...)
}
return nil
}
// OrderBy generate "Order By order" statement
func (statement *Statement) OrderBy(order string, args ...interface{}) *Statement {
if len(statement.OrderStr) > 0 {
statement.OrderStr += ", "
}
statement.OrderStr += statement.ReplaceQuote(order)
if len(args) > 0 {
statement.orderArgs = append(statement.orderArgs, args...)
}
return statement
}
// Desc generate `ORDER BY xx DESC`
func (statement *Statement) Desc(colNames ...string) *Statement {
var buf strings.Builder
if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, statement.OrderStr, ", ")
}
for i, col := range colNames {
if i > 0 {
fmt.Fprint(&buf, ", ")
}
_ = statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " DESC")
}
statement.OrderStr = buf.String()
return statement
}
// Asc provide asc order by query condition, the input parameters are columns.
func (statement *Statement) Asc(colNames ...string) *Statement {
var buf strings.Builder
if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, statement.OrderStr, ", ")
}
for i, col := range colNames {
if i > 0 {
fmt.Fprint(&buf, ", ")
}
_ = statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " ASC")
}
statement.OrderStr = buf.String()
return statement
}
// Conds returns condtions
func (statement *Statement) Conds() builder.Cond {
return statement.cond
}
// SetTable tempororily set table name, the parameter could be a string or a pointer of struct // SetTable tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) SetTable(tableNameOrBean interface{}) error { func (statement *Statement) SetTable(tableNameOrBean interface{}) error {
v := rValue(tableNameOrBean) v := rValue(tableNameOrBean)
@ -543,59 +284,6 @@ func (statement *Statement) SetTable(tableNameOrBean interface{}) error {
return nil return nil
} }
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
var buf strings.Builder
if len(statement.JoinStr) > 0 {
fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
} else {
fmt.Fprintf(&buf, "%v JOIN ", joinOP)
}
switch tp := tablename.(type) {
case builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition))
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
case *builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition))
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default:
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
if !utils.IsSubQuery(tbName) {
var buf strings.Builder
_ = statement.dialect.Quoter().QuoteTo(&buf, tbName)
tbName = buf.String()
} else {
tbName = statement.ReplaceQuote(tbName)
}
fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condition))
}
statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...)
return statement
}
// GroupBy generate "Group By keys" statement // GroupBy generate "Group By keys" statement
func (statement *Statement) GroupBy(keys string) *Statement { func (statement *Statement) GroupBy(keys string) *Statement {
statement.GroupByStr = statement.ReplaceQuote(keys) statement.GroupByStr = statement.ReplaceQuote(keys)
@ -635,47 +323,6 @@ func (statement *Statement) GetUnscoped() bool {
return statement.unscoped return statement.unscoped
} }
func (statement *Statement) genColumnStr() string {
if statement.RefTable == nil {
return ""
}
var buf strings.Builder
columns := statement.RefTable.Columns()
for _, col := range columns {
if statement.OmitColumnMap.Contain(col.Name) {
continue
}
if len(statement.ColumnMap) > 0 && !statement.ColumnMap.Contain(col.Name) {
continue
}
if col.MapType == schemas.ONLYTODB {
continue
}
if buf.Len() != 0 {
buf.WriteString(", ")
}
if statement.JoinStr != "" {
if statement.TableAlias != "" {
buf.WriteString(statement.TableAlias)
} else {
buf.WriteString(statement.TableName())
}
buf.WriteString(".")
}
statement.dialect.Quoter().QuoteTo(&buf, col.Name)
}
return buf.String()
}
// GenIndexSQL generated create index SQL // GenIndexSQL generated create index SQL
func (statement *Statement) GenIndexSQL() []string { func (statement *Statement) GenIndexSQL() []string {
var sqls []string var sqls []string

View File

@ -0,0 +1,20 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
// TableName return current tableName
func (statement *Statement) TableName() string {
if statement.AltTableName != "" {
return statement.AltTableName
}
return statement.tableName
}
// Alias set the table alias
func (statement *Statement) Alias(alias string) *Statement {
statement.TableAlias = alias
return statement
}

View File

@ -258,10 +258,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp) colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp)
case *builder.Builder: case *builder.Builder:
subQuery, subArgs, err := session.statement.GenCondSQL(tp) subQuery, subArgs, err := builder.ToSQL(tp)
if err != nil { if err != nil {
return 0, err return 0, err
} }
subQuery = session.statement.ReplaceQuote(subQuery)
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")") colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")")
args = append(args, subArgs...) args = append(args, subArgs...)
default: default: