From bc715dd7aefa4d630698019af25eb87a09442400 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 22 Jul 2023 21:05:03 +0800 Subject: [PATCH] refactor write insert sql --- internal/statements/insert.go | 96 +++++++++++++++++++++++++++++++++++ internal/statements/query.go | 2 +- session_insert.go | 51 ++++++++----------- 3 files changed, 117 insertions(+), 32 deletions(-) diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 187b94a3..9370c984 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -293,3 +293,99 @@ func (statement *Statement) GenInsertMultipleMapSQL(columns []string, argss [][] return buf.String(), buf.Args(), nil } + +func (statement *Statement) writeColumns(w *builder.BytesWriter, slice []string) error { + for i, s := range slice { + if i > 0 { + if _, err := fmt.Fprint(w, ","); err != nil { + return err + } + } + if err := statement.dialect.Quoter().QuoteTo(w.Builder, s); err != nil { + return err + } + } + return nil +} + +func (statement *Statement) writeQuestions(w *builder.BytesWriter, length int) error { + for i := 0; i < length; i++ { + if i > 0 { + if _, err := fmt.Fprint(w, ","); err != nil { + return err + } + } + if _, err := fmt.Fprint(w, "?"); err != nil { + return err + } + } + return nil +} + +func (statement *Statement) oracleWriteInsertMultiple(w *builder.BytesWriter, tableName string, colNames []string, colMultiPlaces []string) error { + if _, err := fmt.Fprint(w, "INSERT ALL"); err != nil { + return err + } + + for _, cols := range colMultiPlaces { + if _, err := fmt.Fprint(w, " INTO "); err != nil { + return err + } + if err := statement.dialect.Quoter().QuoteTo(w.Builder, tableName); err != nil { + return err + } + if _, err := fmt.Fprint(w, " ("); err != nil { + return err + } + if err := statement.writeColumns(w, colNames); err != nil { + return err + } + if _, err := fmt.Fprint(w, ") VALUES ("); err != nil { + return err + } + if _, err := fmt.Fprintf(w, cols, ")"); err != nil { + return err + } + } + + if _, err := fmt.Fprint(w, " SELECT 1 FROM DUAL"); err != nil { + return err + } + return nil +} + +func (statement *Statement) WriteInsertMultiple(w *builder.BytesWriter, tableName string, colNames []string, colMultiPlaces []string) error { + if statement.dialect.URI().DBType == schemas.ORACLE { + return statement.oracleWriteInsertMultiple(w, tableName, colNames, colMultiPlaces) + } + return statement.plainWriteInsertMultiple(w, tableName, colNames, colMultiPlaces) +} + +func (statement *Statement) plainWriteInsertMultiple(w *builder.BytesWriter, tableName string, colNames []string, colMultiPlaces []string) error { + if _, err := fmt.Fprint(w, "INSERT INTO "); err != nil { + return err + } + if err := statement.dialect.Quoter().QuoteTo(w.Builder, tableName); err != nil { + return err + } + if _, err := fmt.Fprint(w, " ("); err != nil { + return err + } + if err := statement.writeColumns(w, colNames); err != nil { + return err + } + if _, err := fmt.Fprint(w, ") VALUES ("); err != nil { + return err + } + for i, cols := range colMultiPlaces { + if _, err := fmt.Fprint(w, cols, ")"); err != nil { + return err + } + if i < len(colMultiPlaces)-1 { + if _, err := fmt.Fprint(w, ",("); err != nil { + return err + } + } + } + return nil +} diff --git a/internal/statements/query.go b/internal/statements/query.go index 63e079e7..211ba268 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -230,7 +230,7 @@ func (statement *Statement) writeDistinct(w builder.Writer) error { } func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr string) error { - if _, err := fmt.Fprintf(w, "SELECT "); err != nil { + if _, err := fmt.Fprintf(w, "SELECT"); err != nil { return err } if err := statement.writeDistinct(w); err != nil { diff --git a/session_insert.go b/session_insert.go index cfa26d39..7003e0f7 100644 --- a/session_insert.go +++ b/session_insert.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "xorm.io/builder" "xorm.io/xorm/convert" "xorm.io/xorm/dialects" "xorm.io/xorm/internal/utils" @@ -156,14 +157,14 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, e } args = append(args, val) - var colName = col.Name + colName := col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnTime(bean, col, t) }) } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) - var colName = col.Name + colName := col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnInt(bean, col, 1) @@ -186,24 +187,12 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, e } cleanupProcessorsClosures(&session.beforeClosures) - quoter := session.engine.dialect.Quoter() - var sql string - colStr := quoter.Join(colNames, ",") - if session.engine.dialect.URI().DBType == schemas.ORACLE { - temp := fmt.Sprintf(") INTO %s (%v) VALUES (", - quoter.Quote(tableName), - colStr) - sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL", - quoter.Quote(tableName), - colStr, - strings.Join(colMultiPlaces, temp)) - } else { - sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)", - quoter.Quote(tableName), - colStr, - strings.Join(colMultiPlaces, "),(")) + w := builder.NewWriter() + if err := session.statement.WriteInsertMultiple(w, tableName, colNames, colMultiPlaces); err != nil { + return 0, err } - res, err := session.exec(sql, args...) + + res, err := session.exec(w.String(), args...) if err != nil { return 0, err } @@ -276,7 +265,7 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { processor.BeforeInsert() } - var tableName = session.statement.TableName() + tableName := session.statement.TableName() table := session.statement.RefTable colNames, args, err := session.genInsertColumns(bean) @@ -517,7 +506,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } args = append(args, val) - var colName = col.Name + colName := col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnTime(bean, col, t) @@ -547,7 +536,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err return 0, ErrTableNotFound } - var columns = make([]string, 0, len(m)) + columns := make([]string, 0, len(m)) exprs := session.statement.ExprColumns for k := range m { if !exprs.IsColExist(k) { @@ -556,7 +545,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err } sort.Strings(columns) - var args = make([]interface{}, 0, len(m)) + args := make([]interface{}, 0, len(m)) for _, colName := range columns { args = append(args, m[colName]) } @@ -574,7 +563,7 @@ func (session *Session) insertMultipleMapInterface(maps []map[string]interface{} return 0, ErrTableNotFound } - var columns = make([]string, 0, len(maps[0])) + columns := make([]string, 0, len(maps[0])) exprs := session.statement.ExprColumns for k := range maps[0] { if !exprs.IsColExist(k) { @@ -583,9 +572,9 @@ func (session *Session) insertMultipleMapInterface(maps []map[string]interface{} } sort.Strings(columns) - var argss = make([][]interface{}, 0, len(maps)) + argss := make([][]interface{}, 0, len(maps)) for _, m := range maps { - var args = make([]interface{}, 0, len(m)) + args := make([]interface{}, 0, len(m)) for _, colName := range columns { args = append(args, m[colName]) } @@ -605,7 +594,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { return 0, ErrTableNotFound } - var columns = make([]string, 0, len(m)) + columns := make([]string, 0, len(m)) exprs := session.statement.ExprColumns for k := range m { if !exprs.IsColExist(k) { @@ -615,7 +604,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { sort.Strings(columns) - var args = make([]interface{}, 0, len(m)) + args := make([]interface{}, 0, len(m)) for _, colName := range columns { args = append(args, m[colName]) } @@ -633,7 +622,7 @@ func (session *Session) insertMultipleMapString(maps []map[string]string) (int64 return 0, ErrTableNotFound } - var columns = make([]string, 0, len(maps[0])) + columns := make([]string, 0, len(maps[0])) exprs := session.statement.ExprColumns for k := range maps[0] { if !exprs.IsColExist(k) { @@ -642,9 +631,9 @@ func (session *Session) insertMultipleMapString(maps []map[string]string) (int64 } sort.Strings(columns) - var argss = make([][]interface{}, 0, len(maps)) + argss := make([][]interface{}, 0, len(maps)) for _, m := range maps { - var args = make([]interface{}, 0, len(m)) + args := make([]interface{}, 0, len(m)) for _, colName := range columns { args = append(args, m[colName]) }