diff --git a/internal/statements/cond.go b/internal/statements/cond.go new file mode 100644 index 00000000..30eb5b4f --- /dev/null +++ b/internal/statements/cond.go @@ -0,0 +1,91 @@ +// Copyright 2022 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import "xorm.io/builder" + +// Where add Where statement +func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement { + return statement.And(query, args...) +} + +// And add Where & and statement +func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { + switch qr := query.(type) { + case string: + cond := builder.Expr(qr, args...) + statement.cond = statement.cond.And(cond) + case map[string]interface{}: + cond := make(builder.Eq) + for k, v := range qr { + cond[statement.quote(k)] = v + } + statement.cond = statement.cond.And(cond) + case builder.Cond: + statement.cond = statement.cond.And(qr) + for _, v := range args { + if vv, ok := v.(builder.Cond); ok { + statement.cond = statement.cond.And(vv) + } + } + default: + statement.LastError = ErrConditionType + } + + return statement +} + +// Or add Where & Or statement +func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement { + switch qr := query.(type) { + case string: + cond := builder.Expr(qr, args...) + statement.cond = statement.cond.Or(cond) + case map[string]interface{}: + cond := make(builder.Eq) + for k, v := range qr { + cond[statement.quote(k)] = v + } + statement.cond = statement.cond.Or(cond) + case builder.Cond: + statement.cond = statement.cond.Or(qr) + for _, v := range args { + if vv, ok := v.(builder.Cond); ok { + statement.cond = statement.cond.Or(vv) + } + } + default: + statement.LastError = ErrConditionType + } + return statement +} + +// In generate "Where column IN (?) " statement +func (statement *Statement) In(column string, args ...interface{}) *Statement { + in := builder.In(statement.quote(column), args...) + statement.cond = statement.cond.And(in) + return statement +} + +// NotIn generate "Where column NOT IN (?) " statement +func (statement *Statement) NotIn(column string, args ...interface{}) *Statement { + notIn := builder.NotIn(statement.quote(column), args...) + statement.cond = statement.cond.And(notIn) + return statement +} + +// SetNoAutoCondition if you do not want convert bean's field as query condition, then use this function +func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement { + statement.NoAutoCondition = true + if len(no) > 0 { + statement.NoAutoCondition = no[0] + } + return statement +} + +// Conds returns condtions +func (statement *Statement) Conds() builder.Cond { + return statement.cond +} diff --git a/internal/statements/join.go b/internal/statements/join.go new file mode 100644 index 00000000..45fc2441 --- /dev/null +++ b/internal/statements/join.go @@ -0,0 +1,78 @@ +// Copyright 2022 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "strings" + + "xorm.io/builder" + "xorm.io/xorm/dialects" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" +) + +// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN +func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { + var buf strings.Builder + if len(statement.JoinStr) > 0 { + fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) + } else { + fmt.Fprintf(&buf, "%v JOIN ", joinOP) + } + + switch tp := tablename.(type) { + case builder.Builder: + subSQL, subQueryArgs, err := tp.ToSQL() + if err != nil { + statement.LastError = err + return statement + } + + fields := strings.Split(tp.TableName(), ".") + aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) + aliasName = schemas.CommonQuoter.Trim(aliasName) + + fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition)) + statement.joinArgs = append(statement.joinArgs, subQueryArgs...) + case *builder.Builder: + subSQL, subQueryArgs, err := tp.ToSQL() + if err != nil { + statement.LastError = err + return statement + } + + fields := strings.Split(tp.TableName(), ".") + aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) + aliasName = schemas.CommonQuoter.Trim(aliasName) + + fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition)) + statement.joinArgs = append(statement.joinArgs, subQueryArgs...) + default: + tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) + if !utils.IsSubQuery(tbName) { + var buf strings.Builder + _ = statement.dialect.Quoter().QuoteTo(&buf, tbName) + tbName = buf.String() + } else { + tbName = statement.ReplaceQuote(tbName) + } + fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condition)) + } + + statement.JoinStr = buf.String() + statement.joinArgs = append(statement.joinArgs, args...) + return statement +} + +func (statement *Statement) writeJoin(w builder.Writer) error { + if statement.JoinStr != "" { + if _, err := fmt.Fprint(w, " ", statement.JoinStr); err != nil { + return err + } + w.Append(statement.joinArgs...) + } + return nil +} diff --git a/internal/statements/order_by.go b/internal/statements/order_by.go new file mode 100644 index 00000000..8a5cc1ef --- /dev/null +++ b/internal/statements/order_by.go @@ -0,0 +1,79 @@ +// Copyright 2022 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "strings" + + "xorm.io/builder" +) + +func (statement *Statement) HasOrderBy() bool { + return statement.OrderStr != "" +} + +// ResetOrderBy reset ordery conditions +func (statement *Statement) ResetOrderBy() { + statement.OrderStr = "" + statement.orderArgs = nil +} + +// WriteOrderBy write order by to writer +func (statement *Statement) WriteOrderBy(w builder.Writer) error { + if len(statement.OrderStr) > 0 { + if _, err := fmt.Fprintf(w, " ORDER BY %s", statement.OrderStr); err != nil { + return err + } + w.Append(statement.orderArgs...) + } + return nil +} + +// OrderBy generate "Order By order" statement +func (statement *Statement) OrderBy(order string, args ...interface{}) *Statement { + if len(statement.OrderStr) > 0 { + statement.OrderStr += ", " + } + statement.OrderStr += statement.ReplaceQuote(order) + if len(args) > 0 { + statement.orderArgs = append(statement.orderArgs, args...) + } + return statement +} + +// Desc generate `ORDER BY xx DESC` +func (statement *Statement) Desc(colNames ...string) *Statement { + var buf strings.Builder + if len(statement.OrderStr) > 0 { + fmt.Fprint(&buf, statement.OrderStr, ", ") + } + for i, col := range colNames { + if i > 0 { + fmt.Fprint(&buf, ", ") + } + _ = statement.dialect.Quoter().QuoteTo(&buf, col) + fmt.Fprint(&buf, " DESC") + } + statement.OrderStr = buf.String() + return statement +} + +// Asc provide asc order by query condition, the input parameters are columns. +func (statement *Statement) Asc(colNames ...string) *Statement { + var buf strings.Builder + if len(statement.OrderStr) > 0 { + fmt.Fprint(&buf, statement.OrderStr, ", ") + } + for i, col := range colNames { + if i > 0 { + fmt.Fprint(&buf, ", ") + } + _ = statement.dialect.Quoter().QuoteTo(&buf, col) + fmt.Fprint(&buf, " ASC") + } + statement.OrderStr = buf.String() + return statement +} diff --git a/internal/statements/query.go b/internal/statements/query.go index 9d16c891..152e297d 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -59,18 +59,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int return "", nil, err } - sqlStr, args, err := statement.genSelectSQL(columnStr, true, true) - if err != nil { - return "", nil, err - } - - // for mssql and use limit - qs := strings.Count(sqlStr, "?") - if len(args)*2 == qs { - args = append(args, args...) - } - - return sqlStr, args, nil + return statement.genSelectSQL(columnStr, true, true) } // GenSumSQL generates sum SQL @@ -199,16 +188,6 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa return sqlStr, condArgs, nil } -func (statement *Statement) writeJoin(w builder.Writer) error { - if statement.JoinStr != "" { - if _, err := fmt.Fprint(w, " ", statement.JoinStr); err != nil { - return err - } - w.Append(statement.joinArgs...) - } - return nil -} - func (statement *Statement) writeAlias(w builder.Writer) error { if statement.TableAlias != "" { if statement.dialect.URI().DBType == schemas.ORACLE { @@ -422,7 +401,6 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac return statement.GenRawSQL(), statement.RawParams, nil } - var joinStr string var b interface{} if len(bean) > 0 { b = bean[0] @@ -446,13 +424,13 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac } tableName = statement.quote(tableName) - if len(statement.JoinStr) > 0 { - joinStr = " " + statement.JoinStr + " " - } buf := builder.NewWriter() if statement.dialect.URI().DBType == schemas.MSSQL { - if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s%s", tableName, joinStr); err != nil { + if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s", tableName); err != nil { + return "", nil, err + } + if err := statement.writeJoin(buf); err != nil { return "", nil, err } if statement.Conds().IsValid() { @@ -464,7 +442,13 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac } } } else if statement.dialect.URI().DBType == schemas.ORACLE { - if _, err := fmt.Fprintf(buf, "SELECT * FROM %s%s WHERE ", tableName, joinStr); err != nil { + if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil { + return "", nil, err + } + if err := statement.writeJoin(buf); err != nil { + return "", nil, err + } + if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { return "", nil, err } if statement.Conds().IsValid() { @@ -479,7 +463,10 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac return "", nil, err } } else { - if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s%s", tableName, joinStr); err != nil { + if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s", tableName); err != nil { + return "", nil, err + } + if err := statement.writeJoin(buf); err != nil { return "", nil, err } if statement.Conds().IsValid() { diff --git a/internal/statements/select.go b/internal/statements/select.go new file mode 100644 index 00000000..2bd2e94d --- /dev/null +++ b/internal/statements/select.go @@ -0,0 +1,137 @@ +// Copyright 2022 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "strings" + + "xorm.io/xorm/schemas" +) + +// Select replace select +func (statement *Statement) Select(str string) *Statement { + statement.SelectStr = statement.ReplaceQuote(str) + return statement +} + +func col2NewCols(columns ...string) []string { + newColumns := make([]string, 0, len(columns)) + for _, col := range columns { + col = strings.Replace(col, "`", "", -1) + col = strings.Replace(col, `"`, "", -1) + ccols := strings.Split(col, ",") + for _, c := range ccols { + newColumns = append(newColumns, strings.TrimSpace(c)) + } + } + return newColumns +} + +// Cols generate "col1, col2" statement +func (statement *Statement) Cols(columns ...string) *Statement { + cols := col2NewCols(columns...) + for _, nc := range cols { + statement.ColumnMap.Add(nc) + } + return statement +} + +// ColumnStr returns column string +func (statement *Statement) ColumnStr() string { + return statement.dialect.Quoter().Join(statement.ColumnMap, ", ") +} + +// AllCols update use only: update all columns +func (statement *Statement) AllCols() *Statement { + statement.useAllCols = true + return statement +} + +// MustCols update use only: must update columns +func (statement *Statement) MustCols(columns ...string) *Statement { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.MustColumnMap[strings.ToLower(nc)] = true + } + return statement +} + +// UseBool indicates that use bool fields as update contents and query contiditions +func (statement *Statement) UseBool(columns ...string) *Statement { + if len(columns) > 0 { + statement.MustCols(columns...) + } else { + statement.allUseBool = true + } + return statement +} + +// Omit do not use the columns +func (statement *Statement) Omit(columns ...string) { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.OmitColumnMap = append(statement.OmitColumnMap, nc) + } +} + +func (statement *Statement) genColumnStr() string { + if statement.RefTable == nil { + return "" + } + + var buf strings.Builder + columns := statement.RefTable.Columns() + + for _, col := range columns { + if statement.OmitColumnMap.Contain(col.Name) { + continue + } + + if len(statement.ColumnMap) > 0 && !statement.ColumnMap.Contain(col.Name) { + continue + } + + if col.MapType == schemas.ONLYTODB { + continue + } + + if buf.Len() != 0 { + buf.WriteString(", ") + } + + if statement.JoinStr != "" { + if statement.TableAlias != "" { + buf.WriteString(statement.TableAlias) + } else { + buf.WriteString(statement.TableName()) + } + + buf.WriteString(".") + } + + statement.dialect.Quoter().QuoteTo(&buf, col.Name) + } + + return buf.String() +} + +func (statement *Statement) colName(col *schemas.Column, tableName string) string { + if statement.needTableName() { + nm := tableName + if len(statement.TableAlias) > 0 { + nm = statement.TableAlias + } + return fmt.Sprintf("%s.%s", statement.quote(nm), statement.quote(col.Name)) + } + return statement.quote(col.Name) +} + +// Distinct generates "DISTINCT col1, col2 " statement +func (statement *Statement) Distinct(columns ...string) *Statement { + statement.IsDistinct = true + statement.Cols(columns...) + return statement +} diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 605d4bd7..1974da1e 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -102,15 +102,6 @@ func (statement *Statement) GenRawSQL() string { return statement.ReplaceQuote(statement.RawSQL) } -// GenCondSQL generates condition SQL -func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []interface{}, error) { - condSQL, condArgs, err := builder.ToSQL(condOrBuilder) - if err != nil { - return "", nil, err - } - return statement.ReplaceQuote(condSQL), condArgs, nil -} - // ReplaceQuote replace sql key words with quote func (statement *Statement) ReplaceQuote(sql string) string { if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL || @@ -165,21 +156,6 @@ func (statement *Statement) Reset() { statement.LastError = nil } -// SetNoAutoCondition if you do not want convert bean's field as query condition, then use this function -func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement { - statement.NoAutoCondition = true - if len(no) > 0 { - statement.NoAutoCondition = no[0] - } - return statement -} - -// Alias set the table alias -func (statement *Statement) Alias(alias string) *Statement { - statement.TableAlias = alias - return statement -} - // SQL adds raw sql statement func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement { switch query.(type) { @@ -199,80 +175,10 @@ func (statement *Statement) SQL(query interface{}, args ...interface{}) *Stateme return statement } -// Where add Where statement -func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement { - return statement.And(query, args...) -} - func (statement *Statement) quote(s string) string { return statement.dialect.Quoter().Quote(s) } -// And add Where & and statement -func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { - switch qr := query.(type) { - case string: - cond := builder.Expr(qr, args...) - statement.cond = statement.cond.And(cond) - case map[string]interface{}: - cond := make(builder.Eq) - for k, v := range qr { - cond[statement.quote(k)] = v - } - statement.cond = statement.cond.And(cond) - case builder.Cond: - statement.cond = statement.cond.And(qr) - for _, v := range args { - if vv, ok := v.(builder.Cond); ok { - statement.cond = statement.cond.And(vv) - } - } - default: - statement.LastError = ErrConditionType - } - - return statement -} - -// Or add Where & Or statement -func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement { - switch qr := query.(type) { - case string: - cond := builder.Expr(qr, args...) - statement.cond = statement.cond.Or(cond) - case map[string]interface{}: - cond := make(builder.Eq) - for k, v := range qr { - cond[statement.quote(k)] = v - } - statement.cond = statement.cond.Or(cond) - case builder.Cond: - statement.cond = statement.cond.Or(qr) - for _, v := range args { - if vv, ok := v.(builder.Cond); ok { - statement.cond = statement.cond.Or(vv) - } - } - default: - statement.LastError = ErrConditionType - } - return statement -} - -// In generate "Where column IN (?) " statement -func (statement *Statement) In(column string, args ...interface{}) *Statement { - in := builder.In(statement.quote(column), args...) - statement.cond = statement.cond.And(in) - return statement -} - -// NotIn generate "Where column NOT IN (?) " statement -func (statement *Statement) NotIn(column string, args ...interface{}) *Statement { - notIn := builder.NotIn(statement.quote(column), args...) - statement.cond = statement.cond.And(notIn) - return statement -} - // SetRefValue set ref value func (statement *Statement) SetRefValue(v reflect.Value) error { var err error @@ -303,26 +209,6 @@ func (statement *Statement) needTableName() bool { return len(statement.JoinStr) > 0 } -func (statement *Statement) colName(col *schemas.Column, tableName string) string { - if statement.needTableName() { - nm := tableName - if len(statement.TableAlias) > 0 { - nm = statement.TableAlias - } - return fmt.Sprintf("%s.%s", statement.quote(nm), statement.quote(col.Name)) - } - return statement.quote(col.Name) -} - -// TableName return current tableName -func (statement *Statement) TableName() string { - if statement.AltTableName != "" { - return statement.AltTableName - } - - return statement.tableName -} - // Incr Generate "Update ... Set column = column + arg" statement func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { if len(arg) > 0 { @@ -353,85 +239,12 @@ func (statement *Statement) SetExpr(column string, expression interface{}) *Stat return statement } -// Distinct generates "DISTINCT col1, col2 " statement -func (statement *Statement) Distinct(columns ...string) *Statement { - statement.IsDistinct = true - statement.Cols(columns...) - return statement -} - // ForUpdate generates "SELECT ... FOR UPDATE" statement func (statement *Statement) ForUpdate() *Statement { statement.IsForUpdate = true return statement } -// Select replace select -func (statement *Statement) Select(str string) *Statement { - statement.SelectStr = statement.ReplaceQuote(str) - return statement -} - -func col2NewCols(columns ...string) []string { - newColumns := make([]string, 0, len(columns)) - for _, col := range columns { - col = strings.Replace(col, "`", "", -1) - col = strings.Replace(col, `"`, "", -1) - ccols := strings.Split(col, ",") - for _, c := range ccols { - newColumns = append(newColumns, strings.TrimSpace(c)) - } - } - return newColumns -} - -// Cols generate "col1, col2" statement -func (statement *Statement) Cols(columns ...string) *Statement { - cols := col2NewCols(columns...) - for _, nc := range cols { - statement.ColumnMap.Add(nc) - } - return statement -} - -// ColumnStr returns column string -func (statement *Statement) ColumnStr() string { - return statement.dialect.Quoter().Join(statement.ColumnMap, ", ") -} - -// AllCols update use only: update all columns -func (statement *Statement) AllCols() *Statement { - statement.useAllCols = true - return statement -} - -// MustCols update use only: must update columns -func (statement *Statement) MustCols(columns ...string) *Statement { - newColumns := col2NewCols(columns...) - for _, nc := range newColumns { - statement.MustColumnMap[strings.ToLower(nc)] = true - } - return statement -} - -// UseBool indicates that use bool fields as update contents and query contiditions -func (statement *Statement) UseBool(columns ...string) *Statement { - if len(columns) > 0 { - statement.MustCols(columns...) - } else { - statement.allUseBool = true - } - return statement -} - -// Omit do not use the columns -func (statement *Statement) Omit(columns ...string) { - newColumns := col2NewCols(columns...) - for _, nc := range newColumns { - statement.OmitColumnMap = append(statement.OmitColumnMap, nc) - } -} - // Nullable Update use only: update columns to null when value is nullable and zero-value func (statement *Statement) Nullable(columns ...string) { newColumns := col2NewCols(columns...) @@ -455,78 +268,6 @@ func (statement *Statement) Limit(limit int, start ...int) *Statement { return statement } -func (statement *Statement) HasOrderBy() bool { - return statement.OrderStr != "" -} - -// ResetOrderBy reset ordery conditions -func (statement *Statement) ResetOrderBy() { - statement.OrderStr = "" - statement.orderArgs = nil -} - -// WriteOrderBy write order by to writer -func (statement *Statement) WriteOrderBy(w builder.Writer) error { - if len(statement.OrderStr) > 0 { - if _, err := fmt.Fprintf(w, " ORDER BY %s", statement.OrderStr); err != nil { - return err - } - w.Append(statement.orderArgs...) - } - return nil -} - -// OrderBy generate "Order By order" statement -func (statement *Statement) OrderBy(order string, args ...interface{}) *Statement { - if len(statement.OrderStr) > 0 { - statement.OrderStr += ", " - } - statement.OrderStr += statement.ReplaceQuote(order) - if len(args) > 0 { - statement.orderArgs = append(statement.orderArgs, args...) - } - return statement -} - -// Desc generate `ORDER BY xx DESC` -func (statement *Statement) Desc(colNames ...string) *Statement { - var buf strings.Builder - if len(statement.OrderStr) > 0 { - fmt.Fprint(&buf, statement.OrderStr, ", ") - } - for i, col := range colNames { - if i > 0 { - fmt.Fprint(&buf, ", ") - } - _ = statement.dialect.Quoter().QuoteTo(&buf, col) - fmt.Fprint(&buf, " DESC") - } - statement.OrderStr = buf.String() - return statement -} - -// Asc provide asc order by query condition, the input parameters are columns. -func (statement *Statement) Asc(colNames ...string) *Statement { - var buf strings.Builder - if len(statement.OrderStr) > 0 { - fmt.Fprint(&buf, statement.OrderStr, ", ") - } - for i, col := range colNames { - if i > 0 { - fmt.Fprint(&buf, ", ") - } - _ = statement.dialect.Quoter().QuoteTo(&buf, col) - fmt.Fprint(&buf, " ASC") - } - statement.OrderStr = buf.String() - return statement -} - -// Conds returns condtions -func (statement *Statement) Conds() builder.Cond { - return statement.cond -} - // SetTable tempororily set table name, the parameter could be a string or a pointer of struct func (statement *Statement) SetTable(tableNameOrBean interface{}) error { v := rValue(tableNameOrBean) @@ -543,59 +284,6 @@ func (statement *Statement) SetTable(tableNameOrBean interface{}) error { return nil } -// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { - var buf strings.Builder - if len(statement.JoinStr) > 0 { - fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) - } else { - fmt.Fprintf(&buf, "%v JOIN ", joinOP) - } - - switch tp := tablename.(type) { - case builder.Builder: - subSQL, subQueryArgs, err := tp.ToSQL() - if err != nil { - statement.LastError = err - return statement - } - - fields := strings.Split(tp.TableName(), ".") - aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) - aliasName = schemas.CommonQuoter.Trim(aliasName) - - fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition)) - statement.joinArgs = append(statement.joinArgs, subQueryArgs...) - case *builder.Builder: - subSQL, subQueryArgs, err := tp.ToSQL() - if err != nil { - statement.LastError = err - return statement - } - - fields := strings.Split(tp.TableName(), ".") - aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) - aliasName = schemas.CommonQuoter.Trim(aliasName) - - fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition)) - statement.joinArgs = append(statement.joinArgs, subQueryArgs...) - default: - tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) - if !utils.IsSubQuery(tbName) { - var buf strings.Builder - _ = statement.dialect.Quoter().QuoteTo(&buf, tbName) - tbName = buf.String() - } else { - tbName = statement.ReplaceQuote(tbName) - } - fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condition)) - } - - statement.JoinStr = buf.String() - statement.joinArgs = append(statement.joinArgs, args...) - return statement -} - // GroupBy generate "Group By keys" statement func (statement *Statement) GroupBy(keys string) *Statement { statement.GroupByStr = statement.ReplaceQuote(keys) @@ -635,47 +323,6 @@ func (statement *Statement) GetUnscoped() bool { return statement.unscoped } -func (statement *Statement) genColumnStr() string { - if statement.RefTable == nil { - return "" - } - - var buf strings.Builder - columns := statement.RefTable.Columns() - - for _, col := range columns { - if statement.OmitColumnMap.Contain(col.Name) { - continue - } - - if len(statement.ColumnMap) > 0 && !statement.ColumnMap.Contain(col.Name) { - continue - } - - if col.MapType == schemas.ONLYTODB { - continue - } - - if buf.Len() != 0 { - buf.WriteString(", ") - } - - if statement.JoinStr != "" { - if statement.TableAlias != "" { - buf.WriteString(statement.TableAlias) - } else { - buf.WriteString(statement.TableName()) - } - - buf.WriteString(".") - } - - statement.dialect.Quoter().QuoteTo(&buf, col.Name) - } - - return buf.String() -} - // GenIndexSQL generated create index SQL func (statement *Statement) GenIndexSQL() []string { var sqls []string diff --git a/internal/statements/table_name.go b/internal/statements/table_name.go new file mode 100644 index 00000000..c41157ca --- /dev/null +++ b/internal/statements/table_name.go @@ -0,0 +1,20 @@ +// Copyright 2022 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +// TableName return current tableName +func (statement *Statement) TableName() string { + if statement.AltTableName != "" { + return statement.AltTableName + } + + return statement.tableName +} + +// Alias set the table alias +func (statement *Statement) Alias(alias string) *Statement { + statement.TableAlias = alias + return statement +} diff --git a/session_update.go b/session_update.go index 66e5d980..b7d0d7ae 100644 --- a/session_update.go +++ b/session_update.go @@ -258,10 +258,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp) case *builder.Builder: - subQuery, subArgs, err := session.statement.GenCondSQL(tp) + subQuery, subArgs, err := builder.ToSQL(tp) if err != nil { return 0, err } + subQuery = session.statement.ReplaceQuote(subQuery) colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")") args = append(args, subArgs...) default: