diff --git a/integrations/session_find_test.go b/integrations/session_find_test.go index 5c2a4c68..ff197ebb 100644 --- a/integrations/session_find_test.go +++ b/integrations/session_find_test.go @@ -12,6 +12,7 @@ import ( "xorm.io/xorm" "xorm.io/xorm/internal/utils" "xorm.io/xorm/names" + "xorm.io/xorm/schemas" "github.com/stretchr/testify/assert" ) @@ -1196,3 +1197,41 @@ func TestUpdateFindDate(t *testing.T) { assert.EqualValues(t, 1, len(tufs)) assert.EqualValues(t, tuf.Tm.Format("2006-01-02"), tufs[0].Tm.Format("2006-01-02")) } + +func TestBuilderDialect(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestBuilderDialect struct { + Id int64 + Name string `xorm:"index"` + Age2 int + } + + type TestBuilderDialectFoo struct { + Id int64 + DialectId int64 `xorm:"index"` + Age int + } + + assertSync(t, new(TestBuilderDialect), new(TestBuilderDialectFoo)) + + session := testEngine.NewSession() + defer session.Close() + + var dialect string + switch testEngine.Dialect().URI().DBType { + case schemas.MYSQL: + dialect = builder.MYSQL + case schemas.MSSQL: + dialect = builder.MSSQL + case schemas.POSTGRES: + dialect = builder.POSTGRES + case schemas.SQLITE: + dialect = builder.SQLITE + } + + inner := builder.Dialect(dialect).Select("*").From("test_builder_dialect_foo").Where(builder.Eq{"age": 20}) + result := make([]*TestBuilderDialect, 0, 10) + err := testEngine.Table("test_builder_dialect").Where(builder.Eq{"age2": 2}).Join("INNER", inner, "test_builder_dialect_foo.dialect_id = test_builder_dialect.id").Find(&result) + assert.NoError(t, err) +} diff --git a/internal/statements/join.go b/internal/statements/join.go index adf349e7..79ae2b5f 100644 --- a/internal/statements/join.go +++ b/internal/statements/join.go @@ -15,82 +15,85 @@ import ( ) // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (statement *Statement) Join(joinOP string, tablename interface{}, condition interface{}, args ...interface{}) *Statement { - var buf strings.Builder - if len(statement.JoinStr) > 0 { - fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) - } else { - fmt.Fprintf(&buf, "%v JOIN ", joinOP) - } - - condStr := "" - condArgs := []interface{}{} - switch condTp := condition.(type) { - case string: - condStr = condTp - case builder.Cond: - var err error - condStr, condArgs, err = builder.ToSQL(condTp) - if err != nil { - statement.LastError = err - return statement - } - default: - statement.LastError = fmt.Errorf("unsupported join condition type: %v", condTp) - return statement - } - - switch tp := tablename.(type) { - case builder.Builder: - subSQL, subQueryArgs, err := tp.ToSQL() - if err != nil { - statement.LastError = err - return statement - } - - fields := strings.Split(tp.TableName(), ".") - aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) - aliasName = schemas.CommonQuoter.Trim(aliasName) - - fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr)) - statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...) - case *builder.Builder: - subSQL, subQueryArgs, err := tp.ToSQL() - if err != nil { - statement.LastError = err - return statement - } - - fields := strings.Split(tp.TableName(), ".") - aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) - aliasName = schemas.CommonQuoter.Trim(aliasName) - - fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr)) - statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...) - default: - tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) - if !utils.IsSubQuery(tbName) { - var buf strings.Builder - _ = statement.dialect.Quoter().QuoteTo(&buf, tbName) - tbName = buf.String() - } else { - tbName = statement.ReplaceQuote(tbName) - } - fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condStr)) - statement.joinArgs = append(statement.joinArgs, condArgs...) - } - - statement.JoinStr = buf.String() - statement.joinArgs = append(statement.joinArgs, args...) +func (statement *Statement) Join(joinOP string, joinTable interface{}, condition interface{}, args ...interface{}) *Statement { + statement.joins = append(statement.joins, join{ + op: joinOP, + table: joinTable, + condition: condition, + args: args, + }) return statement } -func (statement *Statement) writeJoin(w builder.Writer) error { - if statement.JoinStr != "" { - if _, err := fmt.Fprint(w, " ", statement.JoinStr); err != nil { +func (statement *Statement) writeJoins(w builder.Writer) error { + for _, join := range statement.joins { + if err := statement.writeJoin(w, join); err != nil { return err } - w.Append(statement.joinArgs...) } return nil } + +func (statement *Statement) writeJoin(buf builder.Writer, join join) error { + // write join operator + if _, err := fmt.Fprintf(buf, " %v JOIN ", join.op); err != nil { + return err + } + + // write join table or subquery + switch tp := join.table.(type) { + case builder.Builder: + fmt.Fprintf(buf, "(") + // statement.ReplaceQuote(subSQL), + if err := tp.WriteTo(buf); err != nil { + return err + } + + fields := strings.Split(tp.TableName(), ".") + aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) + aliasName = schemas.CommonQuoter.Trim(aliasName) + + fmt.Fprintf(buf, ") %s", statement.quote(aliasName)) + case *builder.Builder: + fmt.Fprintf(buf, "(") + // statement.ReplaceQuote(subSQL), + if err := tp.WriteTo(buf); err != nil { + return err + } + + fields := strings.Split(tp.TableName(), ".") + aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) + aliasName = schemas.CommonQuoter.Trim(aliasName) + + fmt.Fprintf(buf, ") %s", statement.quote(aliasName)) + default: + tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), join.table, true) + if !utils.IsSubQuery(tbName) { + var sb strings.Builder + _ = statement.dialect.Quoter().QuoteTo(&sb, tbName) + tbName = sb.String() + } else { + tbName = statement.ReplaceQuote(tbName) + } + fmt.Fprintf(buf, tbName) + } + + // write on condition + if _, err := fmt.Fprint(buf, " ON "); err != nil { + return err + } + + switch condTp := join.condition.(type) { + case string: + fmt.Fprint(buf, statement.ReplaceQuote(condTp)) + case builder.Cond: + if err := condTp.WriteTo(buf); err != nil { + return err + } + default: + return fmt.Errorf("unsupported join condition type: %v", condTp) + } + buf.Append(join.args...) + + return nil +} diff --git a/internal/statements/query.go b/internal/statements/query.go index f72c8602..54cb5a40 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -33,7 +33,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int if len(statement.SelectStr) > 0 { columnStr = statement.SelectStr } else { - if statement.JoinStr == "" { + if len(statement.joins) == 0 { if columnStr == "" { if statement.GroupByStr != "" { columnStr = statement.quoteColumnStr(statement.GroupByStr) @@ -108,7 +108,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, columnStr = statement.SelectStr } else { // TODO: always generate column names, not use * even if join - if len(statement.JoinStr) == 0 { + if len(statement.joins) == 0 { if len(columnStr) == 0 { if len(statement.GroupByStr) > 0 { columnStr = statement.quoteColumnStr(statement.GroupByStr) @@ -198,7 +198,7 @@ func (statement *Statement) writeFrom(w builder.Writer) error { if err := statement.writeAlias(w); err != nil { return err } - return statement.writeJoin(w) + return statement.writeJoins(w) } func (statement *Statement) writeLimitOffset(w builder.Writer) error { @@ -263,7 +263,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB } else { column = statement.RefTable.PKColumns()[0].Name } - if statement.needTableName() { + if statement.NeedTableName() { if len(statement.TableAlias) > 0 { column = fmt.Sprintf("%s.%s", statement.TableAlias, column) } else { @@ -402,7 +402,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s", tableName); err != nil { return "", nil, err } - if err := statement.writeJoin(buf); err != nil { + if err := statement.writeJoins(buf); err != nil { return "", nil, err } if statement.Conds().IsValid() { @@ -417,7 +417,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil { return "", nil, err } - if err := statement.writeJoin(buf); err != nil { + if err := statement.writeJoins(buf); err != nil { return "", nil, err } if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { @@ -438,7 +438,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s", tableName); err != nil { return "", nil, err } - if err := statement.writeJoin(buf); err != nil { + if err := statement.writeJoins(buf); err != nil { return "", nil, err } if statement.Conds().IsValid() { @@ -471,7 +471,7 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa if len(statement.SelectStr) > 0 { columnStr = statement.SelectStr } else { - if statement.JoinStr == "" { + if len(statement.joins) == 0 { if columnStr == "" { if statement.GroupByStr != "" { columnStr = statement.quoteColumnStr(statement.GroupByStr) diff --git a/internal/statements/select.go b/internal/statements/select.go index 2bd2e94d..59161d76 100644 --- a/internal/statements/select.go +++ b/internal/statements/select.go @@ -102,7 +102,7 @@ func (statement *Statement) genColumnStr() string { buf.WriteString(", ") } - if statement.JoinStr != "" { + if len(statement.joins) > 0 { if statement.TableAlias != "" { buf.WriteString(statement.TableAlias) } else { @@ -119,7 +119,7 @@ func (statement *Statement) genColumnStr() string { } func (statement *Statement) colName(col *schemas.Column, tableName string) string { - if statement.needTableName() { + if statement.NeedTableName() { nm := tableName if len(statement.TableAlias) > 0 { nm = statement.TableAlias diff --git a/internal/statements/statement.go b/internal/statements/statement.go index a8fe34fa..516adff1 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -34,6 +34,13 @@ var ( ErrTableNotFound = errors.New("Table not found") ) +type join struct { + op string + table interface{} + condition interface{} + args []interface{} +} + // Statement save all the sql info for executing SQL type Statement struct { RefTable *schemas.Table @@ -45,8 +52,7 @@ type Statement struct { idParam schemas.PK orderStr string orderArgs []interface{} - JoinStr string - joinArgs []interface{} + joins []join GroupByStr string HavingStr string SelectStr string @@ -123,8 +129,7 @@ func (statement *Statement) Reset() { statement.LimitN = nil statement.ResetOrderBy() statement.UseCascade = true - statement.JoinStr = "" - statement.joinArgs = make([]interface{}, 0) + statement.joins = nil statement.GroupByStr = "" statement.HavingStr = "" statement.ColumnMap = columnMap{} @@ -205,8 +210,8 @@ func (statement *Statement) SetRefBean(bean interface{}) error { return nil } -func (statement *Statement) needTableName() bool { - return len(statement.JoinStr) > 0 +func (statement *Statement) NeedTableName() bool { + return len(statement.joins) > 0 } // Incr Generate "Update ... Set column = column + arg" statement @@ -605,7 +610,7 @@ func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, i // MergeConds merge conditions from bean and id func (statement *Statement) MergeConds(bean interface{}) error { if !statement.NoAutoCondition && statement.RefTable != nil { - addedTableName := (len(statement.JoinStr) > 0) + addedTableName := (len(statement.joins) > 0) autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName) if err != nil { return err @@ -673,7 +678,7 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName // CondDeleted returns the conditions whether a record is soft deleted. func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { colName := statement.quote(col.Name) - if statement.JoinStr != "" { + if len(statement.joins) > 0 { var prefix string if statement.TableAlias != "" { prefix = statement.TableAlias diff --git a/rows.go b/rows.go index 4801c300..ef3e42b6 100644 --- a/rows.go +++ b/rows.go @@ -46,8 +46,8 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { if rows.session.statement.RawSQL == "" { var autoCond builder.Cond - var addedTableName = (len(session.statement.JoinStr) > 0) - var table = rows.session.statement.RefTable + addedTableName := session.statement.NeedTableName() + table := rows.session.statement.RefTable if !session.statement.NoAutoCondition { var err error @@ -103,12 +103,12 @@ func (rows *Rows) Scan(beans ...interface{}) error { return rows.Err() } - var bean = beans[0] - var tp = reflect.TypeOf(bean) + bean := beans[0] + tp := reflect.TypeOf(bean) if tp.Kind() == reflect.Ptr { tp = tp.Elem() } - var beanKind = tp.Kind() + beanKind := tp.Kind() if len(beans) == 1 { if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { diff --git a/session.go b/session.go index e1a16e5b..af6e4921 100644 --- a/session.go +++ b/session.go @@ -354,7 +354,7 @@ func (session *Session) DB() *core.DB { func (session *Session) canCache() bool { if session.statement.RefTable == nil || - session.statement.JoinStr != "" || + session.statement.NeedTableName() || session.statement.RawSQL != "" || !session.statement.UseCache || session.statement.IsForUpdate || diff --git a/session_find.go b/session_find.go index 3341eafe..d9444aee 100644 --- a/session_find.go +++ b/session_find.go @@ -114,7 +114,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) var ( table = session.statement.RefTable - addedTableName = (len(session.statement.JoinStr) > 0) + addedTableName = session.statement.NeedTableName() autoCond builder.Cond ) if tp == tpStruct {