refactor write insert sql (#2302)

Reviewed-on: https://gitea.com/xorm/xorm/pulls/2302
This commit is contained in:
Lunny Xiao 2023-07-22 15:24:19 +00:00
parent 9988dac44d
commit 6c29ab378e
4 changed files with 118 additions and 33 deletions

View File

@ -293,3 +293,99 @@ func (statement *Statement) GenInsertMultipleMapSQL(columns []string, argss [][]
return buf.String(), buf.Args(), nil
}
func (statement *Statement) writeColumns(w *builder.BytesWriter, slice []string) error {
for i, s := range slice {
if i > 0 {
if _, err := fmt.Fprint(w, ","); err != nil {
return err
}
}
if err := statement.dialect.Quoter().QuoteTo(w.Builder, s); err != nil {
return err
}
}
return nil
}
func (statement *Statement) writeQuestions(w *builder.BytesWriter, length int) error {
for i := 0; i < length; i++ {
if i > 0 {
if _, err := fmt.Fprint(w, ","); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, "?"); err != nil {
return err
}
}
return nil
}
func (statement *Statement) oracleWriteInsertMultiple(w *builder.BytesWriter, tableName string, colNames []string, colMultiPlaces []string) error {
if _, err := fmt.Fprint(w, "INSERT ALL"); err != nil {
return err
}
for _, cols := range colMultiPlaces {
if _, err := fmt.Fprint(w, " INTO "); err != nil {
return err
}
if err := statement.dialect.Quoter().QuoteTo(w.Builder, tableName); err != nil {
return err
}
if _, err := fmt.Fprint(w, " ("); err != nil {
return err
}
if err := statement.writeColumns(w, colNames); err != nil {
return err
}
if _, err := fmt.Fprint(w, ") VALUES ("); err != nil {
return err
}
if _, err := fmt.Fprintf(w, cols, ")"); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, " SELECT 1 FROM DUAL"); err != nil {
return err
}
return nil
}
func (statement *Statement) WriteInsertMultiple(w *builder.BytesWriter, tableName string, colNames []string, colMultiPlaces []string) error {
if statement.dialect.URI().DBType == schemas.ORACLE {
return statement.oracleWriteInsertMultiple(w, tableName, colNames, colMultiPlaces)
}
return statement.plainWriteInsertMultiple(w, tableName, colNames, colMultiPlaces)
}
func (statement *Statement) plainWriteInsertMultiple(w *builder.BytesWriter, tableName string, colNames []string, colMultiPlaces []string) error {
if _, err := fmt.Fprint(w, "INSERT INTO "); err != nil {
return err
}
if err := statement.dialect.Quoter().QuoteTo(w.Builder, tableName); err != nil {
return err
}
if _, err := fmt.Fprint(w, " ("); err != nil {
return err
}
if err := statement.writeColumns(w, colNames); err != nil {
return err
}
if _, err := fmt.Fprint(w, ") VALUES ("); err != nil {
return err
}
for i, cols := range colMultiPlaces {
if _, err := fmt.Fprint(w, cols, ")"); err != nil {
return err
}
if i < len(colMultiPlaces)-1 {
if _, err := fmt.Fprint(w, ",("); err != nil {
return err
}
}
}
return nil
}

View File

@ -36,7 +36,7 @@ func (statement *Statement) writeJoins(w *builder.BytesWriter) error {
func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error {
// write join operator
if _, err := fmt.Fprintf(buf, " %v JOIN", join.op); err != nil {
if _, err := fmt.Fprint(buf, " ", join.op, " JOIN"); err != nil {
return err
}

View File

@ -12,6 +12,7 @@ import (
"strings"
"time"
"xorm.io/builder"
"xorm.io/xorm/convert"
"xorm.io/xorm/dialects"
"xorm.io/xorm/internal/utils"
@ -156,14 +157,14 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, e
}
args = append(args, val)
var colName = col.Name
colName := col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.statement.CheckVersion {
args = append(args, 1)
var colName = col.Name
colName := col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnInt(bean, col, 1)
@ -186,24 +187,12 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, e
}
cleanupProcessorsClosures(&session.beforeClosures)
quoter := session.engine.dialect.Quoter()
var sql string
colStr := quoter.Join(colNames, ",")
if session.engine.dialect.URI().DBType == schemas.ORACLE {
temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
quoter.Quote(tableName),
colStr)
sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
quoter.Quote(tableName),
colStr,
strings.Join(colMultiPlaces, temp))
} else {
sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
quoter.Quote(tableName),
colStr,
strings.Join(colMultiPlaces, "),("))
w := builder.NewWriter()
if err := session.statement.WriteInsertMultiple(w, tableName, colNames, colMultiPlaces); err != nil {
return 0, err
}
res, err := session.exec(sql, args...)
res, err := session.exec(w.String(), args...)
if err != nil {
return 0, err
}
@ -276,7 +265,7 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
processor.BeforeInsert()
}
var tableName = session.statement.TableName()
tableName := session.statement.TableName()
table := session.statement.RefTable
colNames, args, err := session.genInsertColumns(bean)
@ -517,7 +506,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
}
args = append(args, val)
var colName = col.Name
colName := col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
@ -547,7 +536,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
return 0, ErrTableNotFound
}
var columns = make([]string, 0, len(m))
columns := make([]string, 0, len(m))
exprs := session.statement.ExprColumns
for k := range m {
if !exprs.IsColExist(k) {
@ -556,7 +545,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
}
sort.Strings(columns)
var args = make([]interface{}, 0, len(m))
args := make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
@ -574,7 +563,7 @@ func (session *Session) insertMultipleMapInterface(maps []map[string]interface{}
return 0, ErrTableNotFound
}
var columns = make([]string, 0, len(maps[0]))
columns := make([]string, 0, len(maps[0]))
exprs := session.statement.ExprColumns
for k := range maps[0] {
if !exprs.IsColExist(k) {
@ -583,9 +572,9 @@ func (session *Session) insertMultipleMapInterface(maps []map[string]interface{}
}
sort.Strings(columns)
var argss = make([][]interface{}, 0, len(maps))
argss := make([][]interface{}, 0, len(maps))
for _, m := range maps {
var args = make([]interface{}, 0, len(m))
args := make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
@ -605,7 +594,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
return 0, ErrTableNotFound
}
var columns = make([]string, 0, len(m))
columns := make([]string, 0, len(m))
exprs := session.statement.ExprColumns
for k := range m {
if !exprs.IsColExist(k) {
@ -615,7 +604,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
sort.Strings(columns)
var args = make([]interface{}, 0, len(m))
args := make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
@ -633,7 +622,7 @@ func (session *Session) insertMultipleMapString(maps []map[string]string) (int64
return 0, ErrTableNotFound
}
var columns = make([]string, 0, len(maps[0]))
columns := make([]string, 0, len(maps[0]))
exprs := session.statement.ExprColumns
for k := range maps[0] {
if !exprs.IsColExist(k) {
@ -642,9 +631,9 @@ func (session *Session) insertMultipleMapString(maps []map[string]string) (int64
}
sort.Strings(columns)
var argss = make([][]interface{}, 0, len(maps))
argss := make([][]interface{}, 0, len(maps))
for _, m := range maps {
var args = make([]interface{}, 0, len(m))
args := make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}