diff --git a/.drone.yml b/.drone.yml index 6001ec59..bd682e5f 100644 --- a/.drone.yml +++ b/.drone.yml @@ -537,6 +537,202 @@ services: - tag - pull_request +- name: mssql + pull: default + image: microsoft/mssql-server-linux:latest + environment: + ACCEPT_EULA: Y + SA_PASSWORD: yourStrong(!)Password + MSSQL_PID: Developer + when: + event: + - push + - tag + - pull_request + +--- +kind: pipeline +name: go1.13 + +platform: + os: linux + arch: amd64 + +clone: + disable: true + +workspace: + base: /go + path: src/github.com/go-xorm/xorm + +steps: +- name: git + pull: default + image: plugins/git:next + settings: + depth: 50 + tags: true + +- name: init_postgres + pull: default + image: postgres:9.5 + commands: + - "until psql -U postgres -d xorm_test -h pgsql \\\n -c \"SELECT 1;\" >/dev/null 2>&1; do sleep 1; done\n" + - "psql -U postgres -d xorm_test -h pgsql \\\n -c \"create schema xorm;\"\n" + +- name: build + pull: default + image: golang:1.13 + environment: + GO111MODULE: "off" + commands: + - go get -t -d -v ./... + - go get -u xorm.io/core + - go get -u xorm.io/builder + - go build -v + when: + event: + - push + - pull_request + +- name: build-gomod + pull: default + image: golang:1.13 + environment: + GO111MODULE: "on" + GOPROXY: "https://goproxy.cn" + commands: + - go build -v + when: + event: + - push + - pull_request + +- name: test-sqlite + pull: default + image: golang:1.13 + environment: + GO111MODULE: "on" + GOPROXY: "https://goproxy.cn" + commands: + - "go test -v -race -db=\"sqlite3\" -conn_str=\"./test.db\" -coverprofile=coverage1-1.txt -covermode=atomic" + - "go test -v -race -db=\"sqlite3\" -conn_str=\"./test.db\" -cache=true -coverprofile=coverage1-2.txt -covermode=atomic" + when: + event: + - push + - pull_request + +- name: test-mysql + pull: default + image: golang:1.13 + environment: + GO111MODULE: "on" + GOPROXY: "https://goproxy.cn" + commands: + - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(mysql)/xorm_test\" -coverprofile=coverage2-1.txt -covermode=atomic" + - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(mysql)/xorm_test\" -cache=true -coverprofile=coverage2-2.txt -covermode=atomic" + when: + event: + - push + - pull_request + +- name: test-mysql-utf8mb4 + pull: default + image: golang:1.13 + environment: + GO111MODULE: "on" + GOPROXY: "https://goproxy.cn" + commands: + - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(mysql)/xorm_test?charset=utf8mb4\" -coverprofile=coverage2.1-1.txt -covermode=atomic" + - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(mysql)/xorm_test?charset=utf8mb4\" -cache=true -coverprofile=coverage2.1-2.txt -covermode=atomic" + when: + event: + - push + - pull_request + +- name: test-mymysql + pull: default + image: golang:1.13 + environment: + GO111MODULE: "on" + GOPROXY: "https://goproxy.cn" + commands: + - "go test -v -race -db=\"mymysql\" -conn_str=\"tcp:mysql:3306*xorm_test/root/\" -coverprofile=coverage3-1.txt -covermode=atomic" + - "go test -v -race -db=\"mymysql\" -conn_str=\"tcp:mysql:3306*xorm_test/root/\" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic" + when: + event: + - push + - pull_request + +- name: test-postgres + pull: default + image: golang:1.13 + environment: + GO111MODULE: "on" + GOPROXY: "https://goproxy.cn" + commands: + - "go test -v -race -db=\"postgres\" -conn_str=\"postgres://postgres:@pgsql/xorm_test?sslmode=disable\" -coverprofile=coverage4-1.txt -covermode=atomic" + - "go test -v -race -db=\"postgres\" -conn_str=\"postgres://postgres:@pgsql/xorm_test?sslmode=disable\" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic" + when: + event: + - push + - pull_request + +- name: test-postgres-schema + pull: default + image: golang:1.13 + environment: + GO111MODULE: "on" + GOPROXY: "https://goproxy.cn" + commands: + - "go test -v -race -db=\"postgres\" -conn_str=\"postgres://postgres:@pgsql/xorm_test?sslmode=disable\" -schema=xorm -coverprofile=coverage5-1.txt -covermode=atomic" + - "go test -v -race -db=\"postgres\" -conn_str=\"postgres://postgres:@pgsql/xorm_test?sslmode=disable\" -schema=xorm -cache=true -coverprofile=coverage5-2.txt -covermode=atomic" + when: + event: + - push + - pull_request + +- name: test-mssql + pull: default + image: golang:1.13 + environment: + GO111MODULE: "on" + GOPROXY: "https://goproxy.cn" + commands: + - "go test -v -race -db=\"mssql\" -conn_str=\"server=mssql;user id=sa;password=yourStrong(!)Password;database=xorm_test\" -coverprofile=coverage6-1.txt -covermode=atomic" + - "go test -v -race -db=\"mssql\" -conn_str=\"server=mssql;user id=sa;password=yourStrong(!)Password;database=xorm_test\" -cache=true -coverprofile=coverage6-2.txt -covermode=atomic" + - go get -u github.com/wadey/gocovmerge + - gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage2.1-1.txt coverage2.1-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt coverage5-1.txt coverage5-2.txt coverage6-1.txt coverage6-2.txt > coverage.txt + when: + event: + - push + - pull_request + +services: +- name: mysql + pull: default + image: mysql:5.7 + environment: + MYSQL_ALLOW_EMPTY_PASSWORD: yes + MYSQL_DATABASE: xorm_test + when: + event: + - push + - tag + - pull_request + +- name: pgsql + pull: default + image: postgres:9.5 + environment: + POSTGRES_DB: xorm_test + POSTGRES_USER: postgres + when: + event: + - push + - tag + - pull_request + - name: mssql pull: default image: microsoft/mssql-server-linux:latest diff --git a/dialect_mssql.go b/dialect_mssql.go index 61061cb2..ce4dd00c 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -254,6 +254,9 @@ func (db *mssql) SqlType(c *core.Column) string { case core.TinyInt: res = core.TinyInt c.Length = 0 + case core.BigInt: + res = core.BigInt + c.Length = 0 default: res = t } diff --git a/dialect_postgres_test.go b/dialect_postgres_test.go index ebc079e1..f2afdefc 100644 --- a/dialect_postgres_test.go +++ b/dialect_postgres_test.go @@ -4,9 +4,8 @@ import ( "reflect" "testing" - "xorm.io/core" - "github.com/jackc/pgx/stdlib" "github.com/stretchr/testify/assert" + "xorm.io/core" ) func TestParsePostgres(t *testing.T) { @@ -72,10 +71,7 @@ func TestParsePgx(t *testing.T) { } // Register DriverConfig - drvierConfig := stdlib.DriverConfig{} - stdlib.RegisterDriverConfig(&drvierConfig) - uri, err = driver.Parse("pgx", - drvierConfig.ConnectionString(test.in)) + uri, err = driver.Parse("pgx", test.in) if err != nil && test.valid { t.Errorf("%q got unexpected error: %s", test.in, err) } else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { diff --git a/engine.go b/engine.go index 4cbb7dd3..29a091c2 100644 --- a/engine.go +++ b/engine.go @@ -739,7 +739,7 @@ func (engine *Engine) Decr(column string, arg ...interface{}) *Session { } // SetExpr provides a update string like "column = {expression}" -func (engine *Engine) SetExpr(column string, expression string) *Session { +func (engine *Engine) SetExpr(column string, expression interface{}) *Session { session := engine.NewSession() session.isAutoClose = true return session.SetExpr(column, expression) diff --git a/go.mod b/go.mod index a3e78cae..1ab39831 100644 --- a/go.mod +++ b/go.mod @@ -6,16 +6,15 @@ require ( github.com/cockroachdb/apd v1.1.0 // indirect github.com/denisenkom/go-mssqldb v0.0.0-20190707035753-2be1aa521ff4 github.com/go-sql-driver/mysql v1.4.1 + github.com/gofrs/uuid v3.2.0+incompatible // indirect github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 // indirect - github.com/jackc/pgx v3.3.0+incompatible - github.com/kr/pretty v0.1.0 // indirect + github.com/jackc/pgx v3.6.0+incompatible github.com/lib/pq v1.0.0 github.com/mattn/go-sqlite3 v1.10.0 github.com/pkg/errors v0.8.1 // indirect - github.com/satori/go.uuid v1.2.0 // indirect github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 // indirect - github.com/stretchr/testify v1.3.0 + github.com/stretchr/testify v1.4.0 github.com/ziutek/mymysql v1.5.4 - xorm.io/builder v0.3.6-0.20190906062455-b937eb46ecfb - xorm.io/core v0.7.0 + xorm.io/builder v0.3.6 + xorm.io/core v0.7.2-0.20190928055935-90aeac8d08eb ) diff --git a/go.sum b/go.sum index 0f2baf17..cf637a8e 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:9wScpmSP5A3Bk8V3XHWUcJmYTh+ZnlHVyc+A4oZYS3Y= github.com/go-xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:56xuuqnHyryaerycW3BfssRdxQstACi0Epw/yC5E2xM= +github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -48,18 +50,13 @@ github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 h1:vr3AYkKovP8uR8AvSGGUK1IDqRa5lAAvEkZG1LKaCRc= github.com/jackc/fake v0.0.0-20150926172116-812a484cc733/go.mod h1:WrMFNQdiFJ80sQsxDoMokWK1W5TQtxBFNpzWTD84ibQ= -github.com/jackc/pgx v3.3.0+incompatible h1:Wa90/+qsITBAPkAZjiByeIGHFcj3Ztu+VzrrIpHjL90= -github.com/jackc/pgx v3.3.0+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I= +github.com/jackc/pgx v3.6.0+incompatible h1:bJeo4JdVbDAW8KB2m8XkFeo8CPipREoG37BwEoKGz+Q= +github.com/jackc/pgx v3.6.0+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= @@ -84,8 +81,6 @@ github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y8 github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= -github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= -github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 h1:pntxY8Ary0t43dCZ5dqY4YTJCObLY1kIXl0uzMv+7DE= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= @@ -94,11 +89,14 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/ziutek/mymysql v1.5.4 h1:GB0qdRGsTwQSBVYuVShFBKaXSnSnYYC2d9knnE1LHFs= github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0= go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmVhcMLU6v5fEb/ok4wyqtNU= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -133,6 +131,7 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190602015325-4c4f7f33c9ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -158,11 +157,16 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -xorm.io/builder v0.3.6-0.20190906062455-b937eb46ecfb h1:2idZcp79ldX5qLeQ6WKCdS7aEFNOMvQc9wrtt5hSRwM= -xorm.io/builder v0.3.6-0.20190906062455-b937eb46ecfb/go.mod h1:LEFAPISnRzG+zxaxj2vPicRwz67BdhFreKg8yv8/TgU= +xorm.io/builder v0.3.6 h1:ha28mQ2M+TFx96Hxo+iq6tQgnkC9IZkM6D8w9sKHHF8= +xorm.io/builder v0.3.6/go.mod h1:LEFAPISnRzG+zxaxj2vPicRwz67BdhFreKg8yv8/TgU= xorm.io/core v0.7.0 h1:hKxuOKWZNeiFQsSuGet/KV8HZ788hclvAl+7azx3tkM= xorm.io/core v0.7.0/go.mod h1:TuOJjIVa7e3w/rN8tDcAvuLBMtwzdHPbyOzE6Gk1EUI= +xorm.io/core v0.7.1 h1:I6x6Q6dYb67aDEoYFWr2t8UcKIYjJPyCHS+aXuj5V0Y= +xorm.io/core v0.7.1/go.mod h1:jJfd0UAEzZ4t87nbQYtVjmqpIODugN6PD2D9E+dJvdM= +xorm.io/core v0.7.2-0.20190928055935-90aeac8d08eb h1:msX3zG3BPl8Ti+LDzP33/9K7BzO/WqFXk610K1kYKfo= +xorm.io/core v0.7.2-0.20190928055935-90aeac8d08eb/go.mod h1:jJfd0UAEzZ4t87nbQYtVjmqpIODugN6PD2D9E+dJvdM= diff --git a/interface.go b/interface.go index 47240928..685ed287 100644 --- a/interface.go +++ b/interface.go @@ -54,7 +54,7 @@ type Interface interface { QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error) QueryString(sqlOrArgs ...interface{}) ([]map[string]string, error) Rows(bean interface{}) (*Rows, error) - SetExpr(string, string) *Session + SetExpr(string, interface{}) *Session SQL(interface{}, ...interface{}) *Session Sum(bean interface{}, colName string) (float64, error) SumInt(bean interface{}, colName string) (int64, error) diff --git a/session_cols.go b/session_cols.go index dc3befcf..1558074f 100644 --- a/session_cols.go +++ b/session_cols.go @@ -12,49 +12,6 @@ import ( "xorm.io/core" ) -type incrParam struct { - colName string - arg interface{} -} - -type decrParam struct { - colName string - arg interface{} -} - -type exprParam struct { - colName string - expr string -} - -type columnMap []string - -func (m columnMap) contain(colName string) bool { - if len(m) == 0 { - return false - } - - n := len(colName) - for _, mk := range m { - if len(mk) != n { - continue - } - if strings.EqualFold(mk, colName) { - return true - } - } - - return false -} - -func (m *columnMap) add(colName string) bool { - if m.contain(colName) { - return false - } - *m = append(*m, colName) - return true -} - func setColumnInt(bean interface{}, col *core.Column, t int64) { v, err := col.ValueOf(bean) if err != nil { @@ -132,7 +89,7 @@ func (session *Session) Decr(column string, arg ...interface{}) *Session { } // SetExpr provides a query string like "column = {expression}" -func (session *Session) SetExpr(column string, expression string) *Session { +func (session *Session) SetExpr(column string, expression interface{}) *Session { session.statement.SetExpr(column, expression) return session } diff --git a/session_cols_test.go b/session_cols_test.go index 5f5954c7..96cb1620 100644 --- a/session_cols_test.go +++ b/session_cols_test.go @@ -7,21 +7,38 @@ package xorm import ( "testing" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/builder" + "xorm.io/core" ) func TestSetExpr(t *testing.T) { assert.NoError(t, prepareEngine()) + type UserExprIssue struct { + Id int64 + Title string + } + + assert.NoError(t, testEngine.Sync2(new(UserExprIssue))) + + var issue = UserExprIssue{ + Title: "my issue", + } + cnt, err := testEngine.Insert(&issue) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + assert.EqualValues(t, 1, issue.Id) + type UserExpr struct { - Id int64 - Show bool + Id int64 + IssueId int64 `xorm:"index"` + Show bool } assert.NoError(t, testEngine.Sync2(new(UserExpr))) - cnt, err := testEngine.Insert(&UserExpr{ + cnt, err = testEngine.Insert(&UserExpr{ Show: true, }) assert.NoError(t, err) @@ -34,6 +51,16 @@ func TestSetExpr(t *testing.T) { cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr)) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + + tableName := testEngine.TableName(new(UserExprIssue), true) + cnt, err = testEngine.SetExpr("issue_id", + builder.Select("id"). + From(tableName). + Where(builder.Eq{"id": issue.Id})). + ID(1). + Update(new(UserExpr)) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) } func TestCols(t *testing.T) { diff --git a/session_insert.go b/session_insert.go index 24b32831..1e19ce7a 100644 --- a/session_insert.go +++ b/session_insert.go @@ -25,6 +25,12 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { defer session.Close() } + session.autoResetStatement = false + defer func() { + session.autoResetStatement = true + session.resetStatement() + }() + for _, bean := range beans { switch bean.(type) { case map[string]interface{}: @@ -35,7 +41,6 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { affected += cnt case []map[string]interface{}: s := bean.([]map[string]interface{}) - session.autoResetStatement = false for i := 0; i < len(s); i++ { cnt, err := session.insertMapInterface(s[i]) if err != nil { @@ -51,7 +56,6 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { affected += cnt case []map[string]string: s := bean.([]map[string]string) - session.autoResetStatement = false for i := 0; i < len(s); i++ { cnt, err := session.insertMapString(s[i]) if err != nil { @@ -340,74 +344,96 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if err != nil { return 0, err } - // insert expr columns, override if exists - exprColumns := session.statement.getExpr() - exprColVals := make([]string, 0, len(exprColumns)) - for _, v := range exprColumns { - // remove the expr columns - for i, colName := range colNames { - if colName == strings.Trim(v.colName, "`") { - colNames = append(colNames[:i], colNames[i+1:]...) - args = append(args[:i], args[i+1:]...) - } - } - // append expr column to the end - colNames = append(colNames, v.colName) - exprColVals = append(exprColVals, v.expr) + exprs := session.statement.exprColumns + colPlaces := strings.Repeat("?, ", len(colNames)) + if exprs.Len() <= 0 && len(colPlaces) > 0 { + colPlaces = colPlaces[0 : len(colPlaces)-2] } - colPlaces := strings.Repeat("?, ", len(colNames)-len(exprColumns)) - if len(exprColVals) > 0 { - colPlaces = colPlaces + strings.Join(exprColVals, ", ") - } else { - if len(colPlaces) > 0 { - colPlaces = colPlaces[0 : len(colPlaces)-2] - } - } - - var sqlStr string var tableName = session.statement.TableName() var output string if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 { output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement) } - if len(colPlaces) > 0 { + var buf = builder.NewWriter() + if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s", session.engine.Quote(tableName))); err != nil { + return 0, err + } + + if len(colPlaces) <= 0 { + if session.engine.dialect.DBType() == core.MYSQL { + if _, err := buf.WriteString(" VALUES ()"); err != nil { + return 0, err + } + } else { + if _, err := buf.WriteString(fmt.Sprintf("%s DEFAULT VALUES", output)); err != nil { + return 0, err + } + } + } else { + if _, err := buf.WriteString(" ("); err != nil { + return 0, err + } + + if err := writeStrings(buf, append(colNames, exprs.colNames...), "`", "`"); err != nil { + return 0, err + } + if session.statement.cond.IsValid() { - condSQL, condArgs, err := builder.ToSQL(session.statement.cond) - if err != nil { + if _, err := buf.WriteString(fmt.Sprintf(")%s SELECT ", output)); err != nil { return 0, err } - sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s SELECT %v FROM %v WHERE %v", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ","), - output, - colPlaces, - session.engine.Quote(tableName), - condSQL, - ) - args = append(args, condArgs...) + if err := session.statement.writeArgs(buf, args); err != nil { + return 0, err + } + + if len(exprs.args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return 0, err + } + } + if err := exprs.writeArgs(buf); err != nil { + return 0, err + } + + if _, err := buf.WriteString(fmt.Sprintf(" FROM %v WHERE ", session.engine.Quote(tableName))); err != nil { + return 0, err + } + + if err := session.statement.cond.WriteTo(buf); err != nil { + return 0, err + } } else { - sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ","), + buf.Append(args...) + + if _, err := buf.WriteString(fmt.Sprintf(")%s VALUES (%v", output, - colPlaces) - } - } else { - if session.engine.dialect.DBType() == core.MYSQL { - sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName)) - } else { - sqlStr = fmt.Sprintf("INSERT INTO %s%s DEFAULT VALUES", session.engine.Quote(tableName), output) + colPlaces)); err != nil { + return 0, err + } + + if err := exprs.writeArgs(buf); err != nil { + return 0, err + } + + if _, err := buf.WriteString(")"); err != nil { + return 0, err + } } } if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES { - sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement) + if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil { + return 0, err + } } + sqlStr := buf.String() + args = buf.Args() + handleAfterInsertProcessorFunc := func(bean interface{}) { if session.isAutoCommit { for _, closure := range session.afterClosures { @@ -611,9 +637,11 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac continue } - if _, ok := session.statement.incrColumns[col.Name]; ok { + if session.statement.incrColumns.isColExist(col.Name) { continue - } else if _, ok := session.statement.decrColumns[col.Name]; ok { + } else if session.statement.decrColumns.isColExist(col.Name) { + continue + } else if session.statement.exprColumns.isColExist(col.Name) { continue } @@ -688,46 +716,66 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err } var columns = make([]string, 0, len(m)) + exprs := session.statement.exprColumns for k := range m { - columns = append(columns, k) + if !exprs.isColExist(k) { + columns = append(columns, k) + } } sort.Strings(columns) - qm := strings.Repeat("?,", len(columns)) - var args = make([]interface{}, 0, len(m)) for _, colName := range columns { args = append(args, m[colName]) } - // insert expr columns, override if exists - exprColumns := session.statement.getExpr() - for _, col := range exprColumns { - columns = append(columns, strings.Trim(col.colName, "`")) - qm = qm + col.expr + "," - } - - qm = qm[:len(qm)-1] - - var sql string - + w := builder.NewWriter() if session.statement.cond.IsValid() { - condSQL, condArgs, err := builder.ToSQL(session.statement.cond) - if err != nil { + if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { + return 0, err + } + + if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { + return 0, err + } + + if _, err := w.WriteString(") SELECT "); err != nil { + return 0, err + } + + if err := session.statement.writeArgs(w, args); err != nil { + return 0, err + } + + if len(exprs.args) > 0 { + if _, err := w.WriteString(","); err != nil { + return 0, err + } + if err := exprs.writeArgs(w); err != nil { + return 0, err + } + } + + if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil { + return 0, err + } + + if err := session.statement.cond.WriteTo(w); err != nil { return 0, err } - sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s", - session.engine.Quote(tableName), - strings.Join(columns, "`,`"), - qm, - session.engine.Quote(tableName), - condSQL, - ) - args = append(args, condArgs...) } else { - sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm) + qm := strings.Repeat("?,", len(columns)) + qm = qm[:len(qm)-1] + + if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil { + return 0, err + } + w.Append(args...) } + sql := w.String() + args = w.Args() + if err := session.cacheInsert(tableName); err != nil { return 0, err } @@ -754,8 +802,11 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { } var columns = make([]string, 0, len(m)) + exprs := session.statement.exprColumns for k := range m { - columns = append(columns, k) + if !exprs.isColExist(k) { + columns = append(columns, k) + } } sort.Strings(columns) @@ -764,37 +815,53 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { args = append(args, m[colName]) } - qm := strings.Repeat("?,", len(columns)) - - // insert expr columns, override if exists - exprColumns := session.statement.getExpr() - for _, col := range exprColumns { - columns = append(columns, strings.Trim(col.colName, "`")) - qm = qm + col.expr + "," - } - - qm = qm[:len(qm)-1] - - var sql string - + w := builder.NewWriter() if session.statement.cond.IsValid() { - qm = "(" + qm[:len(qm)-1] + ")" - condSQL, condArgs, err := builder.ToSQL(session.statement.cond) - if err != nil { + if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { + return 0, err + } + + if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { + return 0, err + } + + if _, err := w.WriteString(") SELECT "); err != nil { + return 0, err + } + + if err := session.statement.writeArgs(w, args); err != nil { + return 0, err + } + + if len(exprs.args) > 0 { + if _, err := w.WriteString(","); err != nil { + return 0, err + } + if err := exprs.writeArgs(w); err != nil { + return 0, err + } + } + + if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil { + return 0, err + } + + if err := session.statement.cond.WriteTo(w); err != nil { return 0, err } - sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s", - session.engine.Quote(tableName), - strings.Join(columns, "`,`"), - qm, - session.engine.Quote(tableName), - condSQL, - ) - args = append(args, condArgs...) } else { - sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm) + qm := strings.Repeat("?,", len(columns)) + qm = qm[:len(qm)-1] + + if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil { + return 0, err + } + w.Append(args...) } + sql := w.String() + args = w.Args() + if err := session.cacheInsert(tableName); err != nil { return 0, err } diff --git a/session_insert_test.go b/session_insert_test.go index daf08e7f..2785401d 100644 --- a/session_insert_test.go +++ b/session_insert_test.go @@ -846,6 +846,7 @@ func TestInsertWhere(t *testing.T) { Width uint32 Height uint32 Name string + IsTrue bool } assert.NoError(t, prepareEngine()) @@ -892,4 +893,134 @@ func TestInsertWhere(t *testing.T) { assert.EqualValues(t, 40, j2.Height) assert.EqualValues(t, "trest2", j2.Name) assert.EqualValues(t, 2, j2.Index) + + inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). + SetExpr("repo_id", "1"). + Insert(map[string]string{ + "name": "trest3", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, inserted) + + var j3 InsertWhere + has, err = testEngine.ID(3).Get(&j3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "trest3", j3.Name) + assert.EqualValues(t, 3, j3.Index) + + inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). + Insert(map[string]interface{}{ + "repo_id": 1, + "name": "10';delete * from insert_where; --", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, inserted) + + var j4 InsertWhere + has, err = testEngine.ID(4).Get(&j4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "10';delete * from insert_where; --", j4.Name) + assert.EqualValues(t, 4, j4.Index) + + inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). + Insert(map[string]interface{}{ + "repo_id": 1, + "name": "10\\';delete * from insert_where; --", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, inserted) + + var j5 InsertWhere + has, err = testEngine.ID(5).Get(&j5) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "10\\';delete * from insert_where; --", j5.Name) + assert.EqualValues(t, 5, j5.Index) +} + +type NightlyRate struct { + ID int64 `xorm:"'id' not null pk BIGINT(20)" json:"id"` +} + +func (NightlyRate) TableName() string { + return "prd_nightly_rate" +} + +func TestMultipleInsertTableName(t *testing.T) { + assert.NoError(t, prepareEngine()) + + tableName := `prd_nightly_rate_16` + assert.NoError(t, testEngine.Table(tableName).Sync2(new(NightlyRate))) + + trans := testEngine.NewSession() + defer trans.Close() + err := trans.Begin() + assert.NoError(t, err) + + rtArr := []interface{}{ + []*NightlyRate{ + {ID: 1}, + {ID: 2}, + }, + []*NightlyRate{ + {ID: 3}, + {ID: 4}, + }, + []*NightlyRate{ + {ID: 5}, + }, + } + + _, err = trans.Table(tableName).Insert(rtArr...) + assert.NoError(t, err) + + assert.NoError(t, trans.Commit()) +} + +func TestInsertMultiWithOmit(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type TestMultiOmit struct { + Id int64 `xorm:"int(11) pk"` + Name string `xorm:"varchar(255)"` + Omitted string `xorm:"varchar(255) 'omitted'"` + } + + assert.NoError(t, testEngine.Sync2(new(TestMultiOmit))) + + l := []interface{}{ + TestMultiOmit{Id: 1, Name: "1", Omitted: "1"}, + TestMultiOmit{Id: 2, Name: "1", Omitted: "2"}, + TestMultiOmit{Id: 3, Name: "1", Omitted: "3"}, + } + + check := func() { + var ls []TestMultiOmit + err := testEngine.Find(&ls) + assert.NoError(t, err) + assert.EqualValues(t, 3, len(ls)) + + for e := range ls { + assert.EqualValues(t, "", ls[e].Omitted) + } + } + + num, err := testEngine.Omit("omitted").Insert(l...) + assert.NoError(t, err) + assert.EqualValues(t, 3, num) + check() + + num, err = testEngine.Delete(TestMultiOmit{Name: "1"}) + assert.NoError(t, err) + assert.EqualValues(t, 3, num) + + num, err = testEngine.Omit("omitted").Insert(l) + assert.NoError(t, err) + assert.EqualValues(t, 3, num) + check() } diff --git a/session_update.go b/session_update.go index 85b0bb0b..c5c65a45 100644 --- a/session_update.go +++ b/session_update.go @@ -223,21 +223,31 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } // for update action to like "column = column + ?" - incColumns := session.statement.getInc() - for _, v := range incColumns { - colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" + ?") - args = append(args, v.arg) + incColumns := session.statement.incrColumns + for i, colName := range incColumns.colNames { + colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?") + args = append(args, incColumns.args[i]) } // for update action to like "column = column - ?" - decColumns := session.statement.getDec() - for _, v := range decColumns { - colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" - ?") - args = append(args, v.arg) + decColumns := session.statement.decrColumns + for i, colName := range decColumns.colNames { + colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?") + args = append(args, decColumns.args[i]) } // for update action to like "column = expression" - exprColumns := session.statement.getExpr() - for _, v := range exprColumns { - colNames = append(colNames, session.engine.Quote(v.colName)+" = "+v.expr) + exprColumns := session.statement.exprColumns + for i, colName := range exprColumns.colNames { + switch tp := exprColumns.args[i].(type) { + case string: + colNames = append(colNames, session.engine.Quote(colName)+" = "+tp) + case *builder.Builder: + subQuery, subArgs, err := builder.ToSQL(tp) + if err != nil { + return 0, err + } + colNames = append(colNames, session.engine.Quote(colName)+" = ("+subQuery+")") + args = append(args, subArgs...) + } } if err = session.statement.processIDParam(); err != nil { @@ -468,14 +478,17 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac continue } - if len(session.statement.columnMap) > 0 { - if !session.statement.columnMap.contain(col.Name) { - continue - } else if _, ok := session.statement.incrColumns[col.Name]; ok { - continue - } else if _, ok := session.statement.decrColumns[col.Name]; ok { - continue - } + // if only update specify columns + if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + continue + } + + if session.statement.incrColumns.isColExist(col.Name) { + continue + } else if session.statement.decrColumns.isColExist(col.Name) { + continue + } else if session.statement.exprColumns.isColExist(col.Name) { + continue } // !evalphobia! set fieldValue as nil when column is nullable and zero-value diff --git a/statement.go b/statement.go index 6cdbad7d..3cc0831e 100644 --- a/statement.go +++ b/statement.go @@ -52,9 +52,9 @@ type Statement struct { omitColumnMap columnMap mustColumnMap map[string]bool nullableMap map[string]bool - incrColumns map[string]incrParam - decrColumns map[string]decrParam - exprColumns map[string]exprParam + incrColumns exprParams + decrColumns exprParams + exprColumns exprParams cond builder.Cond bufferSize int context ContextCache @@ -94,9 +94,9 @@ func (statement *Statement) Init() { statement.nullableMap = make(map[string]bool) statement.checkVersion = true statement.unscoped = false - statement.incrColumns = make(map[string]incrParam) - statement.decrColumns = make(map[string]decrParam) - statement.exprColumns = make(map[string]exprParam) + statement.incrColumns = exprParams{} + statement.decrColumns = exprParams{} + statement.exprColumns = exprParams{} statement.cond = builder.NewCond() statement.bufferSize = 0 statement.context = nil @@ -534,48 +534,30 @@ func (statement *Statement) ID(id interface{}) *Statement { // Incr Generate "Update ... Set column = column + arg" statement func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { - k := strings.ToLower(column) if len(arg) > 0 { - statement.incrColumns[k] = incrParam{column, arg[0]} + statement.incrColumns.addParam(column, arg[0]) } else { - statement.incrColumns[k] = incrParam{column, 1} + statement.incrColumns.addParam(column, 1) } return statement } // Decr Generate "Update ... Set column = column - arg" statement func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { - k := strings.ToLower(column) if len(arg) > 0 { - statement.decrColumns[k] = decrParam{column, arg[0]} + statement.decrColumns.addParam(column, arg[0]) } else { - statement.decrColumns[k] = decrParam{column, 1} + statement.decrColumns.addParam(column, 1) } return statement } // SetExpr Generate "Update ... Set column = {expression}" statement -func (statement *Statement) SetExpr(column string, expression string) *Statement { - k := strings.ToLower(column) - statement.exprColumns[k] = exprParam{column, expression} +func (statement *Statement) SetExpr(column string, expression interface{}) *Statement { + statement.exprColumns.addParam(column, expression) return statement } -// Generate "Update ... Set column = column + arg" statement -func (statement *Statement) getInc() map[string]incrParam { - return statement.incrColumns -} - -// Generate "Update ... Set column = column - arg" statement -func (statement *Statement) getDec() map[string]decrParam { - return statement.decrColumns -} - -// Generate "Update ... Set column = {expression}" statement -func (statement *Statement) getExpr() map[string]exprParam { - return statement.exprColumns -} - func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { newColumns := make([]string, 0) quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") diff --git a/statement_args.go b/statement_args.go new file mode 100644 index 00000000..310f24d6 --- /dev/null +++ b/statement_args.go @@ -0,0 +1,170 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "fmt" + "reflect" + "strings" + "time" + + "xorm.io/builder" + "xorm.io/core" +) + +func quoteNeeded(a interface{}) bool { + switch a.(type) { + case int, int8, int16, int32, int64: + return false + case uint, uint8, uint16, uint32, uint64: + return false + case float32, float64: + return false + case bool: + return false + case string: + return true + case time.Time, *time.Time: + return true + case builder.Builder, *builder.Builder: + return false + } + + t := reflect.TypeOf(a) + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return false + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return false + case reflect.Float32, reflect.Float64: + return false + case reflect.Bool: + return false + case reflect.String: + return true + } + + return true +} + +func convertStringSingleQuote(arg string) string { + return "'" + strings.Replace(arg, "'", "''", -1) + "'" +} + +func convertString(arg string) string { + var buf strings.Builder + buf.WriteRune('\'') + for _, c := range arg { + if c == '\\' || c == '\'' { + buf.WriteRune('\\') + } + buf.WriteRune(c) + } + buf.WriteRune('\'') + return buf.String() +} + +func convertArg(arg interface{}, convertFunc func(string) string) string { + if quoteNeeded(arg) { + argv := fmt.Sprintf("%v", arg) + return convertFunc(argv) + } + + return fmt.Sprintf("%v", arg) +} + +const insertSelectPlaceHolder = true + +func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error { + switch argv := arg.(type) { + case bool: + if statement.Engine.dialect.DBType() == core.MSSQL { + if argv { + if _, err := w.WriteString("1"); err != nil { + return err + } + } else { + if _, err := w.WriteString("0"); err != nil { + return err + } + } + } else { + if argv { + if _, err := w.WriteString("true"); err != nil { + return err + } + } else { + if _, err := w.WriteString("false"); err != nil { + return err + } + } + } + case *builder.Builder: + if _, err := w.WriteString("("); err != nil { + return err + } + if err := argv.WriteTo(w); err != nil { + return err + } + if _, err := w.WriteString(")"); err != nil { + return err + } + default: + if insertSelectPlaceHolder { + if err := w.WriteByte('?'); err != nil { + return err + } + w.Append(arg) + } else { + var convertFunc = convertStringSingleQuote + if statement.Engine.dialect.DBType() == core.MYSQL { + convertFunc = convertString + } + if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil { + return err + } + } + } + return nil +} + +func (statement *Statement) writeArgs(w *builder.BytesWriter, args []interface{}) error { + for i, arg := range args { + if err := statement.writeArg(w, arg); err != nil { + return err + } + + if i+1 != len(args) { + if _, err := w.WriteString(","); err != nil { + return err + } + } + } + return nil +} + +func writeStrings(w *builder.BytesWriter, cols []string, leftQuote, rightQuote string) error { + for i, colName := range cols { + if len(leftQuote) > 0 && colName[0] != '`' { + if _, err := w.WriteString(leftQuote); err != nil { + return err + } + } + if _, err := w.WriteString(colName); err != nil { + return err + } + if len(rightQuote) > 0 && colName[len(colName)-1] != '`' { + if _, err := w.WriteString(rightQuote); err != nil { + return err + } + } + if i+1 != len(cols) { + if _, err := w.WriteString(","); err != nil { + return err + } + } + } + return nil +} diff --git a/statement_columnmap.go b/statement_columnmap.go new file mode 100644 index 00000000..b6523b1e --- /dev/null +++ b/statement_columnmap.go @@ -0,0 +1,35 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import "strings" + +type columnMap []string + +func (m columnMap) contain(colName string) bool { + if len(m) == 0 { + return false + } + + n := len(colName) + for _, mk := range m { + if len(mk) != n { + continue + } + if strings.EqualFold(mk, colName) { + return true + } + } + + return false +} + +func (m *columnMap) add(colName string) bool { + if m.contain(colName) { + return false + } + *m = append(*m, colName) + return true +} diff --git a/statement_exprparam.go b/statement_exprparam.go new file mode 100644 index 00000000..4da4f1ea --- /dev/null +++ b/statement_exprparam.go @@ -0,0 +1,117 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "fmt" + "strings" + + "xorm.io/builder" +) + +type ErrUnsupportedExprType struct { + tp string +} + +func (err ErrUnsupportedExprType) Error() string { + return fmt.Sprintf("Unsupported expression type: %v", err.tp) +} + +type exprParam struct { + colName string + arg interface{} +} + +type exprParams struct { + colNames []string + args []interface{} +} + +func (exprs *exprParams) Len() int { + return len(exprs.colNames) +} + +func (exprs *exprParams) addParam(colName string, arg interface{}) { + exprs.colNames = append(exprs.colNames, colName) + exprs.args = append(exprs.args, arg) +} + +func (exprs *exprParams) isColExist(colName string) bool { + for _, name := range exprs.colNames { + if strings.EqualFold(trimQuote(name), trimQuote(colName)) { + return true + } + } + return false +} + +func (exprs *exprParams) getByName(colName string) (exprParam, bool) { + for i, name := range exprs.colNames { + if strings.EqualFold(name, colName) { + return exprParam{name, exprs.args[i]}, true + } + } + return exprParam{}, false +} + +func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { + for i, expr := range exprs.args { + switch arg := expr.(type) { + case *builder.Builder: + if _, err := w.WriteString("("); err != nil { + return err + } + if err := arg.WriteTo(w); err != nil { + return err + } + if _, err := w.WriteString(")"); err != nil { + return err + } + default: + if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil { + return err + } + } + if i != len(exprs.args)-1 { + if _, err := w.WriteString(","); err != nil { + return err + } + } + } + return nil +} + +func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { + for i, colName := range exprs.colNames { + if _, err := w.WriteString(colName); err != nil { + return err + } + if _, err := w.WriteString("="); err != nil { + return err + } + + switch arg := exprs.args[i].(type) { + case *builder.Builder: + if _, err := w.WriteString("("); err != nil { + return err + } + if err := arg.WriteTo(w); err != nil { + return err + } + if _, err := w.WriteString("("); err != nil { + return err + } + default: + w.Append(exprs.args[i]) + } + + if i+1 != len(exprs.colNames) { + if _, err := w.WriteString(","); err != nil { + return err + } + } + } + return nil +} diff --git a/statement_quote.go b/statement_quote.go new file mode 100644 index 00000000..e22e0d14 --- /dev/null +++ b/statement_quote.go @@ -0,0 +1,19 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +func trimQuote(s string) string { + if len(s) == 0 { + return s + } + + if s[0] == '`' { + s = s[1:] + } + if len(s) > 0 && s[len(s)-1] == '`' { + return s[:len(s)-1] + } + return s +} diff --git a/test_mssql.sh b/test_mssql.sh index 7f060cff..e26e1641 100755 --- a/test_mssql.sh +++ b/test_mssql.sh @@ -1 +1 @@ -go test -db=mssql -conn_str="server=localhost;user id=sa;password=yourStrong(!)Password;database=xorm_test" \ No newline at end of file +go test -db=mssql -conn_str="server=localhost;user id=sa;password=MwantsaSecurePassword1;database=xorm_test" \ No newline at end of file