fix tests

This commit is contained in:
Lunny Xiao 2019-09-22 20:50:00 +08:00
parent 18ce386663
commit e7221d4aff
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
3 changed files with 22 additions and 27 deletions

View File

@ -7,7 +7,6 @@ package xorm
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"reflect" "reflect"
"sort" "sort"
"strconv" "strconv"
@ -355,22 +354,22 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
} }
var buf = builder.NewWriter() var buf = builder.NewWriter()
if _, err := io.WriteString(buf, fmt.Sprintf("INSERT INTO %s", session.engine.Quote(tableName))); err != nil { if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s", session.engine.Quote(tableName))); err != nil {
return 0, err return 0, err
} }
if len(colPlaces) <= 0 { if len(colPlaces) <= 0 {
if session.engine.dialect.DBType() == core.MYSQL { if session.engine.dialect.DBType() == core.MYSQL {
if _, err := io.WriteString(buf, " VALUES ()"); err != nil { if _, err := buf.WriteString(" VALUES ()"); err != nil {
return 0, err return 0, err
} }
} else { } else {
if _, err := io.WriteString(buf, fmt.Sprintf("%s DEFAULT VALUES", output)); err != nil { if _, err := buf.WriteString(fmt.Sprintf("%s DEFAULT VALUES", output)); err != nil {
return 0, err return 0, err
} }
} }
} else { } else {
if _, err := io.WriteString(buf, " ("); err != nil { if _, err := buf.WriteString(" ("); err != nil {
return 0, err return 0, err
} }
@ -379,9 +378,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
} }
if session.statement.cond.IsValid() { if session.statement.cond.IsValid() {
if _, err := io.WriteString(buf, fmt.Sprintf(")%s SELECT ", if _, err := buf.WriteString(fmt.Sprintf(")%s SELECT ", output)); err != nil {
output,
)); err != nil {
return 0, err return 0, err
} }
@ -398,7 +395,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 0, err return 0, err
} }
if _, err := io.WriteString(buf, fmt.Sprintf(" FROM %v WHERE ", session.engine.Quote(tableName))); err != nil { if _, err := buf.WriteString(fmt.Sprintf(" FROM %v WHERE ", session.engine.Quote(tableName))); err != nil {
return 0, err return 0, err
} }
@ -408,7 +405,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
} else { } else {
buf.Append(args...) buf.Append(args...)
if _, err := io.WriteString(buf, fmt.Sprintf(")%s VALUES (%v", if _, err := buf.WriteString(fmt.Sprintf(")%s VALUES (%v",
output, output,
colPlaces)); err != nil { colPlaces)); err != nil {
return 0, err return 0, err
@ -418,21 +415,21 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 0, err return 0, err
} }
if _, err := io.WriteString(buf, ")"); err != nil { if _, err := buf.WriteString(")"); err != nil {
return 0, err return 0, err
} }
} }
} }
if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES {
if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil {
return 0, err
}
}
sqlStr := buf.String() sqlStr := buf.String()
args = buf.Args() args = buf.Args()
if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES {
if _, err := io.WriteString(buf, " RETURNING "+session.engine.Quote(table.AutoIncrement)); err != nil {
return 0, err
}
}
handleAfterInsertProcessorFunc := func(bean interface{}) { handleAfterInsertProcessorFunc := func(bean interface{}) {
if session.isAutoCommit { if session.isAutoCommit {
for _, closure := range session.afterClosures { for _, closure := range session.afterClosures {

View File

@ -6,7 +6,6 @@ package xorm
import ( import (
"fmt" "fmt"
"io"
"xorm.io/builder" "xorm.io/builder"
) )
@ -47,20 +46,20 @@ func writeArgs(w *builder.BytesWriter, args []interface{}) error {
func writeStrings(w *builder.BytesWriter, cols []string, leftQuote, rightQuote string) error { func writeStrings(w *builder.BytesWriter, cols []string, leftQuote, rightQuote string) error {
for i, colName := range cols { for i, colName := range cols {
if len(leftQuote) > 0 && colName[0] != '`' { if len(leftQuote) > 0 && colName[0] != '`' {
if _, err := io.WriteString(w, leftQuote); err != nil { if _, err := w.WriteString(leftQuote); err != nil {
return err return err
} }
} }
if _, err := io.WriteString(w, colName); err != nil { if _, err := w.WriteString(colName); err != nil {
return err return err
} }
if len(rightQuote) > 0 && colName[len(colName)-1] != '`' { if len(rightQuote) > 0 && colName[len(colName)-1] != '`' {
if _, err := io.WriteString(w, rightQuote); err != nil { if _, err := w.WriteString(rightQuote); err != nil {
return err return err
} }
} }
if i+1 != len(cols) { if i+1 != len(cols) {
if _, err := io.WriteString(w, ","); err != nil { if _, err := w.WriteString(","); err != nil {
return err return err
} }
} }

View File

@ -6,7 +6,6 @@ package xorm
import ( import (
"fmt" "fmt"
"io"
"strings" "strings"
"xorm.io/builder" "xorm.io/builder"
@ -57,7 +56,7 @@ func (exprs *exprParams) getByName(colName string) (exprParam, bool) {
return exprParam{}, false return exprParam{}, false
} }
func (exprs *exprParams) writeArgs(w builder.Writer) error { func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
for _, expr := range exprs.args { for _, expr := range exprs.args {
switch arg := expr.(type) { switch arg := expr.(type) {
case *builder.Builder: case *builder.Builder:
@ -65,7 +64,7 @@ func (exprs *exprParams) writeArgs(w builder.Writer) error {
return err return err
} }
default: default:
if _, err := io.WriteString(w, fmt.Sprintf("%v", arg)); err != nil { if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil {
return err return err
} }
} }
@ -73,12 +72,12 @@ func (exprs *exprParams) writeArgs(w builder.Writer) error {
return nil return nil
} }
func (exprs *exprParams) writeNameArgs(w builder.Writer) error { func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error {
for i, colName := range exprs.colNames { for i, colName := range exprs.colNames {
if _, err := io.WriteString(w, colName); err != nil { if _, err := w.WriteString(colName); err != nil {
return err return err
} }
if _, err := io.WriteString(w, "="); err != nil { if _, err := w.WriteString("="); err != nil {
return err return err
} }
@ -92,7 +91,7 @@ func (exprs *exprParams) writeNameArgs(w builder.Writer) error {
} }
if i+1 != len(exprs.colNames) { if i+1 != len(exprs.colNames) {
if _, err := io.WriteString(w, ","); err != nil { if _, err := w.WriteString(","); err != nil {
return err return err
} }
} }