From 2ac051f07535be24618a7ba3f0f968e6092cf359 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 27 Mar 2020 03:13:25 +0000 Subject: [PATCH] Improve insert map generating SQL (#1634) Fix writeArg Improve insert map generating SQL Reviewed-on: https://gitea.com/xorm/xorm/pulls/1634 --- internal/statements/insert.go | 108 ++++++++++++++++++++------ internal/statements/statement_args.go | 32 +++----- session_insert.go | 70 +---------------- 3 files changed, 98 insertions(+), 112 deletions(-) diff --git a/internal/statements/insert.go b/internal/statements/insert.go index db2fc91c..6cbbbeda 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -5,6 +5,7 @@ package statements import ( + "fmt" "strings" "xorm.io/builder" @@ -23,18 +24,15 @@ func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schem return nil } +// GenInsertSQL generates insert beans SQL func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) (string, []interface{}, error) { var ( + buf = builder.NewWriter() + exprs = statement.ExprColumns table = statement.RefTable tableName = statement.TableName() - exprs = statement.ExprColumns - colPlaces = strings.Repeat("?, ", len(colNames)) ) - if exprs.Len() <= 0 && len(colPlaces) > 0 { - colPlaces = colPlaces[0 : len(colPlaces)-2] - } - var buf = builder.NewWriter() if _, err := buf.WriteString("INSERT INTO "); err != nil { return "", nil, err } @@ -43,7 +41,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - if len(colPlaces) <= 0 { + if len(colNames) <= 0 { if statement.dialect.URI().DBType == schemas.MYSQL { if _, err := buf.WriteString(" VALUES ()"); err != nil { return "", nil, err @@ -65,13 +63,14 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } + if _, err := buf.WriteString(")"); err != nil { + return "", nil, err + } + if err := statement.writeInsertOutput(buf.Builder, table); err != nil { + return "", nil, err + } + if statement.Conds().IsValid() { - if _, err := buf.WriteString(")"); err != nil { - return "", nil, err - } - if err := statement.writeInsertOutput(buf.Builder, table); err != nil { - return "", nil, err - } if _, err := buf.WriteString(" SELECT "); err != nil { return "", nil, err } @@ -105,21 +104,20 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } } else { - buf.Append(args...) - - if _, err := buf.WriteString(")"); err != nil { - return "", nil, err - } - if err := statement.writeInsertOutput(buf.Builder, table); err != nil { - return "", nil, err - } if _, err := buf.WriteString(" VALUES ("); err != nil { return "", nil, err } - if _, err := buf.WriteString(colPlaces); err != nil { + + if err := statement.WriteArgs(buf, args); err != nil { return "", nil, err } + if len(exprs.Args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + } + if err := exprs.WriteArgs(buf); err != nil { return "", nil, err } @@ -141,3 +139,69 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return buf.String(), buf.Args(), nil } + +// GenInsertMapSQL generates insert map SQL +func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{}) (string, []interface{}, error) { + var ( + buf = builder.NewWriter() + exprs = statement.ExprColumns + tableName = statement.TableName() + ) + + if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s (", statement.quote(tableName))); err != nil { + return "", nil, err + } + + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames...), ","); err != nil { + return "", nil, err + } + + // if insert where + if statement.Conds().IsValid() { + if _, err := buf.WriteString(") SELECT "); err != nil { + return "", nil, err + } + + if err := statement.WriteArgs(buf, args); err != nil { + return "", nil, err + } + + if len(exprs.Args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + if err := exprs.WriteArgs(buf); err != nil { + return "", nil, err + } + } + + if _, err := buf.WriteString(fmt.Sprintf(" FROM %s WHERE ", statement.quote(tableName))); err != nil { + return "", nil, err + } + + if err := statement.Conds().WriteTo(buf); err != nil { + return "", nil, err + } + } else { + if _, err := buf.WriteString(") VALUES ("); err != nil { + return "", nil, err + } + if err := statement.WriteArgs(buf, args); err != nil { + return "", nil, err + } + + if len(exprs.Args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + if err := exprs.WriteArgs(buf); err != nil { + return "", nil, err + } + } + if _, err := buf.WriteString(")"); err != nil { + return "", nil, err + } + } + + return buf.String(), buf.Args(), nil +} diff --git a/internal/statements/statement_args.go b/internal/statements/statement_args.go index 7d1ef9eb..dc14467d 100644 --- a/internal/statements/statement_args.go +++ b/internal/statements/statement_args.go @@ -79,28 +79,6 @@ const insertSelectPlaceHolder = true func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error { switch argv := arg.(type) { - case bool: - if statement.dialect.URI().DBType == schemas.MSSQL { - if argv { - if _, err := w.WriteString("1"); err != nil { - return err - } - } else { - if _, err := w.WriteString("0"); err != nil { - return err - } - } - } else { - if argv { - if _, err := w.WriteString("true"); err != nil { - return err - } - } else { - if _, err := w.WriteString("false"); err != nil { - return err - } - } - } case *builder.Builder: if _, err := w.WriteString("("); err != nil { return err @@ -116,7 +94,15 @@ func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) er if err := w.WriteByte('?'); err != nil { return err } - w.Append(arg) + if v, ok := arg.(bool); ok && statement.dialect.URI().DBType == schemas.MSSQL { + if v { + w.Append(1) + } else { + w.Append(0) + } + } else { + w.Append(arg) + } } else { var convertFunc = convertStringSingleQuote if statement.dialect.URI().DBType == schemas.MYSQL { diff --git a/session_insert.go b/session_insert.go index 1270d5db..2c46a59b 100644 --- a/session_insert.go +++ b/session_insert.go @@ -12,7 +12,6 @@ import ( "strconv" "strings" - "xorm.io/builder" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -623,74 +622,11 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64, return 0, ErrTableNotFound } - exprs := session.statement.ExprColumns - w := builder.NewWriter() - // if insert where - if session.statement.Conds().IsValid() { - if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { - return 0, err - } - - if err := session.engine.dialect.Quoter().JoinWrite(w.Builder, append(columns, exprs.ColNames...), ","); err != nil { - return 0, err - } - - if _, err := w.WriteString(") SELECT "); err != nil { - return 0, err - } - - if err := session.statement.WriteArgs(w, args); err != nil { - return 0, err - } - - if len(exprs.Args) > 0 { - if _, err := w.WriteString(","); err != nil { - return 0, err - } - if err := exprs.WriteArgs(w); err != nil { - return 0, err - } - } - - if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil { - return 0, err - } - - if err := session.statement.Conds().WriteTo(w); err != nil { - return 0, err - } - } else { - qm := strings.Repeat("?,", len(columns)) - qm = qm[:len(qm)-1] - - if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { - return 0, err - } - - if err := session.engine.dialect.Quoter().JoinWrite(w.Builder, append(columns, exprs.ColNames...), ","); err != nil { - return 0, err - } - if _, err := w.WriteString(fmt.Sprintf(") VALUES (%s", qm)); err != nil { - return 0, err - } - - w.Append(args...) - if len(exprs.Args) > 0 { - if _, err := w.WriteString(","); err != nil { - return 0, err - } - if err := exprs.WriteArgs(w); err != nil { - return 0, err - } - } - if _, err := w.WriteString(")"); err != nil { - return 0, err - } + sql, args, err := session.statement.GenInsertMapSQL(columns, args) + if err != nil { + return 0, err } - sql := w.String() - args = w.Args() - if err := session.cacheInsert(tableName); err != nil { return 0, err }