diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 00000000..bf8d9288 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,61 @@ +# Golang CircleCI 2.0 configuration file +# +# Check https://circleci.com/docs/2.0/language-go/ for more details +version: 2 +jobs: + build: + docker: + # specify the version + - image: circleci/golang:1.10 + + - image: circleci/mysql:5.7 + environment: + MYSQL_ALLOW_EMPTY_PASSWORD: true + MYSQL_DATABASE: xorm_test + MYSQL_HOST: 127.0.0.1 + MYSQL_ROOT_HOST: '%' + MYSQL_USER: root + + # CircleCI PostgreSQL images available at: https://hub.docker.com/r/circleci/postgres/ + - image: circleci/postgres:9.6.2-alpine + environment: + POSTGRES_USER: circleci + POSTGRES_DB: xorm_test + + - image: microsoft/mssql-server-linux:latest + environment: + ACCEPT_EULA: Y + SA_PASSWORD: yourStrong(!)Password + MSSQL_PID: Developer + + - image: pingcap/tidb:v2.1.2 + + working_directory: /go/src/github.com/go-xorm/xorm + steps: + - checkout + + - run: go get -t -d -v ./... + - run: go get -u github.com/go-xorm/core + - run: go get -u github.com/go-xorm/builder + - run: GO111MODULE=off go build -v + - run: GO111MODULE=on go build -v + + - run: go get -u github.com/wadey/gocovmerge + + - run: go test -v -race -db="sqlite3" -conn_str="./test.db" -coverprofile=coverage1-1.txt -covermode=atomic + - run: go test -v -race -db="sqlite3" -conn_str="./test.db" -cache=true -coverprofile=coverage1-2.txt -covermode=atomic + - run: go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -coverprofile=coverage2-1.txt -covermode=atomic + - run: go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -cache=true -coverprofile=coverage2-2.txt -covermode=atomic + - run: go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -coverprofile=coverage3-1.txt -covermode=atomic + - run: go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic + - run: go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -coverprofile=coverage4-1.txt -covermode=atomic + - run: go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic + - run: go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -coverprofile=coverage5-1.txt -covermode=atomic + - run: go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -cache=true -coverprofile=coverage5-2.txt -covermode=atomic + - run: go test -v -race -db="mssql" -conn_str="server=localhost;user id=sa;password=yourStrong(!)Password;database=xorm_test" -coverprofile=coverage6-1.txt -covermode=atomic + - run: go test -v -race -db="mssql" -conn_str="server=localhost;user id=sa;password=yourStrong(!)Password;database=xorm_test" -cache=true -coverprofile=coverage6-2.txt -covermode=atomic + - run: go test -v -race -db="mysql" -conn_str="root:@tcp(localhost:4000)/xorm_test" -ignore_select_update=true -coverprofile=coverage7-1.txt -covermode=atomic + - run: go test -v -race -db="mysql" -conn_str="root:@tcp(localhost:4000)/xorm_test" -ignore_select_update=true -cache=true -coverprofile=coverage7-2.txt -covermode=atomic + - run: gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt coverage5-1.txt coverage5-2.txt coverage6-1.txt coverage6-2.txt coverage7-1.txt coverage7-2.txt > coverage.txt + + - run: bash <(curl -s https://codecov.io/bash) \ No newline at end of file diff --git a/.drone.yml b/.drone.yml new file mode 100644 index 00000000..0a79ed02 --- /dev/null +++ b/.drone.yml @@ -0,0 +1,125 @@ +workspace: + base: /go + path: src/github.com/go-xorm/xorm + +clone: + git: + image: plugins/git:next + depth: 50 + tags: true + +services: + mysql: + image: mysql:5.7 + environment: + - MYSQL_DATABASE=xorm_test + - MYSQL_ALLOW_EMPTY_PASSWORD=yes + when: + event: [ push, tag, pull_request ] + + pgsql: + image: postgres:9.5 + environment: + - POSTGRES_USER=postgres + - POSTGRES_DB=xorm_test + when: + event: [ push, tag, pull_request ] + + #mssql: + # image: microsoft/mssql-server-linux:2017-CU11 + # environment: + # - ACCEPT_EULA=Y + # - SA_PASSWORD=yourStrong(!)Password + # - MSSQL_PID=Developer + # commands: + # - echo 'CREATE DATABASE xorm_test' > create.sql + # - /opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P yourStrong(!)Password -i "create.sql" + +matrix: + GO_VERSION: + - 1.8 + - 1.9 + - 1.10 + - 1.11 + +pipeline: + init_postgres: + image: postgres:9.5 + commands: + # wait for postgres service to become available + - | + until psql -U postgres -d xorm_test -h pgsql \ + -c "SELECT 1;" >/dev/null 2>&1; do sleep 1; done + # query the database + - | + psql -U postgres -d xorm_test -h pgsql \ + -c "create schema xorm;" + + build: + image: golang:${GO_VERSION} + commands: + - go get -t -d -v ./... + - go get -u github.com/go-xorm/core + - go get -u github.com/go-xorm/builder + - go build -v + when: + event: [ push, pull_request ] + + test-sqlite: + image: golang:${GO_VERSION} + commands: + - go get -u github.com/wadey/gocovmerge + - go test -v -race -db="sqlite3" -conn_str="./test.db" -coverprofile=coverage1-1.txt -covermode=atomic + - go test -v -race -db="sqlite3" -conn_str="./test.db" -cache=true -coverprofile=coverage1-2.txt -covermode=atomic + when: + event: [ push, pull_request ] + + test-mysql: + image: golang:${GO_VERSION} + commands: + - go test -v -race -db="mysql" -conn_str="root:@tcp(mysql)/xorm_test" -coverprofile=coverage2-1.txt -covermode=atomic + - go test -v -race -db="mysql" -conn_str="root:@tcp(mysql)/xorm_test" -cache=true -coverprofile=coverage2-2.txt -covermode=atomic + when: + event: [ push, pull_request ] + + test-mysql-utf8mb4: + image: golang:${GO_VERSION} + commands: + - go test -v -race -db="mysql" -conn_str="root:@tcp(mysql)/xorm_test?charset=utf8mb4" -coverprofile=coverage2.1-1.txt -covermode=atomic + - go test -v -race -db="mysql" -conn_str="root:@tcp(mysql)/xorm_test?charset=utf8mb4" -cache=true -coverprofile=coverage2.1-2.txt -covermode=atomic + when: + event: [ push, pull_request ] + + test-mymysql: + image: golang:${GO_VERSION} + commands: + - go test -v -race -db="mymysql" -conn_str="tcp:mysql:3306*xorm_test/root/" -coverprofile=coverage3-1.txt -covermode=atomic + - go test -v -race -db="mymysql" -conn_str="tcp:mysql:3306*xorm_test/root/" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic + when: + event: [ push, pull_request ] + + test-postgres: + image: golang:${GO_VERSION} + commands: + - go test -v -race -db="postgres" -conn_str="postgres://postgres:@pgsql/xorm_test?sslmode=disable" -coverprofile=coverage4-1.txt -covermode=atomic + - go test -v -race -db="postgres" -conn_str="postgres://postgres:@pgsql/xorm_test?sslmode=disable" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic + when: + event: [ push, pull_request ] + + test-postgres-schema: + image: golang:${GO_VERSION} + commands: + - go test -v -race -db="postgres" -conn_str="postgres://postgres:@pgsql/xorm_test?sslmode=disable" -schema=xorm -coverprofile=coverage5-1.txt -covermode=atomic + - go test -v -race -db="postgres" -conn_str="postgres://postgres:@pgsql/xorm_test?sslmode=disable" -schema=xorm -cache=true -coverprofile=coverage5-2.txt -covermode=atomic + - gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage2.1-1.txt coverage2.1-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt coverage5-1.txt coverage5-2.txt > coverage.txt + when: + event: [ push, pull_request ] + + #coverage: + # image: robertstettner/drone-codecov + # secrets: [ codecov_token ] + # files: + # - coverage.txt + # when: + # event: [ push, pull_request ] + # branch: [ master ] \ No newline at end of file diff --git a/README.md b/README.md index 2443e4ef..5c64e776 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,8 @@ Xorm is a simple and powerful ORM for Go. * Postgres schema support +* Context Cache support + ## Drivers Support Drivers for Go's sql package which currently support database/sql includes: @@ -358,6 +360,56 @@ if _, err := session.Exec("delete from userinfo where username = ?", user2.Usern return session.Commit() ``` +* Or you can use `Transaction` to replace above codes. + +```Go +res, err := engine.Transaction(func(session *xorm.Session) (interface{}, error) { + user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} + if _, err := session.Insert(&user1); err != nil { + return nil, err + } + + user2 := Userinfo{Username: "yyy"} + if _, err := session.Where("id = ?", 2).Update(&user2); err != nil { + return nil, err + } + + if _, err := session.Exec("delete from userinfo where username = ?", user2.Username); err != nil { + return nil, err + } + return nil, nil +}) +``` + +* Context Cache, if enabled, current query result will be cached on session and be used by next same statement on the same session. + +```Go + sess := engine.NewSession() + defer sess.Close() + + var context = xorm.NewMemoryContextCache() + + var c2 ContextGetStruct + has, err := sess.ID(1).ContextCache(context).Get(&c2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, c2.Id) + assert.EqualValues(t, "1", c2.Name) + sql, args := sess.LastSQL() + assert.True(t, len(sql) > 0) + assert.True(t, len(args) > 0) + + var c3 ContextGetStruct + has, err = sess.ID(1).ContextCache(context).Get(&c3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, c3.Id) + assert.EqualValues(t, "1", c3.Name) + sql, args = sess.LastSQL() + assert.True(t, len(sql) == 0) + assert.True(t, len(args) == 0) +``` + ## Contributing If you want to pull request, please see [CONTRIBUTING](https://github.com/go-xorm/xorm/blob/master/CONTRIBUTING.md). And we also provide [Xorm on Google Groups](https://groups.google.com/forum/#!forum/xorm) to discuss. @@ -441,4 +493,4 @@ Support this project by becoming a sponsor. Your logo will show up here with a l ## LICENSE -BSD License [http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/) \ No newline at end of file +BSD License [http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/) diff --git a/README_CN.md b/README_CN.md index c51cec05..e2ed95b6 100644 --- a/README_CN.md +++ b/README_CN.md @@ -32,6 +32,8 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 * 内置SQL Builder支持 +* 上下文缓存支持 + ## 驱动支持 目前支持的Go数据库驱动和对应的数据库如下: @@ -360,9 +362,61 @@ if _, err := session.Exec("delete from userinfo where username = ?", user2.Usern return session.Commit() ``` +* 事物的简写方法 + +```Go +res, err := engine.Transaction(func(session *xorm.Session) (interface{}, error) { + user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} + if _, err := session.Insert(&user1); err != nil { + return nil, err + } + + user2 := Userinfo{Username: "yyy"} + if _, err := session.Where("id = ?", 2).Update(&user2); err != nil { + return nil, err + } + + if _, err := session.Exec("delete from userinfo where username = ?", user2.Username); err != nil { + return nil, err + } + return nil, nil +}) +``` + +* 上下文缓存,如果启用,那么针对单个对象的查询将会被缓存到系统中,可以被下一个查询使用。 + +```Go + sess := engine.NewSession() + defer sess.Close() + + var context = xorm.NewMemoryContextCache() + + var c2 ContextGetStruct + has, err := sess.ID(1).ContextCache(context).Get(&c2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, c2.Id) + assert.EqualValues(t, "1", c2.Name) + sql, args := sess.LastSQL() + assert.True(t, len(sql) > 0) + assert.True(t, len(args) > 0) + + var c3 ContextGetStruct + has, err = sess.ID(1).ContextCache(context).Get(&c3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, c3.Id) + assert.EqualValues(t, "1", c3.Name) + sql, args = sess.LastSQL() + assert.True(t, len(sql) == 0) + assert.True(t, len(args) == 0) +``` + ## 贡献 -如果您也想为Xorm贡献您的力量,请查看 [CONTRIBUTING](https://github.com/go-xorm/xorm/blob/master/CONTRIBUTING.md)。您也可以加入QQ群 280360085 技术帮助和讨论。 +如果您也想为Xorm贡献您的力量,请查看 [CONTRIBUTING](https://github.com/go-xorm/xorm/blob/master/CONTRIBUTING.md)。您也可以加入QQ群 技术帮助和讨论。 +群一:280360085 (已满) +群二:795010183 ## Credits diff --git a/circle.yml b/circle.yml deleted file mode 100644 index 8fde3169..00000000 --- a/circle.yml +++ /dev/null @@ -1,41 +0,0 @@ -dependencies: - override: - # './...' is a relative pattern which means all subdirectories - - go get -t -d -v ./... - - go get -t -d -v github.com/go-xorm/tests - - go get -u github.com/go-xorm/core - - go get -u github.com/go-xorm/builder - - go build -v - -database: - override: - - mysql -u root -e "CREATE DATABASE xorm_test DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci" - - mysql -u root -e "CREATE DATABASE xorm_test1 DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci" - - mysql -u root -e "CREATE DATABASE xorm_test2 DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci" - - mysql -u root -e "CREATE DATABASE xorm_test3 DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci" - - createdb -p 5432 -e -U postgres xorm_test - - createdb -p 5432 -e -U postgres xorm_test1 - - createdb -p 5432 -e -U postgres xorm_test2 - - createdb -p 5432 -e -U postgres xorm_test3 - - psql xorm_test postgres -c "create schema xorm" - -test: - override: - # './...' is a relative pattern which means all subdirectories - - go get -u github.com/wadey/gocovmerge - - go test -v -race -db="sqlite3" -conn_str="./test.db" -coverprofile=coverage1-1.txt -covermode=atomic - - go test -v -race -db="sqlite3" -conn_str="./test.db" -cache=true -coverprofile=coverage1-2.txt -covermode=atomic - - go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -coverprofile=coverage2-1.txt -covermode=atomic - - go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -cache=true -coverprofile=coverage2-2.txt -covermode=atomic - - go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -coverprofile=coverage3-1.txt -covermode=atomic - - go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic - - go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -coverprofile=coverage4-1.txt -covermode=atomic - - go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic - - go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -coverprofile=coverage5-1.txt -covermode=atomic - - go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -cache=true -coverprofile=coverage5-2.txt -covermode=atomic - - gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt coverage5-1.txt coverage5-2.txt > coverage.txt - - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./sqlite3.sh - - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./mysql.sh - - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./postgres.sh - post: - - bash <(curl -s https://codecov.io/bash) \ No newline at end of file diff --git a/context_cache.go b/context_cache.go new file mode 100644 index 00000000..1bc22884 --- /dev/null +++ b/context_cache.go @@ -0,0 +1,30 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +// ContextCache is the interface that operates the cache data. +type ContextCache interface { + // Put puts value into cache with key. + Put(key string, val interface{}) + // Get gets cached value by given key. + Get(key string) interface{} +} + +type memoryContextCache map[string]interface{} + +// NewMemoryContextCache return memoryContextCache +func NewMemoryContextCache() memoryContextCache { + return make(map[string]interface{}) +} + +// Put puts value into cache with key. +func (m memoryContextCache) Put(key string, val interface{}) { + m[key] = val +} + +// Get gets cached value by given key. +func (m memoryContextCache) Get(key string) interface{} { + return m[key] +} diff --git a/dialect_mssql.go b/dialect_mssql.go index 6d2291dc..ea543825 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -7,6 +7,7 @@ package xorm import ( "errors" "fmt" + "net/url" "strconv" "strings" @@ -218,7 +219,7 @@ func (db *mssql) SqlType(c *core.Column) string { res = core.Bit if strings.EqualFold(c.Default, "true") { c.Default = "1" - } else { + } else if strings.EqualFold(c.Default, "false") { c.Default = "0" } case core.Serial: @@ -544,14 +545,23 @@ type odbcDriver struct { } func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - kv := strings.Split(dataSourceName, ";") var dbName string - for _, c := range kv { - vv := strings.Split(strings.TrimSpace(c), "=") - if len(vv) == 2 { - switch strings.ToLower(vv[0]) { - case "database": - dbName = vv[1] + + if strings.HasPrefix(dataSourceName, "sqlserver://") { + u, err := url.Parse(dataSourceName) + if err != nil { + return nil, err + } + dbName = u.Query().Get("database") + } else { + kv := strings.Split(dataSourceName, ";") + for _, c := range kv { + vv := strings.Split(strings.TrimSpace(c), "=") + if len(vv) == 2 { + switch strings.ToLower(vv[0]) { + case "database": + dbName = vv[1] + } } } } diff --git a/dialect_mssql_test.go b/dialect_mssql_test.go new file mode 100644 index 00000000..f0673016 --- /dev/null +++ b/dialect_mssql_test.go @@ -0,0 +1,35 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "reflect" + "testing" + + "github.com/go-xorm/core" +) + +func TestParseMSSQL(t *testing.T) { + tests := []struct { + in string + expected string + valid bool + }{ + {"sqlserver://sa:yourStrong(!)Password@localhost:1433?database=db&connection+timeout=30", "db", true}, + {"server=localhost;user id=sa;password=yourStrong(!)Password;database=db", "db", true}, + } + + driver := core.QueryDriver("mssql") + + for _, test := range tests { + uri, err := driver.Parse("mssql", test.in) + + if err != nil && test.valid { + t.Errorf("%q got unexpected error: %s", test.in, err) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) + } + } +} diff --git a/dialect_postgres.go b/dialect_postgres.go index 1f74bd31..d6f40368 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -822,7 +822,7 @@ func (db *postgres) SqlType(c *core.Column) string { case core.NVarchar: res = core.Varchar case core.Uuid: - res = core.Uuid + return core.Uuid case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob: return core.Bytea case core.Double: @@ -834,6 +834,10 @@ func (db *postgres) SqlType(c *core.Column) string { res = t } + if strings.EqualFold(res, "bool") { + // for bool, we don't need length information + return res + } hasLen1 := (c.Length > 0) hasLen2 := (c.Length2 > 0) @@ -1089,6 +1093,19 @@ func (db *postgres) GetTables() ([]*core.Table, error) { return tables, nil } + +func getIndexColName(indexdef string) []string { + var colNames []string + + cs := strings.Split(indexdef, "(") + for _, v := range strings.Split(strings.Split(cs[1], ")")[0], ",") { + colNames = append(colNames, strings.Split(strings.TrimLeft(v, " "), " ")[0]) + } + + return colNames +} + + func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { args := []interface{}{tableName} s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") @@ -1122,8 +1139,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) } else { indexType = core.IndexType } - cs := strings.Split(indexdef, "(") - colNames = strings.Split(cs[1][0:len(cs[1])-1], ",") + colNames = getIndexColName(indexdef) var isRegular bool if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { newIdxName := indexName[5+len(tableName):] diff --git a/dialect_postgres_test.go b/dialect_postgres_test.go index 6e6c44bb..06c79ca4 100644 --- a/dialect_postgres_test.go +++ b/dialect_postgres_test.go @@ -6,6 +6,7 @@ import ( "github.com/go-xorm/core" "github.com/jackc/pgx/stdlib" + "github.com/stretchr/testify/assert" ) func TestParsePostgres(t *testing.T) { @@ -84,3 +85,37 @@ func TestParsePgx(t *testing.T) { } } + +func TestGetIndexColName(t *testing.T) { + t.Run("Index", func(t *testing.T) { + s := "CREATE INDEX test2_mm_idx ON test2 (major);" + colNames := getIndexColName(s) + assert.Equal(t, []string{"major"}, colNames) + }) + + t.Run("Multicolumn indexes", func(t *testing.T) { + s := "CREATE INDEX test2_mm_idx ON test2 (major, minor);" + colNames := getIndexColName(s) + assert.Equal(t, []string{"major", "minor"}, colNames) + }) + + t.Run("Indexes and ORDER BY", func(t *testing.T) { + s := "CREATE INDEX test2_mm_idx ON test2 (major NULLS FIRST, minor DESC NULLS LAST);" + colNames := getIndexColName(s) + assert.Equal(t, []string{"major", "minor"}, colNames) + }) + + t.Run("Combining Multiple Indexes", func(t *testing.T) { + s := "CREATE INDEX test2_mm_cm_idx ON public.test2 USING btree (major, minor) WHERE ((major <> 5) AND (minor <> 6))" + colNames := getIndexColName(s) + assert.Equal(t, []string{"major", "minor"}, colNames) + }) + + t.Run("unique", func(t *testing.T) { + s := "CREATE UNIQUE INDEX test2_mm_uidx ON test2 (major);" + colNames := getIndexColName(s) + assert.Equal(t, []string{"major"}, colNames) + }) + + t.Run("Indexes on Expressions", func(t *testing.T) {}) +} diff --git a/engine.go b/engine.go index 89a96d9f..07649df7 100644 --- a/engine.go +++ b/engine.go @@ -7,6 +7,7 @@ package xorm import ( "bufio" "bytes" + "context" "database/sql" "encoding/gob" "errors" @@ -52,6 +53,8 @@ type Engine struct { cachers map[string]core.Cacher cacherLock sync.RWMutex + + defaultContext context.Context } func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { @@ -481,7 +484,8 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D } cols := table.ColumnsSeq() - colNames := dialect.Quote(strings.Join(cols, dialect.Quote(", "))) + colNames := engine.dialect.Quote(strings.Join(cols, engine.dialect.Quote(", "))) + destColNames := dialect.Quote(strings.Join(cols, dialect.Quote(", "))) rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.Quote(table.Name)) if err != nil { @@ -496,7 +500,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D return err } - _, err = io.WriteString(w, "INSERT INTO "+dialect.Quote(table.Name)+" ("+colNames+") VALUES (") + _, err = io.WriteString(w, "INSERT INTO "+dialect.Quote(table.Name)+" ("+destColNames+") VALUES (") if err != nil { return err } @@ -526,7 +530,11 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D } else if col.SQLType.IsNumeric() { switch reflect.TypeOf(d).Kind() { case reflect.Slice: - temp += fmt.Sprintf(", %s", string(d.([]byte))) + if col.SQLType.Name == core.Bool { + temp += fmt.Sprintf(", %v", strconv.FormatBool(d.([]byte)[0] != byte('0'))) + } else { + temp += fmt.Sprintf(", %s", string(d.([]byte))) + } case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int: if col.SQLType.Name == core.Bool { temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Int() > 0)) @@ -563,7 +571,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D // FIXME: Hack for postgres if string(dialect.DBType()) == core.POSTGRES && table.AutoIncrColumn() != nil { - _, err = io.WriteString(w, "SELECT setval('table_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") FROM "+dialect.Quote(table.Name)+"), 1), false);\n") + _, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quote(table.Name)+"), 1), false);\n") if err != nil { return err } diff --git a/engine_context.go b/engine_context.go new file mode 100644 index 00000000..c6cbb76c --- /dev/null +++ b/engine_context.go @@ -0,0 +1,28 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.8 + +package xorm + +import "context" + +// Context creates a session with the context +func (engine *Engine) Context(ctx context.Context) *Session { + session := engine.NewSession() + session.isAutoClose = true + return session.Context(ctx) +} + +// SetDefaultContext set the default context +func (engine *Engine) SetDefaultContext(ctx context.Context) { + engine.defaultContext = ctx +} + +// PingContext tests if database is alive +func (engine *Engine) PingContext(ctx context.Context) error { + session := engine.NewSession() + defer session.Close() + return session.PingContext(ctx) +} diff --git a/context_test.go b/engine_context_test.go similarity index 57% rename from context_test.go rename to engine_context_test.go index 29a6786b..cc564694 100644 --- a/context_test.go +++ b/engine_context_test.go @@ -7,7 +7,9 @@ package xorm import ( + "context" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -15,10 +17,10 @@ import ( func TestPingContext(t *testing.T) { assert.NoError(t, prepareEngine()) - // TODO: Since EngineInterface should be compitable with old Go version, PingContext is not supported. - /* - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) - err := testEngine.PingContext(ctx) - assert.NoError(t, err) - */ + ctx, canceled := context.WithTimeout(context.Background(), time.Nanosecond) + defer canceled() + + err := testEngine.(*Engine).PingContext(ctx) + assert.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") } diff --git a/engine_group.go b/engine_group.go index 5eee3e61..6796075e 100644 --- a/engine_group.go +++ b/engine_group.go @@ -74,6 +74,13 @@ func (eg *EngineGroup) Close() error { return nil } +// NewSession returned a group session +func (eg *EngineGroup) NewSession() *Session { + sess := eg.Engine.NewSession() + sess.sessionType = groupSession + return sess +} + // Master returns the master engine func (eg *EngineGroup) Master() *Engine { return eg.Engine diff --git a/engine_table.go b/engine_table.go index 94871a4b..4b672a6f 100644 --- a/engine_table.go +++ b/engine_table.go @@ -12,7 +12,7 @@ import ( "github.com/go-xorm/core" ) -// TableNameWithSchema will automatically add schema prefix on table name +// tbNameWithSchema will automatically add schema prefix on table name func (engine *Engine) tbNameWithSchema(v string) string { // Add schema name as prefix of table name. // Only for postgres database. diff --git a/error.go b/error.go index a223fc4a..a67527ac 100644 --- a/error.go +++ b/error.go @@ -26,6 +26,8 @@ var ( ErrNotImplemented = errors.New("Not implemented") // ErrConditionType condition type unsupported ErrConditionType = errors.New("Unsupported condition type") + // ErrUnSupportedSQLType parameter of SQL is not supported + ErrUnSupportedSQLType = errors.New("unsupported sql type") ) // ErrFieldIsNotExist columns does not exist diff --git a/go.mod b/go.mod index 4510e93c..f6a036f2 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,29 @@ -module "github.com/go-xorm/xorm" +module github.com/go-xorm/xorm require ( - "github.com/go-xorm/builder" v0.0.0-20180322150003-a9b7ffcca3f0 - "github.com/go-xorm/core" v0.0.0-20180322150003-0177c08cee88 + cloud.google.com/go v0.34.0 // indirect + github.com/cockroachdb/apd v1.1.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/denisenkom/go-mssqldb v0.0.0-20190121005146-b04fd42d9952 + github.com/go-sql-driver/mysql v1.4.1 + github.com/go-xorm/builder v0.3.3 + github.com/go-xorm/core v0.6.2 + github.com/go-xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a // indirect + github.com/google/go-cmp v0.2.0 // indirect + github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 // indirect + github.com/jackc/pgx v3.3.0+incompatible + github.com/kr/pretty v0.1.0 // indirect + github.com/kr/pty v1.1.3 // indirect + github.com/lib/pq v1.0.0 + github.com/mattn/go-sqlite3 v1.10.0 + github.com/pkg/errors v0.8.1 // indirect + github.com/satori/go.uuid v1.2.0 // indirect + github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 // indirect + github.com/stretchr/objx v0.1.1 // indirect + github.com/stretchr/testify v1.3.0 + github.com/ziutek/mymysql v1.5.4 + golang.org/x/crypto v0.0.0-20190122013713-64072686203f // indirect + golang.org/x/net v0.0.0-20190110200230-915654e7eabc // indirect + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + gopkg.in/stretchr/testify.v1 v1.2.2 ) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..cdfb06e1 --- /dev/null +++ b/go.sum @@ -0,0 +1,60 @@ +cloud.google.com/go v0.34.0 h1:eOI3/cP2VTU6uZLDYAoic+eyzzB9YyGmJ7eIjl8rOPg= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/denisenkom/go-mssqldb v0.0.0-20190121005146-b04fd42d9952 h1:b5OnbZD49x9g+/FcYbs/vukEt8C/jUbGhCJ3uduQmu8= +github.com/denisenkom/go-mssqldb v0.0.0-20190121005146-b04fd42d9952/go.mod h1:xN/JuLBIz4bjkxNmByTiV1IbhfnYb6oo99phBn4Eqhc= +github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= +github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-xorm/builder v0.3.3 h1:v8grgrwOGv/iHXIEhIvOwHZIPLrpxRKSX8yWSMLFn/4= +github.com/go-xorm/builder v0.3.3/go.mod h1:v8mE3MFBgtL+RGFNfUnAMUqqfk/Y4W5KuwCFQIEpQLk= +github.com/go-xorm/core v0.6.2 h1:EJLcSxf336POJr670wKB55Mah9f93xzvGYzNRgnT8/Y= +github.com/go-xorm/core v0.6.2/go.mod h1:bwPIfLdm/FzWgVUH8WPVlr+uJhscvNGFcaZKXsI3n2c= +github.com/go-xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:9wScpmSP5A3Bk8V3XHWUcJmYTh+ZnlHVyc+A4oZYS3Y= +github.com/go-xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:56xuuqnHyryaerycW3BfssRdxQstACi0Epw/yC5E2xM= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 h1:vr3AYkKovP8uR8AvSGGUK1IDqRa5lAAvEkZG1LKaCRc= +github.com/jackc/fake v0.0.0-20150926172116-812a484cc733/go.mod h1:WrMFNQdiFJ80sQsxDoMokWK1W5TQtxBFNpzWTD84ibQ= +github.com/jackc/pgx v3.3.0+incompatible h1:Wa90/+qsITBAPkAZjiByeIGHFcj3Ztu+VzrrIpHjL90= +github.com/jackc/pgx v3.3.0+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= +github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 h1:pntxY8Ary0t43dCZ5dqY4YTJCObLY1kIXl0uzMv+7DE= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/ziutek/mymysql v1.5.4 h1:GB0qdRGsTwQSBVYuVShFBKaXSnSnYYC2d9knnE1LHFs= +github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0= +golang.org/x/crypto v0.0.0-20190122013713-64072686203f h1:u1CmMhe3a44hy8VIgpInORnI01UVaUYheqR7x9BxT3c= +golang.org/x/crypto v0.0.0-20190122013713-64072686203f/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190110200230-915654e7eabc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/stretchr/testify.v1 v1.2.2 h1:yhQC6Uy5CqibAIlk1wlusa/MJ3iAN49/BsR/dCCKz3M= +gopkg.in/stretchr/testify.v1 v1.2.2/go.mod h1:QI5V/q6UbPmuhtm10CaFZxED9NreB8PnFYN9JcR6TxU= diff --git a/interface.go b/interface.go index 33d2078e..4f084421 100644 --- a/interface.go +++ b/interface.go @@ -5,6 +5,7 @@ package xorm import ( + "context" "database/sql" "reflect" "time" @@ -73,6 +74,7 @@ type EngineInterface interface { Before(func(interface{})) *Session Charset(charset string) *Session ClearCache(...interface{}) error + Context(context.Context) *Session CreateTables(...interface{}) error DBMetas() ([]*core.Table, error) Dialect() core.Dialect diff --git a/processors_test.go b/processors_test.go index e8c27e89..d1efc047 100644 --- a/processors_test.go +++ b/processors_test.go @@ -154,118 +154,86 @@ func TestProcessors(t *testing.T) { } _, err = testEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p.AfterInsertedFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p.AfterInsertedViaExt == 0 { - t.Error(errors.New("AfterInsertedViaExt not set")) - } - } + assert.NoError(t, err) + assert.True(t, p.Id > 0, "Inserted ID not set") + assert.True(t, p.B4InsertFlag > 0, "B4InsertFlag not set") + assert.True(t, p.AfterInsertedFlag > 0, "B4InsertFlag not set") + assert.True(t, p.B4InsertViaExt > 0, "B4InsertFlag not set") + assert.True(t, p.AfterInsertedViaExt > 0, "AfterInsertedViaExt not set") p2 := &ProcessorsStruct{} - _, err = testEngine.ID(p.Id).Get(p2) - if err != nil { - t.Error(err) - panic(err) - } else { - if p2.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p2.AfterInsertedFlag != 0 { - t.Error(errors.New("AfterInsertedFlag is set")) - } - if p2.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p2.AfterInsertedViaExt != 0 { - t.Error(errors.New("AfterInsertedViaExt is set")) - } - if p2.BeforeSetFlag != 9 { - t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p2.BeforeSetFlag)) - } - if p2.AfterSetFlag != 9 { - t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p2.BeforeSetFlag)) - } - } + has, err := testEngine.ID(p.Id).Get(p2) + assert.NoError(t, err) + assert.True(t, has) + assert.True(t, p2.B4InsertFlag > 0, "B4InsertFlag not set") + assert.True(t, p2.AfterInsertedFlag == 0, "AfterInsertedFlag is set") + assert.True(t, p2.B4InsertViaExt > 0, "B4InsertViaExt not set") + assert.True(t, p2.AfterInsertedViaExt == 0, "AfterInsertedViaExt is set") + assert.True(t, p2.BeforeSetFlag == 9, fmt.Sprintf("BeforeSetFlag is %d not 9", p2.BeforeSetFlag)) + assert.True(t, p2.AfterSetFlag == 9, fmt.Sprintf("AfterSetFlag is %d not 9", p2.BeforeSetFlag)) // -- // test find processors var p2Find []*ProcessorsStruct err = testEngine.Find(&p2Find) - if err != nil { + assert.NoError(t, err) + + if len(p2Find) != 1 { + err = errors.New("Should get 1") t.Error(err) - panic(err) - } else { - if len(p2Find) != 1 { - err = errors.New("Should get 1") - t.Error(err) - } - p21 := p2Find[0] - if p21.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p21.AfterInsertedFlag != 0 { - t.Error(errors.New("AfterInsertedFlag is set")) - } - if p21.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p21.AfterInsertedViaExt != 0 { - t.Error(errors.New("AfterInsertedViaExt is set")) - } - if p21.BeforeSetFlag != 9 { - t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p21.BeforeSetFlag)) - } - if p21.AfterSetFlag != 9 { - t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p21.BeforeSetFlag)) - } + } + p21 := p2Find[0] + if p21.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p21.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p21.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p21.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + if p21.BeforeSetFlag != 9 { + t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p21.BeforeSetFlag)) + } + if p21.AfterSetFlag != 9 { + t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p21.BeforeSetFlag)) } // -- // test find map processors var p2FindMap = make(map[int64]*ProcessorsStruct) err = testEngine.Find(&p2FindMap) - if err != nil { - t.Error(err) - panic(err) - } else { - if len(p2FindMap) != 1 { - err = errors.New("Should get 1") - t.Error(err) - } - var p22 *ProcessorsStruct - for _, v := range p2FindMap { - p22 = v - } + assert.NoError(t, err) - if p22.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p22.AfterInsertedFlag != 0 { - t.Error(errors.New("AfterInsertedFlag is set")) - } - if p22.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p22.AfterInsertedViaExt != 0 { - t.Error(errors.New("AfterInsertedViaExt is set")) - } - if p22.BeforeSetFlag != 9 { - t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p22.BeforeSetFlag)) - } - if p22.AfterSetFlag != 9 { - t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p22.BeforeSetFlag)) - } + if len(p2FindMap) != 1 { + err = errors.New("Should get 1") + t.Error(err) + } + var p22 *ProcessorsStruct + for _, v := range p2FindMap { + p22 = v + } + + if p22.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p22.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p22.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p22.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + if p22.BeforeSetFlag != 9 { + t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p22.BeforeSetFlag)) + } + if p22.AfterSetFlag != 9 { + t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p22.BeforeSetFlag)) } // -- @@ -289,48 +257,43 @@ func TestProcessors(t *testing.T) { p = p2 // reset _, err = testEngine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag == 0 { - t.Error(errors.New("AfterUpdatedFlag not set")) - } - if p.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p.AfterUpdatedViaExt == 0 { - t.Error(errors.New("AfterUpdatedViaExt not set")) - } + assert.NoError(t, err) + + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag == 0 { + t.Error(errors.New("AfterUpdatedFlag not set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt == 0 { + t.Error(errors.New("AfterUpdatedViaExt not set")) } p2 = &ProcessorsStruct{} - _, err = testEngine.ID(p.Id).Get(p2) - if err != nil { - t.Error(err) - panic(err) - } else { - if p2.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p2.AfterUpdatedFlag != 0 { - t.Error(errors.New("AfterUpdatedFlag is set: " + string(p.AfterUpdatedFlag))) - } - if p2.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p2.AfterUpdatedViaExt != 0 { - t.Error(errors.New("AfterUpdatedViaExt is set: " + string(p.AfterUpdatedViaExt))) - } - if p2.BeforeSetFlag != 9 { - t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p2.BeforeSetFlag)) - } - if p2.AfterSetFlag != 9 { - t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p2.BeforeSetFlag)) - } + has, err = testEngine.ID(p.Id).Get(p2) + assert.NoError(t, err) + assert.True(t, has) + + if p2.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p2.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set: " + string(p.AfterUpdatedFlag))) + } + if p2.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p2.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set: " + string(p.AfterUpdatedViaExt))) + } + if p2.BeforeSetFlag != 9 { + t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p2.BeforeSetFlag)) + } + if p2.AfterSetFlag != 9 { + t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p2.BeforeSetFlag)) } // -- @@ -353,22 +316,18 @@ func TestProcessors(t *testing.T) { p = p2 // reset _, err = testEngine.Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4DeleteFlag == 0 { - t.Error(errors.New("B4DeleteFlag not set")) - } - if p.AfterDeletedFlag == 0 { - t.Error(errors.New("AfterDeletedFlag not set")) - } - if p.B4DeleteViaExt == 0 { - t.Error(errors.New("B4DeleteViaExt not set")) - } - if p.AfterDeletedViaExt == 0 { - t.Error(errors.New("AfterDeletedViaExt not set")) - } + assert.NoError(t, err) + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag == 0 { + t.Error(errors.New("AfterDeletedFlag not set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt == 0 { + t.Error(errors.New("AfterDeletedViaExt not set")) } // -- @@ -377,54 +336,46 @@ func TestProcessors(t *testing.T) { pslice = append(pslice, &ProcessorsStruct{}) pslice = append(pslice, &ProcessorsStruct{}) cnt, err := testEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(&pslice) - if err != nil { - t.Error(err) - panic(err) - } else { - if cnt != 2 { - t.Error(errors.New("incorrect insert count")) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt, "incorrect insert count") + + for _, elem := range pslice { + if elem.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) } - for _, elem := range pslice { - if elem.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if elem.AfterInsertedFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if elem.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if elem.AfterInsertedViaExt == 0 { - t.Error(errors.New("AfterInsertedViaExt not set")) - } + if elem.AfterInsertedFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if elem.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if elem.AfterInsertedViaExt == 0 { + t.Error(errors.New("AfterInsertedViaExt not set")) } } for _, elem := range pslice { p = &ProcessorsStruct{} _, err = testEngine.ID(elem.Id).Get(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p2.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p2.AfterInsertedFlag != 0 { - t.Error(errors.New("AfterInsertedFlag is set")) - } - if p2.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p2.AfterInsertedViaExt != 0 { - t.Error(errors.New("AfterInsertedViaExt is set")) - } - if p2.BeforeSetFlag != 9 { - t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p2.BeforeSetFlag)) - } - if p2.AfterSetFlag != 9 { - t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p2.BeforeSetFlag)) - } + assert.NoError(t, err) + + if p2.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p2.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p2.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p2.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + if p2.BeforeSetFlag != 9 { + t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p2.BeforeSetFlag)) + } + if p2.AfterSetFlag != 9 { + t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p2.BeforeSetFlag)) } } // -- @@ -434,24 +385,17 @@ func TestProcessorsTx(t *testing.T) { assert.NoError(t, prepareEngine()) err := testEngine.DropTables(&ProcessorsStruct{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) err = testEngine.CreateTables(&ProcessorsStruct{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) // test insert processors with tx rollback session := testEngine.NewSession() + defer session.Close() + err = session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) p := &ProcessorsStruct{} b4InsertFunc := func(bean interface{}) { @@ -470,133 +414,117 @@ func TestProcessorsTx(t *testing.T) { } } _, err = session.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p.AfterInsertedFlag != 0 { - t.Error(errors.New("B4InsertFlag is set")) - } - if p.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p.AfterInsertedViaExt != 0 { - t.Error(errors.New("AfterInsertedViaExt is set")) - } + assert.NoError(t, err) + + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("B4InsertFlag is set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) } err = session.Rollback() - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p.AfterInsertedFlag != 0 { - t.Error(errors.New("B4InsertFlag is set")) - } - if p.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p.AfterInsertedViaExt != 0 { - t.Error(errors.New("AfterInsertedViaExt is set")) - } + assert.NoError(t, err) + + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("B4InsertFlag is set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + session.Close() + p2 := &ProcessorsStruct{} _, err = testEngine.ID(p.Id).Get(p2) - if err != nil { + assert.NoError(t, err) + + if p2.Id > 0 { + err = errors.New("tx got committed upon insert!?") t.Error(err) panic(err) - } else { - if p2.Id > 0 { - err = errors.New("tx got committed upon insert!?") - t.Error(err) - panic(err) - } } // -- // test insert processors with tx commit session = testEngine.NewSession() + defer session.Close() + err = session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) p = &ProcessorsStruct{} _, err = session.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p.AfterInsertedFlag != 0 { - t.Error(errors.New("AfterInsertedFlag is set")) - } - if p.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p.AfterInsertedViaExt != 0 { - t.Error(errors.New("AfterInsertedViaExt is set")) - } + assert.NoError(t, err) + + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) } err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p.AfterInsertedFlag == 0 { - t.Error(errors.New("AfterInsertedFlag not set")) - } - if p.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p.AfterInsertedViaExt == 0 { - t.Error(errors.New("AfterInsertedViaExt not set")) - } + assert.NoError(t, err) + + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) } + if p.AfterInsertedFlag == 0 { + t.Error(errors.New("AfterInsertedFlag not set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt == 0 { + t.Error(errors.New("AfterInsertedViaExt not set")) + } + session.Close() p2 = &ProcessorsStruct{} _, err = testEngine.ID(p.Id).Get(p2) - if err != nil { - t.Error(err) - panic(err) - } else { - if p2.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p2.AfterInsertedFlag != 0 { - t.Error(errors.New("AfterInsertedFlag is set")) - } - if p2.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p2.AfterInsertedViaExt != 0 { - t.Error(errors.New("AfterInsertedViaExt is set")) - } + assert.NoError(t, err) + + if p2.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) } + if p2.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p2.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p2.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + insertedId := p2.Id // -- // test update processors with tx rollback session = testEngine.NewSession() + defer session.Close() + err = session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) b4UpdateFunc := func(bean interface{}) { if v, ok := (bean).(*ProcessorsStruct); ok { @@ -617,183 +545,160 @@ func TestProcessorsTx(t *testing.T) { p = p2 // reset _, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag != 0 { - t.Error(errors.New("AfterUpdatedFlag is set")) - } - if p.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p.AfterUpdatedViaExt != 0 { - t.Error(errors.New("AfterUpdatedViaExt is set")) - } + assert.NoError(t, err) + + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) + } + err = session.Rollback() - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag != 0 { - t.Error(errors.New("AfterUpdatedFlag is set")) - } - if p.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p.AfterUpdatedViaExt != 0 { - t.Error(errors.New("AfterUpdatedViaExt is set")) - } + assert.NoError(t, err) + + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) } session.Close() p2 = &ProcessorsStruct{} _, err = testEngine.ID(insertedId).Get(p2) - if err != nil { - t.Error(err) - panic(err) - } else { - if p2.B4UpdateFlag != 0 { - t.Error(errors.New("B4UpdateFlag is set")) - } - if p2.AfterUpdatedFlag != 0 { - t.Error(errors.New("AfterUpdatedFlag is set")) - } - if p2.B4UpdateViaExt != 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p2.AfterUpdatedViaExt != 0 { - t.Error(errors.New("AfterUpdatedViaExt is set")) - } + assert.NoError(t, err) + + if p2.B4UpdateFlag != 0 { + t.Error(errors.New("B4UpdateFlag is set")) + } + if p2.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p2.B4UpdateViaExt != 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p2.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) } // -- // test update processors with tx rollback session = testEngine.NewSession() + defer session.Close() + err = session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) p = &ProcessorsStruct{Id: insertedId} _, err = session.Update(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag != 0 { - t.Error(errors.New("AfterUpdatedFlag is set")) - } + assert.NoError(t, err) + + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag == 0 { - t.Error(errors.New("AfterUpdatedFlag not set")) - } - if p.AfterDeletedFlag != 0 { - t.Error(errors.New("AfterDeletedFlag set")) - } - if p.AfterInsertedFlag != 0 { - t.Error(errors.New("AfterInsertedFlag set")) - } + assert.NoError(t, err) + + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag == 0 { + t.Error(errors.New("AfterUpdatedFlag not set")) + } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag set")) + } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag set")) } session.Close() // test update processors with tx commit session = testEngine.NewSession() + defer session.Close() + err = session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) p = &ProcessorsStruct{} _, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag != 0 { - t.Error(errors.New("AfterUpdatedFlag is set")) - } - if p.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p.AfterUpdatedViaExt != 0 { - t.Error(errors.New("AfterUpdatedViaExt is set")) - } + assert.NoError(t, err) + + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) + } + err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag == 0 { - t.Error(errors.New("AfterUpdatedFlag not set")) - } - if p.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p.AfterUpdatedViaExt == 0 { - t.Error(errors.New("AfterUpdatedViaExt not set")) - } + assert.NoError(t, err) + + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) } + if p.AfterUpdatedFlag == 0 { + t.Error(errors.New("AfterUpdatedFlag not set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt == 0 { + t.Error(errors.New("AfterUpdatedViaExt not set")) + } + session.Close() p2 = &ProcessorsStruct{} _, err = testEngine.ID(insertedId).Get(p2) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag == 0 { - t.Error(errors.New("AfterUpdatedFlag not set")) - } - if p.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p.AfterUpdatedViaExt == 0 { - t.Error(errors.New("AfterUpdatedViaExt not set")) - } + assert.NoError(t, err) + + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag == 0 { + t.Error(errors.New("AfterUpdatedFlag not set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt == 0 { + t.Error(errors.New("AfterUpdatedViaExt not set")) } // -- // test delete processors with tx rollback session = testEngine.NewSession() + defer session.Close() + err = session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) b4DeleteFunc := func(bean interface{}) { if v, ok := (bean).(*ProcessorsStruct); ok { @@ -814,152 +719,131 @@ func TestProcessorsTx(t *testing.T) { p = &ProcessorsStruct{} // reset _, err = session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4DeleteFlag == 0 { - t.Error(errors.New("B4DeleteFlag not set")) - } - if p.AfterDeletedFlag != 0 { - t.Error(errors.New("AfterDeletedFlag is set")) - } - if p.B4DeleteViaExt == 0 { - t.Error(errors.New("B4DeleteViaExt not set")) - } - if p.AfterDeletedViaExt != 0 { - t.Error(errors.New("AfterDeletedViaExt is set")) - } + assert.NoError(t, err) + + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) + } + err = session.Rollback() - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4DeleteFlag == 0 { - t.Error(errors.New("B4DeleteFlag not set")) - } - if p.AfterDeletedFlag != 0 { - t.Error(errors.New("AfterDeletedFlag is set")) - } - if p.B4DeleteViaExt == 0 { - t.Error(errors.New("B4DeleteViaExt not set")) - } - if p.AfterDeletedViaExt != 0 { - t.Error(errors.New("AfterDeletedViaExt is set")) - } + assert.NoError(t, err) + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) + } + session.Close() p2 = &ProcessorsStruct{} _, err = testEngine.ID(insertedId).Get(p2) - if err != nil { - t.Error(err) - panic(err) - } else { - if p2.B4DeleteFlag != 0 { - t.Error(errors.New("B4DeleteFlag is set")) - } - if p2.AfterDeletedFlag != 0 { - t.Error(errors.New("AfterDeletedFlag is set")) - } - if p2.B4DeleteViaExt != 0 { - t.Error(errors.New("B4DeleteViaExt is set")) - } - if p2.AfterDeletedViaExt != 0 { - t.Error(errors.New("AfterDeletedViaExt is set")) - } + assert.NoError(t, err) + + if p2.B4DeleteFlag != 0 { + t.Error(errors.New("B4DeleteFlag is set")) + } + if p2.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p2.B4DeleteViaExt != 0 { + t.Error(errors.New("B4DeleteViaExt is set")) + } + if p2.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) } // -- // test delete processors with tx commit session = testEngine.NewSession() + defer session.Close() + err = session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) p = &ProcessorsStruct{} _, err = session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4DeleteFlag == 0 { - t.Error(errors.New("B4DeleteFlag not set")) - } - if p.AfterDeletedFlag != 0 { - t.Error(errors.New("AfterDeletedFlag is set")) - } - if p.B4DeleteViaExt == 0 { - t.Error(errors.New("B4DeleteViaExt not set")) - } - if p.AfterDeletedViaExt != 0 { - t.Error(errors.New("AfterDeletedViaExt is set")) - } + assert.NoError(t, err) + + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) + } + err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4DeleteFlag == 0 { - t.Error(errors.New("B4DeleteFlag not set")) - } - if p.AfterDeletedFlag == 0 { - t.Error(errors.New("AfterDeletedFlag not set")) - } - if p.B4DeleteViaExt == 0 { - t.Error(errors.New("B4DeleteViaExt not set")) - } - if p.AfterDeletedViaExt == 0 { - t.Error(errors.New("AfterDeletedViaExt not set")) - } + assert.NoError(t, err) + + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) } + if p.AfterDeletedFlag == 0 { + t.Error(errors.New("AfterDeletedFlag not set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt == 0 { + t.Error(errors.New("AfterDeletedViaExt not set")) + } + session.Close() // test delete processors with tx commit session = testEngine.NewSession() + defer session.Close() + err = session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) p = &ProcessorsStruct{Id: insertedId} - fmt.Println("delete") _, err = session.Delete(p) + assert.NoError(t, err) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4DeleteFlag == 0 { - t.Error(errors.New("B4DeleteFlag not set")) - } - if p.AfterDeletedFlag != 0 { - t.Error(errors.New("AfterDeletedFlag is set")) - } + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4DeleteFlag == 0 { - t.Error(errors.New("B4DeleteFlag not set")) - } - if p.AfterDeletedFlag == 0 { - t.Error(errors.New("AfterDeletedFlag not set")) - } - if p.AfterInsertedFlag != 0 { - t.Error(errors.New("AfterInsertedFlag set")) - } - if p.AfterUpdatedFlag != 0 { - t.Error(errors.New("AfterUpdatedFlag set")) - } + assert.NoError(t, err) + + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag == 0 { + t.Error(errors.New("AfterDeletedFlag not set")) + } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag set")) + } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag set")) } session.Close() // -- diff --git a/rows.go b/rows.go index 54ec7f37..ab5607e5 100644 --- a/rows.go +++ b/rows.go @@ -14,11 +14,8 @@ import ( // Rows rows wrapper a rows to type Rows struct { - NoTypeCheck bool - session *Session rows *core.Rows - fields []string beanType reflect.Type lastError error } @@ -57,13 +54,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { return nil, err } - rows.fields, err = rows.rows.Columns() - if err != nil { - rows.lastError = err - rows.Close() - return nil, err - } - return rows, nil } @@ -90,7 +80,7 @@ func (rows *Rows) Scan(bean interface{}) error { return rows.lastError } - if !rows.NoTypeCheck && reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { + if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) } @@ -98,13 +88,18 @@ func (rows *Rows) Scan(bean interface{}) error { return err } - scanResults, err := rows.session.row2Slice(rows.rows, rows.fields, bean) + fields, err := rows.rows.Columns() + if err != nil { + return err + } + + scanResults, err := rows.session.row2Slice(rows.rows, fields, bean) if err != nil { return err } dataStruct := rValue(bean) - _, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable) + _, err = rows.session.slice2Bean(scanResults, fields, bean, &dataStruct, rows.session.statement.RefTable) if err != nil { return err } diff --git a/rows_test.go b/rows_test.go index ee121c5e..c5b44279 100644 --- a/rows_test.go +++ b/rows_test.go @@ -67,3 +67,68 @@ func TestRows(t *testing.T) { } assert.EqualValues(t, 1, cnt) } + +func TestRowsMyTableName(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UserRowsMyTable struct { + Id int64 + IsMan bool + } + + var tableName = "user_rows_my_table_name" + + assert.NoError(t, testEngine.Table(tableName).Sync2(new(UserRowsMyTable))) + + cnt, err := testEngine.Table(tableName).Insert(&UserRowsMyTable{ + IsMan: true, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + rows, err := testEngine.Table(tableName).Rows(new(UserRowsMyTable)) + assert.NoError(t, err) + defer rows.Close() + + cnt = 0 + user := new(UserRowsMyTable) + for rows.Next() { + err = rows.Scan(user) + assert.NoError(t, err) + cnt++ + } + assert.EqualValues(t, 1, cnt) +} + +type UserRowsSpecTable struct { + Id int64 + IsMan bool +} + +func (UserRowsSpecTable) TableName() string { + return "user_rows_my_table_name" +} + +func TestRowsSpecTableName(t *testing.T) { + assert.NoError(t, prepareEngine()) + assert.NoError(t, testEngine.Sync2(new(UserRowsSpecTable))) + + cnt, err := testEngine.Insert(&UserRowsSpecTable{ + IsMan: true, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + rows, err := testEngine.Rows(new(UserRowsSpecTable)) + assert.NoError(t, err) + defer rows.Close() + + cnt = 0 + user := new(UserRowsSpecTable) + for rows.Next() { + err = rows.Scan(user) + assert.NoError(t, err) + cnt++ + } + assert.EqualValues(t, 1, cnt) +} diff --git a/session.go b/session.go index 3775eb01..2307a414 100644 --- a/session.go +++ b/session.go @@ -5,6 +5,7 @@ package xorm import ( + "context" "database/sql" "encoding/json" "errors" @@ -17,6 +18,13 @@ import ( "github.com/go-xorm/core" ) +type sessionType int + +const ( + engineSession sessionType = iota + groupSession +) + // Session keep a pointer to sql.DB and provides all execution of all // kind of database operations. type Session struct { @@ -51,7 +59,8 @@ type Session struct { lastSQL string lastSQLArgs []interface{} - err error + ctx context.Context + sessionType sessionType } // Clone copy all the session's content and return a new session @@ -82,6 +91,8 @@ func (session *Session) Init() { session.lastSQL = "" session.lastSQLArgs = []interface{}{} + + session.ctx = session.engine.defaultContext } // Close release the connection from pool @@ -102,6 +113,12 @@ func (session *Session) Close() { } } +// ContextCache enable context cache or not +func (session *Session) ContextCache(context ContextCache) *Session { + session.statement.context = context + return session +} + // IsClosed returns if session is closed func (session *Session) IsClosed() bool { return session.db == nil @@ -269,7 +286,7 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, var has bool stmt, has = session.stmtCache[crc] if !has { - stmt, err = db.Prepare(sqlStr) + stmt, err = db.PrepareContext(session.ctx, sqlStr) if err != nil { return nil, err } @@ -839,3 +856,12 @@ func (session *Session) Unscoped() *Session { session.statement.Unscoped() return session } + +func (session *Session) incrVersionFieldValue(fieldValue *reflect.Value) { + switch fieldValue.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fieldValue.SetInt(fieldValue.Int() + 1) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + fieldValue.SetUint(fieldValue.Uint() + 1) + } +} diff --git a/context.go b/session_context.go similarity index 60% rename from context.go rename to session_context.go index 074ba35a..915f0568 100644 --- a/context.go +++ b/session_context.go @@ -1,18 +1,15 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. +// Copyright 2019 The Xorm Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build go1.8 - package xorm import "context" -// PingContext tests if database is alive -func (engine *Engine) PingContext(ctx context.Context) error { - session := engine.NewSession() - defer session.Close() - return session.PingContext(ctx) +// Context sets the context on this session +func (session *Session) Context(ctx context.Context) *Session { + session.ctx = ctx + return session } // PingContext test if database is ok diff --git a/session_context_test.go b/session_context_test.go new file mode 100644 index 00000000..b6fd1f6e --- /dev/null +++ b/session_context_test.go @@ -0,0 +1,33 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestQueryContext(t *testing.T) { + type ContextQueryStruct struct { + Id int64 + Name string + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(ContextQueryStruct)) + + _, err := testEngine.Insert(&ContextQueryStruct{Name: "1"}) + assert.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + defer cancel() + has, err := testEngine.Context(ctx).Exist(&ContextQueryStruct{Name: "1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") + assert.False(t, has) +} diff --git a/session_delete.go b/session_delete.go index d9cf3ea9..26782f69 100644 --- a/session_delete.go +++ b/session_delete.go @@ -79,6 +79,10 @@ func (session *Session) Delete(bean interface{}) (int64, error) { defer session.Close() } + if session.statement.lastError != nil { + return 0, session.statement.lastError + } + if err := session.statement.setRefBean(bean); err != nil { return 0, err } @@ -199,7 +203,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { }) } - if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { + if cacher := session.engine.getCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache { session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) } diff --git a/session_delete_test.go b/session_delete_test.go index 916dab46..66032afe 100644 --- a/session_delete_test.go +++ b/session_delete_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/go-xorm/core" "github.com/stretchr/testify/assert" ) @@ -21,11 +22,27 @@ func TestDelete(t *testing.T) { assert.NoError(t, testEngine.Sync2(new(UserinfoDelete))) + session := testEngine.NewSession() + defer session.Close() + + var err error + if testEngine.Dialect().DBType() == core.MSSQL { + err = session.Begin() + assert.NoError(t, err) + _, err = session.Exec("SET IDENTITY_INSERT userinfo_delete ON") + assert.NoError(t, err) + } + user := UserinfoDelete{Uid: 1} - cnt, err := testEngine.Insert(&user) + cnt, err := session.Insert(&user) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + if testEngine.Dialect().DBType() == core.MSSQL { + err = session.Commit() + assert.NoError(t, err) + } + cnt, err = testEngine.Delete(&UserinfoDelete{Uid: user.Uid}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) @@ -40,7 +57,7 @@ func TestDelete(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - cnt, err = testEngine.Where("id=?", user.Uid).Delete(&UserinfoDelete{}) + cnt, err = testEngine.Where("`id`=?", user.Uid).Delete(&UserinfoDelete{}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) diff --git a/session_exist.go b/session_exist.go index 74a660e8..6aa154aa 100644 --- a/session_exist.go +++ b/session_exist.go @@ -19,6 +19,10 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { defer session.Close() } + if session.statement.lastError != nil { + return false, session.statement.lastError + } + var sqlStr string var args []interface{} var err error @@ -37,14 +41,18 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { } if session.engine.dialect.DBType() == core.MSSQL { - sqlStr = fmt.Sprintf("SELECT top 1 * FROM %s WHERE %s", tableName, condSQL) + sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s WHERE %s", tableName, condSQL) + } else if session.engine.dialect.DBType() == core.ORACLE { + sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) AND ROWNUM=1", tableName, condSQL) } else { sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL) } args = condArgs } else { if session.engine.dialect.DBType() == core.MSSQL { - sqlStr = fmt.Sprintf("SELECT top 1 * FROM %s", tableName) + sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s", tableName) + } else if session.engine.dialect.DBType() == core.ORACLE { + sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE ROWNUM=1", tableName) } else { sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName) } diff --git a/session_find.go b/session_find.go index b75f8347..48ee3209 100644 --- a/session_find.go +++ b/session_find.go @@ -63,6 +63,10 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte } func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { + if session.statement.lastError != nil { + return session.statement.lastError + } + sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { return errors.New("needs a pointer to a slice or a map") @@ -176,7 +180,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } if session.canCache() { - if cacher := session.engine.getCacher(table.Name); cacher != nil && + if cacher := session.engine.getCacher(session.statement.TableName()); cacher != nil && !session.statement.IsDistinct && !session.statement.unscoped { err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) diff --git a/session_get.go b/session_get.go index 69194a23..5ecf2f37 100644 --- a/session_get.go +++ b/session_get.go @@ -7,6 +7,7 @@ package xorm import ( "database/sql" "errors" + "fmt" "reflect" "strconv" @@ -23,6 +24,10 @@ func (session *Session) Get(bean interface{}) (bool, error) { } func (session *Session) get(bean interface{}) (bool, error) { + if session.statement.lastError != nil { + return false, session.statement.lastError + } + beanValue := reflect.ValueOf(bean) if beanValue.Kind() != reflect.Ptr { return false, errors.New("needs a pointer to a value") @@ -57,7 +62,7 @@ func (session *Session) get(bean interface{}) (bool, error) { table := session.statement.RefTable if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { - if cacher := session.engine.getCacher(table.Name); cacher != nil && + if cacher := session.engine.getCacher(session.statement.TableName()); cacher != nil && !session.statement.unscoped { has, err := session.cacheGet(bean, sqlStr, args...) if err != ErrCacheFailed { @@ -66,7 +71,28 @@ func (session *Session) get(bean interface{}) (bool, error) { } } - return session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...) + context := session.statement.context + if context != nil { + res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args)) + if res != nil { + structValue := reflect.Indirect(reflect.ValueOf(bean)) + structValue.Set(reflect.Indirect(reflect.ValueOf(res))) + session.lastSQL = "" + session.lastSQLArgs = nil + return true, nil + } + } + + has, err := session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...) + if err != nil || !has { + return has, err + } + + if context != nil { + context.Put(fmt.Sprintf("%v-%v", sqlStr, args), bean) + } + + return true, nil } func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { diff --git a/session_get_test.go b/session_get_test.go index 4ec7cf02..025a747a 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -84,7 +84,11 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money)) var money2 float64 - has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2) + if testEngine.Dialect().DBType() == core.MSSQL { + has, err = testEngine.SQL("SELECT TOP 1 money FROM " + testEngine.TableName("get_var", true)).Get(&money2) + } else { + has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2) + } assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2)) @@ -156,14 +160,23 @@ func TestGetStruct(t *testing.T) { assert.NoError(t, testEngine.Sync2(new(UserinfoGet))) + session := testEngine.NewSession() + defer session.Close() + var err error if testEngine.Dialect().DBType() == core.MSSQL { - _, err = testEngine.Exec("SET IDENTITY_INSERT userinfo_get ON") + err = session.Begin() + assert.NoError(t, err) + _, err = session.Exec("SET IDENTITY_INSERT userinfo_get ON") assert.NoError(t, err) } - cnt, err := testEngine.Insert(&UserinfoGet{Uid: 2}) + cnt, err := session.Insert(&UserinfoGet{Uid: 2}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + if testEngine.Dialect().DBType() == core.MSSQL { + err = session.Commit() + assert.NoError(t, err) + } user := UserinfoGet{Uid: 2} has, err := testEngine.Get(&user) @@ -319,3 +332,104 @@ func TestGetStructId(t *testing.T) { assert.True(t, has) assert.EqualValues(t, 2, maxid.Id) } + +func TestContextGet(t *testing.T) { + type ContextGetStruct struct { + Id int64 + Name string + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(ContextGetStruct)) + + _, err := testEngine.Insert(&ContextGetStruct{Name: "1"}) + assert.NoError(t, err) + + sess := testEngine.NewSession() + defer sess.Close() + + context := NewMemoryContextCache() + + var c2 ContextGetStruct + has, err := sess.ID(1).NoCache().ContextCache(context).Get(&c2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, c2.Id) + assert.EqualValues(t, "1", c2.Name) + sql, args := sess.LastSQL() + assert.True(t, len(sql) > 0) + assert.True(t, len(args) > 0) + + var c3 ContextGetStruct + has, err = sess.ID(1).NoCache().ContextCache(context).Get(&c3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, c3.Id) + assert.EqualValues(t, "1", c3.Name) + sql, args = sess.LastSQL() + assert.True(t, len(sql) == 0) + assert.True(t, len(args) == 0) +} + +func TestContextGet2(t *testing.T) { + type ContextGetStruct2 struct { + Id int64 + Name string + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(ContextGetStruct2)) + + _, err := testEngine.Insert(&ContextGetStruct2{Name: "1"}) + assert.NoError(t, err) + + context := NewMemoryContextCache() + + var c2 ContextGetStruct2 + has, err := testEngine.ID(1).NoCache().ContextCache(context).Get(&c2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, c2.Id) + assert.EqualValues(t, "1", c2.Name) + + var c3 ContextGetStruct2 + has, err = testEngine.ID(1).NoCache().ContextCache(context).Get(&c3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, c3.Id) + assert.EqualValues(t, "1", c3.Name) +} + +type GetCustomTableInterface interface { + TableName() string +} + +type MyGetCustomTableImpletation struct { + Id int64 `json:"id"` + Name string `json:"name"` +} + +const getCustomTableName = "GetCustomTableInterface" + +func (m *MyGetCustomTableImpletation) TableName() string { + return getCustomTableName +} + +func TestGetCustomTableInterface(t *testing.T) { + assert.NoError(t, prepareEngine()) + assert.NoError(t, testEngine.Table(getCustomTableName).Sync2(new(MyGetCustomTableImpletation))) + + exist, err := testEngine.IsTableExist(getCustomTableName) + assert.NoError(t, err) + assert.True(t, exist) + + _, err = testEngine.Insert(&MyGetCustomTableImpletation{ + Name: "xlw", + }) + assert.NoError(t, err) + + var c GetCustomTableInterface = new(MyGetCustomTableImpletation) + has, err := testEngine.Get(c) + assert.NoError(t, err) + assert.True(t, has) +} diff --git a/session_insert.go b/session_insert.go index 2ea58fda..aa2a432b 100644 --- a/session_insert.go +++ b/session_insert.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "reflect" + "sort" "strconv" "strings" @@ -24,32 +25,67 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { } for _, bean := range beans { - sliceValue := reflect.Indirect(reflect.ValueOf(bean)) - if sliceValue.Kind() == reflect.Slice { - size := sliceValue.Len() - if size > 0 { - if session.engine.SupportInsertMany() { - cnt, err := session.innerInsertMulti(bean) - if err != nil { - return affected, err - } - affected += cnt - } else { - for i := 0; i < size; i++ { - cnt, err := session.innerInsert(sliceValue.Index(i).Interface()) - if err != nil { - return affected, err - } - affected += cnt - } - } - } - } else { - cnt, err := session.innerInsert(bean) + switch bean.(type) { + case map[string]interface{}: + cnt, err := session.insertMapInterface(bean.(map[string]interface{})) if err != nil { return affected, err } affected += cnt + case []map[string]interface{}: + s := bean.([]map[string]interface{}) + session.autoResetStatement = false + for i := 0; i < len(s); i++ { + cnt, err := session.insertMapInterface(s[i]) + if err != nil { + return affected, err + } + affected += cnt + } + case map[string]string: + cnt, err := session.insertMapString(bean.(map[string]string)) + if err != nil { + return affected, err + } + affected += cnt + case []map[string]string: + s := bean.([]map[string]string) + session.autoResetStatement = false + for i := 0; i < len(s); i++ { + cnt, err := session.insertMapString(s[i]) + if err != nil { + return affected, err + } + affected += cnt + } + default: + sliceValue := reflect.Indirect(reflect.ValueOf(bean)) + if sliceValue.Kind() == reflect.Slice { + size := sliceValue.Len() + if size > 0 { + if session.engine.SupportInsertMany() { + cnt, err := session.innerInsertMulti(bean) + if err != nil { + return affected, err + } + affected += cnt + } else { + for i := 0; i < size; i++ { + cnt, err := session.innerInsert(sliceValue.Index(i).Interface()) + if err != nil { + return affected, err + } + affected += cnt + } + } + } + } else { + cnt, err := session.innerInsert(bean) + if err != nil { + return affected, err + } + affected += cnt + } } } @@ -337,21 +373,30 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { var sqlStr string var tableName = session.statement.TableName() + var output string + if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 { + output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement) + } if len(colPlaces) > 0 { - sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)", + sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v)%s VALUES (%v)", session.engine.Quote(tableName), session.engine.QuoteStr(), strings.Join(colNames, session.engine.Quote(", ")), session.engine.QuoteStr(), + output, colPlaces) } else { if session.engine.dialect.DBType() == core.MYSQL { sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName)) } else { - sqlStr = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", session.engine.Quote(tableName)) + sqlStr = fmt.Sprintf("INSERT INTO %s%s DEFAULT VALUES", session.engine.Quote(tableName), output) } } + if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES { + sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement) + } + handleAfterInsertProcessorFunc := func(bean interface{}) { if session.isAutoCommit { for _, closure := range session.afterClosures { @@ -397,7 +442,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if err != nil { session.engine.logger.Error(err) } else if verValue.IsValid() && verValue.CanSet() { - verValue.SetInt(1) + session.incrVersionFieldValue(verValue) } } @@ -423,9 +468,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue.Set(int64ToIntValue(id, aiValue.Type())) return 1, nil - } else if session.engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 { - //assert table.AutoIncrement != "" - sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement) + } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.DBType() == core.POSTGRES || session.engine.dialect.DBType() == core.MSSQL) { res, err := session.queryBytes(sqlStr, args...) if err != nil { @@ -440,12 +483,12 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if err != nil { session.engine.logger.Error(err) } else if verValue.IsValid() && verValue.CanSet() { - verValue.SetInt(1) + session.incrVersionFieldValue(verValue) } } if len(res) < 1 { - return 0, errors.New("insert no error but not returned id") + return 0, errors.New("insert successfully but not returned id") } idByte := res[0][table.AutoIncrement] @@ -481,7 +524,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if err != nil { session.engine.logger.Error(err) } else if verValue.IsValid() && verValue.CanSet() { - verValue.SetInt(1) + session.incrVersionFieldValue(verValue) } } @@ -622,3 +665,83 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } return colNames, args, nil } + +func (session *Session) insertMapInterface(m map[string]interface{}) (int64, error) { + if len(m) == 0 { + return 0, ErrParamsType + } + + var columns = make([]string, 0, len(m)) + for k := range m { + columns = append(columns, k) + } + sort.Strings(columns) + + qm := strings.Repeat("?,", len(columns)) + qm = "(" + qm[:len(qm)-1] + ")" + + tableName := session.statement.TableName() + if len(tableName) <= 0 { + return 0, ErrTableNotFound + } + + var sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES %s", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm) + var args = make([]interface{}, 0, len(m)) + for _, colName := range columns { + args = append(args, m[colName]) + } + + if err := session.cacheInsert(tableName); err != nil { + return 0, err + } + + res, err := session.exec(sql, args...) + if err != nil { + return 0, err + } + affected, err := res.RowsAffected() + if err != nil { + return 0, err + } + return affected, nil +} + +func (session *Session) insertMapString(m map[string]string) (int64, error) { + if len(m) == 0 { + return 0, ErrParamsType + } + + var columns = make([]string, 0, len(m)) + for k := range m { + columns = append(columns, k) + } + sort.Strings(columns) + + qm := strings.Repeat("?,", len(columns)) + qm = "(" + qm[:len(qm)-1] + ")" + + tableName := session.statement.TableName() + if len(tableName) <= 0 { + return 0, ErrTableNotFound + } + + var sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES %s", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm) + var args = make([]interface{}, 0, len(m)) + for _, colName := range columns { + args = append(args, m[colName]) + } + + if err := session.cacheInsert(tableName); err != nil { + return 0, err + } + + res, err := session.exec(sql, args...) + if err != nil { + return 0, err + } + affected, err := res.RowsAffected() + if err != nil { + return 0, err + } + return affected, nil +} diff --git a/session_insert_test.go b/session_insert_test.go index 50943032..8e7ffa99 100644 --- a/session_insert_test.go +++ b/session_insert_test.go @@ -145,41 +145,22 @@ func TestInsert(t *testing.T) { user := Userinfo{0, "xiaolunwen", "dev", "lunny", time.Now(), Userdetail{Id: 1}, 1.78, []byte{1, 2, 3}, true} cnt, err := testEngine.Insert(&user) - fmt.Println(user.Uid) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - } - - if user.Uid <= 0 { - err = errors.New("not return id error") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt, "insert not returned 1") + assert.True(t, user.Uid > 0, "not return id error") user.Uid = 0 cnt, err = testEngine.Insert(&user) + // Username is unique, so this should return error + assert.Error(t, err, "insert should fail but no error returned") + assert.EqualValues(t, 0, cnt, "insert not returned 1") if err == nil { - err = errors.New("insert failed but no return error") - t.Error(err) - panic(err) - } - if cnt != 0 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return + panic("should return err") } } func TestInsertAutoIncr(t *testing.T) { assert.NoError(t, prepareEngine()) - assertSync(t, new(Userinfo)) // auto increment insert @@ -214,20 +195,14 @@ func TestInsertDefault(t *testing.T) { di := new(DefaultInsert) err := testEngine.Sync2(di) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) var di2 = DefaultInsert{Name: "test"} _, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("Status")).Insert(&di2) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) has, err := testEngine.Desc("(id)").Get(di) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) if !has { err = errors.New("error with no data") t.Error(err) @@ -780,3 +755,82 @@ func TestAnonymousStruct(t *testing.T) { }) assert.NoError(t, err) } + +func TestInsertMap(t *testing.T) { + type InsertMap struct { + Id int64 + Width uint32 + Height uint32 + Name string + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(InsertMap)) + + cnt, err := testEngine.Table(new(InsertMap)).Insert(map[string]interface{}{ + "width": 20, + "height": 10, + "name": "lunny", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var im InsertMap + has, err := testEngine.Get(&im) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 20, im.Width) + assert.EqualValues(t, 10, im.Height) + assert.EqualValues(t, "lunny", im.Name) + + cnt, err = testEngine.Table("insert_map").Insert(map[string]interface{}{ + "width": 30, + "height": 10, + "name": "lunny", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var ims []InsertMap + err = testEngine.Find(&ims) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(ims)) + assert.EqualValues(t, 20, ims[0].Width) + assert.EqualValues(t, 10, ims[0].Height) + assert.EqualValues(t, "lunny", ims[0].Name) + assert.EqualValues(t, 30, ims[1].Width) + assert.EqualValues(t, 10, ims[1].Height) + assert.EqualValues(t, "lunny", ims[1].Name) + + cnt, err = testEngine.Table("insert_map").Insert([]map[string]interface{}{ + { + "width": 40, + "height": 10, + "name": "lunny", + }, + { + "width": 50, + "height": 10, + "name": "lunny", + }, + }) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + + ims = make([]InsertMap, 0, 4) + err = testEngine.Find(&ims) + assert.NoError(t, err) + assert.EqualValues(t, 4, len(ims)) + assert.EqualValues(t, 20, ims[0].Width) + assert.EqualValues(t, 10, ims[0].Height) + assert.EqualValues(t, "lunny", ims[1].Name) + assert.EqualValues(t, 30, ims[1].Width) + assert.EqualValues(t, 10, ims[1].Height) + assert.EqualValues(t, "lunny", ims[1].Name) + assert.EqualValues(t, 40, ims[2].Width) + assert.EqualValues(t, 10, ims[2].Height) + assert.EqualValues(t, "lunny", ims[2].Name) + assert.EqualValues(t, 50, ims[3].Width) + assert.EqualValues(t, 10, ims[3].Height) + assert.EqualValues(t, "lunny", ims[3].Name) +} diff --git a/session_iterate.go b/session_iterate.go index 071fce49..ca996c28 100644 --- a/session_iterate.go +++ b/session_iterate.go @@ -23,6 +23,10 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { defer session.Close() } + if session.statement.lastError != nil { + return session.statement.lastError + } + if session.statement.bufferSize > 0 { return session.bufferIterate(bean, fun) } diff --git a/session_query.go b/session_query.go index 1d0b156b..6d597cc4 100644 --- a/session_query.go +++ b/session_query.go @@ -166,6 +166,34 @@ func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string, return result, nil } +func row2sliceStr(rows *core.Rows, fields []string) (results []string, err error) { + result := make([]string, 0, len(fields)) + scanResultContainers := make([]interface{}, len(fields)) + for i := 0; i < len(fields); i++ { + var scanResultContainer interface{} + scanResultContainers[i] = &scanResultContainer + } + if err := rows.Scan(scanResultContainers...); err != nil { + return nil, err + } + + for i := 0; i < len(fields); i++ { + rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[i])) + // if row is null then as empty string + if rawValue.Interface() == nil { + result = append(result, "") + continue + } + + if data, err := value2String(&rawValue); err == nil { + result = append(result, data) + } else { + return nil, err + } + } + return result, nil +} + func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) { fields, err := rows.Columns() if err != nil { @@ -182,6 +210,22 @@ func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) return resultsSlice, nil } +func rows2SliceString(rows *core.Rows) (resultsSlice [][]string, err error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + for rows.Next() { + record, err := row2sliceStr(rows, fields) + if err != nil { + return nil, err + } + resultsSlice = append(resultsSlice, record) + } + + return resultsSlice, nil +} + // QueryString runs a raw sql and return records as []map[string]string func (session *Session) QueryString(sqlorArgs ...interface{}) ([]map[string]string, error) { if session.isAutoClose { @@ -202,6 +246,26 @@ func (session *Session) QueryString(sqlorArgs ...interface{}) ([]map[string]stri return rows2Strings(rows) } +// QuerySliceString runs a raw sql and return records as [][]string +func (session *Session) QuerySliceString(sqlorArgs ...interface{}) ([][]string, error) { + if session.isAutoClose { + defer session.Close() + } + + sqlStr, args, err := session.genQuerySQL(sqlorArgs...) + if err != nil { + return nil, err + } + + rows, err := session.queryRows(sqlStr, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + return rows2SliceString(rows) +} + func row2mapInterface(rows *core.Rows, fields []string) (resultsMap map[string]interface{}, err error) { resultsMap = make(map[string]interface{}, len(fields)) scanResultContainers := make([]interface{}, len(fields)) diff --git a/session_query_test.go b/session_query_test.go index 2c1fb617..233929af 100644 --- a/session_query_test.go +++ b/session_query_test.go @@ -207,7 +207,7 @@ func TestQueryStringNoParam(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0]["id"]) - if testEngine.Dialect().URI().DbType == core.POSTGRES { + if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { assert.EqualValues(t, "false", records[0]["msg"]) } else { assert.EqualValues(t, "0", records[0]["msg"]) @@ -217,13 +217,50 @@ func TestQueryStringNoParam(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0]["id"]) - if testEngine.Dialect().URI().DbType == core.POSTGRES { + if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { assert.EqualValues(t, "false", records[0]["msg"]) } else { assert.EqualValues(t, "0", records[0]["msg"]) } } +func TestQuerySliceStringNoParam(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type GetVar6 struct { + Id int64 `xorm:"autoincr pk"` + Msg bool `xorm:"bit"` + } + + assert.NoError(t, testEngine.Sync2(new(GetVar6))) + + var data = GetVar6{ + Msg: false, + } + _, err := testEngine.Insert(data) + assert.NoError(t, err) + + records, err := testEngine.Table("get_var6").Limit(1).QuerySliceString() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(records)) + assert.EqualValues(t, "1", records[0][0]) + if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { + assert.EqualValues(t, "false", records[0][1]) + } else { + assert.EqualValues(t, "0", records[0][1]) + } + + records, err = testEngine.Table("get_var6").Where(builder.Eq{"id": 1}).QuerySliceString() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(records)) + assert.EqualValues(t, "1", records[0][0]) + if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { + assert.EqualValues(t, "false", records[0][1]) + } else { + assert.EqualValues(t, "0", records[0][1]) + } +} + func TestQueryInterfaceNoParam(t *testing.T) { assert.NoError(t, prepareEngine()) @@ -297,3 +334,47 @@ func TestQueryWithBuilder(t *testing.T) { assert.NoError(t, err) assertResult(t, results) } + +func TestJoinWithSubQuery(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type JoinWithSubQuery1 struct { + Id int64 `xorm:"autoincr pk"` + Msg string `xorm:"varchar(255)"` + DepartId int64 + Money float32 + } + + type JoinWithSubQueryDepart struct { + Id int64 `xorm:"autoincr pk"` + Name string + } + + testEngine.ShowSQL(true) + + assert.NoError(t, testEngine.Sync2(new(JoinWithSubQuery1), new(JoinWithSubQueryDepart))) + + var depart = JoinWithSubQueryDepart{ + Name: "depart1", + } + cnt, err := testEngine.Insert(&depart) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var q = JoinWithSubQuery1{ + Msg: "message", + DepartId: depart.Id, + Money: 3000, + } + + cnt, err = testEngine.Insert(&q) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var querys []JoinWithSubQuery1 + err = testEngine.Join("INNER", builder.Select("id").From(testEngine.Quote(testEngine.TableName("join_with_sub_query_depart", true))), + "join_with_sub_query_depart.id = join_with_sub_query1.depart_id").Find(&querys) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(querys)) + assert.EqualValues(t, q, querys[0]) +} diff --git a/session_raw.go b/session_raw.go index 47823d67..c2556365 100644 --- a/session_raw.go +++ b/session_raw.go @@ -49,7 +49,7 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row if session.isAutoCommit { var db *core.DB - if session.engine.engineGroup != nil { + if session.sessionType == groupSession { db = session.engine.engineGroup.Slave().DB() } else { db = session.DB() @@ -62,21 +62,21 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row return nil, err } - rows, err := stmt.Query(args...) + rows, err := stmt.QueryContext(session.ctx, args...) if err != nil { return nil, err } return rows, nil } - rows, err := db.Query(sqlStr, args...) + rows, err := db.QueryContext(session.ctx, sqlStr, args...) if err != nil { return nil, err } return rows, nil } - rows, err := session.tx.Query(sqlStr, args...) + rows, err := session.tx.QueryContext(session.ctx, sqlStr, args...) if err != nil { return nil, err } @@ -175,7 +175,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er } if !session.isAutoCommit { - return session.tx.Exec(sqlStr, args...) + return session.tx.ExecContext(session.ctx, sqlStr, args...) } if session.prepareStmt { @@ -184,14 +184,14 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er return nil, err } - res, err := stmt.Exec(args...) + res, err := stmt.ExecContext(session.ctx, args...) if err != nil { return nil, err } return res, nil } - return session.DB().Exec(sqlStr, args...) + return session.DB().ExecContext(session.ctx, sqlStr, args...) } func convertSQLOrArgs(sqlorArgs ...interface{}) (string, []interface{}, error) { diff --git a/session_schema.go b/session_schema.go index 369ec72a..7629906f 100644 --- a/session_schema.go +++ b/session_schema.go @@ -19,7 +19,7 @@ func (session *Session) Ping() error { } session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) - return session.DB().Ping() + return session.DB().PingContext(session.ctx) } // CreateTable create a table according a bean diff --git a/session_stats_test.go b/session_stats_test.go index b66a84b4..61839576 100644 --- a/session_stats_test.go +++ b/session_stats_test.go @@ -18,13 +18,14 @@ func isFloatEq(i, j float64, precision int) bool { } func TestSum(t *testing.T) { - assert.NoError(t, prepareEngine()) - type SumStruct struct { Int int Float float32 } + assert.NoError(t, prepareEngine()) + assert.NoError(t, testEngine.Sync2(new(SumStruct))) + var ( cases = []SumStruct{ {1, 6.2}, @@ -40,8 +41,6 @@ func TestSum(t *testing.T) { f += v.Float } - assert.NoError(t, testEngine.Sync2(new(SumStruct))) - cnt, err := testEngine.Insert(cases) assert.NoError(t, err) assert.EqualValues(t, 3, cnt) @@ -73,6 +72,65 @@ func TestSum(t *testing.T) { assert.EqualValues(t, i, int(sumsInt[0])) } +type SumStructWithTableName struct { + Int int + Float float32 +} + +func (s SumStructWithTableName) TableName() string { + return "sum_struct_with_table_name_1" +} + +func TestSumWithTableName(t *testing.T) { + assert.NoError(t, prepareEngine()) + assert.NoError(t, testEngine.Sync2(new(SumStructWithTableName))) + + var ( + cases = []SumStructWithTableName{ + {1, 6.2}, + {2, 5.3}, + {92, -0.2}, + } + ) + + var i int + var f float32 + for _, v := range cases { + i += v.Int + f += v.Float + } + + cnt, err := testEngine.Insert(cases) + assert.NoError(t, err) + assert.EqualValues(t, 3, cnt) + + colInt := testEngine.GetColumnMapper().Obj2Table("Int") + colFloat := testEngine.GetColumnMapper().Obj2Table("Float") + + sumInt, err := testEngine.Sum(new(SumStructWithTableName), colInt) + assert.NoError(t, err) + assert.EqualValues(t, int(sumInt), i) + + sumFloat, err := testEngine.Sum(new(SumStructWithTableName), colFloat) + assert.NoError(t, err) + assert.Condition(t, func() bool { + return isFloatEq(sumFloat, float64(f), 2) + }) + + sums, err := testEngine.Sums(new(SumStructWithTableName), colInt, colFloat) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(sums)) + assert.EqualValues(t, i, int(sums[0])) + assert.Condition(t, func() bool { + return isFloatEq(sums[1], float64(f), 2) + }) + + sumsInt, err := testEngine.SumsInt(new(SumStructWithTableName), colInt) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(sumsInt)) + assert.EqualValues(t, i, int(sumsInt[0])) +} + func TestSumCustomColumn(t *testing.T) { assert.NoError(t, prepareEngine()) @@ -183,3 +241,36 @@ func TestCountWithOthers(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, total) } + +type CountWithTableName struct { + Id int64 + Name string +} + +func (CountWithTableName) TableName() string { + return "count_with_table_name1" +} + +func TestWithTableName(t *testing.T) { + assert.NoError(t, prepareEngine()) + + assertSync(t, new(CountWithTableName)) + + _, err := testEngine.Insert(&CountWithTableName{ + Name: "orderby", + }) + assert.NoError(t, err) + + _, err = testEngine.Insert(CountWithTableName{ + Name: "limit", + }) + assert.NoError(t, err) + + total, err := testEngine.OrderBy("id desc").Count(new(CountWithTableName)) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) + + total, err = testEngine.OrderBy("id desc").Count(CountWithTableName{}) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) +} diff --git a/session_tx.go b/session_tx.go index c8d759a3..ee3d473f 100644 --- a/session_tx.go +++ b/session_tx.go @@ -7,7 +7,7 @@ package xorm // Begin a transaction func (session *Session) Begin() error { if session.isAutoCommit { - tx, err := session.DB().Begin() + tx, err := session.DB().BeginTx(session.ctx, nil) if err != nil { return err } diff --git a/session_update.go b/session_update.go index 42dfaacd..6bd16aaf 100644 --- a/session_update.go +++ b/session_update.go @@ -116,7 +116,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, } else { session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) if col.IsVersion && session.statement.checkVersion { - fieldValue.SetInt(fieldValue.Int() + 1) + session.incrVersionFieldValue(fieldValue) } else { fieldValue.Set(reflect.ValueOf(args[idx])) } @@ -147,6 +147,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 defer session.Close() } + if session.statement.lastError != nil { + return 0, session.statement.lastError + } + v := rValue(bean) t := v.Type() @@ -357,7 +361,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return 0, err } else if doIncVer { if verValue != nil && verValue.IsValid() && verValue.CanSet() { - verValue.SetInt(verValue.Int() + 1) + session.incrVersionFieldValue(verValue) } } diff --git a/session_update_test.go b/session_update_test.go index 2a7005ee..480fc5fc 100644 --- a/session_update_test.go +++ b/session_update_test.go @@ -110,7 +110,7 @@ func setupForUpdate(engine EngineInterface) error { } func TestForUpdate(t *testing.T) { - if testEngine.Dialect().DriverName() != "mysql" && testEngine.Dialect().DriverName() != "mymysql" { + if *ignoreSelectUpdate { return } @@ -1331,3 +1331,21 @@ func TestUpdateCondiBean(t *testing.T) { assert.NoError(t, err) assert.True(t, has) } + +func TestWhereCondErrorWhenUpdate(t *testing.T) { + type AuthRequestError struct { + ChallengeToken string + RequestToken string + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(AuthRequestError)) + + _, err := testEngine.Cols("challenge_token", "request_token", "challenge_agent", "status"). + Where(&AuthRequestError{ChallengeToken: "1"}). + Update(&AuthRequestError{ + ChallengeToken: "2", + }) + assert.Error(t, err) + assert.EqualValues(t, ErrConditionType, err) +} diff --git a/statement.go b/statement.go index 56644036..03ac107a 100644 --- a/statement.go +++ b/statement.go @@ -59,6 +59,8 @@ type Statement struct { exprColumns map[string]exprParam cond builder.Cond bufferSize int + context ContextCache + lastError error } // Init reset all the statement's fields @@ -99,6 +101,8 @@ func (statement *Statement) Init() { statement.exprColumns = make(map[string]exprParam) statement.cond = builder.NewCond() statement.bufferSize = 0 + statement.context = nil + statement.lastError = nil } // NoAutoCondition if you do not want convert bean's field as query condition, then use this function @@ -123,13 +127,13 @@ func (statement *Statement) SQL(query interface{}, args ...interface{}) *Stateme var err error statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL() if err != nil { - statement.Engine.logger.Error(err) + statement.lastError = err } case string: statement.RawSQL = query.(string) statement.RawParams = args default: - statement.Engine.logger.Error("unsupported sql type") + statement.lastError = ErrUnSupportedSQLType } return statement @@ -158,7 +162,7 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme } } default: - // TODO: not support condition type + statement.lastError = ErrConditionType } return statement @@ -753,9 +757,32 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition fmt.Fprintf(&buf, "%v JOIN ", joinOP) } - tbName := statement.Engine.TableName(tablename, true) + switch tp := tablename.(type) { + case builder.Builder: + subSQL, subQueryArgs, err := tp.ToSQL() + if err != nil { + statement.lastError = err + return statement + } + tbs := strings.Split(tp.TableName(), ".") + var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr()) + fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) + statement.joinArgs = append(statement.joinArgs, subQueryArgs...) + case *builder.Builder: + subSQL, subQueryArgs, err := tp.ToSQL() + if err != nil { + statement.lastError = err + return statement + } + tbs := strings.Split(tp.TableName(), ".") + var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr()) + fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) + statement.joinArgs = append(statement.joinArgs, subQueryArgs...) + default: + tbName := statement.Engine.TableName(tablename, true) + fmt.Fprintf(&buf, "%s ON %v", tbName, condition) + } - fmt.Fprintf(&buf, "%s ON %v", tbName, condition) statement.JoinStr = buf.String() statement.joinArgs = append(statement.joinArgs, args...) return statement @@ -1062,7 +1089,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n if dialect.DBType() == core.MSSQL { if statement.LimitN > 0 { - top = fmt.Sprintf(" TOP %d ", statement.LimitN) + top = fmt.Sprintf("TOP %d ", statement.LimitN) } if statement.Start > 0 { var column string diff --git a/tag_extends_test.go b/tag_extends_test.go index 4a4150ba..aec30da9 100644 --- a/tag_extends_test.go +++ b/tag_extends_test.go @@ -60,63 +60,37 @@ func TestExtends(t *testing.T) { assert.NoError(t, prepareEngine()) err := testEngine.DropTables(&tempUser2{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) err = testEngine.CreateTables(&tempUser2{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) tu := &tempUser2{tempUser{0, "extends"}, "dev depart"} _, err = testEngine.Insert(tu) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) tu2 := &tempUser2{} _, err = testEngine.Get(tu2) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) tu3 := &tempUser2{tempUser{0, "extends update"}, ""} _, err = testEngine.ID(tu2.TempUser.Id).Update(tu3) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) err = testEngine.DropTables(&tempUser4{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) err = testEngine.CreateTables(&tempUser4{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) tu8 := &tempUser4{tempUser2{tempUser{0, "extends"}, "dev depart"}} _, err = testEngine.Insert(tu8) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) tu9 := &tempUser4{} _, err = testEngine.Get(tu9) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) + if tu9.TempUser2.TempUser.Username != tu8.TempUser2.TempUser.Username || tu9.TempUser2.Departname != tu8.TempUser2.Departname { err = errors.New(fmt.Sprintln("not equal for", tu8, tu9)) t.Error(err) @@ -125,36 +99,22 @@ func TestExtends(t *testing.T) { tu10 := &tempUser4{tempUser2{tempUser{0, "extends update"}, ""}} _, err = testEngine.ID(tu9.TempUser2.TempUser.Id).Update(tu10) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) err = testEngine.DropTables(&tempUser3{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) err = testEngine.CreateTables(&tempUser3{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) tu4 := &tempUser3{&tempUser{0, "extends"}, "dev depart"} _, err = testEngine.Insert(tu4) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) tu5 := &tempUser3{} _, err = testEngine.Get(tu5) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) + if tu5.Temp == nil { err = errors.New("error get data extends") t.Error(err) @@ -169,22 +129,12 @@ func TestExtends(t *testing.T) { tu6 := &tempUser3{&tempUser{0, "extends update"}, ""} _, err = testEngine.ID(tu5.Temp.Id).Update(tu6) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) users := make([]tempUser3, 0) err = testEngine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - if len(users) != 1 { - err = errors.New("error get data not 1") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users), "error get data not 1") assertSync(t, new(Userinfo), new(Userdetail)) @@ -249,10 +199,7 @@ func TestExtends(t *testing.T) { Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). NoCascade(). Find(&infos2) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(infos2) } @@ -297,25 +244,16 @@ func TestExtends2(t *testing.T) { assert.NoError(t, prepareEngine()) err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) var sender = MessageUser{Name: "sender"} var receiver = MessageUser{Name: "receiver"} var msgtype = MessageType{Name: "type"} _, err = testEngine.Insert(&sender, &receiver, &msgtype) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) msg := Message{ MessageBase: MessageBase{ @@ -326,15 +264,24 @@ func TestExtends2(t *testing.T) { Uid: sender.Id, ToUid: receiver.Id, } + + session := testEngine.NewSession() + defer session.Close() + + // MSSQL deny insert identity column excep declare as below if testEngine.Dialect().DBType() == core.MSSQL { - _, err = testEngine.Exec("SET IDENTITY_INSERT message ON") + err = session.Begin() + assert.NoError(t, err) + _, err = session.Exec("SET IDENTITY_INSERT message ON") assert.NoError(t, err) } + cnt, err := session.Insert(&msg) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) - _, err = testEngine.Insert(&msg) - if err != nil { - t.Error(err) - panic(err) + if testEngine.Dialect().DBType() == core.MSSQL { + err = session.Commit() + assert.NoError(t, err) } var mapper = testEngine.GetTableMapper().Obj2Table @@ -344,23 +291,14 @@ func TestExtends2(t *testing.T) { msgTableName := quote(testEngine.TableName(mapper("Message"), true)) list := make([]Message, 0) - err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). + err = session.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). Find(&list) assert.NoError(t, err) - if len(list) != 1 { - err = errors.New(fmt.Sprintln("should have 1 message, got", len(list))) - t.Error(err) - panic(err) - } - - if list[0].Id != msg.Id { - err = errors.New(fmt.Sprintln("should message equal", list[0], msg)) - t.Error(err) - panic(err) - } + assert.EqualValues(t, 1, len(list), fmt.Sprintln("should have 1 message, got", len(list))) + assert.EqualValues(t, msg.Id, list[0].Id, fmt.Sprintln("should message equal", list[0], msg)) } func TestExtends3(t *testing.T) { @@ -396,13 +334,25 @@ func TestExtends3(t *testing.T) { Uid: sender.Id, ToUid: receiver.Id, } + + session := testEngine.NewSession() + defer session.Close() + + // MSSQL deny insert identity column excep declare as below if testEngine.Dialect().DBType() == core.MSSQL { - _, err = testEngine.Exec("SET IDENTITY_INSERT message ON") + err = session.Begin() + assert.NoError(t, err) + _, err = session.Exec("SET IDENTITY_INSERT message ON") assert.NoError(t, err) } - _, err = testEngine.Insert(&msg) + _, err = session.Insert(&msg) assert.NoError(t, err) + if testEngine.Dialect().DBType() == core.MSSQL { + err = session.Commit() + assert.NoError(t, err) + } + var mapper = testEngine.GetTableMapper().Obj2Table var quote = testEngine.Quote userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) @@ -410,7 +360,7 @@ func TestExtends3(t *testing.T) { msgTableName := quote(testEngine.TableName(mapper("Message"), true)) list := make([]MessageExtend3, 0) - err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). + err = session.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). Find(&list) @@ -478,14 +428,23 @@ func TestExtends4(t *testing.T) { Content: "test", Uid: sender.Id, } + + session := testEngine.NewSession() + defer session.Close() + + // MSSQL deny insert identity column excep declare as below if testEngine.Dialect().DBType() == core.MSSQL { - _, err = testEngine.Exec("SET IDENTITY_INSERT message ON") + err = session.Begin() + assert.NoError(t, err) + _, err = session.Exec("SET IDENTITY_INSERT message ON") assert.NoError(t, err) } - _, err = testEngine.Insert(&msg) - if err != nil { - t.Error(err) - panic(err) + _, err = session.Insert(&msg) + assert.NoError(t, err) + + if testEngine.Dialect().DBType() == core.MSSQL { + err = session.Commit() + assert.NoError(t, err) } var mapper = testEngine.GetTableMapper().Obj2Table @@ -495,7 +454,7 @@ func TestExtends4(t *testing.T) { msgTableName := quote(testEngine.TableName(mapper("Message"), true)) list := make([]MessageExtend4, 0) - err = testEngine.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). + err = session.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). Join("LEFT", typeTableName, typeTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). Find(&list) if err != nil { diff --git a/tag_version_test.go b/tag_version_test.go index 570a6754..cd6dc935 100644 --- a/tag_version_test.go +++ b/tag_version_test.go @@ -85,7 +85,7 @@ func TestVersion1(t *testing.T) { } fmt.Println(newVer) if newVer.Ver != 2 { - err = errors.New("insert error") + err = errors.New("update error") t.Error(err) panic(err) } @@ -126,3 +126,117 @@ func TestVersion2(t *testing.T) { } } } + +type VersionUintS struct { + Id int64 + Name string + Ver uint `xorm:"version"` + Created time.Time `xorm:"created"` +} + +func TestVersion3(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.DropTables(new(VersionUintS)) + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.CreateTables(new(VersionUintS)) + if err != nil { + t.Error(err) + panic(err) + } + + ver := &VersionUintS{Name: "sfsfdsfds"} + _, err = testEngine.Insert(ver) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(ver) + if ver.Ver != 1 { + err = errors.New("insert error") + t.Error(err) + panic(err) + } + + newVer := new(VersionUintS) + has, err := testEngine.ID(ver.Id).Get(newVer) + if err != nil { + t.Error(err) + panic(err) + } + + if !has { + t.Error(errors.New(fmt.Sprintf("no version id is %v", ver.Id))) + panic(err) + } + fmt.Println(newVer) + if newVer.Ver != 1 { + err = errors.New("insert error") + t.Error(err) + panic(err) + } + + newVer.Name = "-------" + _, err = testEngine.ID(ver.Id).Update(newVer) + if err != nil { + t.Error(err) + panic(err) + } + if newVer.Ver != 2 { + err = errors.New("update should set version back to struct") + t.Error(err) + } + + newVer = new(VersionUintS) + has, err = testEngine.ID(ver.Id).Get(newVer) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(newVer) + if newVer.Ver != 2 { + err = errors.New("update error") + t.Error(err) + panic(err) + } +} + +func TestVersion4(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.DropTables(new(VersionUintS)) + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.CreateTables(new(VersionUintS)) + if err != nil { + t.Error(err) + panic(err) + } + + var vers = []VersionUintS{ + {Name: "sfsfdsfds"}, + {Name: "xxxxx"}, + } + _, err = testEngine.Insert(vers) + if err != nil { + t.Error(err) + panic(err) + } + + fmt.Println(vers) + + for _, v := range vers { + if v.Ver != 1 { + err := errors.New("version should be 1") + t.Error(err) + panic(err) + } + } +} diff --git a/test_mssql.sh b/test_mssql.sh index 6f9cf729..7f060cff 100755 --- a/test_mssql.sh +++ b/test_mssql.sh @@ -1 +1 @@ -go test -db=mssql -conn_str="server=192.168.1.58;user id=sa;password=123456;database=xorm_test" \ No newline at end of file +go test -db=mssql -conn_str="server=localhost;user id=sa;password=yourStrong(!)Password;database=xorm_test" \ No newline at end of file diff --git a/test_tidb.sh b/test_tidb.sh new file mode 100755 index 00000000..03d2d6cd --- /dev/null +++ b/test_tidb.sh @@ -0,0 +1 @@ +go test -db=mysql -conn_str="root:@tcp(localhost:4000)/xorm_test" -ignore_select_update=true \ No newline at end of file diff --git a/types.go b/types.go index 99d761c2..25c007d7 100644 --- a/types.go +++ b/types.go @@ -1,3 +1,7 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package xorm import ( diff --git a/types_test.go b/types_test.go index 20511407..b863671d 100644 --- a/types_test.go +++ b/types_test.go @@ -309,16 +309,24 @@ func TestCustomType2(t *testing.T) { _, err = testEngine.Exec("delete from " + testEngine.Quote(tableName)) assert.NoError(t, err) + session := testEngine.NewSession() + defer session.Close() + if testEngine.Dialect().DBType() == core.MSSQL { - return - /*_, err = engine.Exec("set IDENTITY_INSERT " + tableName + " on") - if err != nil { - t.Fatal(err) - }*/ + err = session.Begin() + assert.NoError(t, err) + _, err = session.Exec("set IDENTITY_INSERT " + tableName + " on") + assert.NoError(t, err) } - _, err = testEngine.Insert(&UserCus{1, "xlw", Registed}) + cnt, err := session.Insert(&UserCus{1, "xlw", Registed}) assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + if testEngine.Dialect().DBType() == core.MSSQL { + err = session.Commit() + assert.NoError(t, err) + } user := UserCus{} exist, err := testEngine.ID(1).Get(&user) diff --git a/xorm.go b/xorm.go index 739de8d4..157c9d34 100644 --- a/xorm.go +++ b/xorm.go @@ -7,6 +7,7 @@ package xorm import ( + "context" "fmt" "os" "reflect" @@ -85,14 +86,15 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { } engine := &Engine{ - db: db, - dialect: dialect, - Tables: make(map[reflect.Type]*core.Table), - mutex: &sync.RWMutex{}, - TagIdentifier: "xorm", - TZLocation: time.Local, - tagHandlers: defaultTagHandlers, - cachers: make(map[string]core.Cacher), + db: db, + dialect: dialect, + Tables: make(map[reflect.Type]*core.Table), + mutex: &sync.RWMutex{}, + TagIdentifier: "xorm", + TZLocation: time.Local, + tagHandlers: defaultTagHandlers, + cachers: make(map[string]core.Cacher), + defaultContext: context.Background(), } if uri.DbType == core.SQLITE { diff --git a/xorm_test.go b/xorm_test.go index 4e88dc40..a35f0743 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -1,8 +1,14 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package xorm import ( + "database/sql" "flag" "fmt" + "log" "os" "strings" "testing" @@ -20,14 +26,15 @@ var ( dbType string connString string - db = flag.String("db", "sqlite3", "the tested database") - showSQL = flag.Bool("show_sql", true, "show generated SQLs") - ptrConnStr = flag.String("conn_str", "./test.db?cache=shared&mode=rwc", "test database connection string") - mapType = flag.String("map_type", "snake", "indicate the name mapping") - cache = flag.Bool("cache", false, "if enable cache") - cluster = flag.Bool("cluster", false, "if this is a cluster") - splitter = flag.String("splitter", ";", "the splitter on connstr for cluster") - schema = flag.String("schema", "", "specify the schema") + db = flag.String("db", "sqlite3", "the tested database") + showSQL = flag.Bool("show_sql", true, "show generated SQLs") + ptrConnStr = flag.String("conn_str", "./test.db?cache=shared&mode=rwc", "test database connection string") + mapType = flag.String("map_type", "snake", "indicate the name mapping") + cache = flag.Bool("cache", false, "if enable cache") + cluster = flag.Bool("cluster", false, "if this is a cluster") + splitter = flag.String("splitter", ";", "the splitter on connstr for cluster") + schema = flag.String("schema", "", "specify the schema") + ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb") ) func createEngine(dbType, connStr string) error { @@ -35,9 +42,59 @@ func createEngine(dbType, connStr string) error { var err error if !*cluster { + switch strings.ToLower(dbType) { + case core.MSSQL: + db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "master", -1)) + if err != nil { + return err + } + if _, err = db.Exec("If(db_id(N'xorm_test') IS NULL) BEGIN CREATE DATABASE xorm_test; END;"); err != nil { + return fmt.Errorf("db.Exec: %v", err) + } + db.Close() + *ignoreSelectUpdate = true + case core.POSTGRES: + db, err := sql.Open(dbType, connStr) + if err != nil { + return err + } + rows, err := db.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = 'xorm_test'")) + if err != nil { + return fmt.Errorf("db.Query: %v", err) + } + defer rows.Close() + + if !rows.Next() { + if _, err = db.Exec("CREATE DATABASE xorm_test"); err != nil { + return fmt.Errorf("CREATE DATABASE: %v", err) + } + } + if *schema != "" { + if _, err = db.Exec("CREATE SCHEMA IF NOT EXISTS " + *schema); err != nil { + return fmt.Errorf("CREATE SCHEMA: %v", err) + } + } + db.Close() + *ignoreSelectUpdate = true + case core.MYSQL: + db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "mysql", -1)) + if err != nil { + return err + } + if _, err = db.Exec("CREATE DATABASE IF NOT EXISTS xorm_test"); err != nil { + return fmt.Errorf("db.Exec: %v", err) + } + db.Close() + default: + *ignoreSelectUpdate = true + } + testEngine, err = NewEngine(dbType, connStr) } else { testEngine, err = NewEngineGroup(dbType, strings.Split(connStr, *splitter)) + if dbType != "mysql" && dbType != "mymysql" { + *ignoreSelectUpdate = true + } } if err != nil { return err @@ -95,7 +152,7 @@ func TestMain(m *testing.M) { } } else { if ptrConnStr == nil { - fmt.Println("you should indicate conn string") + log.Fatal("you should indicate conn string") return } connString = *ptrConnStr @@ -112,7 +169,7 @@ func TestMain(m *testing.M) { fmt.Println("testing", dbType, connString) if err := prepareEngine(); err != nil { - fmt.Println(err) + log.Fatal(err) return }