Merge branch 'master' into lunny/mssql_test

This commit is contained in:
Lunny Xiao 2019-07-30 15:15:41 +08:00 committed by GitHub
commit 92bbcdb5e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
75 changed files with 3676 additions and 2211 deletions

61
.circleci/config.yml Normal file
View File

@ -0,0 +1,61 @@
# Golang CircleCI 2.0 configuration file
#
# Check https://circleci.com/docs/2.0/language-go/ for more details
version: 2
jobs:
build:
docker:
# specify the version
- image: circleci/golang:1.10
- image: circleci/mysql:5.7
environment:
MYSQL_ALLOW_EMPTY_PASSWORD: true
MYSQL_DATABASE: xorm_test
MYSQL_HOST: 127.0.0.1
MYSQL_ROOT_HOST: '%'
MYSQL_USER: root
# CircleCI PostgreSQL images available at: https://hub.docker.com/r/circleci/postgres/
- image: circleci/postgres:9.6.2-alpine
environment:
POSTGRES_USER: circleci
POSTGRES_DB: xorm_test
- image: microsoft/mssql-server-linux:latest
environment:
ACCEPT_EULA: Y
SA_PASSWORD: yourStrong(!)Password
MSSQL_PID: Developer
- image: pingcap/tidb:v2.1.2
working_directory: /go/src/github.com/go-xorm/xorm
steps:
- checkout
- run: go get -t -d -v ./...
- run: go get -u xorm.io/core
- run: go get -u xorm.io/builder
- run: GO111MODULE=off go build -v
- run: GO111MODULE=on go build -v
- run: go get -u github.com/wadey/gocovmerge
- run: go test -v -race -db="sqlite3" -conn_str="./test.db" -coverprofile=coverage1-1.txt -covermode=atomic
- run: go test -v -race -db="sqlite3" -conn_str="./test.db" -cache=true -coverprofile=coverage1-2.txt -covermode=atomic
- run: go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -coverprofile=coverage2-1.txt -covermode=atomic
- run: go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -cache=true -coverprofile=coverage2-2.txt -covermode=atomic
- run: go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -coverprofile=coverage3-1.txt -covermode=atomic
- run: go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic
- run: go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -coverprofile=coverage4-1.txt -covermode=atomic
- run: go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic
- run: go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -coverprofile=coverage5-1.txt -covermode=atomic
- run: go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -cache=true -coverprofile=coverage5-2.txt -covermode=atomic
- run: go test -v -race -db="mssql" -conn_str="server=localhost;user id=sa;password=yourStrong(!)Password;database=xorm_test" -coverprofile=coverage6-1.txt -covermode=atomic
- run: go test -v -race -db="mssql" -conn_str="server=localhost;user id=sa;password=yourStrong(!)Password;database=xorm_test" -cache=true -coverprofile=coverage6-2.txt -covermode=atomic
- run: go test -v -race -db="mysql" -conn_str="root:@tcp(localhost:4000)/xorm_test" -ignore_select_update=true -coverprofile=coverage7-1.txt -covermode=atomic
- run: go test -v -race -db="mysql" -conn_str="root:@tcp(localhost:4000)/xorm_test" -ignore_select_update=true -cache=true -coverprofile=coverage7-2.txt -covermode=atomic
- run: gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt coverage5-1.txt coverage5-2.txt coverage6-1.txt coverage6-2.txt coverage7-1.txt coverage7-2.txt > coverage.txt
- run: bash <(curl -s https://codecov.io/bash)

View File

@ -65,8 +65,8 @@ pipeline:
image: golang:${GO_VERSION} image: golang:${GO_VERSION}
commands: commands:
- go get -t -d -v ./... - go get -t -d -v ./...
- go get -u github.com/go-xorm/core - go get -u xorm.io/core
- go get -u github.com/go-xorm/builder - go get -u xorm.io/builder
- go build -v - go build -v
when: when:
event: [ push, pull_request ] event: [ push, pull_request ]

View File

@ -28,7 +28,7 @@ Xorm is a simple and powerful ORM for Go.
* Optimistic Locking support * Optimistic Locking support
* SQL Builder support via [github.com/go-xorm/builder](https://github.com/go-xorm/builder) * SQL Builder support via [xorm.io/builder](https://xorm.io/builder)
* Automatical Read/Write seperatelly * Automatical Read/Write seperatelly
@ -151,20 +151,20 @@ has, err := engine.Where("name = ?", name).Desc("id").Get(&user)
// SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1 // SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1
var name string var name string
has, err := engine.Where("id = ?", id).Cols("name").Get(&name) has, err := engine.Table(&user).Where("id = ?", id).Cols("name").Get(&name)
// SELECT name FROM user WHERE id = ? // SELECT name FROM user WHERE id = ?
var id int64 var id int64
has, err := engine.Where("name = ?", name).Cols("id").Get(&id) has, err := engine.Table(&user).Where("name = ?", name).Cols("id").Get(&id)
has, err := engine.SQL("select id from user").Get(&id) has, err := engine.SQL("select id from user").Get(&id)
// SELECT id FROM user WHERE name = ? // SELECT id FROM user WHERE name = ?
var valuesMap = make(map[string]string) var valuesMap = make(map[string]string)
has, err := engine.Where("id = ?", id).Get(&valuesMap) has, err := engine.Table(&user).Where("id = ?", id).Get(&valuesMap)
// SELECT * FROM user WHERE id = ? // SELECT * FROM user WHERE id = ?
var valuesSlice = make([]interface{}, len(cols)) var valuesSlice = make([]interface{}, len(cols))
has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice) has, err := engine.Table(&user).Where("id = ?", id).Cols(cols...).Get(&valuesSlice)
// SELECT col1, col2, col3 FROM user WHERE id = ? // SELECT col1, col2, col3 FROM user WHERE id = ?
``` ```
@ -284,6 +284,13 @@ counts, err := engine.Count(&user)
// SELECT count(*) AS total FROM user // SELECT count(*) AS total FROM user
``` ```
* `FindAndCount` combines function `Find` with `Count` which is usually used in query by page
```Go
var users []User
counts, err := engine.FindAndCount(&users)
```
* `Sum` sum functions * `Sum` sum functions
```Go ```Go
@ -363,7 +370,7 @@ return session.Commit()
* Or you can use `Transaction` to replace above codes. * Or you can use `Transaction` to replace above codes.
```Go ```Go
res, err := engine.Transaction(func(sess *xorm.Session) (interface{}, error) { res, err := engine.Transaction(func(session *xorm.Session) (interface{}, error) {
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
if _, err := session.Insert(&user1); err != nil { if _, err := session.Insert(&user1); err != nil {
return nil, err return nil, err

View File

@ -153,20 +153,20 @@ has, err := engine.Where("name = ?", name).Desc("id").Get(&user)
// SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1 // SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1
var name string var name string
has, err := engine.Where("id = ?", id).Cols("name").Get(&name) has, err := engine.Table(&user).Where("id = ?", id).Cols("name").Get(&name)
// SELECT name FROM user WHERE id = ? // SELECT name FROM user WHERE id = ?
var id int64 var id int64
has, err := engine.Where("name = ?", name).Cols("id").Get(&id) has, err := engine.Table(&user).Where("name = ?", name).Cols("id").Get(&id)
has, err := engine.SQL("select id from user").Get(&id) has, err := engine.SQL("select id from user").Get(&id)
// SELECT id FROM user WHERE name = ? // SELECT id FROM user WHERE name = ?
var valuesMap = make(map[string]string) var valuesMap = make(map[string]string)
has, err := engine.Where("id = ?", id).Get(&valuesMap) has, err := engine.Table(&user).Where("id = ?", id).Get(&valuesMap)
// SELECT * FROM user WHERE id = ? // SELECT * FROM user WHERE id = ?
var valuesSlice = make([]interface{}, len(cols)) var valuesSlice = make([]interface{}, len(cols))
has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice) has, err := engine.Table(&user).Where("id = ?", id).Cols(cols...).Get(&valuesSlice)
// SELECT col1, col2, col3 FROM user WHERE id = ? // SELECT col1, col2, col3 FROM user WHERE id = ?
``` ```
@ -362,10 +362,10 @@ if _, err := session.Exec("delete from userinfo where username = ?", user2.Usern
return session.Commit() return session.Commit()
``` ```
* 事的简写方法 * 事的简写方法
```Go ```Go
res, err := engine.Transaction(func(sess *xorm.Session) (interface{}, error) { res, err := engine.Transaction(func(session *xorm.Session) (interface{}, error) {
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
if _, err := session.Insert(&user1); err != nil { if _, err := session.Insert(&user1); err != nil {
return nil, err return nil, err
@ -383,7 +383,7 @@ res, err := engine.Transaction(func(sess *xorm.Session) (interface{}, error) {
}) })
``` ```
* Context Cache, if enabled, current query result will be cached on session and be used by next same statement on the same session. * 上下文缓存,如果启用,那么针对单个对象的查询将会被缓存到系统中,可以被下一个查询使用。
```Go ```Go
sess := engine.NewSession() sess := engine.NewSession()

View File

@ -10,7 +10,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/go-xorm/core" "xorm.io/core"
) )
// LRUCacher implments cache object facilities // LRUCacher implments cache object facilities

View File

@ -7,7 +7,7 @@ package xorm
import ( import (
"testing" "testing"
"github.com/go-xorm/core" "xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )

View File

@ -7,7 +7,7 @@ package xorm
import ( import (
"sync" "sync"
"github.com/go-xorm/core" "xorm.io/core"
) )
var _ core.CacheStore = NewMemoryStore() var _ core.CacheStore = NewMemoryStore()

View File

@ -1,41 +0,0 @@
dependencies:
override:
# './...' is a relative pattern which means all subdirectories
- go get -t -d -v ./...
- go get -t -d -v github.com/go-xorm/tests
- go get -u github.com/go-xorm/core
- go get -u github.com/go-xorm/builder
- go build -v
database:
override:
- mysql -u root -e "CREATE DATABASE xorm_test DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci"
- mysql -u root -e "CREATE DATABASE xorm_test1 DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci"
- mysql -u root -e "CREATE DATABASE xorm_test2 DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci"
- mysql -u root -e "CREATE DATABASE xorm_test3 DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci"
- createdb -p 5432 -e -U postgres xorm_test
- createdb -p 5432 -e -U postgres xorm_test1
- createdb -p 5432 -e -U postgres xorm_test2
- createdb -p 5432 -e -U postgres xorm_test3
- psql xorm_test postgres -c "create schema xorm"
test:
override:
# './...' is a relative pattern which means all subdirectories
- go get -u github.com/wadey/gocovmerge
- go test -v -race -db="sqlite3" -conn_str="./test.db" -coverprofile=coverage1-1.txt -covermode=atomic
- go test -v -race -db="sqlite3" -conn_str="./test.db" -cache=true -coverprofile=coverage1-2.txt -covermode=atomic
- go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -coverprofile=coverage2-1.txt -covermode=atomic
- go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -cache=true -coverprofile=coverage2-2.txt -covermode=atomic
- go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -coverprofile=coverage3-1.txt -covermode=atomic
- go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic
- go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -coverprofile=coverage4-1.txt -covermode=atomic
- go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic
- go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -coverprofile=coverage5-1.txt -covermode=atomic
- go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -cache=true -coverprofile=coverage5-2.txt -covermode=atomic
- gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt coverage5-1.txt coverage5-2.txt > coverage.txt
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./sqlite3.sh
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./mysql.sh
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./postgres.sh
post:
- bash <(curl -s https://codecov.io/bash)

View File

@ -7,10 +7,11 @@ package xorm
import ( import (
"errors" "errors"
"fmt" "fmt"
"net/url"
"strconv" "strconv"
"strings" "strings"
"github.com/go-xorm/core" "xorm.io/core"
) )
var ( var (
@ -218,7 +219,7 @@ func (db *mssql) SqlType(c *core.Column) string {
res = core.Bit res = core.Bit
if strings.EqualFold(c.Default, "true") { if strings.EqualFold(c.Default, "true") {
c.Default = "1" c.Default = "1"
} else { } else if strings.EqualFold(c.Default, "false") {
c.Default = "0" c.Default = "0"
} }
case core.Serial: case core.Serial:
@ -285,10 +286,6 @@ func (db *mssql) Quote(name string) string {
return "\"" + name + "\"" return "\"" + name + "\""
} }
func (db *mssql) QuoteStr() string {
return "\""
}
func (db *mssql) SupportEngine() bool { func (db *mssql) SupportEngine() bool {
return false return false
} }
@ -506,7 +503,7 @@ func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, chars
sql = "IF NOT EXISTS (SELECT [name] FROM sys.tables WHERE [name] = '" + tableName + "' ) CREATE TABLE " sql = "IF NOT EXISTS (SELECT [name] FROM sys.tables WHERE [name] = '" + tableName + "' ) CREATE TABLE "
sql += db.QuoteStr() + tableName + db.QuoteStr() + " (" sql += db.Quote(tableName) + " ("
pkList := table.PrimaryKeys pkList := table.PrimaryKeys
@ -544,8 +541,16 @@ type odbcDriver struct {
} }
func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
kv := strings.Split(dataSourceName, ";")
var dbName string var dbName string
if strings.HasPrefix(dataSourceName, "sqlserver://") {
u, err := url.Parse(dataSourceName)
if err != nil {
return nil, err
}
dbName = u.Query().Get("database")
} else {
kv := strings.Split(dataSourceName, ";")
for _, c := range kv { for _, c := range kv {
vv := strings.Split(strings.TrimSpace(c), "=") vv := strings.Split(strings.TrimSpace(c), "=")
if len(vv) == 2 { if len(vv) == 2 {
@ -555,6 +560,7 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error)
} }
} }
} }
}
if dbName == "" { if dbName == "" {
return nil, errors.New("no db name provided") return nil, errors.New("no db name provided")
} }

35
dialect_mssql_test.go Normal file
View File

@ -0,0 +1,35 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"reflect"
"testing"
"xorm.io/core"
)
func TestParseMSSQL(t *testing.T) {
tests := []struct {
in string
expected string
valid bool
}{
{"sqlserver://sa:yourStrong(!)Password@localhost:1433?database=db&connection+timeout=30", "db", true},
{"server=localhost;user id=sa;password=yourStrong(!)Password;database=db", "db", true},
}
driver := core.QueryDriver("mssql")
for _, test := range tests {
uri, err := driver.Parse("mssql", test.in)
if err != nil && test.valid {
t.Errorf("%q got unexpected error: %s", test.in, err)
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected)
}
}
}

View File

@ -13,7 +13,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/go-xorm/core" "xorm.io/core"
) )
var ( var (
@ -278,10 +278,6 @@ func (db *mysql) Quote(name string) string {
return "`" + name + "`" return "`" + name + "`"
} }
func (db *mysql) QuoteStr() string {
return "`"
}
func (db *mysql) SupportEngine() bool { func (db *mysql) SupportEngine() bool {
return true return true
} }
@ -393,6 +389,9 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
if colType == "FLOAT UNSIGNED" { if colType == "FLOAT UNSIGNED" {
colType = "FLOAT" colType = "FLOAT"
} }
if colType == "DOUBLE UNSIGNED" {
colType = "DOUBLE"
}
col.Length = len1 col.Length = len1
col.Length2 = len2 col.Length2 = len2
if _, ok := core.SqlTypes[colType]; ok { if _, ok := core.SqlTypes[colType]; ok {
@ -556,8 +555,6 @@ func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, chars
sql += " DEFAULT CHARSET " + charset sql += " DEFAULT CHARSET " + charset
} }
if db.rowFormat != "" { if db.rowFormat != "" {
sql += " ROW_FORMAT=" + db.rowFormat sql += " ROW_FORMAT=" + db.rowFormat
} }

View File

@ -11,7 +11,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/go-xorm/core" "xorm.io/core"
) )
var ( var (
@ -552,11 +552,7 @@ func (db *oracle) IsReserved(name string) bool {
} }
func (db *oracle) Quote(name string) string { func (db *oracle) Quote(name string) string {
return "\"" + name + "\"" return "[" + name + "]"
}
func (db *oracle) QuoteStr() string {
return "\""
} }
func (db *oracle) SupportEngine() bool { func (db *oracle) SupportEngine() bool {

View File

@ -11,7 +11,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/go-xorm/core" "xorm.io/core"
) )
// from http://www.postgresql.org/docs/current/static/sql-keywords-appendix.html // from http://www.postgresql.org/docs/current/static/sql-keywords-appendix.html
@ -822,7 +822,7 @@ func (db *postgres) SqlType(c *core.Column) string {
case core.NVarchar: case core.NVarchar:
res = core.Varchar res = core.Varchar
case core.Uuid: case core.Uuid:
res = core.Uuid return core.Uuid
case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob: case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob:
return core.Bytea return core.Bytea
case core.Double: case core.Double:
@ -834,6 +834,10 @@ func (db *postgres) SqlType(c *core.Column) string {
res = t res = t
} }
if strings.EqualFold(res, "bool") {
// for bool, we don't need length information
return res
}
hasLen1 := (c.Length > 0) hasLen1 := (c.Length > 0)
hasLen2 := (c.Length2 > 0) hasLen2 := (c.Length2 > 0)
@ -859,10 +863,6 @@ func (db *postgres) Quote(name string) string {
return "\"" + name + "\"" return "\"" + name + "\""
} }
func (db *postgres) QuoteStr() string {
return "\""
}
func (db *postgres) AutoIncrStr() string { func (db *postgres) AutoIncrStr() string {
return "" return ""
} }
@ -1089,6 +1089,17 @@ func (db *postgres) GetTables() ([]*core.Table, error) {
return tables, nil return tables, nil
} }
func getIndexColName(indexdef string) []string {
var colNames []string
cs := strings.Split(indexdef, "(")
for _, v := range strings.Split(strings.Split(cs[1], ")")[0], ",") {
colNames = append(colNames, strings.Split(strings.TrimLeft(v, " "), " ")[0])
}
return colNames
}
func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
@ -1122,8 +1133,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error)
} else { } else {
indexType = core.IndexType indexType = core.IndexType
} }
cs := strings.Split(indexdef, "(") colNames = getIndexColName(indexdef)
colNames = strings.Split(cs[1][0:len(cs[1])-1], ",")
var isRegular bool var isRegular bool
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
newIdxName := indexName[5+len(tableName):] newIdxName := indexName[5+len(tableName):]

View File

@ -4,8 +4,9 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/go-xorm/core" "xorm.io/core"
"github.com/jackc/pgx/stdlib" "github.com/jackc/pgx/stdlib"
"github.com/stretchr/testify/assert"
) )
func TestParsePostgres(t *testing.T) { func TestParsePostgres(t *testing.T) {
@ -84,3 +85,37 @@ func TestParsePgx(t *testing.T) {
} }
} }
func TestGetIndexColName(t *testing.T) {
t.Run("Index", func(t *testing.T) {
s := "CREATE INDEX test2_mm_idx ON test2 (major);"
colNames := getIndexColName(s)
assert.Equal(t, []string{"major"}, colNames)
})
t.Run("Multicolumn indexes", func(t *testing.T) {
s := "CREATE INDEX test2_mm_idx ON test2 (major, minor);"
colNames := getIndexColName(s)
assert.Equal(t, []string{"major", "minor"}, colNames)
})
t.Run("Indexes and ORDER BY", func(t *testing.T) {
s := "CREATE INDEX test2_mm_idx ON test2 (major NULLS FIRST, minor DESC NULLS LAST);"
colNames := getIndexColName(s)
assert.Equal(t, []string{"major", "minor"}, colNames)
})
t.Run("Combining Multiple Indexes", func(t *testing.T) {
s := "CREATE INDEX test2_mm_cm_idx ON public.test2 USING btree (major, minor) WHERE ((major <> 5) AND (minor <> 6))"
colNames := getIndexColName(s)
assert.Equal(t, []string{"major", "minor"}, colNames)
})
t.Run("unique", func(t *testing.T) {
s := "CREATE UNIQUE INDEX test2_mm_uidx ON test2 (major);"
colNames := getIndexColName(s)
assert.Equal(t, []string{"major"}, colNames)
})
t.Run("Indexes on Expressions", func(t *testing.T) {})
}

View File

@ -11,7 +11,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"github.com/go-xorm/core" "xorm.io/core"
) )
var ( var (
@ -202,10 +202,6 @@ func (db *sqlite3) Quote(name string) string {
return "`" + name + "`" return "`" + name + "`"
} }
func (db *sqlite3) QuoteStr() string {
return "`"
}
func (db *sqlite3) AutoIncrStr() string { func (db *sqlite3) AutoIncrStr() string {
return "AUTOINCREMENT" return "AUTOINCREMENT"
} }

View File

@ -7,6 +7,7 @@ package xorm
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"database/sql" "database/sql"
"encoding/gob" "encoding/gob"
"errors" "errors"
@ -19,8 +20,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/go-xorm/builder" "xorm.io/builder"
"github.com/go-xorm/core" "xorm.io/core"
) )
// Engine is the major struct of xorm, it means a database manager. // Engine is the major struct of xorm, it means a database manager.
@ -52,6 +53,8 @@ type Engine struct {
cachers map[string]core.Cacher cachers map[string]core.Cacher
cacherLock sync.RWMutex cacherLock sync.RWMutex
defaultContext context.Context
} }
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {
@ -122,6 +125,7 @@ func (engine *Engine) Logger() core.ILogger {
// SetLogger set the new logger // SetLogger set the new logger
func (engine *Engine) SetLogger(logger core.ILogger) { func (engine *Engine) SetLogger(logger core.ILogger) {
engine.logger = logger engine.logger = logger
engine.showSQL = logger.IsShowSQL()
engine.dialect.SetLogger(logger) engine.dialect.SetLogger(logger)
} }
@ -171,12 +175,6 @@ func (engine *Engine) SupportInsertMany() bool {
return engine.dialect.SupportInsertMany() return engine.dialect.SupportInsertMany()
} }
// QuoteStr Engine's database use which character as quote.
// mysql, sqlite use ` and postgres use "
func (engine *Engine) QuoteStr() string {
return engine.dialect.QuoteStr()
}
func (engine *Engine) quoteColumns(columnStr string) string { func (engine *Engine) quoteColumns(columnStr string) string {
columns := strings.Split(columnStr, ",") columns := strings.Split(columnStr, ",")
for i := 0; i < len(columns); i++ { for i := 0; i < len(columns); i++ {
@ -192,13 +190,10 @@ func (engine *Engine) Quote(value string) string {
return value return value
} }
if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' { buf := builder.StringBuilder{}
return value engine.QuoteTo(&buf, value)
}
value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1) return buf.String()
return engine.dialect.QuoteStr() + value + engine.dialect.QuoteStr()
} }
// QuoteTo quotes string and writes into the buffer // QuoteTo quotes string and writes into the buffer
@ -212,20 +207,30 @@ func (engine *Engine) QuoteTo(buf *builder.StringBuilder, value string) {
return return
} }
if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' { quotePair := engine.dialect.Quote("")
buf.WriteString(value)
if value[0] == '`' || len(quotePair) < 2 || value[0] == quotePair[0] { // no quote
_, _ = buf.WriteString(value)
return return
} else {
prefix, suffix := quotePair[0], quotePair[1]
_ = buf.WriteByte(prefix)
for i := 0; i < len(value); i++ {
if value[i] == '.' {
_ = buf.WriteByte(suffix)
_ = buf.WriteByte('.')
_ = buf.WriteByte(prefix)
} else {
_ = buf.WriteByte(value[i])
}
}
_ = buf.WriteByte(suffix)
} }
value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1)
buf.WriteString(engine.dialect.QuoteStr())
buf.WriteString(value)
buf.WriteString(engine.dialect.QuoteStr())
} }
func (engine *Engine) quote(sql string) string { func (engine *Engine) quote(sql string) string {
return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr() return engine.dialect.Quote(sql)
} }
// SqlType will be deprecated, please use SQLType instead // SqlType will be deprecated, please use SQLType instead
@ -481,7 +486,8 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
} }
cols := table.ColumnsSeq() cols := table.ColumnsSeq()
colNames := dialect.Quote(strings.Join(cols, dialect.Quote(", "))) colNames := engine.dialect.Quote(strings.Join(cols, engine.dialect.Quote(", ")))
destColNames := dialect.Quote(strings.Join(cols, dialect.Quote(", ")))
rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.Quote(table.Name)) rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.Quote(table.Name))
if err != nil { if err != nil {
@ -496,7 +502,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
return err return err
} }
_, err = io.WriteString(w, "INSERT INTO "+dialect.Quote(table.Name)+" ("+colNames+") VALUES (") _, err = io.WriteString(w, "INSERT INTO "+dialect.Quote(table.Name)+" ("+destColNames+") VALUES (")
if err != nil { if err != nil {
return err return err
} }
@ -526,7 +532,11 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
} else if col.SQLType.IsNumeric() { } else if col.SQLType.IsNumeric() {
switch reflect.TypeOf(d).Kind() { switch reflect.TypeOf(d).Kind() {
case reflect.Slice: case reflect.Slice:
if col.SQLType.Name == core.Bool {
temp += fmt.Sprintf(", %v", strconv.FormatBool(d.([]byte)[0] != byte('0')))
} else {
temp += fmt.Sprintf(", %s", string(d.([]byte))) temp += fmt.Sprintf(", %s", string(d.([]byte)))
}
case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int: case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int:
if col.SQLType.Name == core.Bool { if col.SQLType.Name == core.Bool {
temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Int() > 0)) temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Int() > 0))
@ -563,7 +573,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
// FIXME: Hack for postgres // FIXME: Hack for postgres
if string(dialect.DBType()) == core.POSTGRES && table.AutoIncrColumn() != nil { if string(dialect.DBType()) == core.POSTGRES && table.AutoIncrColumn() != nil {
_, err = io.WriteString(w, "SELECT setval('table_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") FROM "+dialect.Quote(table.Name)+"), 1), false);\n") _, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quote(table.Name)+"), 1), false);\n")
if err != nil { if err != nil {
return err return err
} }
@ -914,7 +924,16 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
engine: engine, engine: engine,
} }
if strings.ToUpper(tags[0]) == "EXTENDS" { if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") {
pStart := strings.Index(tags[0], "(")
if pStart > -1 && strings.HasSuffix(tags[0], ")") {
var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool {
return r == '\'' || r == '"'
})
ctx.params = []string{tagPrefix}
}
if err := ExtendsTagHandler(&ctx); err != nil { if err := ExtendsTagHandler(&ctx); err != nil {
return nil, err return nil, err
} }
@ -1346,31 +1365,31 @@ func (engine *Engine) DropIndexes(bean interface{}) error {
} }
// Exec raw sql // Exec raw sql
func (engine *Engine) Exec(sqlorArgs ...interface{}) (sql.Result, error) { func (engine *Engine) Exec(sqlOrArgs ...interface{}) (sql.Result, error) {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
return session.Exec(sqlorArgs...) return session.Exec(sqlOrArgs...)
} }
// Query a raw sql and return records as []map[string][]byte // Query a raw sql and return records as []map[string][]byte
func (engine *Engine) Query(sqlorArgs ...interface{}) (resultsSlice []map[string][]byte, err error) { func (engine *Engine) Query(sqlOrArgs ...interface{}) (resultsSlice []map[string][]byte, err error) {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
return session.Query(sqlorArgs...) return session.Query(sqlOrArgs...)
} }
// QueryString runs a raw sql and return records as []map[string]string // QueryString runs a raw sql and return records as []map[string]string
func (engine *Engine) QueryString(sqlorArgs ...interface{}) ([]map[string]string, error) { func (engine *Engine) QueryString(sqlOrArgs ...interface{}) ([]map[string]string, error) {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
return session.QueryString(sqlorArgs...) return session.QueryString(sqlOrArgs...)
} }
// QueryInterface runs a raw sql and return records as []map[string]interface{} // QueryInterface runs a raw sql and return records as []map[string]interface{}
func (engine *Engine) QueryInterface(sqlorArgs ...interface{}) ([]map[string]interface{}, error) { func (engine *Engine) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error) {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
return session.QueryInterface(sqlorArgs...) return session.QueryInterface(sqlOrArgs...)
} }
// Insert one or more records // Insert one or more records

View File

@ -6,14 +6,13 @@ package xorm
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"time" "time"
"github.com/go-xorm/builder" "xorm.io/builder"
"github.com/go-xorm/core" "xorm.io/core"
) )
func (engine *Engine) buildConds(table *core.Table, bean interface{}, func (engine *Engine) buildConds(table *core.Table, bean interface{},
@ -147,7 +146,7 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{},
} else { } else {
if col.SQLType.IsJson() { if col.SQLType.IsJson() {
if col.SQLType.IsText() { if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface()) bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
engine.logger.Error(err) engine.logger.Error(err)
continue continue
@ -156,7 +155,7 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{},
} else if col.SQLType.IsBlob() { } else if col.SQLType.IsBlob() {
var bytes []byte var bytes []byte
var err error var err error
bytes, err = json.Marshal(fieldValue.Interface()) bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
engine.logger.Error(err) engine.logger.Error(err)
continue continue
@ -195,7 +194,7 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{},
} }
if col.SQLType.IsText() { if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface()) bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
engine.logger.Error(err) engine.logger.Error(err)
continue continue
@ -212,7 +211,7 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{},
continue continue
} }
} else { } else {
bytes, err = json.Marshal(fieldValue.Interface()) bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
engine.logger.Error(err) engine.logger.Error(err)
continue continue

28
engine_context.go Normal file
View File

@ -0,0 +1,28 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package xorm
import "context"
// Context creates a session with the context
func (engine *Engine) Context(ctx context.Context) *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.Context(ctx)
}
// SetDefaultContext set the default context
func (engine *Engine) SetDefaultContext(ctx context.Context) {
engine.defaultContext = ctx
}
// PingContext tests if database is alive
func (engine *Engine) PingContext(ctx context.Context) error {
session := engine.NewSession()
defer session.Close()
return session.PingContext(ctx)
}

View File

@ -17,9 +17,12 @@ import (
func TestPingContext(t *testing.T) { func TestPingContext(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
ctx, canceled := context.WithTimeout(context.Background(), 10*time.Second) ctx, canceled := context.WithTimeout(context.Background(), time.Nanosecond)
defer canceled() defer canceled()
time.Sleep(time.Nanosecond)
err := testEngine.(*Engine).PingContext(ctx) err := testEngine.(*Engine).PingContext(ctx)
assert.NoError(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "context deadline exceeded")
} }

View File

@ -5,9 +5,10 @@
package xorm package xorm
import ( import (
"context"
"time" "time"
"github.com/go-xorm/core" "xorm.io/core"
) )
// EngineGroup defines an engine group // EngineGroup defines an engine group
@ -74,6 +75,20 @@ func (eg *EngineGroup) Close() error {
return nil return nil
} }
// Context returned a group session
func (eg *EngineGroup) Context(ctx context.Context) *Session {
sess := eg.NewSession()
sess.isAutoClose = true
return sess.Context(ctx)
}
// NewSession returned a group session
func (eg *EngineGroup) NewSession() *Session {
sess := eg.Engine.NewSession()
sess.sessionType = groupSession
return sess
}
// Master returns the master engine // Master returns the master engine
func (eg *EngineGroup) Master() *Engine { func (eg *EngineGroup) Master() *Engine {
return eg.Engine return eg.Engine

View File

@ -9,10 +9,10 @@ import (
"reflect" "reflect"
"strings" "strings"
"github.com/go-xorm/core" "xorm.io/core"
) )
// TableNameWithSchema will automatically add schema prefix on table name // tbNameWithSchema will automatically add schema prefix on table name
func (engine *Engine) tbNameWithSchema(v string) string { func (engine *Engine) tbNameWithSchema(v string) string {
// Add schema name as prefix of table name. // Add schema name as prefix of table name.
// Only for postgres database. // Only for postgres database.

View File

@ -26,6 +26,8 @@ var (
ErrNotImplemented = errors.New("Not implemented") ErrNotImplemented = errors.New("Not implemented")
// ErrConditionType condition type unsupported // ErrConditionType condition type unsupported
ErrConditionType = errors.New("Unsupported condition type") ErrConditionType = errors.New("Unsupported condition type")
// ErrUnSupportedSQLType parameter of SQL is not supported
ErrUnSupportedSQLType = errors.New("unsupported sql type")
) )
// ErrFieldIsNotExist columns does not exist // ErrFieldIsNotExist columns does not exist

19
go.mod
View File

@ -1,6 +1,19 @@
module "github.com/go-xorm/xorm" module github.com/go-xorm/xorm
require ( require (
"github.com/go-xorm/builder" v0.0.0-20180322150003-a9b7ffcca3f0 github.com/cockroachdb/apd v1.1.0 // indirect
"github.com/go-xorm/core" v0.0.0-20180322150003-0177c08cee88 github.com/denisenkom/go-mssqldb v0.0.0-20190707035753-2be1aa521ff4
github.com/go-sql-driver/mysql v1.4.1
github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 // indirect
github.com/jackc/pgx v3.3.0+incompatible
github.com/kr/pretty v0.1.0 // indirect
github.com/lib/pq v1.0.0
github.com/mattn/go-sqlite3 v1.10.0
github.com/pkg/errors v0.8.1 // indirect
github.com/satori/go.uuid v1.2.0 // indirect
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 // indirect
github.com/stretchr/testify v1.3.0
github.com/ziutek/mymysql v1.5.4
xorm.io/builder v0.3.5
xorm.io/core v0.7.0
) )

159
go.sum Normal file
View File

@ -0,0 +1,159 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.34.0 h1:eOI3/cP2VTU6uZLDYAoic+eyzzB9YyGmJ7eIjl8rOPg=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.37.4 h1:glPeL3BQJsbF6aIIYfZizMwc5LTYz250bDMjttbBGAU=
cloud.google.com/go v0.37.4/go.mod h1:NHPJ89PdicEuT9hdPXMROBD91xc5uRDxsMtSB16k7hw=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/denisenkom/go-mssqldb v0.0.0-20190707035753-2be1aa521ff4 h1:YcpmyvADGYw5LqMnHqSkyIELsHCGF6PkrmM31V8rF7o=
github.com/denisenkom/go-mssqldb v0.0.0-20190707035753-2be1aa521ff4/go.mod h1:zAg7JM8CkOJ43xKXIj7eRO9kmWm/TW578qo+oDO6tuM=
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:9wScpmSP5A3Bk8V3XHWUcJmYTh+ZnlHVyc+A4oZYS3Y=
github.com/go-xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:56xuuqnHyryaerycW3BfssRdxQstACi0Epw/yC5E2xM=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 h1:vr3AYkKovP8uR8AvSGGUK1IDqRa5lAAvEkZG1LKaCRc=
github.com/jackc/fake v0.0.0-20150926172116-812a484cc733/go.mod h1:WrMFNQdiFJ80sQsxDoMokWK1W5TQtxBFNpzWTD84ibQ=
github.com/jackc/pgx v3.3.0+incompatible h1:Wa90/+qsITBAPkAZjiByeIGHFcj3Ztu+VzrrIpHjL90=
github.com/jackc/pgx v3.3.0+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o=
github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw=
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 h1:pntxY8Ary0t43dCZ5dqY4YTJCObLY1kIXl0uzMv+7DE=
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/ziutek/mymysql v1.5.4 h1:GB0qdRGsTwQSBVYuVShFBKaXSnSnYYC2d9knnE1LHFs=
github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0=
go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20190404172233-64821d5d2107/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
xorm.io/builder v0.3.5 h1:EilU39fvWDxjb1cDaELpYhsF+zziRBhew8xk4pngO+A=
xorm.io/builder v0.3.5/go.mod h1:ZFbByS/KxZI1FKRjL05PyJ4YrK2bcxlUaAxdum5aTR8=
xorm.io/core v0.6.3 h1:n1NhVZt1s2oLw1BZfX2ocIJsHyso259uPgg63BGr37M=
xorm.io/core v0.6.3/go.mod h1:8kz/C6arVW/O9vk3PgCiMJO2hIAm1UcuOL3dSPyZ2qo=

View File

@ -12,7 +12,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/go-xorm/core" "xorm.io/core"
) )
// str2PK convert string value to primary key value according to tp // str2PK convert string value to primary key value according to tp
@ -309,3 +309,24 @@ func sliceEq(left, right []string) bool {
func indexName(tableName, idxName string) string { func indexName(tableName, idxName string) string {
return fmt.Sprintf("IDX_%v_%v", tableName, idxName) return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
} }
func eraseAny(value string, strToErase ...string) string {
if len(strToErase) == 0 {
return value
}
var replaceSeq []string
for _, s := range strToErase {
replaceSeq = append(replaceSeq, s, "")
}
replacer := strings.NewReplacer(replaceSeq...)
return replacer.Replace(value)
}
func quoteColumns(cols []string, quoteFunc func(string) string, sep string) string {
for i := range cols {
cols[i] = quoteFunc(cols[i])
}
return strings.Join(cols, sep+" ")
}

View File

@ -4,7 +4,11 @@
package xorm package xorm
import "testing" import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSplitTag(t *testing.T) { func TestSplitTag(t *testing.T) {
var cases = []struct { var cases = []struct {
@ -24,3 +28,19 @@ func TestSplitTag(t *testing.T) {
} }
} }
} }
func TestEraseAny(t *testing.T) {
raw := "SELECT * FROM `table`.[table_name]"
assert.EqualValues(t, raw, eraseAny(raw))
assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`"))
assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]"))
}
func TestQuoteColumns(t *testing.T) {
cols := []string{"f1", "f2", "f3"}
quoteFunc := func(value string) string {
return "[" + value + "]"
}
assert.EqualValues(t, "[f1], [f2], [f3]", quoteColumns(cols, quoteFunc, ","))
}

View File

@ -5,11 +5,12 @@
package xorm package xorm
import ( import (
"context"
"database/sql" "database/sql"
"reflect" "reflect"
"time" "time"
"github.com/go-xorm/core" "xorm.io/core"
) )
// Interface defines the interface which Engine, EngineGroup and Session will implementate. // Interface defines the interface which Engine, EngineGroup and Session will implementate.
@ -27,7 +28,7 @@ type Interface interface {
Delete(interface{}) (int64, error) Delete(interface{}) (int64, error)
Distinct(columns ...string) *Session Distinct(columns ...string) *Session
DropIndexes(bean interface{}) error DropIndexes(bean interface{}) error
Exec(sqlOrAgrs ...interface{}) (sql.Result, error) Exec(sqlOrArgs ...interface{}) (sql.Result, error)
Exist(bean ...interface{}) (bool, error) Exist(bean ...interface{}) (bool, error)
Find(interface{}, ...interface{}) error Find(interface{}, ...interface{}) error
FindAndCount(interface{}, ...interface{}) (int64, error) FindAndCount(interface{}, ...interface{}) (int64, error)
@ -49,9 +50,9 @@ type Interface interface {
Omit(columns ...string) *Session Omit(columns ...string) *Session
OrderBy(order string) *Session OrderBy(order string) *Session
Ping() error Ping() error
Query(sqlOrAgrs ...interface{}) (resultsSlice []map[string][]byte, err error) Query(sqlOrArgs ...interface{}) (resultsSlice []map[string][]byte, err error)
QueryInterface(sqlorArgs ...interface{}) ([]map[string]interface{}, error) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error)
QueryString(sqlorArgs ...interface{}) ([]map[string]string, error) QueryString(sqlOrArgs ...interface{}) ([]map[string]string, error)
Rows(bean interface{}) (*Rows, error) Rows(bean interface{}) (*Rows, error)
SetExpr(string, string) *Session SetExpr(string, string) *Session
SQL(interface{}, ...interface{}) *Session SQL(interface{}, ...interface{}) *Session
@ -73,6 +74,7 @@ type EngineInterface interface {
Before(func(interface{})) *Session Before(func(interface{})) *Session
Charset(charset string) *Session Charset(charset string) *Session
ClearCache(...interface{}) error ClearCache(...interface{}) error
Context(context.Context) *Session
CreateTables(...interface{}) error CreateTables(...interface{}) error
DBMetas() ([]*core.Table, error) DBMetas() ([]*core.Table, error)
Dialect() core.Dialect Dialect() core.Dialect

31
json.go Normal file
View File

@ -0,0 +1,31 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import "encoding/json"
// JSONInterface represents an interface to handle json data
type JSONInterface interface {
Marshal(v interface{}) ([]byte, error)
Unmarshal(data []byte, v interface{}) error
}
var (
// DefaultJSONHandler default json handler
DefaultJSONHandler JSONInterface = StdJSON{}
)
// StdJSON implements JSONInterface via encoding/json
type StdJSON struct{}
// Marshal implements JSONInterface
func (StdJSON) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}
// Unmarshal implements JSONInterface
func (StdJSON) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}

View File

@ -9,7 +9,7 @@ import (
"io" "io"
"log" "log"
"github.com/go-xorm/core" "xorm.io/core"
) )
// default log options // default log options

View File

@ -8,7 +8,7 @@ import (
"github.com/go-xorm/xorm" "github.com/go-xorm/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"gopkg.in/stretchr/testify.v1/assert" "github.com/stretchr/testify/assert"
) )
type Person struct { type Person struct {

View File

@ -154,58 +154,30 @@ func TestProcessors(t *testing.T) {
} }
_, err = testEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) _, err = testEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(p)
if err != nil { assert.NoError(t, err)
t.Error(err) assert.True(t, p.Id > 0, "Inserted ID not set")
panic(err) assert.True(t, p.B4InsertFlag > 0, "B4InsertFlag not set")
} else { assert.True(t, p.AfterInsertedFlag > 0, "B4InsertFlag not set")
if p.B4InsertFlag == 0 { assert.True(t, p.B4InsertViaExt > 0, "B4InsertFlag not set")
t.Error(errors.New("B4InsertFlag not set")) assert.True(t, p.AfterInsertedViaExt > 0, "AfterInsertedViaExt not set")
}
if p.AfterInsertedFlag == 0 {
t.Error(errors.New("B4InsertFlag not set"))
}
if p.B4InsertViaExt == 0 {
t.Error(errors.New("B4InsertFlag not set"))
}
if p.AfterInsertedViaExt == 0 {
t.Error(errors.New("AfterInsertedViaExt not set"))
}
}
p2 := &ProcessorsStruct{} p2 := &ProcessorsStruct{}
_, err = testEngine.ID(p.Id).Get(p2) has, err := testEngine.ID(p.Id).Get(p2)
if err != nil { assert.NoError(t, err)
t.Error(err) assert.True(t, has)
panic(err) assert.True(t, p2.B4InsertFlag > 0, "B4InsertFlag not set")
} else { assert.True(t, p2.AfterInsertedFlag == 0, "AfterInsertedFlag is set")
if p2.B4InsertFlag == 0 { assert.True(t, p2.B4InsertViaExt > 0, "B4InsertViaExt not set")
t.Error(errors.New("B4InsertFlag not set")) assert.True(t, p2.AfterInsertedViaExt == 0, "AfterInsertedViaExt is set")
} assert.True(t, p2.BeforeSetFlag == 9, fmt.Sprintf("BeforeSetFlag is %d not 9", p2.BeforeSetFlag))
if p2.AfterInsertedFlag != 0 { assert.True(t, p2.AfterSetFlag == 9, fmt.Sprintf("AfterSetFlag is %d not 9", p2.BeforeSetFlag))
t.Error(errors.New("AfterInsertedFlag is set"))
}
if p2.B4InsertViaExt == 0 {
t.Error(errors.New("B4InsertViaExt not set"))
}
if p2.AfterInsertedViaExt != 0 {
t.Error(errors.New("AfterInsertedViaExt is set"))
}
if p2.BeforeSetFlag != 9 {
t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p2.BeforeSetFlag))
}
if p2.AfterSetFlag != 9 {
t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p2.BeforeSetFlag))
}
}
// -- // --
// test find processors // test find processors
var p2Find []*ProcessorsStruct var p2Find []*ProcessorsStruct
err = testEngine.Find(&p2Find) err = testEngine.Find(&p2Find)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if len(p2Find) != 1 { if len(p2Find) != 1 {
err = errors.New("Should get 1") err = errors.New("Should get 1")
t.Error(err) t.Error(err)
@ -229,16 +201,13 @@ func TestProcessors(t *testing.T) {
if p21.AfterSetFlag != 9 { if p21.AfterSetFlag != 9 {
t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p21.BeforeSetFlag)) t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p21.BeforeSetFlag))
} }
}
// -- // --
// test find map processors // test find map processors
var p2FindMap = make(map[int64]*ProcessorsStruct) var p2FindMap = make(map[int64]*ProcessorsStruct)
err = testEngine.Find(&p2FindMap) err = testEngine.Find(&p2FindMap)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if len(p2FindMap) != 1 { if len(p2FindMap) != 1 {
err = errors.New("Should get 1") err = errors.New("Should get 1")
t.Error(err) t.Error(err)
@ -266,7 +235,6 @@ func TestProcessors(t *testing.T) {
if p22.AfterSetFlag != 9 { if p22.AfterSetFlag != 9 {
t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p22.BeforeSetFlag)) t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p22.BeforeSetFlag))
} }
}
// -- // --
// test update processors // test update processors
@ -289,10 +257,8 @@ func TestProcessors(t *testing.T) {
p = p2 // reset p = p2 // reset
_, err = testEngine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) _, err = testEngine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4UpdateFlag == 0 { if p.B4UpdateFlag == 0 {
t.Error(errors.New("B4UpdateFlag not set")) t.Error(errors.New("B4UpdateFlag not set"))
} }
@ -305,14 +271,12 @@ func TestProcessors(t *testing.T) {
if p.AfterUpdatedViaExt == 0 { if p.AfterUpdatedViaExt == 0 {
t.Error(errors.New("AfterUpdatedViaExt not set")) t.Error(errors.New("AfterUpdatedViaExt not set"))
} }
}
p2 = &ProcessorsStruct{} p2 = &ProcessorsStruct{}
_, err = testEngine.ID(p.Id).Get(p2) has, err = testEngine.ID(p.Id).Get(p2)
if err != nil { assert.NoError(t, err)
t.Error(err) assert.True(t, has)
panic(err)
} else {
if p2.B4UpdateFlag == 0 { if p2.B4UpdateFlag == 0 {
t.Error(errors.New("B4UpdateFlag not set")) t.Error(errors.New("B4UpdateFlag not set"))
} }
@ -331,7 +295,6 @@ func TestProcessors(t *testing.T) {
if p2.AfterSetFlag != 9 { if p2.AfterSetFlag != 9 {
t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p2.BeforeSetFlag)) t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p2.BeforeSetFlag))
} }
}
// -- // --
// test delete processors // test delete processors
@ -353,10 +316,7 @@ func TestProcessors(t *testing.T) {
p = p2 // reset p = p2 // reset
_, err = testEngine.Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) _, err = testEngine.Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4DeleteFlag == 0 { if p.B4DeleteFlag == 0 {
t.Error(errors.New("B4DeleteFlag not set")) t.Error(errors.New("B4DeleteFlag not set"))
} }
@ -369,7 +329,6 @@ func TestProcessors(t *testing.T) {
if p.AfterDeletedViaExt == 0 { if p.AfterDeletedViaExt == 0 {
t.Error(errors.New("AfterDeletedViaExt not set")) t.Error(errors.New("AfterDeletedViaExt not set"))
} }
}
// -- // --
// test insert multi // test insert multi
@ -377,13 +336,9 @@ func TestProcessors(t *testing.T) {
pslice = append(pslice, &ProcessorsStruct{}) pslice = append(pslice, &ProcessorsStruct{})
pslice = append(pslice, &ProcessorsStruct{}) pslice = append(pslice, &ProcessorsStruct{})
cnt, err := testEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(&pslice) cnt, err := testEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(&pslice)
if err != nil { assert.NoError(t, err)
t.Error(err) assert.EqualValues(t, 2, cnt, "incorrect insert count")
panic(err)
} else {
if cnt != 2 {
t.Error(errors.New("incorrect insert count"))
}
for _, elem := range pslice { for _, elem := range pslice {
if elem.B4InsertFlag == 0 { if elem.B4InsertFlag == 0 {
t.Error(errors.New("B4InsertFlag not set")) t.Error(errors.New("B4InsertFlag not set"))
@ -398,15 +353,12 @@ func TestProcessors(t *testing.T) {
t.Error(errors.New("AfterInsertedViaExt not set")) t.Error(errors.New("AfterInsertedViaExt not set"))
} }
} }
}
for _, elem := range pslice { for _, elem := range pslice {
p = &ProcessorsStruct{} p = &ProcessorsStruct{}
_, err = testEngine.ID(elem.Id).Get(p) _, err = testEngine.ID(elem.Id).Get(p)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p2.B4InsertFlag == 0 { if p2.B4InsertFlag == 0 {
t.Error(errors.New("B4InsertFlag not set")) t.Error(errors.New("B4InsertFlag not set"))
} }
@ -426,7 +378,6 @@ func TestProcessors(t *testing.T) {
t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p2.BeforeSetFlag)) t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p2.BeforeSetFlag))
} }
} }
}
// -- // --
} }
@ -434,24 +385,17 @@ func TestProcessorsTx(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
err := testEngine.DropTables(&ProcessorsStruct{}) err := testEngine.DropTables(&ProcessorsStruct{})
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
err = testEngine.CreateTables(&ProcessorsStruct{}) err = testEngine.CreateTables(&ProcessorsStruct{})
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
// test insert processors with tx rollback // test insert processors with tx rollback
session := testEngine.NewSession() session := testEngine.NewSession()
defer session.Close()
err = session.Begin() err = session.Begin()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
p := &ProcessorsStruct{} p := &ProcessorsStruct{}
b4InsertFunc := func(bean interface{}) { b4InsertFunc := func(bean interface{}) {
@ -470,10 +414,8 @@ func TestProcessorsTx(t *testing.T) {
} }
} }
_, err = session.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) _, err = session.Before(b4InsertFunc).After(afterInsertFunc).Insert(p)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4InsertFlag == 0 { if p.B4InsertFlag == 0 {
t.Error(errors.New("B4InsertFlag not set")) t.Error(errors.New("B4InsertFlag not set"))
} }
@ -486,13 +428,10 @@ func TestProcessorsTx(t *testing.T) {
if p.AfterInsertedViaExt != 0 { if p.AfterInsertedViaExt != 0 {
t.Error(errors.New("AfterInsertedViaExt is set")) t.Error(errors.New("AfterInsertedViaExt is set"))
} }
}
err = session.Rollback() err = session.Rollback()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4InsertFlag == 0 { if p.B4InsertFlag == 0 {
t.Error(errors.New("B4InsertFlag not set")) t.Error(errors.New("B4InsertFlag not set"))
} }
@ -505,36 +444,31 @@ func TestProcessorsTx(t *testing.T) {
if p.AfterInsertedViaExt != 0 { if p.AfterInsertedViaExt != 0 {
t.Error(errors.New("AfterInsertedViaExt is set")) t.Error(errors.New("AfterInsertedViaExt is set"))
} }
}
session.Close() session.Close()
p2 := &ProcessorsStruct{} p2 := &ProcessorsStruct{}
_, err = testEngine.ID(p.Id).Get(p2) _, err = testEngine.ID(p.Id).Get(p2)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p2.Id > 0 { if p2.Id > 0 {
err = errors.New("tx got committed upon insert!?") err = errors.New("tx got committed upon insert!?")
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
}
// -- // --
// test insert processors with tx commit // test insert processors with tx commit
session = testEngine.NewSession() session = testEngine.NewSession()
defer session.Close()
err = session.Begin() err = session.Begin()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
p = &ProcessorsStruct{} p = &ProcessorsStruct{}
_, err = session.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) _, err = session.Before(b4InsertFunc).After(afterInsertFunc).Insert(p)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4InsertFlag == 0 { if p.B4InsertFlag == 0 {
t.Error(errors.New("B4InsertFlag not set")) t.Error(errors.New("B4InsertFlag not set"))
} }
@ -547,13 +481,10 @@ func TestProcessorsTx(t *testing.T) {
if p.AfterInsertedViaExt != 0 { if p.AfterInsertedViaExt != 0 {
t.Error(errors.New("AfterInsertedViaExt is set")) t.Error(errors.New("AfterInsertedViaExt is set"))
} }
}
err = session.Commit() err = session.Commit()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4InsertFlag == 0 { if p.B4InsertFlag == 0 {
t.Error(errors.New("B4InsertFlag not set")) t.Error(errors.New("B4InsertFlag not set"))
} }
@ -566,14 +497,12 @@ func TestProcessorsTx(t *testing.T) {
if p.AfterInsertedViaExt == 0 { if p.AfterInsertedViaExt == 0 {
t.Error(errors.New("AfterInsertedViaExt not set")) t.Error(errors.New("AfterInsertedViaExt not set"))
} }
}
session.Close() session.Close()
p2 = &ProcessorsStruct{} p2 = &ProcessorsStruct{}
_, err = testEngine.ID(p.Id).Get(p2) _, err = testEngine.ID(p.Id).Get(p2)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p2.B4InsertFlag == 0 { if p2.B4InsertFlag == 0 {
t.Error(errors.New("B4InsertFlag not set")) t.Error(errors.New("B4InsertFlag not set"))
} }
@ -586,17 +515,16 @@ func TestProcessorsTx(t *testing.T) {
if p2.AfterInsertedViaExt != 0 { if p2.AfterInsertedViaExt != 0 {
t.Error(errors.New("AfterInsertedViaExt is set")) t.Error(errors.New("AfterInsertedViaExt is set"))
} }
}
insertedId := p2.Id insertedId := p2.Id
// -- // --
// test update processors with tx rollback // test update processors with tx rollback
session = testEngine.NewSession() session = testEngine.NewSession()
defer session.Close()
err = session.Begin() err = session.Begin()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
b4UpdateFunc := func(bean interface{}) { b4UpdateFunc := func(bean interface{}) {
if v, ok := (bean).(*ProcessorsStruct); ok { if v, ok := (bean).(*ProcessorsStruct); ok {
@ -617,10 +545,8 @@ func TestProcessorsTx(t *testing.T) {
p = p2 // reset p = p2 // reset
_, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) _, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4UpdateFlag == 0 { if p.B4UpdateFlag == 0 {
t.Error(errors.New("B4UpdateFlag not set")) t.Error(errors.New("B4UpdateFlag not set"))
} }
@ -633,12 +559,10 @@ func TestProcessorsTx(t *testing.T) {
if p.AfterUpdatedViaExt != 0 { if p.AfterUpdatedViaExt != 0 {
t.Error(errors.New("AfterUpdatedViaExt is set")) t.Error(errors.New("AfterUpdatedViaExt is set"))
} }
}
err = session.Rollback() err = session.Rollback()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4UpdateFlag == 0 { if p.B4UpdateFlag == 0 {
t.Error(errors.New("B4UpdateFlag not set")) t.Error(errors.New("B4UpdateFlag not set"))
} }
@ -651,16 +575,13 @@ func TestProcessorsTx(t *testing.T) {
if p.AfterUpdatedViaExt != 0 { if p.AfterUpdatedViaExt != 0 {
t.Error(errors.New("AfterUpdatedViaExt is set")) t.Error(errors.New("AfterUpdatedViaExt is set"))
} }
}
session.Close() session.Close()
p2 = &ProcessorsStruct{} p2 = &ProcessorsStruct{}
_, err = testEngine.ID(insertedId).Get(p2) _, err = testEngine.ID(insertedId).Get(p2)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p2.B4UpdateFlag != 0 { if p2.B4UpdateFlag != 0 {
t.Error(errors.New("B4UpdateFlag is set")) t.Error(errors.New("B4UpdateFlag is set"))
} }
@ -673,36 +594,30 @@ func TestProcessorsTx(t *testing.T) {
if p2.AfterUpdatedViaExt != 0 { if p2.AfterUpdatedViaExt != 0 {
t.Error(errors.New("AfterUpdatedViaExt is set")) t.Error(errors.New("AfterUpdatedViaExt is set"))
} }
}
// -- // --
// test update processors with tx rollback // test update processors with tx rollback
session = testEngine.NewSession() session = testEngine.NewSession()
defer session.Close()
err = session.Begin() err = session.Begin()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
p = &ProcessorsStruct{Id: insertedId} p = &ProcessorsStruct{Id: insertedId}
_, err = session.Update(p) _, err = session.Update(p)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4UpdateFlag == 0 { if p.B4UpdateFlag == 0 {
t.Error(errors.New("B4UpdateFlag not set")) t.Error(errors.New("B4UpdateFlag not set"))
} }
if p.AfterUpdatedFlag != 0 { if p.AfterUpdatedFlag != 0 {
t.Error(errors.New("AfterUpdatedFlag is set")) t.Error(errors.New("AfterUpdatedFlag is set"))
} }
}
err = session.Commit() err = session.Commit()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4UpdateFlag == 0 { if p.B4UpdateFlag == 0 {
t.Error(errors.New("B4UpdateFlag not set")) t.Error(errors.New("B4UpdateFlag not set"))
} }
@ -715,25 +630,21 @@ func TestProcessorsTx(t *testing.T) {
if p.AfterInsertedFlag != 0 { if p.AfterInsertedFlag != 0 {
t.Error(errors.New("AfterInsertedFlag set")) t.Error(errors.New("AfterInsertedFlag set"))
} }
}
session.Close() session.Close()
// test update processors with tx commit // test update processors with tx commit
session = testEngine.NewSession() session = testEngine.NewSession()
defer session.Close()
err = session.Begin() err = session.Begin()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
p = &ProcessorsStruct{} p = &ProcessorsStruct{}
_, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) _, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4UpdateFlag == 0 { if p.B4UpdateFlag == 0 {
t.Error(errors.New("B4UpdateFlag not set")) t.Error(errors.New("B4UpdateFlag not set"))
} }
@ -746,12 +657,10 @@ func TestProcessorsTx(t *testing.T) {
if p.AfterUpdatedViaExt != 0 { if p.AfterUpdatedViaExt != 0 {
t.Error(errors.New("AfterUpdatedViaExt is set")) t.Error(errors.New("AfterUpdatedViaExt is set"))
} }
}
err = session.Commit() err = session.Commit()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4UpdateFlag == 0 { if p.B4UpdateFlag == 0 {
t.Error(errors.New("B4UpdateFlag not set")) t.Error(errors.New("B4UpdateFlag not set"))
} }
@ -764,14 +673,12 @@ func TestProcessorsTx(t *testing.T) {
if p.AfterUpdatedViaExt == 0 { if p.AfterUpdatedViaExt == 0 {
t.Error(errors.New("AfterUpdatedViaExt not set")) t.Error(errors.New("AfterUpdatedViaExt not set"))
} }
}
session.Close() session.Close()
p2 = &ProcessorsStruct{} p2 = &ProcessorsStruct{}
_, err = testEngine.ID(insertedId).Get(p2) _, err = testEngine.ID(insertedId).Get(p2)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4UpdateFlag == 0 { if p.B4UpdateFlag == 0 {
t.Error(errors.New("B4UpdateFlag not set")) t.Error(errors.New("B4UpdateFlag not set"))
} }
@ -784,16 +691,14 @@ func TestProcessorsTx(t *testing.T) {
if p.AfterUpdatedViaExt == 0 { if p.AfterUpdatedViaExt == 0 {
t.Error(errors.New("AfterUpdatedViaExt not set")) t.Error(errors.New("AfterUpdatedViaExt not set"))
} }
}
// -- // --
// test delete processors with tx rollback // test delete processors with tx rollback
session = testEngine.NewSession() session = testEngine.NewSession()
defer session.Close()
err = session.Begin() err = session.Begin()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
b4DeleteFunc := func(bean interface{}) { b4DeleteFunc := func(bean interface{}) {
if v, ok := (bean).(*ProcessorsStruct); ok { if v, ok := (bean).(*ProcessorsStruct); ok {
@ -814,10 +719,8 @@ func TestProcessorsTx(t *testing.T) {
p = &ProcessorsStruct{} // reset p = &ProcessorsStruct{} // reset
_, err = session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) _, err = session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4DeleteFlag == 0 { if p.B4DeleteFlag == 0 {
t.Error(errors.New("B4DeleteFlag not set")) t.Error(errors.New("B4DeleteFlag not set"))
} }
@ -830,12 +733,9 @@ func TestProcessorsTx(t *testing.T) {
if p.AfterDeletedViaExt != 0 { if p.AfterDeletedViaExt != 0 {
t.Error(errors.New("AfterDeletedViaExt is set")) t.Error(errors.New("AfterDeletedViaExt is set"))
} }
}
err = session.Rollback() err = session.Rollback()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4DeleteFlag == 0 { if p.B4DeleteFlag == 0 {
t.Error(errors.New("B4DeleteFlag not set")) t.Error(errors.New("B4DeleteFlag not set"))
} }
@ -848,15 +748,13 @@ func TestProcessorsTx(t *testing.T) {
if p.AfterDeletedViaExt != 0 { if p.AfterDeletedViaExt != 0 {
t.Error(errors.New("AfterDeletedViaExt is set")) t.Error(errors.New("AfterDeletedViaExt is set"))
} }
}
session.Close() session.Close()
p2 = &ProcessorsStruct{} p2 = &ProcessorsStruct{}
_, err = testEngine.ID(insertedId).Get(p2) _, err = testEngine.ID(insertedId).Get(p2)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p2.B4DeleteFlag != 0 { if p2.B4DeleteFlag != 0 {
t.Error(errors.New("B4DeleteFlag is set")) t.Error(errors.New("B4DeleteFlag is set"))
} }
@ -869,24 +767,20 @@ func TestProcessorsTx(t *testing.T) {
if p2.AfterDeletedViaExt != 0 { if p2.AfterDeletedViaExt != 0 {
t.Error(errors.New("AfterDeletedViaExt is set")) t.Error(errors.New("AfterDeletedViaExt is set"))
} }
}
// -- // --
// test delete processors with tx commit // test delete processors with tx commit
session = testEngine.NewSession() session = testEngine.NewSession()
defer session.Close()
err = session.Begin() err = session.Begin()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
p = &ProcessorsStruct{} p = &ProcessorsStruct{}
_, err = session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) _, err = session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4DeleteFlag == 0 { if p.B4DeleteFlag == 0 {
t.Error(errors.New("B4DeleteFlag not set")) t.Error(errors.New("B4DeleteFlag not set"))
} }
@ -899,12 +793,10 @@ func TestProcessorsTx(t *testing.T) {
if p.AfterDeletedViaExt != 0 { if p.AfterDeletedViaExt != 0 {
t.Error(errors.New("AfterDeletedViaExt is set")) t.Error(errors.New("AfterDeletedViaExt is set"))
} }
}
err = session.Commit() err = session.Commit()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4DeleteFlag == 0 { if p.B4DeleteFlag == 0 {
t.Error(errors.New("B4DeleteFlag not set")) t.Error(errors.New("B4DeleteFlag not set"))
} }
@ -917,37 +809,30 @@ func TestProcessorsTx(t *testing.T) {
if p.AfterDeletedViaExt == 0 { if p.AfterDeletedViaExt == 0 {
t.Error(errors.New("AfterDeletedViaExt not set")) t.Error(errors.New("AfterDeletedViaExt not set"))
} }
}
session.Close() session.Close()
// test delete processors with tx commit // test delete processors with tx commit
session = testEngine.NewSession() session = testEngine.NewSession()
defer session.Close()
err = session.Begin() err = session.Begin()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
p = &ProcessorsStruct{Id: insertedId} p = &ProcessorsStruct{Id: insertedId}
fmt.Println("delete")
_, err = session.Delete(p) _, err = session.Delete(p)
assert.NoError(t, err)
if err != nil {
t.Error(err)
panic(err)
} else {
if p.B4DeleteFlag == 0 { if p.B4DeleteFlag == 0 {
t.Error(errors.New("B4DeleteFlag not set")) t.Error(errors.New("B4DeleteFlag not set"))
} }
if p.AfterDeletedFlag != 0 { if p.AfterDeletedFlag != 0 {
t.Error(errors.New("AfterDeletedFlag is set")) t.Error(errors.New("AfterDeletedFlag is set"))
} }
}
err = session.Commit() err = session.Commit()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
} else {
if p.B4DeleteFlag == 0 { if p.B4DeleteFlag == 0 {
t.Error(errors.New("B4DeleteFlag not set")) t.Error(errors.New("B4DeleteFlag not set"))
} }
@ -960,7 +845,6 @@ func TestProcessorsTx(t *testing.T) {
if p.AfterUpdatedFlag != 0 { if p.AfterUpdatedFlag != 0 {
t.Error(errors.New("AfterUpdatedFlag set")) t.Error(errors.New("AfterUpdatedFlag set"))
} }
}
session.Close() session.Close()
// -- // --
} }

37
rows.go
View File

@ -9,16 +9,13 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"github.com/go-xorm/core" "xorm.io/core"
) )
// Rows rows wrapper a rows to // Rows rows wrapper a rows to
type Rows struct { type Rows struct {
NoTypeCheck bool
session *Session session *Session
rows *core.Rows rows *core.Rows
fields []string
beanType reflect.Type beanType reflect.Type
lastError error lastError error
} }
@ -57,13 +54,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
return nil, err return nil, err
} }
rows.fields, err = rows.rows.Columns()
if err != nil {
rows.lastError = err
rows.Close()
return nil, err
}
return rows, nil return rows, nil
} }
@ -90,7 +80,7 @@ func (rows *Rows) Scan(bean interface{}) error {
return rows.lastError return rows.lastError
} }
if !rows.NoTypeCheck && reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType {
return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType)
} }
@ -98,13 +88,18 @@ func (rows *Rows) Scan(bean interface{}) error {
return err return err
} }
scanResults, err := rows.session.row2Slice(rows.rows, rows.fields, bean) fields, err := rows.rows.Columns()
if err != nil {
return err
}
scanResults, err := rows.session.row2Slice(rows.rows, fields, bean)
if err != nil { if err != nil {
return err return err
} }
dataStruct := rValue(bean) dataStruct := rValue(bean)
_, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable) _, err = rows.session.slice2Bean(scanResults, fields, bean, &dataStruct, rows.session.statement.RefTable)
if err != nil { if err != nil {
return err return err
} }
@ -118,17 +113,9 @@ func (rows *Rows) Close() error {
defer rows.session.Close() defer rows.session.Close()
} }
if rows.lastError == nil {
if rows.rows != nil { if rows.rows != nil {
rows.lastError = rows.rows.Close() return rows.rows.Close()
if rows.lastError != nil { }
return rows.lastError
}
}
} else {
if rows.rows != nil {
defer rows.rows.Close()
}
}
return rows.lastError return rows.lastError
} }

View File

@ -38,6 +38,22 @@ func TestRows(t *testing.T) {
cnt++ cnt++
} }
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
assert.False(t, rows.Next())
assert.NoError(t, rows.Close())
rows0, err := testEngine.Where("1>1").Rows(new(UserRows))
assert.NoError(t, err)
defer rows0.Close()
cnt = 0
user0 := new(UserRows)
for rows0.Next() {
err = rows0.Scan(user0)
assert.NoError(t, err)
cnt++
}
assert.EqualValues(t, 0, cnt)
assert.NoError(t, rows0.Close())
sess := testEngine.NewSession() sess := testEngine.NewSession()
defer sess.Close() defer sess.Close()
@ -67,3 +83,68 @@ func TestRows(t *testing.T) {
} }
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
} }
func TestRowsMyTableName(t *testing.T) {
assert.NoError(t, prepareEngine())
type UserRowsMyTable struct {
Id int64
IsMan bool
}
var tableName = "user_rows_my_table_name"
assert.NoError(t, testEngine.Table(tableName).Sync2(new(UserRowsMyTable)))
cnt, err := testEngine.Table(tableName).Insert(&UserRowsMyTable{
IsMan: true,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
rows, err := testEngine.Table(tableName).Rows(new(UserRowsMyTable))
assert.NoError(t, err)
defer rows.Close()
cnt = 0
user := new(UserRowsMyTable)
for rows.Next() {
err = rows.Scan(user)
assert.NoError(t, err)
cnt++
}
assert.EqualValues(t, 1, cnt)
}
type UserRowsSpecTable struct {
Id int64
IsMan bool
}
func (UserRowsSpecTable) TableName() string {
return "user_rows_my_table_name"
}
func TestRowsSpecTableName(t *testing.T) {
assert.NoError(t, prepareEngine())
assert.NoError(t, testEngine.Sync2(new(UserRowsSpecTable)))
cnt, err := testEngine.Insert(&UserRowsSpecTable{
IsMan: true,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
rows, err := testEngine.Rows(new(UserRowsSpecTable))
assert.NoError(t, err)
defer rows.Close()
cnt = 0
user := new(UserRowsSpecTable)
for rows.Next() {
err = rows.Scan(user)
assert.NoError(t, err)
cnt++
}
assert.EqualValues(t, 1, cnt)
}

View File

@ -5,8 +5,8 @@
package xorm package xorm
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"hash/crc32" "hash/crc32"
@ -14,7 +14,14 @@ import (
"strings" "strings"
"time" "time"
"github.com/go-xorm/core" "xorm.io/core"
)
type sessionType int
const (
engineSession sessionType = iota
groupSession
) )
// Session keep a pointer to sql.DB and provides all execution of all // Session keep a pointer to sql.DB and provides all execution of all
@ -51,7 +58,8 @@ type Session struct {
lastSQL string lastSQL string
lastSQLArgs []interface{} lastSQLArgs []interface{}
err error ctx context.Context
sessionType sessionType
} }
// Clone copy all the session's content and return a new session // Clone copy all the session's content and return a new session
@ -82,6 +90,8 @@ func (session *Session) Init() {
session.lastSQL = "" session.lastSQL = ""
session.lastSQLArgs = []interface{}{} session.lastSQLArgs = []interface{}{}
session.ctx = session.engine.defaultContext
} }
// Close release the connection from pool // Close release the connection from pool
@ -275,7 +285,7 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt,
var has bool var has bool
stmt, has = session.stmtCache[crc] stmt, has = session.stmtCache[crc]
if !has { if !has {
stmt, err = db.Prepare(sqlStr) stmt, err = db.PrepareContext(session.ctx, sqlStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -480,13 +490,13 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
continue continue
} }
if fieldValue.CanAddr() { if fieldValue.CanAddr() {
err := json.Unmarshal(bs, fieldValue.Addr().Interface()) err := DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
x := reflect.New(fieldType) x := reflect.New(fieldType)
err := json.Unmarshal(bs, x.Interface()) err := DefaultJSONHandler.Unmarshal(bs, x.Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -510,13 +520,13 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
hasAssigned = true hasAssigned = true
if len(bs) > 0 { if len(bs) > 0 {
if fieldValue.CanAddr() { if fieldValue.CanAddr() {
err := json.Unmarshal(bs, fieldValue.Addr().Interface()) err := DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
x := reflect.New(fieldType) x := reflect.New(fieldType)
err := json.Unmarshal(bs, x.Interface()) err := DefaultJSONHandler.Unmarshal(bs, x.Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -532,7 +542,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
hasAssigned = true hasAssigned = true
if col.SQLType.IsText() { if col.SQLType.IsText() {
x := reflect.New(fieldType) x := reflect.New(fieldType)
err := json.Unmarshal(vv.Bytes(), x.Interface()) err := DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -647,7 +657,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
hasAssigned = true hasAssigned = true
x := reflect.New(fieldType) x := reflect.New(fieldType)
if len([]byte(vv.String())) > 0 { if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), x.Interface()) err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -657,7 +667,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
hasAssigned = true hasAssigned = true
x := reflect.New(fieldType) x := reflect.New(fieldType)
if len(vv.Bytes()) > 0 { if len(vv.Bytes()) > 0 {
err := json.Unmarshal(vv.Bytes(), x.Interface()) err := DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -793,7 +803,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
case core.Complex64Type: case core.Complex64Type:
var x complex64 var x complex64
if len([]byte(vv.String())) > 0 { if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), &x) err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -803,7 +813,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
case core.Complex128Type: case core.Complex128Type:
var x complex128 var x complex128
if len([]byte(vv.String())) > 0 { if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), &x) err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -845,3 +855,12 @@ func (session *Session) Unscoped() *Session {
session.statement.Unscoped() session.statement.Unscoped()
return session return session
} }
func (session *Session) incrVersionFieldValue(fieldValue *reflect.Value) {
switch fieldValue.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
fieldValue.SetInt(fieldValue.Int() + 1)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
fieldValue.SetUint(fieldValue.Uint() + 1)
}
}

View File

@ -9,7 +9,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/go-xorm/core" "xorm.io/core"
) )
type incrParam struct { type incrParam struct {

View File

@ -7,7 +7,7 @@ package xorm
import ( import (
"testing" "testing"
"github.com/go-xorm/core" "xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )

View File

@ -4,7 +4,7 @@
package xorm package xorm
import "github.com/go-xorm/builder" import "xorm.io/builder"
// Sql provides raw sql input parameter. When you have a complex SQL statement // Sql provides raw sql input parameter. When you have a complex SQL statement
// and cannot use Where, Id, In and etc. Methods to describe, you can use SQL. // and cannot use Where, Id, In and etc. Methods to describe, you can use SQL.

View File

@ -9,7 +9,7 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/go-xorm/builder" "xorm.io/builder"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )

View File

@ -1,18 +1,15 @@
// Copyright 2017 The Xorm Authors. All rights reserved. // Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build go1.8
package xorm package xorm
import "context" import "context"
// PingContext tests if database is alive // Context sets the context on this session
func (engine *Engine) PingContext(ctx context.Context) error { func (session *Session) Context(ctx context.Context) *Session {
session := engine.NewSession() session.ctx = ctx
defer session.Close() return session
return session.PingContext(ctx)
} }
// PingContext test if database is ok // PingContext test if database is ok

36
session_context_test.go Normal file
View File

@ -0,0 +1,36 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestQueryContext(t *testing.T) {
type ContextQueryStruct struct {
Id int64
Name string
}
assert.NoError(t, prepareEngine())
assertSync(t, new(ContextQueryStruct))
_, err := testEngine.Insert(&ContextQueryStruct{Name: "1"})
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
defer cancel()
time.Sleep(time.Nanosecond)
has, err := testEngine.Context(ctx).Exist(&ContextQueryStruct{Name: "1"})
assert.Error(t, err)
assert.Contains(t, err.Error(), "context deadline exceeded")
assert.False(t, has)
}

View File

@ -7,7 +7,6 @@ package xorm
import ( import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@ -15,7 +14,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/go-xorm/core" "xorm.io/core"
) )
func (session *Session) str2Time(col *core.Column, data string) (outTime time.Time, outErr error) { func (session *Session) str2Time(col *core.Column, data string) (outTime time.Time, outErr error) {
@ -103,7 +102,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
x := reflect.New(fieldType) x := reflect.New(fieldType)
if len(data) > 0 { if len(data) > 0 {
err := json.Unmarshal(data, x.Interface()) err := DefaultJSONHandler.Unmarshal(data, x.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return err return err
@ -117,7 +116,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
if col.SQLType.IsText() { if col.SQLType.IsText() {
x := reflect.New(fieldType) x := reflect.New(fieldType)
if len(data) > 0 { if len(data) > 0 {
err := json.Unmarshal(data, x.Interface()) err := DefaultJSONHandler.Unmarshal(data, x.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return err return err
@ -130,7 +129,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
} else { } else {
x := reflect.New(fieldType) x := reflect.New(fieldType)
if len(data) > 0 { if len(data) > 0 {
err := json.Unmarshal(data, x.Interface()) err := DefaultJSONHandler.Unmarshal(data, x.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return err return err
@ -259,7 +258,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
case core.Complex64Type.Kind(): case core.Complex64Type.Kind():
var x complex64 var x complex64
if len(data) > 0 { if len(data) > 0 {
err := json.Unmarshal(data, &x) err := DefaultJSONHandler.Unmarshal(data, &x)
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return err return err
@ -270,7 +269,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
case core.Complex128Type.Kind(): case core.Complex128Type.Kind():
var x complex128 var x complex128
if len(data) > 0 { if len(data) > 0 {
err := json.Unmarshal(data, &x) err := DefaultJSONHandler.Unmarshal(data, &x)
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return err return err
@ -604,14 +603,14 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
} }
if col.SQLType.IsText() { if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface()) bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return 0, err return 0, err
} }
return string(bytes), nil return string(bytes), nil
} else if col.SQLType.IsBlob() { } else if col.SQLType.IsBlob() {
bytes, err := json.Marshal(fieldValue.Interface()) bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return 0, err return 0, err
@ -620,7 +619,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
} }
return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type()) return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type())
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
bytes, err := json.Marshal(fieldValue.Interface()) bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return 0, err return 0, err
@ -632,7 +631,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
} }
if col.SQLType.IsText() { if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface()) bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return 0, err return 0, err
@ -641,11 +640,11 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
} else if col.SQLType.IsBlob() { } else if col.SQLType.IsBlob() {
var bytes []byte var bytes []byte
var err error var err error
if (k == reflect.Array || k == reflect.Slice) && if (k == reflect.Slice) &&
(fieldValue.Type().Elem().Kind() == reflect.Uint8) { (fieldValue.Type().Elem().Kind() == reflect.Uint8) {
bytes = fieldValue.Bytes() bytes = fieldValue.Bytes()
} else { } else {
bytes, err = json.Marshal(fieldValue.Interface()) bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return 0, err return 0, err

View File

@ -9,7 +9,7 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"github.com/go-xorm/core" "xorm.io/core"
) )
func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string, args ...interface{}) error { func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string, args ...interface{}) error {
@ -79,6 +79,10 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
defer session.Close() defer session.Close()
} }
if session.statement.lastError != nil {
return 0, session.statement.lastError
}
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.setRefBean(bean); err != nil {
return 0, err return 0, err
} }
@ -199,7 +203,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
}) })
} }
if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { if cacher := session.engine.getCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache {
session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...)
} }

View File

@ -8,6 +8,7 @@ import (
"testing" "testing"
"time" "time"
"xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -21,11 +22,27 @@ func TestDelete(t *testing.T) {
assert.NoError(t, testEngine.Sync2(new(UserinfoDelete))) assert.NoError(t, testEngine.Sync2(new(UserinfoDelete)))
session := testEngine.NewSession()
defer session.Close()
var err error
if testEngine.Dialect().DBType() == core.MSSQL {
err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT userinfo_delete ON")
assert.NoError(t, err)
}
user := UserinfoDelete{Uid: 1} user := UserinfoDelete{Uid: 1}
cnt, err := testEngine.Insert(&user) cnt, err := session.Insert(&user)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == core.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}
cnt, err = testEngine.Delete(&UserinfoDelete{Uid: user.Uid}) cnt, err = testEngine.Delete(&UserinfoDelete{Uid: user.Uid})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
@ -40,7 +57,7 @@ func TestDelete(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
cnt, err = testEngine.Where("id=?", user.Uid).Delete(&UserinfoDelete{}) cnt, err = testEngine.Where("`id`=?", user.Uid).Delete(&UserinfoDelete{})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)

View File

@ -9,8 +9,8 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"github.com/go-xorm/builder" "xorm.io/builder"
"github.com/go-xorm/core" "xorm.io/core"
) )
// Exist returns true if the record exist otherwise return false // Exist returns true if the record exist otherwise return false
@ -19,6 +19,10 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
defer session.Close() defer session.Close()
} }
if session.statement.lastError != nil {
return false, session.statement.lastError
}
var sqlStr string var sqlStr string
var args []interface{} var args []interface{}
var err error var err error
@ -30,6 +34,8 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
return false, ErrTableNotFound return false, ErrTableNotFound
} }
tableName = session.statement.Engine.Quote(tableName)
if session.statement.cond.IsValid() { if session.statement.cond.IsValid() {
condSQL, condArgs, err := builder.ToSQL(session.statement.cond) condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil { if err != nil {
@ -37,14 +43,18 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
} }
if session.engine.dialect.DBType() == core.MSSQL { if session.engine.dialect.DBType() == core.MSSQL {
sqlStr = fmt.Sprintf("SELECT top 1 * FROM %s WHERE %s", tableName, condSQL) sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s WHERE %s", tableName, condSQL)
} else if session.engine.dialect.DBType() == core.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) AND ROWNUM=1", tableName, condSQL)
} else { } else {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL) sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL)
} }
args = condArgs args = condArgs
} else { } else {
if session.engine.dialect.DBType() == core.MSSQL { if session.engine.dialect.DBType() == core.MSSQL {
sqlStr = fmt.Sprintf("SELECT top 1 * FROM %s", tableName) sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s", tableName)
} else if session.engine.dialect.DBType() == core.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE ROWNUM=1", tableName)
} else { } else {
sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName) sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName)
} }

View File

@ -10,8 +10,8 @@ import (
"reflect" "reflect"
"strings" "strings"
"github.com/go-xorm/builder" "xorm.io/builder"
"github.com/go-xorm/core" "xorm.io/core"
) )
const ( const (
@ -63,6 +63,10 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte
} }
func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
if session.statement.lastError != nil {
return session.statement.lastError
}
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map {
return errors.New("needs a pointer to a slice or a map") return errors.New("needs a pointer to a slice or a map")
@ -176,7 +180,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
if session.canCache() { if session.canCache() {
if cacher := session.engine.getCacher(table.Name); cacher != nil && if cacher := session.engine.getCacher(session.statement.TableName()); cacher != nil &&
!session.statement.IsDistinct && !session.statement.IsDistinct &&
!session.statement.unscoped { !session.statement.unscoped {
err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...)

View File

@ -10,7 +10,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-xorm/core" "xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )

View File

@ -11,7 +11,7 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"github.com/go-xorm/core" "xorm.io/core"
) )
// Get retrieve one record from database, bean's non-empty fields // Get retrieve one record from database, bean's non-empty fields
@ -24,6 +24,10 @@ func (session *Session) Get(bean interface{}) (bool, error) {
} }
func (session *Session) get(bean interface{}) (bool, error) { func (session *Session) get(bean interface{}) (bool, error) {
if session.statement.lastError != nil {
return false, session.statement.lastError
}
beanValue := reflect.ValueOf(bean) beanValue := reflect.ValueOf(bean)
if beanValue.Kind() != reflect.Ptr { if beanValue.Kind() != reflect.Ptr {
return false, errors.New("needs a pointer to a value") return false, errors.New("needs a pointer to a value")
@ -58,7 +62,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
table := session.statement.RefTable table := session.statement.RefTable
if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { if session.canCache() && beanValue.Elem().Kind() == reflect.Struct {
if cacher := session.engine.getCacher(table.Name); cacher != nil && if cacher := session.engine.getCacher(session.statement.TableName()); cacher != nil &&
!session.statement.unscoped { !session.statement.unscoped {
has, err := session.cacheGet(bean, sqlStr, args...) has, err := session.cacheGet(bean, sqlStr, args...)
if err != ErrCacheFailed { if err != ErrCacheFailed {
@ -110,6 +114,114 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea
return true, rows.Scan(&bean) return true, rows.Scan(&bean)
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString: case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString:
return true, rows.Scan(bean) return true, rows.Scan(bean)
case *string:
var res sql.NullString
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*string)) = res.String
}
return true, nil
case *int:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*int)) = int(res.Int64)
}
return true, nil
case *int8:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*int8)) = int8(res.Int64)
}
return true, nil
case *int16:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*int16)) = int16(res.Int64)
}
return true, nil
case *int32:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*int32)) = int32(res.Int64)
}
return true, nil
case *int64:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*int64)) = int64(res.Int64)
}
return true, nil
case *uint:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*uint)) = uint(res.Int64)
}
return true, nil
case *uint8:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*uint8)) = uint8(res.Int64)
}
return true, nil
case *uint16:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*uint16)) = uint16(res.Int64)
}
return true, nil
case *uint32:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*uint32)) = uint32(res.Int64)
}
return true, nil
case *uint64:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*uint64)) = uint64(res.Int64)
}
return true, nil
case *bool:
var res sql.NullBool
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*bool)) = res.Bool
}
return true, nil
} }
switch beanKind { switch beanKind {
@ -138,6 +250,9 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea
err = rows.ScanSlice(bean) err = rows.ScanSlice(bean)
case reflect.Map: case reflect.Map:
err = rows.ScanMap(bean) err = rows.ScanMap(bean)
case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
err = rows.Scan(&bean)
default: default:
err = rows.Scan(bean) err = rows.Scan(bean)
} }

View File

@ -10,8 +10,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-xorm/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core"
) )
func TestGetVar(t *testing.T) { func TestGetVar(t *testing.T) {
@ -47,6 +47,12 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, true, has) assert.Equal(t, true, has)
assert.Equal(t, 28, age) assert.Equal(t, 28, age)
var ageMax int
has, err = testEngine.SQL("SELECT max(age) FROM "+testEngine.TableName("get_var", true)+" WHERE `id` = ?", data.Id).Get(&ageMax)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, 28, ageMax)
var age2 int64 var age2 int64
has, err = testEngine.Table("get_var").Cols("age"). has, err = testEngine.Table("get_var").Cols("age").
Where("age > ?", 20). Where("age > ?", 20).
@ -56,6 +62,69 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, true, has) assert.Equal(t, true, has)
assert.EqualValues(t, 28, age2) assert.EqualValues(t, 28, age2)
var age3 int8
has, err = testEngine.Table("get_var").Cols("age").Get(&age3)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age3)
var age4 int16
has, err = testEngine.Table("get_var").Cols("age").
Where("age > ?", 20).
And("age < ?", 30).
Get(&age4)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age4)
var age5 int32
has, err = testEngine.Table("get_var").Cols("age").
Where("age > ?", 20).
And("age < ?", 30).
Get(&age5)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age5)
var age6 int
has, err = testEngine.Table("get_var").Cols("age").Get(&age6)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age6)
var age7 int64
has, err = testEngine.Table("get_var").Cols("age").
Where("age > ?", 20).
And("age < ?", 30).
Get(&age7)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age7)
var age8 int8
has, err = testEngine.Table("get_var").Cols("age").Get(&age8)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age8)
var age9 int16
has, err = testEngine.Table("get_var").Cols("age").
Where("age > ?", 20).
And("age < ?", 30).
Get(&age9)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age9)
var age10 int32
has, err = testEngine.Table("get_var").Cols("age").
Where("age > ?", 20).
And("age < ?", 30).
Get(&age10)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age10)
var id sql.NullInt64 var id sql.NullInt64
has, err = testEngine.Table("get_var").Cols("id").Get(&id) has, err = testEngine.Table("get_var").Cols("id").Get(&id)
assert.NoError(t, err) assert.NoError(t, err)
@ -84,7 +153,11 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money)) assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money))
var money2 float64 var money2 float64
if testEngine.Dialect().DBType() == core.MSSQL {
has, err = testEngine.SQL("SELECT TOP 1 money FROM " + testEngine.TableName("get_var", true)).Get(&money2)
} else {
has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2) has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2)
}
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2)) assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2))
@ -156,14 +229,23 @@ func TestGetStruct(t *testing.T) {
assert.NoError(t, testEngine.Sync2(new(UserinfoGet))) assert.NoError(t, testEngine.Sync2(new(UserinfoGet)))
session := testEngine.NewSession()
defer session.Close()
var err error var err error
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == core.MSSQL {
_, err = testEngine.Exec("SET IDENTITY_INSERT userinfo_get ON") err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT userinfo_get ON")
assert.NoError(t, err) assert.NoError(t, err)
} }
cnt, err := testEngine.Insert(&UserinfoGet{Uid: 2}) cnt, err := session.Insert(&UserinfoGet{Uid: 2})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == core.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}
user := UserinfoGet{Uid: 2} user := UserinfoGet{Uid: 2}
has, err := testEngine.Get(&user) has, err := testEngine.Get(&user)
@ -386,3 +468,119 @@ func TestContextGet2(t *testing.T) {
assert.EqualValues(t, 1, c3.Id) assert.EqualValues(t, 1, c3.Id)
assert.EqualValues(t, "1", c3.Name) assert.EqualValues(t, "1", c3.Name)
} }
type GetCustomTableInterface interface {
TableName() string
}
type MyGetCustomTableImpletation struct {
Id int64 `json:"id"`
Name string `json:"name"`
}
const getCustomTableName = "GetCustomTableInterface"
func (m *MyGetCustomTableImpletation) TableName() string {
return getCustomTableName
}
func TestGetCustomTableInterface(t *testing.T) {
assert.NoError(t, prepareEngine())
assert.NoError(t, testEngine.Table(getCustomTableName).Sync2(new(MyGetCustomTableImpletation)))
exist, err := testEngine.IsTableExist(getCustomTableName)
assert.NoError(t, err)
assert.True(t, exist)
_, err = testEngine.Insert(&MyGetCustomTableImpletation{
Name: "xlw",
})
assert.NoError(t, err)
var c GetCustomTableInterface = new(MyGetCustomTableImpletation)
has, err := testEngine.Get(c)
assert.NoError(t, err)
assert.True(t, has)
}
func TestGetNullVar(t *testing.T) {
type TestGetNullVarStruct struct {
Id int64
Name string
Age int
}
assert.NoError(t, prepareEngine())
assertSync(t, new(TestGetNullVarStruct))
affected, err := testEngine.Exec("insert into " + testEngine.TableName(new(TestGetNullVarStruct), true) + " (name,age) values (null,null)")
assert.NoError(t, err)
a, _ := affected.RowsAffected()
assert.EqualValues(t, 1, a)
var name string
has, err := testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("name").Get(&name)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "", name)
var age int
has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 0, age)
var age2 int8
has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 0, age2)
var age3 int16
has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 0, age3)
var age4 int32
has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age4)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 0, age4)
var age5 int64
has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age5)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 0, age5)
var age6 uint
has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age6)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 0, age6)
var age7 uint8
has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age7)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 0, age7)
var age8 int16
has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age8)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 0, age8)
var age9 int32
has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age9)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 0, age9)
var age10 int64
has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age10)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 0, age10)
}

View File

@ -8,10 +8,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"sort"
"strconv" "strconv"
"strings" "strings"
"github.com/go-xorm/core" "xorm.io/core"
) )
// Insert insert one or more beans // Insert insert one or more beans
@ -24,6 +25,40 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
} }
for _, bean := range beans { for _, bean := range beans {
switch bean.(type) {
case map[string]interface{}:
cnt, err := session.insertMapInterface(bean.(map[string]interface{}))
if err != nil {
return affected, err
}
affected += cnt
case []map[string]interface{}:
s := bean.([]map[string]interface{})
session.autoResetStatement = false
for i := 0; i < len(s); i++ {
cnt, err := session.insertMapInterface(s[i])
if err != nil {
return affected, err
}
affected += cnt
}
case map[string]string:
cnt, err := session.insertMapString(bean.(map[string]string))
if err != nil {
return affected, err
}
affected += cnt
case []map[string]string:
s := bean.([]map[string]string)
session.autoResetStatement = false
for i := 0; i < len(s); i++ {
cnt, err := session.insertMapString(s[i])
if err != nil {
return affected, err
}
affected += cnt
}
default:
sliceValue := reflect.Indirect(reflect.ValueOf(bean)) sliceValue := reflect.Indirect(reflect.ValueOf(bean))
if sliceValue.Kind() == reflect.Slice { if sliceValue.Kind() == reflect.Slice {
size := sliceValue.Len() size := sliceValue.Len()
@ -52,6 +87,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
affected += cnt affected += cnt
} }
} }
}
return affected, err return affected, err
} }
@ -206,23 +242,17 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
var sql string var sql string
if session.engine.dialect.DBType() == core.ORACLE { if session.engine.dialect.DBType() == core.ORACLE {
temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (", temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
session.engine.Quote(tableName), session.engine.Quote(tableName),
session.engine.QuoteStr(), quoteColumns(colNames, session.engine.Quote, ","))
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
session.engine.QuoteStr())
sql = fmt.Sprintf("INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL",
session.engine.Quote(tableName), session.engine.Quote(tableName),
session.engine.QuoteStr(), quoteColumns(colNames, session.engine.Quote, ","),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
strings.Join(colMultiPlaces, temp)) strings.Join(colMultiPlaces, temp))
} else { } else {
sql = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)", sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
session.engine.Quote(tableName), session.engine.Quote(tableName),
session.engine.QuoteStr(), quoteColumns(colNames, session.engine.Quote, ","),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
strings.Join(colMultiPlaces, "),(")) strings.Join(colMultiPlaces, "),("))
} }
res, err := session.exec(sql, args...) res, err := session.exec(sql, args...)
@ -337,21 +367,28 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
var sqlStr string var sqlStr string
var tableName = session.statement.TableName() var tableName = session.statement.TableName()
var output string
if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 {
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
}
if len(colPlaces) > 0 { if len(colPlaces) > 0 {
sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)", sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
session.engine.Quote(tableName), session.engine.Quote(tableName),
session.engine.QuoteStr(), quoteColumns(colNames, session.engine.Quote, ","),
strings.Join(colNames, session.engine.Quote(", ")), output,
session.engine.QuoteStr(),
colPlaces) colPlaces)
} else { } else {
if session.engine.dialect.DBType() == core.MYSQL { if session.engine.dialect.DBType() == core.MYSQL {
sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName)) sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName))
} else { } else {
sqlStr = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", session.engine.Quote(tableName)) sqlStr = fmt.Sprintf("INSERT INTO %s%s DEFAULT VALUES", session.engine.Quote(tableName), output)
} }
} }
if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES {
sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement)
}
handleAfterInsertProcessorFunc := func(bean interface{}) { handleAfterInsertProcessorFunc := func(bean interface{}) {
if session.isAutoCommit { if session.isAutoCommit {
for _, closure := range session.afterClosures { for _, closure := range session.afterClosures {
@ -397,7 +434,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
} else if verValue.IsValid() && verValue.CanSet() { } else if verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(1) session.incrVersionFieldValue(verValue)
} }
} }
@ -423,9 +460,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue.Set(int64ToIntValue(id, aiValue.Type())) aiValue.Set(int64ToIntValue(id, aiValue.Type()))
return 1, nil return 1, nil
} else if session.engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 { } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.DBType() == core.POSTGRES || session.engine.dialect.DBType() == core.MSSQL) {
//assert table.AutoIncrement != ""
sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement)
res, err := session.queryBytes(sqlStr, args...) res, err := session.queryBytes(sqlStr, args...)
if err != nil { if err != nil {
@ -440,12 +475,12 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
} else if verValue.IsValid() && verValue.CanSet() { } else if verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(1) session.incrVersionFieldValue(verValue)
} }
} }
if len(res) < 1 { if len(res) < 1 {
return 0, errors.New("insert no error but not returned id") return 0, errors.New("insert successfully but not returned id")
} }
idByte := res[0][table.AutoIncrement] idByte := res[0][table.AutoIncrement]
@ -481,7 +516,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
} else if verValue.IsValid() && verValue.CanSet() { } else if verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(1) session.incrVersionFieldValue(verValue)
} }
} }
@ -622,3 +657,83 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
} }
return colNames, args, nil return colNames, args, nil
} }
func (session *Session) insertMapInterface(m map[string]interface{}) (int64, error) {
if len(m) == 0 {
return 0, ErrParamsType
}
var columns = make([]string, 0, len(m))
for k := range m {
columns = append(columns, k)
}
sort.Strings(columns)
qm := strings.Repeat("?,", len(columns))
qm = "(" + qm[:len(qm)-1] + ")"
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
var sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES %s", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
if err := session.cacheInsert(tableName); err != nil {
return 0, err
}
res, err := session.exec(sql, args...)
if err != nil {
return 0, err
}
affected, err := res.RowsAffected()
if err != nil {
return 0, err
}
return affected, nil
}
func (session *Session) insertMapString(m map[string]string) (int64, error) {
if len(m) == 0 {
return 0, ErrParamsType
}
var columns = make([]string, 0, len(m))
for k := range m {
columns = append(columns, k)
}
sort.Strings(columns)
qm := strings.Repeat("?,", len(columns))
qm = "(" + qm[:len(qm)-1] + ")"
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
var sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES %s", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
if err := session.cacheInsert(tableName); err != nil {
return 0, err
}
res, err := session.exec(sql, args...)
if err != nil {
return 0, err
}
affected, err := res.RowsAffected()
if err != nil {
return 0, err
}
return affected, nil
}

View File

@ -145,41 +145,22 @@ func TestInsert(t *testing.T) {
user := Userinfo{0, "xiaolunwen", "dev", "lunny", time.Now(), user := Userinfo{0, "xiaolunwen", "dev", "lunny", time.Now(),
Userdetail{Id: 1}, 1.78, []byte{1, 2, 3}, true} Userdetail{Id: 1}, 1.78, []byte{1, 2, 3}, true}
cnt, err := testEngine.Insert(&user) cnt, err := testEngine.Insert(&user)
fmt.Println(user.Uid) assert.NoError(t, err)
if err != nil { assert.EqualValues(t, 1, cnt, "insert not returned 1")
t.Error(err) assert.True(t, user.Uid > 0, "not return id error")
panic(err)
}
if cnt != 1 {
err = errors.New("insert not returned 1")
t.Error(err)
panic(err)
}
if user.Uid <= 0 {
err = errors.New("not return id error")
t.Error(err)
panic(err)
}
user.Uid = 0 user.Uid = 0
cnt, err = testEngine.Insert(&user) cnt, err = testEngine.Insert(&user)
// Username is unique, so this should return error
assert.Error(t, err, "insert should fail but no error returned")
assert.EqualValues(t, 0, cnt, "insert not returned 1")
if err == nil { if err == nil {
err = errors.New("insert failed but no return error") panic("should return err")
t.Error(err)
panic(err)
}
if cnt != 0 {
err = errors.New("insert not returned 1")
t.Error(err)
panic(err)
return
} }
} }
func TestInsertAutoIncr(t *testing.T) { func TestInsertAutoIncr(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
assertSync(t, new(Userinfo)) assertSync(t, new(Userinfo))
// auto increment insert // auto increment insert
@ -214,20 +195,14 @@ func TestInsertDefault(t *testing.T) {
di := new(DefaultInsert) di := new(DefaultInsert)
err := testEngine.Sync2(di) err := testEngine.Sync2(di)
if err != nil { assert.NoError(t, err)
t.Error(err)
}
var di2 = DefaultInsert{Name: "test"} var di2 = DefaultInsert{Name: "test"}
_, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("Status")).Insert(&di2) _, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("Status")).Insert(&di2)
if err != nil { assert.NoError(t, err)
t.Error(err)
}
has, err := testEngine.Desc("(id)").Get(di) has, err := testEngine.Desc("(id)").Get(di)
if err != nil { assert.NoError(t, err)
t.Error(err)
}
if !has { if !has {
err = errors.New("error with no data") err = errors.New("error with no data")
t.Error(err) t.Error(err)
@ -780,3 +755,82 @@ func TestAnonymousStruct(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestInsertMap(t *testing.T) {
type InsertMap struct {
Id int64
Width uint32
Height uint32
Name string
}
assert.NoError(t, prepareEngine())
assertSync(t, new(InsertMap))
cnt, err := testEngine.Table(new(InsertMap)).Insert(map[string]interface{}{
"width": 20,
"height": 10,
"name": "lunny",
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
var im InsertMap
has, err := testEngine.Get(&im)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 20, im.Width)
assert.EqualValues(t, 10, im.Height)
assert.EqualValues(t, "lunny", im.Name)
cnt, err = testEngine.Table("insert_map").Insert(map[string]interface{}{
"width": 30,
"height": 10,
"name": "lunny",
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
var ims []InsertMap
err = testEngine.Find(&ims)
assert.NoError(t, err)
assert.EqualValues(t, 2, len(ims))
assert.EqualValues(t, 20, ims[0].Width)
assert.EqualValues(t, 10, ims[0].Height)
assert.EqualValues(t, "lunny", ims[0].Name)
assert.EqualValues(t, 30, ims[1].Width)
assert.EqualValues(t, 10, ims[1].Height)
assert.EqualValues(t, "lunny", ims[1].Name)
cnt, err = testEngine.Table("insert_map").Insert([]map[string]interface{}{
{
"width": 40,
"height": 10,
"name": "lunny",
},
{
"width": 50,
"height": 10,
"name": "lunny",
},
})
assert.NoError(t, err)
assert.EqualValues(t, 2, cnt)
ims = make([]InsertMap, 0, 4)
err = testEngine.Find(&ims)
assert.NoError(t, err)
assert.EqualValues(t, 4, len(ims))
assert.EqualValues(t, 20, ims[0].Width)
assert.EqualValues(t, 10, ims[0].Height)
assert.EqualValues(t, "lunny", ims[1].Name)
assert.EqualValues(t, 30, ims[1].Width)
assert.EqualValues(t, 10, ims[1].Height)
assert.EqualValues(t, "lunny", ims[1].Name)
assert.EqualValues(t, 40, ims[2].Width)
assert.EqualValues(t, 10, ims[2].Height)
assert.EqualValues(t, "lunny", ims[2].Name)
assert.EqualValues(t, 50, ims[3].Width)
assert.EqualValues(t, 10, ims[3].Height)
assert.EqualValues(t, "lunny", ims[3].Name)
}

View File

@ -23,6 +23,10 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error {
defer session.Close() defer session.Close()
} }
if session.statement.lastError != nil {
return session.statement.lastError
}
if session.statement.bufferSize > 0 { if session.statement.bufferSize > 0 {
return session.bufferIterate(bean, fun) return session.bufferIterate(bean, fun)
} }

View File

@ -9,7 +9,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-xorm/core" "xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )

View File

@ -11,13 +11,13 @@ import (
"strings" "strings"
"time" "time"
"github.com/go-xorm/builder" "xorm.io/builder"
"github.com/go-xorm/core" "xorm.io/core"
) )
func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interface{}, error) { func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) {
if len(sqlorArgs) > 0 { if len(sqlOrArgs) > 0 {
return convertSQLOrArgs(sqlorArgs...) return convertSQLOrArgs(sqlOrArgs...)
} }
if session.statement.RawSQL != "" { if session.statement.RawSQL != "" {
@ -78,12 +78,12 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa
} }
// Query runs a raw sql and return records as []map[string][]byte // Query runs a raw sql and return records as []map[string][]byte
func (session *Session) Query(sqlorArgs ...interface{}) ([]map[string][]byte, error) { func (session *Session) Query(sqlOrArgs ...interface{}) ([]map[string][]byte, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
sqlStr, args, err := session.genQuerySQL(sqlorArgs...) sqlStr, args, err := session.genQuerySQL(sqlOrArgs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -227,12 +227,12 @@ func rows2SliceString(rows *core.Rows) (resultsSlice [][]string, err error) {
} }
// QueryString runs a raw sql and return records as []map[string]string // QueryString runs a raw sql and return records as []map[string]string
func (session *Session) QueryString(sqlorArgs ...interface{}) ([]map[string]string, error) { func (session *Session) QueryString(sqlOrArgs ...interface{}) ([]map[string]string, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
sqlStr, args, err := session.genQuerySQL(sqlorArgs...) sqlStr, args, err := session.genQuerySQL(sqlOrArgs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -247,12 +247,12 @@ func (session *Session) QueryString(sqlorArgs ...interface{}) ([]map[string]stri
} }
// QuerySliceString runs a raw sql and return records as [][]string // QuerySliceString runs a raw sql and return records as [][]string
func (session *Session) QuerySliceString(sqlorArgs ...interface{}) ([][]string, error) { func (session *Session) QuerySliceString(sqlOrArgs ...interface{}) ([][]string, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
sqlStr, args, err := session.genQuerySQL(sqlorArgs...) sqlStr, args, err := session.genQuerySQL(sqlOrArgs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -300,12 +300,12 @@ func rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, er
} }
// QueryInterface runs a raw sql and return records as []map[string]interface{} // QueryInterface runs a raw sql and return records as []map[string]interface{}
func (session *Session) QueryInterface(sqlorArgs ...interface{}) ([]map[string]interface{}, error) { func (session *Session) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
sqlStr, args, err := session.genQuerySQL(sqlorArgs...) sqlStr, args, err := session.genQuerySQL(sqlOrArgs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -10,8 +10,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-xorm/builder" "xorm.io/builder"
"github.com/go-xorm/core" "xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -207,7 +207,7 @@ func TestQueryStringNoParam(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0]["id"]) assert.EqualValues(t, "1", records[0]["id"])
if testEngine.Dialect().URI().DbType == core.POSTGRES { if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL {
assert.EqualValues(t, "false", records[0]["msg"]) assert.EqualValues(t, "false", records[0]["msg"])
} else { } else {
assert.EqualValues(t, "0", records[0]["msg"]) assert.EqualValues(t, "0", records[0]["msg"])
@ -217,7 +217,7 @@ func TestQueryStringNoParam(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0]["id"]) assert.EqualValues(t, "1", records[0]["id"])
if testEngine.Dialect().URI().DbType == core.POSTGRES { if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL {
assert.EqualValues(t, "false", records[0]["msg"]) assert.EqualValues(t, "false", records[0]["msg"])
} else { } else {
assert.EqualValues(t, "0", records[0]["msg"]) assert.EqualValues(t, "0", records[0]["msg"])
@ -244,7 +244,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0][0]) assert.EqualValues(t, "1", records[0][0])
if testEngine.Dialect().URI().DbType == core.POSTGRES { if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL {
assert.EqualValues(t, "false", records[0][1]) assert.EqualValues(t, "false", records[0][1])
} else { } else {
assert.EqualValues(t, "0", records[0][1]) assert.EqualValues(t, "0", records[0][1])
@ -254,7 +254,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0][0]) assert.EqualValues(t, "1", records[0][0])
if testEngine.Dialect().URI().DbType == core.POSTGRES { if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL {
assert.EqualValues(t, "false", records[0][1]) assert.EqualValues(t, "false", records[0][1])
} else { } else {
assert.EqualValues(t, "0", records[0][1]) assert.EqualValues(t, "0", records[0][1])
@ -334,3 +334,47 @@ func TestQueryWithBuilder(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assertResult(t, results) assertResult(t, results)
} }
func TestJoinWithSubQuery(t *testing.T) {
assert.NoError(t, prepareEngine())
type JoinWithSubQuery1 struct {
Id int64 `xorm:"autoincr pk"`
Msg string `xorm:"varchar(255)"`
DepartId int64
Money float32
}
type JoinWithSubQueryDepart struct {
Id int64 `xorm:"autoincr pk"`
Name string
}
testEngine.ShowSQL(true)
assert.NoError(t, testEngine.Sync2(new(JoinWithSubQuery1), new(JoinWithSubQueryDepart)))
var depart = JoinWithSubQueryDepart{
Name: "depart1",
}
cnt, err := testEngine.Insert(&depart)
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
var q = JoinWithSubQuery1{
Msg: "message",
DepartId: depart.Id,
Money: 3000,
}
cnt, err = testEngine.Insert(&q)
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
var querys []JoinWithSubQuery1
err = testEngine.Join("INNER", builder.Select("id").From(testEngine.Quote(testEngine.TableName("join_with_sub_query_depart", true))),
"join_with_sub_query_depart.id = join_with_sub_query1.depart_id").Find(&querys)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(querys))
assert.EqualValues(t, q, querys[0])
}

View File

@ -9,8 +9,8 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/go-xorm/builder" "xorm.io/builder"
"github.com/go-xorm/core" "xorm.io/core"
) )
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
@ -49,7 +49,7 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row
if session.isAutoCommit { if session.isAutoCommit {
var db *core.DB var db *core.DB
if session.engine.engineGroup != nil { if session.sessionType == groupSession {
db = session.engine.engineGroup.Slave().DB() db = session.engine.engineGroup.Slave().DB()
} else { } else {
db = session.DB() db = session.DB()
@ -62,21 +62,21 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row
return nil, err return nil, err
} }
rows, err := stmt.Query(args...) rows, err := stmt.QueryContext(session.ctx, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return rows, nil return rows, nil
} }
rows, err := db.Query(sqlStr, args...) rows, err := db.QueryContext(session.ctx, sqlStr, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return rows, nil return rows, nil
} }
rows, err := session.tx.Query(sqlStr, args...) rows, err := session.tx.QueryContext(session.ctx, sqlStr, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -175,7 +175,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
} }
if !session.isAutoCommit { if !session.isAutoCommit {
return session.tx.Exec(sqlStr, args...) return session.tx.ExecContext(session.ctx, sqlStr, args...)
} }
if session.prepareStmt { if session.prepareStmt {
@ -184,24 +184,24 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
return nil, err return nil, err
} }
res, err := stmt.Exec(args...) res, err := stmt.ExecContext(session.ctx, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return res, nil return res, nil
} }
return session.DB().Exec(sqlStr, args...) return session.DB().ExecContext(session.ctx, sqlStr, args...)
} }
func convertSQLOrArgs(sqlorArgs ...interface{}) (string, []interface{}, error) { func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) {
switch sqlorArgs[0].(type) { switch sqlOrArgs[0].(type) {
case string: case string:
return sqlorArgs[0].(string), sqlorArgs[1:], nil return sqlOrArgs[0].(string), sqlOrArgs[1:], nil
case *builder.Builder: case *builder.Builder:
return sqlorArgs[0].(*builder.Builder).ToSQL() return sqlOrArgs[0].(*builder.Builder).ToSQL()
case builder.Builder: case builder.Builder:
bd := sqlorArgs[0].(builder.Builder) bd := sqlOrArgs[0].(builder.Builder)
return bd.ToSQL() return bd.ToSQL()
} }
@ -209,16 +209,16 @@ func convertSQLOrArgs(sqlorArgs ...interface{}) (string, []interface{}, error) {
} }
// Exec raw sql // Exec raw sql
func (session *Session) Exec(sqlorArgs ...interface{}) (sql.Result, error) { func (session *Session) Exec(sqlOrArgs ...interface{}) (sql.Result, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
if len(sqlorArgs) == 0 { if len(sqlOrArgs) == 0 {
return nil, ErrUnSupportedType return nil, ErrUnSupportedType
} }
sqlStr, args, err := convertSQLOrArgs(sqlorArgs...) sqlStr, args, err := convertSQLOrArgs(sqlOrArgs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -9,7 +9,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/go-xorm/core" "xorm.io/core"
) )
// Ping test if database is ok // Ping test if database is ok
@ -19,7 +19,7 @@ func (session *Session) Ping() error {
} }
session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
return session.DB().Ping() return session.DB().PingContext(session.ctx)
} }
// CreateTable create a table according a bean // CreateTable create a table according a bean

View File

@ -9,7 +9,7 @@ import (
"strconv" "strconv"
"testing" "testing"
"github.com/go-xorm/builder" "xorm.io/builder"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -18,13 +18,14 @@ func isFloatEq(i, j float64, precision int) bool {
} }
func TestSum(t *testing.T) { func TestSum(t *testing.T) {
assert.NoError(t, prepareEngine())
type SumStruct struct { type SumStruct struct {
Int int Int int
Float float32 Float float32
} }
assert.NoError(t, prepareEngine())
assert.NoError(t, testEngine.Sync2(new(SumStruct)))
var ( var (
cases = []SumStruct{ cases = []SumStruct{
{1, 6.2}, {1, 6.2},
@ -40,8 +41,6 @@ func TestSum(t *testing.T) {
f += v.Float f += v.Float
} }
assert.NoError(t, testEngine.Sync2(new(SumStruct)))
cnt, err := testEngine.Insert(cases) cnt, err := testEngine.Insert(cases)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 3, cnt) assert.EqualValues(t, 3, cnt)
@ -73,6 +72,65 @@ func TestSum(t *testing.T) {
assert.EqualValues(t, i, int(sumsInt[0])) assert.EqualValues(t, i, int(sumsInt[0]))
} }
type SumStructWithTableName struct {
Int int
Float float32
}
func (s SumStructWithTableName) TableName() string {
return "sum_struct_with_table_name_1"
}
func TestSumWithTableName(t *testing.T) {
assert.NoError(t, prepareEngine())
assert.NoError(t, testEngine.Sync2(new(SumStructWithTableName)))
var (
cases = []SumStructWithTableName{
{1, 6.2},
{2, 5.3},
{92, -0.2},
}
)
var i int
var f float32
for _, v := range cases {
i += v.Int
f += v.Float
}
cnt, err := testEngine.Insert(cases)
assert.NoError(t, err)
assert.EqualValues(t, 3, cnt)
colInt := testEngine.GetColumnMapper().Obj2Table("Int")
colFloat := testEngine.GetColumnMapper().Obj2Table("Float")
sumInt, err := testEngine.Sum(new(SumStructWithTableName), colInt)
assert.NoError(t, err)
assert.EqualValues(t, int(sumInt), i)
sumFloat, err := testEngine.Sum(new(SumStructWithTableName), colFloat)
assert.NoError(t, err)
assert.Condition(t, func() bool {
return isFloatEq(sumFloat, float64(f), 2)
})
sums, err := testEngine.Sums(new(SumStructWithTableName), colInt, colFloat)
assert.NoError(t, err)
assert.EqualValues(t, 2, len(sums))
assert.EqualValues(t, i, int(sums[0]))
assert.Condition(t, func() bool {
return isFloatEq(sums[1], float64(f), 2)
})
sumsInt, err := testEngine.SumsInt(new(SumStructWithTableName), colInt)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(sumsInt))
assert.EqualValues(t, i, int(sumsInt[0]))
}
func TestSumCustomColumn(t *testing.T) { func TestSumCustomColumn(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
@ -183,3 +241,36 @@ func TestCountWithOthers(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, total) assert.EqualValues(t, 2, total)
} }
type CountWithTableName struct {
Id int64
Name string
}
func (CountWithTableName) TableName() string {
return "count_with_table_name1"
}
func TestWithTableName(t *testing.T) {
assert.NoError(t, prepareEngine())
assertSync(t, new(CountWithTableName))
_, err := testEngine.Insert(&CountWithTableName{
Name: "orderby",
})
assert.NoError(t, err)
_, err = testEngine.Insert(CountWithTableName{
Name: "limit",
})
assert.NoError(t, err)
total, err := testEngine.OrderBy("id desc").Count(new(CountWithTableName))
assert.NoError(t, err)
assert.EqualValues(t, 2, total)
total, err = testEngine.OrderBy("id desc").Count(CountWithTableName{})
assert.NoError(t, err)
assert.EqualValues(t, 2, total)
}

View File

@ -7,7 +7,7 @@ package xorm
// Begin a transaction // Begin a transaction
func (session *Session) Begin() error { func (session *Session) Begin() error {
if session.isAutoCommit { if session.isAutoCommit {
tx, err := session.DB().Begin() tx, err := session.DB().BeginTx(session.ctx, nil)
if err != nil { if err != nil {
return err return err
} }

View File

@ -9,7 +9,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-xorm/core" "xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )

View File

@ -11,8 +11,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/go-xorm/builder" "xorm.io/builder"
"github.com/go-xorm/core" "xorm.io/core"
) )
func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, args ...interface{}) error { func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, args ...interface{}) error {
@ -96,14 +96,15 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
return ErrCacheFailed return ErrCacheFailed
} }
kvs := strings.Split(strings.TrimSpace(sqls[1]), ",") kvs := strings.Split(strings.TrimSpace(sqls[1]), ",")
for idx, kv := range kvs { for idx, kv := range kvs {
sps := strings.SplitN(kv, "=", 2) sps := strings.SplitN(kv, "=", 2)
sps2 := strings.Split(sps[0], ".") sps2 := strings.Split(sps[0], ".")
colName := sps2[len(sps2)-1] colName := sps2[len(sps2)-1]
if strings.Contains(colName, "`") { // treat quote prefix, suffix and '`' as quotes
colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1)) quotes := append(strings.Split(session.engine.Quote(""), ""), "`")
} else if strings.Contains(colName, session.engine.QuoteStr()) { if strings.ContainsAny(colName, strings.Join(quotes, "")) {
colName = strings.TrimSpace(strings.Replace(colName, session.engine.QuoteStr(), "", -1)) colName = strings.TrimSpace(eraseAny(colName, quotes...))
} else { } else {
session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName) session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
return ErrCacheFailed return ErrCacheFailed
@ -116,7 +117,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
} else { } else {
session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface())
if col.IsVersion && session.statement.checkVersion { if col.IsVersion && session.statement.checkVersion {
fieldValue.SetInt(fieldValue.Int() + 1) session.incrVersionFieldValue(fieldValue)
} else { } else {
fieldValue.Set(reflect.ValueOf(args[idx])) fieldValue.Set(reflect.ValueOf(args[idx]))
} }
@ -147,6 +148,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
defer session.Close() defer session.Close()
} }
if session.statement.lastError != nil {
return 0, session.statement.lastError
}
v := rValue(bean) v := rValue(bean)
t := v.Type() t := v.Type()
@ -240,7 +245,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
var autoCond builder.Cond var autoCond builder.Cond
if !session.statement.noAutoCondition && len(condiBean) > 0 { if !session.statement.noAutoCondition {
condBeanIsStruct := false
if len(condiBean) > 0 {
if c, ok := condiBean[0].(map[string]interface{}); ok { if c, ok := condiBean[0].(map[string]interface{}); ok {
autoCond = builder.Eq(c) autoCond = builder.Eq(c)
} else { } else {
@ -255,12 +262,26 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if err != nil { if err != nil {
return 0, err return 0, err
} }
condBeanIsStruct = true
} else { } else {
return 0, ErrConditionType return 0, ErrConditionType
} }
} }
} }
if !condBeanIsStruct && table != nil {
if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled
autoCond1 := session.engine.CondDeleted(session.engine.Quote(col.Name))
if autoCond == nil {
autoCond = autoCond1
} else {
autoCond = autoCond.And(autoCond1)
}
}
}
}
st := &session.statement st := &session.statement
var sqlStr string var sqlStr string
@ -357,7 +378,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, err return 0, err
} else if doIncVer { } else if doIncVer {
if verValue != nil && verValue.IsValid() && verValue.CanSet() { if verValue != nil && verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(verValue.Int() + 1) session.incrVersionFieldValue(verValue)
} }
} }

View File

@ -11,8 +11,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-xorm/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core"
) )
func TestUpdateMap(t *testing.T) { func TestUpdateMap(t *testing.T) {
@ -110,7 +110,7 @@ func setupForUpdate(engine EngineInterface) error {
} }
func TestForUpdate(t *testing.T) { func TestForUpdate(t *testing.T) {
if testEngine.Dialect().DriverName() != "mysql" && testEngine.Dialect().DriverName() != "mymysql" { if *ignoreSelectUpdate {
return return
} }
@ -1331,3 +1331,62 @@ func TestUpdateCondiBean(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
} }
func TestWhereCondErrorWhenUpdate(t *testing.T) {
type AuthRequestError struct {
ChallengeToken string
RequestToken string
}
assert.NoError(t, prepareEngine())
assertSync(t, new(AuthRequestError))
_, err := testEngine.Cols("challenge_token", "request_token", "challenge_agent", "status").
Where(&AuthRequestError{ChallengeToken: "1"}).
Update(&AuthRequestError{
ChallengeToken: "2",
})
assert.Error(t, err)
assert.EqualValues(t, ErrConditionType, err)
}
func TestUpdateDeleted(t *testing.T) {
assert.NoError(t, prepareEngine())
type UpdateDeletedStruct struct {
Id int64
Name string
DeletedAt time.Time `xorm:"deleted"`
}
assertSync(t, new(UpdateDeletedStruct))
var s = UpdateDeletedStruct{
Name: "test",
}
cnt, err := testEngine.Insert(&s)
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
cnt, err = testEngine.ID(s.Id).Delete(&UpdateDeletedStruct{})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
cnt, err = testEngine.ID(s.Id).Update(&UpdateDeletedStruct{
Name: "test1",
})
assert.NoError(t, err)
assert.EqualValues(t, 0, cnt)
cnt, err = testEngine.Table(&UpdateDeletedStruct{}).ID(s.Id).Update(map[string]interface{}{
"name": "test1",
})
assert.NoError(t, err)
assert.EqualValues(t, 0, cnt)
cnt, err = testEngine.ID(s.Id).Unscoped().Update(&UpdateDeletedStruct{
Name: "test1",
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
}

View File

@ -6,15 +6,13 @@ package xorm
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json"
"errors"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"time" "time"
"github.com/go-xorm/builder" "xorm.io/builder"
"github.com/go-xorm/core" "xorm.io/core"
) )
// Statement save all the sql info for executing SQL // Statement save all the sql info for executing SQL
@ -60,6 +58,7 @@ type Statement struct {
cond builder.Cond cond builder.Cond
bufferSize int bufferSize int
context ContextCache context ContextCache
lastError error
} }
// Init reset all the statement's fields // Init reset all the statement's fields
@ -101,6 +100,7 @@ func (statement *Statement) Init() {
statement.cond = builder.NewCond() statement.cond = builder.NewCond()
statement.bufferSize = 0 statement.bufferSize = 0
statement.context = nil statement.context = nil
statement.lastError = nil
} }
// NoAutoCondition if you do not want convert bean's field as query condition, then use this function // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
@ -125,13 +125,13 @@ func (statement *Statement) SQL(query interface{}, args ...interface{}) *Stateme
var err error var err error
statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL() statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL()
if err != nil { if err != nil {
statement.Engine.logger.Error(err) statement.lastError = err
} }
case string: case string:
statement.RawSQL = query.(string) statement.RawSQL = query.(string)
statement.RawParams = args statement.RawParams = args
default: default:
statement.Engine.logger.Error("unsupported sql type") statement.lastError = ErrUnSupportedSQLType
} }
return statement return statement
@ -160,7 +160,7 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme
} }
} }
default: default:
// TODO: not support condition type statement.lastError = ErrConditionType
} }
return statement return statement
@ -406,7 +406,7 @@ func (statement *Statement) buildUpdates(bean interface{},
} else { } else {
// Blank struct could not be as update data // Blank struct could not be as update data
if requiredField || !isStructZero(fieldValue) { if requiredField || !isStructZero(fieldValue) {
bytes, err := json.Marshal(fieldValue.Interface()) bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface())) panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
} }
@ -435,7 +435,7 @@ func (statement *Statement) buildUpdates(bean interface{},
} }
if col.SQLType.IsText() { if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface()) bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
engine.logger.Error(err) engine.logger.Error(err)
continue continue
@ -455,7 +455,7 @@ func (statement *Statement) buildUpdates(bean interface{},
fieldType.Elem().Kind() == reflect.Uint8 { fieldType.Elem().Kind() == reflect.Uint8 {
val = fieldValue.Slice(0, 0).Interface() val = fieldValue.Slice(0, 0).Interface()
} else { } else {
bytes, err = json.Marshal(fieldValue.Interface()) bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
engine.logger.Error(err) engine.logger.Error(err)
continue continue
@ -578,21 +578,9 @@ func (statement *Statement) getExpr() map[string]exprParam {
func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
newColumns := make([]string, 0) newColumns := make([]string, 0)
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
for _, col := range columns { for _, col := range columns {
col = strings.Replace(col, "`", "", -1) newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...)))
col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
fields := strings.Split(strings.TrimSpace(c), ".")
if len(fields) == 1 {
newColumns = append(newColumns, statement.Engine.quote(fields[0]))
} else if len(fields) == 2 {
newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
statement.Engine.quote(fields[1]))
} else {
panic(errors.New("unwanted colnames"))
}
}
} }
return newColumns return newColumns
} }
@ -755,9 +743,36 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
fmt.Fprintf(&buf, "%v JOIN ", joinOP) fmt.Fprintf(&buf, "%v JOIN ", joinOP)
} }
tbName := statement.Engine.TableName(tablename, true) switch tp := tablename.(type) {
case builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.lastError = err
return statement
}
tbs := strings.Split(tp.TableName(), ".")
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
case *builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.lastError = err
return statement
}
tbs := strings.Split(tp.TableName(), ".")
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default:
tbName := statement.Engine.TableName(tablename, true)
fmt.Fprintf(&buf, "%s ON %v", tbName, condition) fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
}
statement.JoinStr = buf.String() statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...) statement.joinArgs = append(statement.joinArgs, args...)
return statement return statement
@ -1133,8 +1148,12 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
if statement.Start != 0 || statement.LimitN != 0 { if statement.Start != 0 || statement.LimitN != 0 {
oldString := buf.String() oldString := buf.String()
buf.Reset() buf.Reset()
rawColStr := columnStr
if rawColStr == "*" {
rawColStr = "at.*"
}
fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
columnStr, columnStr, oldString, statement.Start+statement.LimitN, statement.Start) columnStr, rawColStr, oldString, statement.Start+statement.LimitN, statement.Start)
} }
} }
} }

View File

@ -9,8 +9,8 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/go-xorm/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core"
) )
var colStrTests = []struct { var colStrTests = []struct {
@ -237,3 +237,12 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) {
testEngine.Update(record) testEngine.Update(record)
assertGetRecord() assertGetRecord()
} }
func TestCol2NewColsWithQuote(t *testing.T) {
cols := []string{"f1", "f2", "t3.f3"}
statement := createTestStatement()
quotedCols := statement.col2NewColsWithQuote(cols...)
assert.EqualValues(t, []string{statement.Engine.Quote("f1"), statement.Engine.Quote("f2"), statement.Engine.Quote("t3.f3")}, quotedCols)
}

View File

@ -10,7 +10,7 @@ import (
"fmt" "fmt"
"log/syslog" "log/syslog"
"github.com/go-xorm/core" "xorm.io/core"
) )
var _ core.ILogger = &SyslogLogger{} var _ core.ILogger = &SyslogLogger{}

22
tag.go
View File

@ -11,7 +11,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/go-xorm/core" "xorm.io/core"
) )
type tagContext struct { type tagContext struct {
@ -244,6 +244,7 @@ func SQLTypeTagHandler(ctx *tagContext) error {
// ExtendsTagHandler describes extends tag handler // ExtendsTagHandler describes extends tag handler
func ExtendsTagHandler(ctx *tagContext) error { func ExtendsTagHandler(ctx *tagContext) error {
var fieldValue = ctx.fieldValue var fieldValue = ctx.fieldValue
var isPtr = false
switch fieldValue.Kind() { switch fieldValue.Kind() {
case reflect.Ptr: case reflect.Ptr:
f := fieldValue.Type().Elem() f := fieldValue.Type().Elem()
@ -254,6 +255,7 @@ func ExtendsTagHandler(ctx *tagContext) error {
fieldValue = reflect.New(f).Elem() fieldValue = reflect.New(f).Elem()
} }
} }
isPtr = true
fallthrough fallthrough
case reflect.Struct: case reflect.Struct:
parentTable, err := ctx.engine.mapType(fieldValue) parentTable, err := ctx.engine.mapType(fieldValue)
@ -262,6 +264,24 @@ func ExtendsTagHandler(ctx *tagContext) error {
} }
for _, col := range parentTable.Columns() { for _, col := range parentTable.Columns() {
col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName) col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName)
var tagPrefix = ctx.col.FieldName
if len(ctx.params) > 0 {
col.Nullable = isPtr
tagPrefix = ctx.params[0]
if col.IsPrimaryKey {
col.Name = ctx.col.FieldName
col.IsPrimaryKey = false
} else {
col.Name = fmt.Sprintf("%v%v", tagPrefix, col.Name)
}
}
if col.Nullable {
col.IsAutoIncrement = false
col.IsPrimaryKey = false
}
ctx.table.AddColumn(col) ctx.table.AddColumn(col)
for indexName, indexType := range col.Indexes { for indexName, indexType := range col.Indexes {
addIndex(indexName, ctx.table, col, indexType) addIndex(indexName, ctx.table, col, indexType)

View File

@ -10,7 +10,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-xorm/core" "xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -60,63 +60,37 @@ func TestExtends(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
err := testEngine.DropTables(&tempUser2{}) err := testEngine.DropTables(&tempUser2{})
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
err = testEngine.CreateTables(&tempUser2{}) err = testEngine.CreateTables(&tempUser2{})
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
tu := &tempUser2{tempUser{0, "extends"}, "dev depart"} tu := &tempUser2{tempUser{0, "extends"}, "dev depart"}
_, err = testEngine.Insert(tu) _, err = testEngine.Insert(tu)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
tu2 := &tempUser2{} tu2 := &tempUser2{}
_, err = testEngine.Get(tu2) _, err = testEngine.Get(tu2)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
tu3 := &tempUser2{tempUser{0, "extends update"}, ""} tu3 := &tempUser2{tempUser{0, "extends update"}, ""}
_, err = testEngine.ID(tu2.TempUser.Id).Update(tu3) _, err = testEngine.ID(tu2.TempUser.Id).Update(tu3)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
err = testEngine.DropTables(&tempUser4{}) err = testEngine.DropTables(&tempUser4{})
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
err = testEngine.CreateTables(&tempUser4{}) err = testEngine.CreateTables(&tempUser4{})
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
tu8 := &tempUser4{tempUser2{tempUser{0, "extends"}, "dev depart"}} tu8 := &tempUser4{tempUser2{tempUser{0, "extends"}, "dev depart"}}
_, err = testEngine.Insert(tu8) _, err = testEngine.Insert(tu8)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
tu9 := &tempUser4{} tu9 := &tempUser4{}
_, err = testEngine.Get(tu9) _, err = testEngine.Get(tu9)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
if tu9.TempUser2.TempUser.Username != tu8.TempUser2.TempUser.Username || tu9.TempUser2.Departname != tu8.TempUser2.Departname { if tu9.TempUser2.TempUser.Username != tu8.TempUser2.TempUser.Username || tu9.TempUser2.Departname != tu8.TempUser2.Departname {
err = errors.New(fmt.Sprintln("not equal for", tu8, tu9)) err = errors.New(fmt.Sprintln("not equal for", tu8, tu9))
t.Error(err) t.Error(err)
@ -125,36 +99,22 @@ func TestExtends(t *testing.T) {
tu10 := &tempUser4{tempUser2{tempUser{0, "extends update"}, ""}} tu10 := &tempUser4{tempUser2{tempUser{0, "extends update"}, ""}}
_, err = testEngine.ID(tu9.TempUser2.TempUser.Id).Update(tu10) _, err = testEngine.ID(tu9.TempUser2.TempUser.Id).Update(tu10)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
err = testEngine.DropTables(&tempUser3{}) err = testEngine.DropTables(&tempUser3{})
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
err = testEngine.CreateTables(&tempUser3{}) err = testEngine.CreateTables(&tempUser3{})
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
tu4 := &tempUser3{&tempUser{0, "extends"}, "dev depart"} tu4 := &tempUser3{&tempUser{0, "extends"}, "dev depart"}
_, err = testEngine.Insert(tu4) _, err = testEngine.Insert(tu4)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
tu5 := &tempUser3{} tu5 := &tempUser3{}
_, err = testEngine.Get(tu5) _, err = testEngine.Get(tu5)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
if tu5.Temp == nil { if tu5.Temp == nil {
err = errors.New("error get data extends") err = errors.New("error get data extends")
t.Error(err) t.Error(err)
@ -169,22 +129,12 @@ func TestExtends(t *testing.T) {
tu6 := &tempUser3{&tempUser{0, "extends update"}, ""} tu6 := &tempUser3{&tempUser{0, "extends update"}, ""}
_, err = testEngine.ID(tu5.Temp.Id).Update(tu6) _, err = testEngine.ID(tu5.Temp.Id).Update(tu6)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
users := make([]tempUser3, 0) users := make([]tempUser3, 0)
err = testEngine.Find(&users) err = testEngine.Find(&users)
if err != nil { assert.NoError(t, err)
t.Error(err) assert.EqualValues(t, 1, len(users), "error get data not 1")
panic(err)
}
if len(users) != 1 {
err = errors.New("error get data not 1")
t.Error(err)
panic(err)
}
assertSync(t, new(Userinfo), new(Userdetail)) assertSync(t, new(Userinfo), new(Userdetail))
@ -249,10 +199,7 @@ func TestExtends(t *testing.T) {
Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)).
NoCascade(). NoCascade().
Find(&infos2) Find(&infos2)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
fmt.Println(infos2) fmt.Println(infos2)
} }
@ -297,25 +244,16 @@ func TestExtends2(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{})
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{})
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
var sender = MessageUser{Name: "sender"} var sender = MessageUser{Name: "sender"}
var receiver = MessageUser{Name: "receiver"} var receiver = MessageUser{Name: "receiver"}
var msgtype = MessageType{Name: "type"} var msgtype = MessageType{Name: "type"}
_, err = testEngine.Insert(&sender, &receiver, &msgtype) _, err = testEngine.Insert(&sender, &receiver, &msgtype)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
msg := Message{ msg := Message{
MessageBase: MessageBase{ MessageBase: MessageBase{
@ -326,15 +264,24 @@ func TestExtends2(t *testing.T) {
Uid: sender.Id, Uid: sender.Id,
ToUid: receiver.Id, ToUid: receiver.Id,
} }
session := testEngine.NewSession()
defer session.Close()
// MSSQL deny insert identity column excep declare as below
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == core.MSSQL {
_, err = testEngine.Exec("SET IDENTITY_INSERT message ON") err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT message ON")
assert.NoError(t, err) assert.NoError(t, err)
} }
cnt, err := session.Insert(&msg)
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
_, err = testEngine.Insert(&msg) if testEngine.Dialect().DBType() == core.MSSQL {
if err != nil { err = session.Commit()
t.Error(err) assert.NoError(t, err)
panic(err)
} }
var mapper = testEngine.GetTableMapper().Obj2Table var mapper = testEngine.GetTableMapper().Obj2Table
@ -344,23 +291,14 @@ func TestExtends2(t *testing.T) {
msgTableName := quote(testEngine.TableName(mapper("Message"), true)) msgTableName := quote(testEngine.TableName(mapper("Message"), true))
list := make([]Message, 0) list := make([]Message, 0)
err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). err = session.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`").
Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
Find(&list) Find(&list)
assert.NoError(t, err) assert.NoError(t, err)
if len(list) != 1 { assert.EqualValues(t, 1, len(list), fmt.Sprintln("should have 1 message, got", len(list)))
err = errors.New(fmt.Sprintln("should have 1 message, got", len(list))) assert.EqualValues(t, msg.Id, list[0].Id, fmt.Sprintln("should message equal", list[0], msg))
t.Error(err)
panic(err)
}
if list[0].Id != msg.Id {
err = errors.New(fmt.Sprintln("should message equal", list[0], msg))
t.Error(err)
panic(err)
}
} }
func TestExtends3(t *testing.T) { func TestExtends3(t *testing.T) {
@ -396,13 +334,25 @@ func TestExtends3(t *testing.T) {
Uid: sender.Id, Uid: sender.Id,
ToUid: receiver.Id, ToUid: receiver.Id,
} }
session := testEngine.NewSession()
defer session.Close()
// MSSQL deny insert identity column excep declare as below
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == core.MSSQL {
_, err = testEngine.Exec("SET IDENTITY_INSERT message ON") err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT message ON")
assert.NoError(t, err) assert.NoError(t, err)
} }
_, err = testEngine.Insert(&msg) _, err = session.Insert(&msg)
assert.NoError(t, err) assert.NoError(t, err)
if testEngine.Dialect().DBType() == core.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}
var mapper = testEngine.GetTableMapper().Obj2Table var mapper = testEngine.GetTableMapper().Obj2Table
var quote = testEngine.Quote var quote = testEngine.Quote
userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) userTableName := quote(testEngine.TableName(mapper("MessageUser"), true))
@ -410,7 +360,7 @@ func TestExtends3(t *testing.T) {
msgTableName := quote(testEngine.TableName(mapper("Message"), true)) msgTableName := quote(testEngine.TableName(mapper("Message"), true))
list := make([]MessageExtend3, 0) list := make([]MessageExtend3, 0)
err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). err = session.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`").
Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
Find(&list) Find(&list)
@ -478,14 +428,23 @@ func TestExtends4(t *testing.T) {
Content: "test", Content: "test",
Uid: sender.Id, Uid: sender.Id,
} }
session := testEngine.NewSession()
defer session.Close()
// MSSQL deny insert identity column excep declare as below
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == core.MSSQL {
_, err = testEngine.Exec("SET IDENTITY_INSERT message ON") err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT message ON")
assert.NoError(t, err) assert.NoError(t, err)
} }
_, err = testEngine.Insert(&msg) _, err = session.Insert(&msg)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err) if testEngine.Dialect().DBType() == core.MSSQL {
err = session.Commit()
assert.NoError(t, err)
} }
var mapper = testEngine.GetTableMapper().Obj2Table var mapper = testEngine.GetTableMapper().Obj2Table
@ -495,7 +454,7 @@ func TestExtends4(t *testing.T) {
msgTableName := quote(testEngine.TableName(mapper("Message"), true)) msgTableName := quote(testEngine.TableName(mapper("Message"), true))
list := make([]MessageExtend4, 0) list := make([]MessageExtend4, 0)
err = testEngine.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). err = session.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
Join("LEFT", typeTableName, typeTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). Join("LEFT", typeTableName, typeTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
Find(&list) Find(&list)
if err != nil { if err != nil {
@ -527,3 +486,123 @@ func TestExtends4(t *testing.T) {
panic(err) panic(err)
} }
} }
type Size struct {
ID int64 `xorm:"int(4) 'id' pk autoincr"`
Width float32 `json:"width" xorm:"float 'Width'"`
Height float32 `json:"height" xorm:"float 'Height'"`
}
type Book struct {
ID int64 `xorm:"int(4) 'id' pk autoincr"`
SizeOpen *Size `xorm:"extends('Open')"`
SizeClosed *Size `xorm:"extends('Closed')"`
Size *Size `xorm:"extends('')"`
}
func TestExtends5(t *testing.T) {
assert.NoError(t, prepareEngine())
err := testEngine.DropTables(&Book{}, &Size{})
if err != nil {
t.Error(err)
panic(err)
}
err = testEngine.CreateTables(&Size{}, &Book{})
if err != nil {
t.Error(err)
panic(err)
}
var sc = Size{Width: 0.2, Height: 0.4}
var so = Size{Width: 0.2, Height: 0.8}
var s = Size{Width: 0.15, Height: 1.5}
var bk1 = Book{
SizeOpen: &so,
SizeClosed: &sc,
Size: &s,
}
var bk2 = Book{
SizeOpen: &so,
}
var bk3 = Book{
SizeClosed: &sc,
Size: &s,
}
var bk4 = Book{}
var bk5 = Book{Size: &s}
_, err = testEngine.Insert(&sc, &so, &s, &bk1, &bk2, &bk3, &bk4, &bk5)
if err != nil {
t.Fatal(err)
}
var books = map[int64]Book{
bk1.ID: bk1,
bk2.ID: bk2,
bk3.ID: bk3,
bk4.ID: bk4,
bk5.ID: bk5,
}
session := testEngine.NewSession()
defer session.Close()
var mapper = testEngine.GetTableMapper().Obj2Table
var quote = testEngine.Quote
bookTableName := quote(testEngine.TableName(mapper("Book"), true))
sizeTableName := quote(testEngine.TableName(mapper("Size"), true))
list := make([]Book, 0)
err = session.
Select(fmt.Sprintf(
"%s.%s, sc.%s AS %s, sc.%s AS %s, s.%s, s.%s",
quote(bookTableName),
quote("id"),
quote("Width"),
quote("ClosedWidth"),
quote("Height"),
quote("ClosedHeight"),
quote("Width"),
quote("Height"),
)).
Table(bookTableName).
Join(
"LEFT",
sizeTableName+" AS `sc`",
bookTableName+".`SizeClosed`=sc.`id`",
).
Join(
"LEFT",
sizeTableName+" AS `s`",
bookTableName+".`Size`=s.`id`",
).
Find(&list)
if err != nil {
t.Error(err)
panic(err)
}
for _, book := range list {
if ok := assert.Equal(t, books[book.ID].SizeClosed.Width, book.SizeClosed.Width); !ok {
t.Error("Not bounded size closed")
panic("Not bounded size closed")
}
if ok := assert.Equal(t, books[book.ID].SizeClosed.Height, book.SizeClosed.Height); !ok {
t.Error("Not bounded size closed")
panic("Not bounded size closed")
}
if books[book.ID].Size != nil || book.Size != nil {
if ok := assert.Equal(t, books[book.ID].Size.Width, book.Size.Width); !ok {
t.Error("Not bounded size")
panic("Not bounded size")
}
if ok := assert.Equal(t, books[book.ID].Size.Height, book.Size.Height); !ok {
t.Error("Not bounded size")
panic("Not bounded size")
}
}
}
}

View File

@ -7,7 +7,7 @@ package xorm
import ( import (
"testing" "testing"
"github.com/go-xorm/core" "xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )

View File

@ -11,8 +11,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-xorm/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core"
) )
type UserCU struct { type UserCU struct {

View File

@ -85,7 +85,7 @@ func TestVersion1(t *testing.T) {
} }
fmt.Println(newVer) fmt.Println(newVer)
if newVer.Ver != 2 { if newVer.Ver != 2 {
err = errors.New("insert error") err = errors.New("update error")
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
@ -126,3 +126,117 @@ func TestVersion2(t *testing.T) {
} }
} }
} }
type VersionUintS struct {
Id int64
Name string
Ver uint `xorm:"version"`
Created time.Time `xorm:"created"`
}
func TestVersion3(t *testing.T) {
assert.NoError(t, prepareEngine())
err := testEngine.DropTables(new(VersionUintS))
if err != nil {
t.Error(err)
panic(err)
}
err = testEngine.CreateTables(new(VersionUintS))
if err != nil {
t.Error(err)
panic(err)
}
ver := &VersionUintS{Name: "sfsfdsfds"}
_, err = testEngine.Insert(ver)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(ver)
if ver.Ver != 1 {
err = errors.New("insert error")
t.Error(err)
panic(err)
}
newVer := new(VersionUintS)
has, err := testEngine.ID(ver.Id).Get(newVer)
if err != nil {
t.Error(err)
panic(err)
}
if !has {
t.Error(errors.New(fmt.Sprintf("no version id is %v", ver.Id)))
panic(err)
}
fmt.Println(newVer)
if newVer.Ver != 1 {
err = errors.New("insert error")
t.Error(err)
panic(err)
}
newVer.Name = "-------"
_, err = testEngine.ID(ver.Id).Update(newVer)
if err != nil {
t.Error(err)
panic(err)
}
if newVer.Ver != 2 {
err = errors.New("update should set version back to struct")
t.Error(err)
}
newVer = new(VersionUintS)
has, err = testEngine.ID(ver.Id).Get(newVer)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(newVer)
if newVer.Ver != 2 {
err = errors.New("update error")
t.Error(err)
panic(err)
}
}
func TestVersion4(t *testing.T) {
assert.NoError(t, prepareEngine())
err := testEngine.DropTables(new(VersionUintS))
if err != nil {
t.Error(err)
panic(err)
}
err = testEngine.CreateTables(new(VersionUintS))
if err != nil {
t.Error(err)
panic(err)
}
var vers = []VersionUintS{
{Name: "sfsfdsfds"},
{Name: "xxxxx"},
}
_, err = testEngine.Insert(vers)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(vers)
for _, v := range vers {
if v.Ver != 1 {
err := errors.New("version should be 1")
t.Error(err)
panic(err)
}
}
}

View File

@ -1 +1 @@
go test -db=mssql -conn_str="server=192.168.1.58;user id=sa;password=123456;database=xorm_test" go test -db=mssql -conn_str="server=localhost;user id=sa;password=yourStrong(!)Password;database=xorm_test"

1
test_tidb.sh Executable file
View File

@ -0,0 +1 @@
go test -db=mysql -conn_str="root:@tcp(localhost:4000)/xorm_test" -ignore_select_update=true

View File

@ -1,9 +1,13 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm package xorm
import ( import (
"reflect" "reflect"
"github.com/go-xorm/core" "xorm.io/core"
) )
var ( var (

View File

@ -5,12 +5,11 @@
package xorm package xorm
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"testing" "testing"
"github.com/go-xorm/core" "xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -117,21 +116,21 @@ type ConvConfig struct {
} }
func (s *ConvConfig) FromDB(data []byte) error { func (s *ConvConfig) FromDB(data []byte) error {
return json.Unmarshal(data, s) return DefaultJSONHandler.Unmarshal(data, s)
} }
func (s *ConvConfig) ToDB() ([]byte, error) { func (s *ConvConfig) ToDB() ([]byte, error) {
return json.Marshal(s) return DefaultJSONHandler.Marshal(s)
} }
type SliceType []*ConvConfig type SliceType []*ConvConfig
func (s *SliceType) FromDB(data []byte) error { func (s *SliceType) FromDB(data []byte) error {
return json.Unmarshal(data, s) return DefaultJSONHandler.Unmarshal(data, s)
} }
func (s *SliceType) ToDB() ([]byte, error) { func (s *SliceType) ToDB() ([]byte, error) {
return json.Marshal(s) return DefaultJSONHandler.Marshal(s)
} }
type ConvStruct struct { type ConvStruct struct {
@ -309,16 +308,24 @@ func TestCustomType2(t *testing.T) {
_, err = testEngine.Exec("delete from " + testEngine.Quote(tableName)) _, err = testEngine.Exec("delete from " + testEngine.Quote(tableName))
assert.NoError(t, err) assert.NoError(t, err)
session := testEngine.NewSession()
defer session.Close()
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == core.MSSQL {
return err = session.Begin()
/*_, err = engine.Exec("set IDENTITY_INSERT " + tableName + " on") assert.NoError(t, err)
if err != nil { _, err = session.Exec("set IDENTITY_INSERT " + tableName + " on")
t.Fatal(err) assert.NoError(t, err)
}*/
} }
_, err = testEngine.Insert(&UserCus{1, "xlw", Registed}) cnt, err := session.Insert(&UserCus{1, "xlw", Registed})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == core.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}
user := UserCus{} user := UserCus{}
exist, err := testEngine.ID(1).Get(&user) exist, err := testEngine.ID(1).Get(&user)

View File

@ -7,6 +7,7 @@
package xorm package xorm
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
@ -14,7 +15,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/go-xorm/core" "xorm.io/core"
) )
const ( const (
@ -93,6 +94,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
TZLocation: time.Local, TZLocation: time.Local,
tagHandlers: defaultTagHandlers, tagHandlers: defaultTagHandlers,
cachers: make(map[string]core.Cacher), cachers: make(map[string]core.Cacher),
defaultContext: context.Background(),
} }
if uri.DbType == core.SQLITE { if uri.DbType == core.SQLITE {

View File

@ -1,15 +1,21 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm package xorm
import ( import (
"database/sql"
"flag" "flag"
"fmt" "fmt"
"log"
"os" "os"
"strings" "strings"
"testing" "testing"
_ "github.com/denisenkom/go-mssqldb" _ "github.com/denisenkom/go-mssqldb"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/go-xorm/core" "xorm.io/core"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv" _ "github.com/ziutek/mymysql/godrv"
@ -28,6 +34,7 @@ var (
cluster = flag.Bool("cluster", false, "if this is a cluster") cluster = flag.Bool("cluster", false, "if this is a cluster")
splitter = flag.String("splitter", ";", "the splitter on connstr for cluster") splitter = flag.String("splitter", ";", "the splitter on connstr for cluster")
schema = flag.String("schema", "", "specify the schema") schema = flag.String("schema", "", "specify the schema")
ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb")
) )
func createEngine(dbType, connStr string) error { func createEngine(dbType, connStr string) error {
@ -35,9 +42,59 @@ func createEngine(dbType, connStr string) error {
var err error var err error
if !*cluster { if !*cluster {
switch strings.ToLower(dbType) {
case core.MSSQL:
db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "master", -1))
if err != nil {
return err
}
if _, err = db.Exec("If(db_id(N'xorm_test') IS NULL) BEGIN CREATE DATABASE xorm_test; END;"); err != nil {
return fmt.Errorf("db.Exec: %v", err)
}
db.Close()
*ignoreSelectUpdate = true
case core.POSTGRES:
db, err := sql.Open(dbType, connStr)
if err != nil {
return err
}
rows, err := db.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = 'xorm_test'"))
if err != nil {
return fmt.Errorf("db.Query: %v", err)
}
defer rows.Close()
if !rows.Next() {
if _, err = db.Exec("CREATE DATABASE xorm_test"); err != nil {
return fmt.Errorf("CREATE DATABASE: %v", err)
}
}
if *schema != "" {
if _, err = db.Exec("CREATE SCHEMA IF NOT EXISTS " + *schema); err != nil {
return fmt.Errorf("CREATE SCHEMA: %v", err)
}
}
db.Close()
*ignoreSelectUpdate = true
case core.MYSQL:
db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "mysql", -1))
if err != nil {
return err
}
if _, err = db.Exec("CREATE DATABASE IF NOT EXISTS xorm_test"); err != nil {
return fmt.Errorf("db.Exec: %v", err)
}
db.Close()
default:
*ignoreSelectUpdate = true
}
testEngine, err = NewEngine(dbType, connStr) testEngine, err = NewEngine(dbType, connStr)
} else { } else {
testEngine, err = NewEngineGroup(dbType, strings.Split(connStr, *splitter)) testEngine, err = NewEngineGroup(dbType, strings.Split(connStr, *splitter))
if dbType != "mysql" && dbType != "mymysql" {
*ignoreSelectUpdate = true
}
} }
if err != nil { if err != nil {
return err return err
@ -95,7 +152,7 @@ func TestMain(m *testing.M) {
} }
} else { } else {
if ptrConnStr == nil { if ptrConnStr == nil {
fmt.Println("you should indicate conn string") log.Fatal("you should indicate conn string")
return return
} }
connString = *ptrConnStr connString = *ptrConnStr
@ -112,7 +169,7 @@ func TestMain(m *testing.M) {
fmt.Println("testing", dbType, connString) fmt.Println("testing", dbType, connString)
if err := prepareEngine(); err != nil { if err := prepareEngine(); err != nil {
fmt.Println(err) log.Fatal(err)
return return
} }