From fd1f887a984e3ea99c4bbdb2bd61a7315a8e3b2f Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 11 Jul 2023 12:36:19 +0800 Subject: [PATCH] more refactor --- internal/statements/join.go | 116 +++++++++++++++---------------- internal/statements/query.go | 16 ++--- internal/statements/select.go | 4 +- internal/statements/statement.go | 23 +++--- rows.go | 11 ++- session.go | 2 +- session_find.go | 7 +- session_insert.go | 28 ++++---- 8 files changed, 104 insertions(+), 103 deletions(-) diff --git a/internal/statements/join.go b/internal/statements/join.go index adf349e7..9cf13b59 100644 --- a/internal/statements/join.go +++ b/internal/statements/join.go @@ -11,64 +11,46 @@ import ( "xorm.io/builder" "xorm.io/xorm/dialects" "xorm.io/xorm/internal/utils" - "xorm.io/xorm/schemas" ) // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (statement *Statement) Join(joinOP string, tablename interface{}, condition interface{}, args ...interface{}) *Statement { - var buf strings.Builder - if len(statement.JoinStr) > 0 { - fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) - } else { - fmt.Fprintf(&buf, "%v JOIN ", joinOP) - } +func (statement *Statement) Join(joinOP string, joinTable interface{}, condition interface{}, args ...interface{}) *Statement { + statement.joins = append(statement.joins, join{ + op: joinOP, + table: joinTable, + condition: condition, + args: args, + }) + return statement +} - condStr := "" - condArgs := []interface{}{} - switch condTp := condition.(type) { - case string: - condStr = condTp - case builder.Cond: - var err error - condStr, condArgs, err = builder.ToSQL(condTp) - if err != nil { - statement.LastError = err - return statement +func (statement *Statement) writeJoins(w builder.Writer) error { + for _, join := range statement.joins { + if err := statement.writeJoin(w, join); err != nil { + return err } - default: - statement.LastError = fmt.Errorf("unsupported join condition type: %v", condTp) - return statement + } + return nil +} + +func (statement *Statement) writeJoin(buf builder.Writer, join join) error { + // write join operator + if _, err := fmt.Fprintf(buf, " %v JOIN ", join.op); err != nil { + return err } - switch tp := tablename.(type) { + // write table or sub query + switch tp := join.table.(type) { case builder.Builder: - subSQL, subQueryArgs, err := tp.ToSQL() - if err != nil { - statement.LastError = err - return statement + if err := tp.WriteTo(buf); err != nil { + return err } - - fields := strings.Split(tp.TableName(), ".") - aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) - aliasName = schemas.CommonQuoter.Trim(aliasName) - - fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr)) - statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...) case *builder.Builder: - subSQL, subQueryArgs, err := tp.ToSQL() - if err != nil { - statement.LastError = err - return statement + if err := tp.WriteTo(buf); err != nil { + return err } - - fields := strings.Split(tp.TableName(), ".") - aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) - aliasName = schemas.CommonQuoter.Trim(aliasName) - - fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr)) - statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...) default: - tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) + tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), join.table, true) if !utils.IsSubQuery(tbName) { var buf strings.Builder _ = statement.dialect.Quoter().QuoteTo(&buf, tbName) @@ -76,21 +58,37 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition } else { tbName = statement.ReplaceQuote(tbName) } - fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condStr)) - statement.joinArgs = append(statement.joinArgs, condArgs...) - } - - statement.JoinStr = buf.String() - statement.joinArgs = append(statement.joinArgs, args...) - return statement -} - -func (statement *Statement) writeJoin(w builder.Writer) error { - if statement.JoinStr != "" { - if _, err := fmt.Fprint(w, " ", statement.JoinStr); err != nil { + if _, err := fmt.Fprint(buf, tbName); err != nil { return err } - w.Append(statement.joinArgs...) } + + // write alias FIXME + /*fields := strings.Split(tp.TableName(), ".") + aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) + aliasName = schemas.CommonQuoter.Trim(aliasName) + if _, err := fmt.Fprint(buf, " ", statement.quote(aliasName)); err != nil { + return err + }*/ + + // write condition + if _, err := fmt.Fprint(buf, " ON "); err != nil { + return err + } + + switch condTp := join.condition.(type) { + case string: + if _, err := fmt.Fprint(buf, condTp); err != nil { + return err + } + buf.Append(join.args...) + case builder.Cond: + if err := condTp.WriteTo(buf); err != nil { + return err + } + default: + return fmt.Errorf("unsupported join condition type: %v", condTp) + } + return nil } diff --git a/internal/statements/query.go b/internal/statements/query.go index 8c8a0b27..ec1bacd0 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -34,7 +34,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int if len(statement.SelectStr) > 0 { columnStr = statement.SelectStr } else { - if statement.JoinStr == "" { + if len(statement.joins) == 0 { if columnStr == "" { if statement.GroupByStr != "" { columnStr = statement.quoteColumnStr(statement.GroupByStr) @@ -109,7 +109,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, columnStr = statement.SelectStr } else { // TODO: always generate column names, not use * even if join - if len(statement.JoinStr) == 0 { + if len(statement.joins) == 0 { if len(columnStr) == 0 { if len(statement.GroupByStr) > 0 { columnStr = statement.quoteColumnStr(statement.GroupByStr) @@ -199,7 +199,7 @@ func (statement *Statement) writeFrom(w builder.Writer) error { if err := statement.writeAlias(w); err != nil { return err } - return statement.writeJoin(w) + return statement.writeJoins(w) } func (statement *Statement) writeLimitOffset(w builder.Writer) error { @@ -296,7 +296,7 @@ func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) err } else { column = statement.RefTable.PKColumns()[0].Name } - if statement.needTableName() { + if statement.NeedTableName() { if len(statement.TableAlias) > 0 { column = fmt.Sprintf("%s.%s", statement.TableAlias, column) } else { @@ -442,7 +442,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s", tableName); err != nil { return "", nil, err } - if err := statement.writeJoin(buf); err != nil { + if err := statement.writeJoins(buf); err != nil { return "", nil, err } if statement.Conds().IsValid() { @@ -457,7 +457,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil { return "", nil, err } - if err := statement.writeJoin(buf); err != nil { + if err := statement.writeJoins(buf); err != nil { return "", nil, err } if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { @@ -478,7 +478,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s", tableName); err != nil { return "", nil, err } - if err := statement.writeJoin(buf); err != nil { + if err := statement.writeJoins(buf); err != nil { return "", nil, err } if statement.Conds().IsValid() { @@ -503,7 +503,7 @@ func (statement *Statement) genSelect() string { } columnStr := statement.ColumnStr() - if statement.JoinStr == "" { + if len(statement.joins) == 0 { if columnStr == "" { if statement.GroupByStr != "" { columnStr = statement.quoteColumnStr(statement.GroupByStr) diff --git a/internal/statements/select.go b/internal/statements/select.go index 2bd2e94d..59161d76 100644 --- a/internal/statements/select.go +++ b/internal/statements/select.go @@ -102,7 +102,7 @@ func (statement *Statement) genColumnStr() string { buf.WriteString(", ") } - if statement.JoinStr != "" { + if len(statement.joins) > 0 { if statement.TableAlias != "" { buf.WriteString(statement.TableAlias) } else { @@ -119,7 +119,7 @@ func (statement *Statement) genColumnStr() string { } func (statement *Statement) colName(col *schemas.Column, tableName string) string { - if statement.needTableName() { + if statement.NeedTableName() { nm := tableName if len(statement.TableAlias) > 0 { nm = statement.TableAlias diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 17f5ae1f..00cafb6b 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -34,6 +34,13 @@ var ( ErrTableNotFound = errors.New("Table not found") ) +type join struct { + op string + table interface{} + condition interface{} + args []interface{} +} + // Statement save all the sql info for executing SQL type Statement struct { RefTable *schemas.Table @@ -45,8 +52,7 @@ type Statement struct { idParam schemas.PK orderStr string orderArgs []interface{} - JoinStr string - joinArgs []interface{} + joins []join GroupByStr string HavingStr string SelectStr string @@ -123,8 +129,7 @@ func (statement *Statement) Reset() { statement.LimitN = nil statement.ResetOrderBy() statement.UseCascade = true - statement.JoinStr = "" - statement.joinArgs = make([]interface{}, 0) + statement.joins = nil statement.GroupByStr = "" statement.HavingStr = "" statement.ColumnMap = columnMap{} @@ -205,8 +210,9 @@ func (statement *Statement) SetRefBean(bean interface{}) error { return nil } -func (statement *Statement) needTableName() bool { - return len(statement.JoinStr) > 0 +// NeedTableName returns true if need table name before column names +func (statement *Statement) NeedTableName() bool { + return len(statement.joins) > 0 } // Incr Generate "Update ... Set column = column + arg" statement @@ -605,8 +611,7 @@ func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, i // MergeConds merge conditions from bean and id func (statement *Statement) MergeConds(bean interface{}) error { if !statement.NoAutoCondition && statement.RefTable != nil { - addedTableName := (len(statement.JoinStr) > 0) - autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName) + autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, statement.NeedTableName()) if err != nil { return err } @@ -673,7 +678,7 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName // CondDeleted returns the conditions whether a record is soft deleted. func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { colName := statement.quote(col.Name) - if statement.JoinStr != "" { + if len(statement.joins) > 0 { var prefix string if statement.TableAlias != "" { prefix = statement.TableAlias diff --git a/rows.go b/rows.go index 4801c300..8d613763 100644 --- a/rows.go +++ b/rows.go @@ -46,12 +46,11 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { if rows.session.statement.RawSQL == "" { var autoCond builder.Cond - var addedTableName = (len(session.statement.JoinStr) > 0) - var table = rows.session.statement.RefTable + table := rows.session.statement.RefTable if !session.statement.NoAutoCondition { var err error - autoCond, err = session.statement.BuildConds(table, bean, true, true, false, true, addedTableName) + autoCond, err = session.statement.BuildConds(table, bean, true, true, false, true, session.statement.NeedTableName()) if err != nil { return nil, err } @@ -103,12 +102,12 @@ func (rows *Rows) Scan(beans ...interface{}) error { return rows.Err() } - var bean = beans[0] - var tp = reflect.TypeOf(bean) + bean := beans[0] + tp := reflect.TypeOf(bean) if tp.Kind() == reflect.Ptr { tp = tp.Elem() } - var beanKind = tp.Kind() + beanKind := tp.Kind() if len(beans) == 1 { if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { diff --git a/session.go b/session.go index e1a16e5b..af6e4921 100644 --- a/session.go +++ b/session.go @@ -354,7 +354,7 @@ func (session *Session) DB() *core.DB { func (session *Session) canCache() bool { if session.statement.RefTable == nil || - session.statement.JoinStr != "" || + session.statement.NeedTableName() || session.statement.RawSQL != "" || !session.statement.UseCache || session.statement.IsForUpdate || diff --git a/session_find.go b/session_find.go index 3341eafe..40890764 100644 --- a/session_find.go +++ b/session_find.go @@ -113,9 +113,8 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } var ( - table = session.statement.RefTable - addedTableName = (len(session.statement.JoinStr) > 0) - autoCond builder.Cond + table = session.statement.RefTable + autoCond builder.Cond ) if tp == tpStruct { if !session.statement.NoAutoCondition && len(condiBean) > 0 { @@ -123,7 +122,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) if err != nil { return err } - autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, addedTableName) + autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, session.statement.NeedTableName()) if err != nil { return err } diff --git a/session_insert.go b/session_insert.go index cfa26d39..3c6a9406 100644 --- a/session_insert.go +++ b/session_insert.go @@ -156,14 +156,14 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, e } args = append(args, val) - var colName = col.Name + colName := col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnTime(bean, col, t) }) } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) - var colName = col.Name + colName := col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnInt(bean, col, 1) @@ -276,7 +276,7 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { processor.BeforeInsert() } - var tableName = session.statement.TableName() + tableName := session.statement.TableName() table := session.statement.RefTable colNames, args, err := session.genInsertColumns(bean) @@ -517,7 +517,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } args = append(args, val) - var colName = col.Name + colName := col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnTime(bean, col, t) @@ -547,7 +547,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err return 0, ErrTableNotFound } - var columns = make([]string, 0, len(m)) + columns := make([]string, 0, len(m)) exprs := session.statement.ExprColumns for k := range m { if !exprs.IsColExist(k) { @@ -556,7 +556,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err } sort.Strings(columns) - var args = make([]interface{}, 0, len(m)) + args := make([]interface{}, 0, len(m)) for _, colName := range columns { args = append(args, m[colName]) } @@ -574,7 +574,7 @@ func (session *Session) insertMultipleMapInterface(maps []map[string]interface{} return 0, ErrTableNotFound } - var columns = make([]string, 0, len(maps[0])) + columns := make([]string, 0, len(maps[0])) exprs := session.statement.ExprColumns for k := range maps[0] { if !exprs.IsColExist(k) { @@ -583,9 +583,9 @@ func (session *Session) insertMultipleMapInterface(maps []map[string]interface{} } sort.Strings(columns) - var argss = make([][]interface{}, 0, len(maps)) + argss := make([][]interface{}, 0, len(maps)) for _, m := range maps { - var args = make([]interface{}, 0, len(m)) + args := make([]interface{}, 0, len(m)) for _, colName := range columns { args = append(args, m[colName]) } @@ -605,7 +605,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { return 0, ErrTableNotFound } - var columns = make([]string, 0, len(m)) + columns := make([]string, 0, len(m)) exprs := session.statement.ExprColumns for k := range m { if !exprs.IsColExist(k) { @@ -615,7 +615,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { sort.Strings(columns) - var args = make([]interface{}, 0, len(m)) + args := make([]interface{}, 0, len(m)) for _, colName := range columns { args = append(args, m[colName]) } @@ -633,7 +633,7 @@ func (session *Session) insertMultipleMapString(maps []map[string]string) (int64 return 0, ErrTableNotFound } - var columns = make([]string, 0, len(maps[0])) + columns := make([]string, 0, len(maps[0])) exprs := session.statement.ExprColumns for k := range maps[0] { if !exprs.IsColExist(k) { @@ -642,9 +642,9 @@ func (session *Session) insertMultipleMapString(maps []map[string]string) (int64 } sort.Strings(columns) - var argss = make([][]interface{}, 0, len(maps)) + argss := make([][]interface{}, 0, len(maps)) for _, m := range maps { - var args = make([]interface{}, 0, len(m)) + args := make([]interface{}, 0, len(m)) for _, colName := range columns { args = append(args, m[colName]) }