From d6963b7d423db56a4a07b8e0737af98317028bb0 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 28 Sep 2019 18:02:19 +0800 Subject: [PATCH] fix arg conversion (#1441) * fix arg conversion * fix bugs * fix bug on postgres * use traditional positional parameters on insert into select * remove unnecessary tests --- go.mod | 2 +- go.sum | 5 +++++ session_insert_test.go | 24 ++++++++++++++++++++++++ statement_args.go | 38 ++++++++++++++++++++++++++++++++++---- statement_exprparam.go | 7 ++++++- 5 files changed, 70 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index c30d6df5..1ab39831 100644 --- a/go.mod +++ b/go.mod @@ -16,5 +16,5 @@ require ( github.com/stretchr/testify v1.4.0 github.com/ziutek/mymysql v1.5.4 xorm.io/builder v0.3.6 - xorm.io/core v0.7.1 + xorm.io/core v0.7.2-0.20190928055935-90aeac8d08eb ) diff --git a/go.sum b/go.sum index 9ca37897..cf637a8e 100644 --- a/go.sum +++ b/go.sum @@ -89,12 +89,14 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/ziutek/mymysql v1.5.4 h1:GB0qdRGsTwQSBVYuVShFBKaXSnSnYYC2d9knnE1LHFs= github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0= go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmVhcMLU6v5fEb/ok4wyqtNU= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -155,6 +157,7 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= @@ -165,3 +168,5 @@ xorm.io/core v0.7.0 h1:hKxuOKWZNeiFQsSuGet/KV8HZ788hclvAl+7azx3tkM= xorm.io/core v0.7.0/go.mod h1:TuOJjIVa7e3w/rN8tDcAvuLBMtwzdHPbyOzE6Gk1EUI= xorm.io/core v0.7.1 h1:I6x6Q6dYb67aDEoYFWr2t8UcKIYjJPyCHS+aXuj5V0Y= xorm.io/core v0.7.1/go.mod h1:jJfd0UAEzZ4t87nbQYtVjmqpIODugN6PD2D9E+dJvdM= +xorm.io/core v0.7.2-0.20190928055935-90aeac8d08eb h1:msX3zG3BPl8Ti+LDzP33/9K7BzO/WqFXk610K1kYKfo= +xorm.io/core v0.7.2-0.20190928055935-90aeac8d08eb/go.mod h1:jJfd0UAEzZ4t87nbQYtVjmqpIODugN6PD2D9E+dJvdM= diff --git a/session_insert_test.go b/session_insert_test.go index 3dcc87d0..2785401d 100644 --- a/session_insert_test.go +++ b/session_insert_test.go @@ -896,6 +896,7 @@ func TestInsertWhere(t *testing.T) { inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). + SetExpr("repo_id", "1"). Insert(map[string]string{ "name": "trest3", }) @@ -917,6 +918,29 @@ func TestInsertWhere(t *testing.T) { }) assert.NoError(t, err) assert.EqualValues(t, 1, inserted) + + var j4 InsertWhere + has, err = testEngine.ID(4).Get(&j4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "10';delete * from insert_where; --", j4.Name) + assert.EqualValues(t, 4, j4.Index) + + inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). + Insert(map[string]interface{}{ + "repo_id": 1, + "name": "10\\';delete * from insert_where; --", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, inserted) + + var j5 InsertWhere + has, err = testEngine.ID(5).Get(&j5) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "10\\';delete * from insert_where; --", j5.Name) + assert.EqualValues(t, 5, j5.Index) } type NightlyRate struct { diff --git a/statement_args.go b/statement_args.go index 23496443..310f24d6 100644 --- a/statement_args.go +++ b/statement_args.go @@ -49,15 +49,34 @@ func quoteNeeded(a interface{}) bool { return true } -func convertArg(arg interface{}) string { +func convertStringSingleQuote(arg string) string { + return "'" + strings.Replace(arg, "'", "''", -1) + "'" +} + +func convertString(arg string) string { + var buf strings.Builder + buf.WriteRune('\'') + for _, c := range arg { + if c == '\\' || c == '\'' { + buf.WriteRune('\\') + } + buf.WriteRune(c) + } + buf.WriteRune('\'') + return buf.String() +} + +func convertArg(arg interface{}, convertFunc func(string) string) string { if quoteNeeded(arg) { argv := fmt.Sprintf("%v", arg) - return "'" + strings.Replace(argv, "'", "''", -1) + "'" + return convertFunc(argv) } return fmt.Sprintf("%v", arg) } +const insertSelectPlaceHolder = true + func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error { switch argv := arg.(type) { case bool: @@ -93,8 +112,19 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er return err } default: - if _, err := w.WriteString(convertArg(arg)); err != nil { - return err + if insertSelectPlaceHolder { + if err := w.WriteByte('?'); err != nil { + return err + } + w.Append(arg) + } else { + var convertFunc = convertStringSingleQuote + if statement.Engine.dialect.DBType() == core.MYSQL { + convertFunc = convertString + } + if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil { + return err + } } } return nil diff --git a/statement_exprparam.go b/statement_exprparam.go index 0cddca02..4da4f1ea 100644 --- a/statement_exprparam.go +++ b/statement_exprparam.go @@ -57,7 +57,7 @@ func (exprs *exprParams) getByName(colName string) (exprParam, bool) { } func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { - for _, expr := range exprs.args { + for i, expr := range exprs.args { switch arg := expr.(type) { case *builder.Builder: if _, err := w.WriteString("("); err != nil { @@ -74,6 +74,11 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { return err } } + if i != len(exprs.args)-1 { + if _, err := w.WriteString(","); err != nil { + return err + } + } } return nil }