From 3ce08ca6878f1f16be5bfeba288236afb5fa06af Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 3 Nov 2020 16:00:55 +0800 Subject: [PATCH] refactor exprParam --- internal/statements/expr_param.go | 63 ++++++++++++++++--------------- internal/statements/insert.go | 12 +++--- session_update.go | 24 ++++++------ 3 files changed, 50 insertions(+), 49 deletions(-) diff --git a/internal/statements/expr_param.go b/internal/statements/expr_param.go index d0c355d3..51fc34fb 100644 --- a/internal/statements/expr_param.go +++ b/internal/statements/expr_param.go @@ -21,46 +21,47 @@ func (err ErrUnsupportedExprType) Error() string { return fmt.Sprintf("Unsupported expression type: %v", err.tp) } -type exprParam struct { - colName string - arg interface{} +// Expr represents an SQL express +type Expr struct { + ColName string + Arg interface{} } -type exprParams struct { - ColNames []string - Args []interface{} +type exprParams []Expr + +func (exprs exprParams) ColNames() []string { + var cols = make([]string, 0, len(exprs)) + for _, expr := range exprs { + cols = append(cols, expr.ColName) + } + return cols } -func (exprs *exprParams) Len() int { - return len(exprs.ColNames) +func (exprs exprParams) addParam(colName string, arg interface{}) { + exprs = append(exprs, Expr{colName, arg}) } -func (exprs *exprParams) addParam(colName string, arg interface{}) { - exprs.ColNames = append(exprs.ColNames, colName) - exprs.Args = append(exprs.Args, arg) -} - -func (exprs *exprParams) IsColExist(colName string) bool { - for _, name := range exprs.ColNames { - if strings.EqualFold(schemas.CommonQuoter.Trim(name), schemas.CommonQuoter.Trim(colName)) { +func (exprs exprParams) IsColExist(colName string) bool { + for _, expr := range exprs { + if strings.EqualFold(schemas.CommonQuoter.Trim(expr.ColName), schemas.CommonQuoter.Trim(colName)) { return true } } return false } -func (exprs *exprParams) getByName(colName string) (exprParam, bool) { - for i, name := range exprs.ColNames { - if strings.EqualFold(name, colName) { - return exprParam{name, exprs.Args[i]}, true +func (exprs exprParams) getByName(colName string) (Expr, bool) { + for _, expr := range exprs { + if strings.EqualFold(expr.ColName, colName) { + return expr, true } } - return exprParam{}, false + return Expr{}, false } -func (exprs *exprParams) WriteArgs(w *builder.BytesWriter) error { - for i, expr := range exprs.Args { - switch arg := expr.(type) { +func (exprs exprParams) WriteArgs(w *builder.BytesWriter) error { + for i, expr := range exprs { + switch arg := expr.Arg.(type) { case *builder.Builder: if _, err := w.WriteString("("); err != nil { return err @@ -84,7 +85,7 @@ func (exprs *exprParams) WriteArgs(w *builder.BytesWriter) error { } w.Append(arg) } - if i != len(exprs.Args)-1 { + if i != len(exprs)-1 { if _, err := w.WriteString(","); err != nil { return err } @@ -93,16 +94,16 @@ func (exprs *exprParams) WriteArgs(w *builder.BytesWriter) error { return nil } -func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { - for i, colName := range exprs.ColNames { - if _, err := w.WriteString(colName); err != nil { +func (exprs exprParams) writeNameArgs(w *builder.BytesWriter) error { + for i, expr := range exprs { + if _, err := w.WriteString(expr.ColName); err != nil { return err } if _, err := w.WriteString("="); err != nil { return err } - switch arg := exprs.Args[i].(type) { + switch arg := expr.Arg.(type) { case *builder.Builder: if _, err := w.WriteString("("); err != nil { return err @@ -114,10 +115,10 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { return err } default: - w.Append(exprs.Args[i]) + w.Append(expr.Arg) } - if i+1 != len(exprs.ColNames) { + if i+1 != len(exprs) { if _, err := w.WriteString(","); err != nil { return err } diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 6cbbbeda..367dbdc9 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -59,7 +59,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames...), ","); err != nil { + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames()...), ","); err != nil { return "", nil, err } @@ -79,7 +79,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - if len(exprs.Args) > 0 { + if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { return "", nil, err } @@ -112,7 +112,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - if len(exprs.Args) > 0 { + if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { return "", nil, err } @@ -152,7 +152,7 @@ func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{} return "", nil, err } - if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames...), ","); err != nil { + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil { return "", nil, err } @@ -166,7 +166,7 @@ func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{} return "", nil, err } - if len(exprs.Args) > 0 { + if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { return "", nil, err } @@ -190,7 +190,7 @@ func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{} return "", nil, err } - if len(exprs.Args) > 0 { + if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { return "", nil, err } diff --git a/session_update.go b/session_update.go index 0adac25e..9e4cddb1 100644 --- a/session_update.go +++ b/session_update.go @@ -224,35 +224,35 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 // for update action to like "column = column + ?" incColumns := session.statement.IncrColumns - for i, colName := range incColumns.ColNames { - colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?") - args = append(args, incColumns.Args[i]) + for _, expr := range incColumns { + colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" + ?") + args = append(args, expr.Arg) } // for update action to like "column = column - ?" decColumns := session.statement.DecrColumns - for i, colName := range decColumns.ColNames { - colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?") - args = append(args, decColumns.Args[i]) + for _, expr := range decColumns { + colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" - ?") + args = append(args, expr.Arg) } // for update action to like "column = expression" exprColumns := session.statement.ExprColumns - for i, colName := range exprColumns.ColNames { - switch tp := exprColumns.Args[i].(type) { + for _, expr := range exprColumns { + switch tp := expr.Arg.(type) { case string: if len(tp) == 0 { tp = "''" } - colNames = append(colNames, session.engine.Quote(colName)+"="+tp) + colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp) case *builder.Builder: subQuery, subArgs, err := session.statement.GenCondSQL(tp) if err != nil { return 0, err } - colNames = append(colNames, session.engine.Quote(colName)+"=("+subQuery+")") + colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")") args = append(args, subArgs...) default: - colNames = append(colNames, session.engine.Quote(colName)+"=?") - args = append(args, exprColumns.Args[i]) + colNames = append(colNames, session.engine.Quote(expr.ColName)+"=?") + args = append(args, expr.Arg) } }