Merge branch 'master' into optimize-slice2Bean

This commit is contained in:
Lunny Xiao 2023-07-12 08:09:58 +00:00
commit 40a3e27af2
48 changed files with 1642 additions and 1095 deletions

View File

@ -1,437 +0,0 @@
---
kind: pipeline
name: test-mysql
environment:
GO111MODULE: "on"
GOPROXY: "https://goproxy.io"
CGO_ENABLED: 1
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-vet
image: golang:1.17
pull: always
volumes:
- name: cache
path: /go/pkg/mod
commands:
- make vet
- name: test-sqlite3
image: golang:1.17
volumes:
- name: cache
path: /go/pkg/mod
depends_on:
- test-vet
commands:
- make fmt-check
- make test
- make test-sqlite3
- TEST_CACHE_ENABLE=true make test-sqlite3
- name: test-sqlite
image: golang:1.17
volumes:
- name: cache
path: /go/pkg/mod
depends_on:
- test-vet
commands:
- make test-sqlite
- TEST_QUOTE_POLICY=reserved make test-sqlite
- name: test-mysql
image: golang:1.17
pull: never
volumes:
- name: cache
path: /go/pkg/mod
depends_on:
- test-vet
environment:
TEST_MYSQL_HOST: mysql
TEST_MYSQL_CHARSET: utf8
TEST_MYSQL_DBNAME: xorm_test
TEST_MYSQL_USERNAME: root
TEST_MYSQL_PASSWORD:
commands:
- TEST_CACHE_ENABLE=true make test-mysql
- name: test-mysql-utf8mb4
image: golang:1.17
pull: never
volumes:
- name: cache
path: /go/pkg/mod
depends_on:
- test-mysql
environment:
TEST_MYSQL_HOST: mysql
TEST_MYSQL_CHARSET: utf8mb4
TEST_MYSQL_DBNAME: xorm_test
TEST_MYSQL_USERNAME: root
TEST_MYSQL_PASSWORD:
commands:
- make test-mysql
- TEST_QUOTE_POLICY=reserved make test-mysql-tls
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: mysql
image: mysql:5.7
environment:
MYSQL_ALLOW_EMPTY_PASSWORD: yes
MYSQL_DATABASE: xorm_test
---
kind: pipeline
name: test-mysql8
depends_on:
- test-mysql
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-mysql8
image: golang:1.17
pull: never
volumes:
- name: cache
path: /go/pkg/mod
environment:
TEST_MYSQL_HOST: mysql8
TEST_MYSQL_CHARSET: utf8mb4
TEST_MYSQL_DBNAME: xorm_test
TEST_MYSQL_USERNAME: root
TEST_MYSQL_PASSWORD:
commands:
- make test-mysql
- TEST_CACHE_ENABLE=true make test-mysql
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: mysql8
image: mysql:8.0
environment:
MYSQL_ALLOW_EMPTY_PASSWORD: yes
MYSQL_DATABASE: xorm_test
---
kind: pipeline
name: test-mariadb
depends_on:
- test-mysql8
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-mariadb
image: golang:1.17
pull: never
volumes:
- name: cache
path: /go/pkg/mod
environment:
TEST_MYSQL_HOST: mariadb
TEST_MYSQL_CHARSET: utf8mb4
TEST_MYSQL_DBNAME: xorm_test
TEST_MYSQL_USERNAME: root
TEST_MYSQL_PASSWORD:
commands:
- make test-mysql
- TEST_QUOTE_POLICY=reserved make test-mysql
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: mariadb
image: mariadb:10.4
environment:
MYSQL_ALLOW_EMPTY_PASSWORD: yes
MYSQL_DATABASE: xorm_test
---
kind: pipeline
name: test-postgres
depends_on:
- test-mariadb
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-postgres
pull: never
image: golang:1.17
volumes:
- name: cache
path: /go/pkg/mod
environment:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
commands:
- make test-postgres
- TEST_CACHE_ENABLE=true make test-postgres
- name: test-postgres-schema
pull: never
image: golang:1.17
volumes:
- name: cache
path: /go/pkg/mod
depends_on:
- test-postgres
environment:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_SCHEMA: xorm
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
commands:
- TEST_QUOTE_POLICY=reserved make test-postgres
- name: test-pgx
pull: never
image: golang:1.17
volumes:
- name: cache
path: /go/pkg/mod
depends_on:
- test-postgres-schema
environment:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
commands:
- make test-pgx
- TEST_CACHE_ENABLE=true make test-pgx
- TEST_QUOTE_POLICY=reserved make test-pgx
- name: test-pgx-schema
pull: never
image: golang:1.17
volumes:
- name: cache
path: /go/pkg/mod
depends_on:
- test-pgx
environment:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_SCHEMA: xorm
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
commands:
- make test-pgx
- TEST_CACHE_ENABLE=true make test-pgx
- TEST_QUOTE_POLICY=reserved make test-pgx
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: pgsql
image: postgres:9.5
environment:
POSTGRES_DB: xorm_test
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
---
kind: pipeline
name: test-mssql
depends_on:
- test-postgres
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-mssql
pull: never
image: golang:1.17
volumes:
- name: cache
path: /go/pkg/mod
environment:
TEST_MSSQL_HOST: mssql
TEST_MSSQL_DBNAME: xorm_test
TEST_MSSQL_USERNAME: sa
TEST_MSSQL_PASSWORD: "yourStrong(!)Password"
commands:
- make test-mssql
- TEST_MSSQL_DEFAULT_VARCHAR=NVARCHAR TEST_MSSQL_DEFAULT_CHAR=NCHAR make test-mssql
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: mssql
pull: always
image: mcr.microsoft.com/mssql/server:latest
environment:
ACCEPT_EULA: Y
SA_PASSWORD: yourStrong(!)Password
MSSQL_PID: Standard
---
kind: pipeline
name: test-tidb
depends_on:
- test-mssql
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-tidb
pull: never
image: golang:1.17
volumes:
- name: cache
path: /go/pkg/mod
environment:
TEST_TIDB_HOST: "tidb:4000"
TEST_TIDB_DBNAME: xorm_test
TEST_TIDB_USERNAME: root
TEST_TIDB_PASSWORD:
commands:
- make test-tidb
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: tidb
image: pingcap/tidb:v3.0.3
---
kind: pipeline
name: test-cockroach
depends_on:
- test-tidb
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-cockroach
pull: never
image: golang:1.17
volumes:
- name: cache
path: /go/pkg/mod
environment:
TEST_COCKROACH_HOST: "cockroach:26257"
TEST_COCKROACH_DBNAME: xorm_test
TEST_COCKROACH_USERNAME: root
TEST_COCKROACH_PASSWORD:
commands:
- sleep 10
- make test-cockroach
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: cockroach
image: cockroachdb/cockroach:v19.2.4
commands:
- /cockroach/cockroach start --insecure
# ---
# kind: pipeline
# name: test-dameng
# depends_on:
# - test-cockroach
# trigger:
# ref:
# - refs/heads/master
# - refs/pull/*/head
# steps:
# - name: test-dameng
# pull: never
# image: golang:1.17
# volumes:
# - name: cache
# path: /go/pkg/mod
# environment:
# TEST_DAMENG_HOST: "dameng:5236"
# TEST_DAMENG_USERNAME: SYSDBA
# TEST_DAMENG_PASSWORD: SYSDBA
# commands:
# - sleep 30
# - make test-dameng
# volumes:
# - name: cache
# host:
# path: /tmp/cache
# services:
# - name: dameng
# image: lunny/dm:v1.0
# commands:
# - /bin/bash /startDm.sh
---
kind: pipeline
name: merge_coverage
depends_on:
- test-mysql
- test-mysql8
- test-mariadb
- test-postgres
- test-mssql
- test-tidb
- test-cockroach
#- test-dameng
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: merge_coverage
image: golang:1.17
commands:
- make coverage
---
kind: pipeline
name: release-tag
trigger:
event:
- tag
steps:
- name: release-tag-gitea
pull: always
image: plugins/gitea-release:latest
settings:
base_url: https://gitea.com
title: '${DRONE_TAG} is released'
api_key:
from_secret: gitea_token

View File

@ -0,0 +1,23 @@
name: release
on:
push:
tags:
- '*'
jobs:
release:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: setup go
uses: https://github.com/actions/setup-go@v4
with:
go-version: '>=1.20.1'
- name: Use Go Action
id: use-go-action
uses: actions/release-action@main
with:
api_key: '${{secrets.RELEASE_TOKEN}}'

View File

@ -0,0 +1,55 @@
name: test cockroach
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
test-cockroach:
name: test cockroach
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: test cockroach
env:
TEST_COCKROACH_HOST: "cockroach:26257"
TEST_COCKROACH_DBNAME: xorm_test
TEST_COCKROACH_USERNAME: root
TEST_COCKROACH_PASSWORD:
run: sleep 20 && make test-cockroach
services:
cockroach:
image: cockroachdb/cockroach:v19.2.4
ports:
- 26257:26257
cmd:
- 'start'
- '--insecure'

View File

@ -0,0 +1,56 @@
name: test mariadb
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
lint:
name: test mariadb
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: test mariadb
env:
TEST_MYSQL_HOST: mariadb
TEST_MYSQL_CHARSET: utf8mb4
TEST_MYSQL_DBNAME: xorm_test
TEST_MYSQL_USERNAME: root
TEST_MYSQL_PASSWORD:
run: TEST_QUOTE_POLICY=reserved make test-mysql
services:
mariadb:
image: mariadb:10.4
env:
MYSQL_ALLOW_EMPTY_PASSWORD: yes
MYSQL_DATABASE: xorm_test
ports:
- 3306:3306

View File

@ -0,0 +1,56 @@
name: test mssql
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
test-mssql:
name: test mssql
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: test mssql
env:
TEST_MSSQL_HOST: mssql
TEST_MSSQL_DBNAME: xorm_test
TEST_MSSQL_USERNAME: sa
TEST_MSSQL_PASSWORD: "yourStrong(!)Password"
run: TEST_MSSQL_DEFAULT_VARCHAR=NVARCHAR TEST_MSSQL_DEFAULT_CHAR=NCHAR make test-mssql
services:
mssql:
image: mcr.microsoft.com/mssql/server:latest
env:
ACCEPT_EULA: Y
SA_PASSWORD: yourStrong(!)Password
MSSQL_PID: Standard
ports:
- 1433:1433

View File

@ -0,0 +1,56 @@
name: test mysql
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
test-mysql:
name: test mysql
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: test mysql utf8mb4
env:
TEST_MYSQL_HOST: mysql
TEST_MYSQL_CHARSET: utf8mb4
TEST_MYSQL_DBNAME: xorm_test
TEST_MYSQL_USERNAME: root
TEST_MYSQL_PASSWORD:
run: TEST_QUOTE_POLICY=reserved make test-mysql-tls
services:
mysql:
image: mysql:5.7
env:
MYSQL_ALLOW_EMPTY_PASSWORD: yes
MYSQL_DATABASE: xorm_test
ports:
- 3306:3306

View File

@ -0,0 +1,56 @@
name: test mysql8
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
lint:
name: test mysql8
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: test mysql8
env:
TEST_MYSQL_HOST: mysql8
TEST_MYSQL_CHARSET: utf8mb4
TEST_MYSQL_DBNAME: xorm_test
TEST_MYSQL_USERNAME: root
TEST_MYSQL_PASSWORD:
run: TEST_CACHE_ENABLE=true make test-mysql
services:
mysql8:
image: mysql:8.0
env:
MYSQL_ALLOW_EMPTY_PASSWORD: yes
MYSQL_DATABASE: xorm_test
ports:
- 3306:3306

View File

@ -0,0 +1,79 @@
name: test postgres
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
lint:
name: test postgres
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: test postgres
env:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
run: TEST_CACHE_ENABLE=true make test-postgres
- name: test postgres with schema
env:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_SCHEMA: xorm
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
run: TEST_QUOTE_POLICY=reserved make test-postgres
- name: test pgx
env:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
run: TEST_CACHE_ENABLE=true make test-pgx
- name: test pgx with schema
env:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_SCHEMA: xorm
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
run: TEST_QUOTE_POLICY=reserved make test-pgx
services:
pgsql:
image: postgres:9.5
env:
POSTGRES_DB: xorm_test
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
ports:
- 5432:5432

View File

@ -0,0 +1,49 @@
name: test sqlite
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
test-sqlite:
name: unit test & test sqlite
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: vet
run: make vet
- name: format check
run: make fmt-check
- name: unit test
run: make test
- name: test sqlite3
run: make test-sqlite3
- name: test sqlite3 with cache
run: TEST_CACHE_ENABLE=true make test-sqlite3

View File

@ -0,0 +1,52 @@
name: test tidb
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
test-tidb:
name: test tidb
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: test tidb
env:
TEST_TIDB_HOST: "tidb:4000"
TEST_TIDB_DBNAME: xorm_test
TEST_TIDB_USERNAME: root
TEST_TIDB_PASSWORD:
run: make test-tidb
services:
tidb:
image: pingcap/tidb:v3.0.3
ports:
- 4000:4000

View File

@ -24,7 +24,10 @@ func Interface2Interface(userLocation *time.Location, v interface{}) (interface{
return vv.String, nil return vv.String, nil
case *sql.RawBytes: case *sql.RawBytes:
if len([]byte(*vv)) > 0 { if len([]byte(*vv)) > 0 {
return []byte(*vv), nil src := []byte(*vv)
dest := make([]byte, len(src))
copy(dest, src)
return dest, nil
} }
return nil, nil return nil, nil
case *sql.NullInt32: case *sql.NullInt32:

View File

@ -659,7 +659,7 @@ func (db *dameng) DropTableSQL(tableName string) (string, bool) {
// ModifyColumnSQL returns a SQL to modify SQL // ModifyColumnSQL returns a SQL to modify SQL
func (db *dameng) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *dameng) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false) s, _ := ColumnString(db.dialect, col, false, false)
return fmt.Sprintf("ALTER TABLE %s MODIFY %s", db.quoter.Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s MODIFY %s", db.quoter.Quote(tableName), s)
} }
@ -692,7 +692,7 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
} }
} }
s, _ := ColumnString(db, col, false) s, _ := ColumnString(db, col, false, false)
if _, err := b.WriteString(s); err != nil { if _, err := b.WriteString(s); err != nil {
return "", false, err return "", false, err
} }
@ -709,7 +709,13 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
return "", false, err return "", false, err
} }
} }
if _, err := b.WriteString(fmt.Sprintf("CONSTRAINT PK_%s PRIMARY KEY (", tableName)); err != nil { if _, err := b.WriteString("CONSTRAINT PK_"); err != nil {
return "", false, err
}
if _, err := b.WriteString(tableName); err != nil {
return "", false, err
}
if _, err := b.WriteString(" PRIMARY KEY ("); err != nil {
return "", false, err return "", false, err
} }
if err := quoter.JoinWrite(&b, pkList, ","); err != nil { if err := quoter.JoinWrite(&b, pkList, ","); err != nil {
@ -837,7 +843,11 @@ func addSingleQuote(name string) string {
if name[0] == '\'' && name[len(name)-1] == '\'' { if name[0] == '\'' && name[len(name)-1] == '\'' {
return name return name
} }
return fmt.Sprintf("'%s'", name) var b strings.Builder
b.WriteRune('\'')
b.WriteString(name)
b.WriteRune('\'')
return b.String()
} }
func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {

View File

@ -85,8 +85,6 @@ type Dialect interface {
AddColumnSQL(tableName string, col *schemas.Column) string AddColumnSQL(tableName string, col *schemas.Column) string
ModifyColumnSQL(tableName string, col *schemas.Column) string ModifyColumnSQL(tableName string, col *schemas.Column) string
ForUpdateSQL(query string) string
Filters() []Filter Filters() []Filter
SetParams(params map[string]string) SetParams(params map[string]string)
} }
@ -135,7 +133,7 @@ func (db *Base) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
for i, colName := range table.ColumnsSeq() { for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, false)
b.WriteString(s) b.WriteString(s)
if i != len(table.ColumnsSeq())-1 { if i != len(table.ColumnsSeq())-1 {
@ -209,7 +207,7 @@ func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableNa
// AddColumnSQL returns a SQL to add a column // AddColumnSQL returns a SQL to add a column
func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, true) s, _ := ColumnString(db.dialect, col, true, false)
return fmt.Sprintf("ALTER TABLE %s ADD %s", db.dialect.Quoter().Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s ADD %s", db.dialect.Quoter().Quote(tableName), s)
} }
@ -241,22 +239,15 @@ func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
// ModifyColumnSQL returns a SQL to modify SQL // ModifyColumnSQL returns a SQL to modify SQL
func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false) s, _ := ColumnString(db.dialect, col, false, false)
return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s)
} }
// ForUpdateSQL returns for updateSQL
func (db *Base) ForUpdateSQL(query string) string {
return query + " FOR UPDATE"
}
// SetParams set params // SetParams set params
func (db *Base) SetParams(params map[string]string) { func (db *Base) SetParams(params map[string]string) {
} }
var ( var dialects = map[string]func() Dialect{}
dialects = map[string]func() Dialect{}
)
// RegisterDialect register database dialect // RegisterDialect register database dialect
func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) { func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) {
@ -307,7 +298,7 @@ func init() {
} }
// ColumnString generate column description string according dialect // ColumnString generate column description string according dialect
func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) (string, error) { func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey, supportCollation bool) (string, error) {
bd := strings.Builder{} bd := strings.Builder{}
if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil { if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil {
@ -322,6 +313,15 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool)
return "", err return "", err
} }
if supportCollation && col.Collation != "" {
if _, err := bd.WriteString(" COLLATE "); err != nil {
return "", err
}
if _, err := bd.WriteString(col.Collation); err != nil {
return "", err
}
}
if includePrimaryKey && col.IsPrimaryKey { if includePrimaryKey && col.IsPrimaryKey {
if _, err := bd.WriteString(" PRIMARY KEY"); err != nil { if _, err := bd.WriteString(" PRIMARY KEY"); err != nil {
return "", err return "", err

View File

@ -5,13 +5,14 @@
package dialects package dialects
import ( import (
"fmt" "context"
"strconv"
"strings" "strings"
) )
// Filter is an interface to filter SQL // Filter is an interface to filter SQL
type Filter interface { type Filter interface {
Do(sql string) string Do(ctx context.Context, sql string) string
} }
// SeqFilter filter SQL replace ?, ? ... to $1, $2 ... // SeqFilter filter SQL replace ?, ? ... to $1, $2 ...
@ -28,10 +29,11 @@ func convertQuestionMark(sql, prefix string, start int) string {
var isMaybeLineComment bool var isMaybeLineComment bool
var isMaybeComment bool var isMaybeComment bool
var isMaybeCommentEnd bool var isMaybeCommentEnd bool
var index = start index := start
for _, c := range sql { for _, c := range sql {
if !beginSingleQuote && !isLineComment && !isComment && c == '?' { if !beginSingleQuote && !isLineComment && !isComment && c == '?' {
buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) buf.WriteString(prefix)
buf.WriteString(strconv.Itoa(index))
index++ index++
} else { } else {
if isMaybeLineComment { if isMaybeLineComment {
@ -71,6 +73,6 @@ func convertQuestionMark(sql, prefix string, start int) string {
} }
// Do implements Filter // Do implements Filter
func (s *SeqFilter) Do(sql string) string { func (s *SeqFilter) Do(ctx context.Context, sql string) string {
return convertQuestionMark(sql, s.Prefix, s.Start) return convertQuestionMark(sql, s.Prefix, s.Start)
} }

View File

@ -428,7 +428,7 @@ func (db *mssql) DropTableSQL(tableName string) (string, bool) {
} }
func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false) s, _ := ColumnString(db.dialect, col, false, true)
return fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", db.quoter.Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", db.quoter.Quote(tableName), s)
} }
@ -454,7 +454,7 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable,
"default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END), "default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END),
replace(replace(isnull(c.text,''),'(',''),')','') as vdefault, replace(replace(isnull(c.text,''),'(',''),')','') as vdefault,
ISNULL(p.is_primary_key, 0), a.is_identity as is_identity ISNULL(p.is_primary_key, 0), a.is_identity as is_identity, a.collation_name
from sys.columns a from sys.columns a
left join sys.types b on a.user_type_id=b.user_type_id left join sys.types b on a.user_type_id=b.user_type_id
left join sys.syscomments c on a.default_object_id=c.id left join sys.syscomments c on a.default_object_id=c.id
@ -475,9 +475,10 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
var name, ctype, vdefault string var name, ctype, vdefault string
var collation *string
var maxLen, precision, scale int64 var maxLen, precision, scale int64
var nullable, isPK, defaultIsNull, isIncrement bool var nullable, isPK, defaultIsNull, isIncrement bool
err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement) err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement, &collation)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -499,6 +500,9 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
} else { } else {
col.Length = maxLen col.Length = maxLen
} }
if collation != nil {
col.Collation = *collation
}
switch ct { switch ct {
case "DATETIMEOFFSET": case "DATETIMEOFFSET":
col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0}
@ -646,7 +650,7 @@ func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
for i, colName := range table.ColumnsSeq() { for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true)
b.WriteString(s) b.WriteString(s)
if i != len(table.ColumnsSeq())-1 { if i != len(table.ColumnsSeq())-1 {
@ -665,10 +669,6 @@ func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
return b.String(), true, nil return b.String(), true, nil
} }
func (db *mssql) ForUpdateSQL(query string) string {
return query
}
func (db *mssql) Filters() []Filter { func (db *mssql) Filters() []Filter {
return []Filter{} return []Filter{}
} }

View File

@ -380,12 +380,27 @@ func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableNa
func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string {
quoter := db.dialect.Quoter() quoter := db.dialect.Quoter()
s, _ := ColumnString(db, col, true) s, _ := ColumnString(db, col, true, true)
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), s) var b strings.Builder
b.WriteString("ALTER TABLE ")
quoter.QuoteTo(&b, tableName)
b.WriteString(" ADD ")
b.WriteString(s)
if len(col.Comment) > 0 { if len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'" b.WriteString(" COMMENT '")
b.WriteString(col.Comment)
b.WriteString("'")
} }
return sql return b.String()
}
// ModifyColumnSQL returns a SQL to modify SQL
func (db *mysql) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false, true)
if col.Comment != "" {
s += fmt.Sprintf(" COMMENT '%s'", col.Comment)
}
return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s)
} }
func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
@ -398,7 +413,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
"SUBSTRING_INDEX(SUBSTRING(VERSION(), 6), '-', 1) >= 7)))))" "SUBSTRING_INDEX(SUBSTRING(VERSION(), 6), '-', 1) >= 7)))))"
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " + " `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " +
alreadyQuoted + " AS NEEDS_QUOTE " + alreadyQuoted + " AS NEEDS_QUOTE, `COLLATION_NAME` " +
"FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + "FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" +
" ORDER BY `COLUMNS`.ORDINAL_POSITION ASC" " ORDER BY `COLUMNS`.ORDINAL_POSITION ASC"
@ -416,8 +431,8 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
var columnName, nullableStr, colType, colKey, extra, comment string var columnName, nullableStr, colType, colKey, extra, comment string
var alreadyQuoted, isUnsigned bool var alreadyQuoted, isUnsigned bool
var colDefault, maxLength *string var colDefault, maxLength, collation *string
err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted) err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted, &collation)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -433,6 +448,9 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
} else { } else {
col.DefaultIsEmpty = true col.DefaultIsEmpty = true
} }
if collation != nil {
col.Collation = *collation
}
fields := strings.Fields(colType) fields := strings.Fields(colType)
if len(fields) == 2 && fields[1] == "unsigned" { if len(fields) == 2 && fields[1] == "unsigned" {
@ -525,7 +543,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
args := []interface{}{db.uri.DBName} args := []interface{}{db.uri.DBName}
s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT`, `TABLE_COLLATION` from " +
"`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')" "`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')"
rows, err := queryer.QueryContext(ctx, s, args...) rows, err := queryer.QueryContext(ctx, s, args...)
@ -537,9 +555,9 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema
tables := make([]*schemas.Table, 0) tables := make([]*schemas.Table, 0)
for rows.Next() { for rows.Next() {
table := schemas.NewEmptyTable() table := schemas.NewEmptyTable()
var name, engine string var name, engine, collation string
var autoIncr, comment *string var autoIncr, comment *string
err = rows.Scan(&name, &engine, &autoIncr, &comment) err = rows.Scan(&name, &engine, &autoIncr, &comment, &collation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -549,6 +567,7 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema
table.Comment = *comment table.Comment = *comment
} }
table.StoreEngine = engine table.StoreEngine = engine
table.Collation = collation
tables = append(tables, table) tables = append(tables, table)
} }
if rows.Err() != nil { if rows.Err() != nil {
@ -640,7 +659,7 @@ func (db *mysql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
for i, colName := range table.ColumnsSeq() { for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true)
b.WriteString(s) b.WriteString(s)
if len(col.Comment) > 0 { if len(col.Comment) > 0 {

View File

@ -609,7 +609,7 @@ func (db *oracle) IsReserved(name string) bool {
} }
func (db *oracle) DropTableSQL(tableName string) (string, bool) { func (db *oracle) DropTableSQL(tableName string) (string, bool) {
return fmt.Sprintf("DROP TABLE `%s`", tableName), false return fmt.Sprintf("DROP TABLE \"%s\"", tableName), false
} }
func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
@ -628,7 +628,7 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
/*if col.IsPrimaryKey && len(pkList) == 1 { /*if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(b.dialect) sql += col.String(b.dialect)
} else {*/ } else {*/
s, _ := ColumnString(db, col, false) s, _ := ColumnString(db, col, false, false)
sql += s sql += s
// } // }
sql = strings.TrimSpace(sql) sql = strings.TrimSpace(sql)
@ -645,6 +645,10 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
return sql, false, nil return sql, false, nil
} }
func (db *oracle) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) {
return db.HasRecords(queryer, ctx, `SELECT sequence_name FROM user_sequences WHERE sequence_name = :1`, seqName)
}
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy { switch quotePolicy {
case QuotePolicyNone: case QuotePolicyNone:

View File

@ -992,7 +992,7 @@ func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tabl
} }
func (db *postgres) AddColumnSQL(tableName string, col *schemas.Column) string { func (db *postgres) AddColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, true) s, _ := ColumnString(db.dialect, col, true, false)
quoter := db.dialect.Quoter() quoter := db.dialect.Quoter()
addColumnSQL := "" addColumnSQL := ""
@ -1078,7 +1078,7 @@ FROM pg_attribute f
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN pg_class AS g ON p.confrelid = g.oid
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` WHERE n.nspname= s.table_schema AND c.relkind = 'r' AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;`
schema := db.getSchema() schema := db.getSchema()
if schema != "" { if schema != "" {

View File

@ -193,11 +193,11 @@ func (db *sqlite3) Features() *DialectFeatures {
func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy { switch quotePolicy {
case QuotePolicyNone: case QuotePolicyNone:
var q = sqlite3Quoter q := sqlite3Quoter
q.IsReserved = schemas.AlwaysNoReserve q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q db.quoter = q
case QuotePolicyReserved: case QuotePolicyReserved:
var q = sqlite3Quoter q := sqlite3Quoter
q.IsReserved = db.IsReserved q.IsReserved = db.IsReserved
db.quoter = q db.quoter = q
case QuotePolicyAlways: case QuotePolicyAlways:
@ -291,10 +291,6 @@ func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string {
return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName))
} }
func (db *sqlite3) ForUpdateSQL(query string) string {
return query
}
func (db *sqlite3) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { func (db *sqlite3) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
query := "SELECT * FROM " + tableName + " LIMIT 0" query := "SELECT * FROM " + tableName + " LIMIT 0"
rows, err := queryer.QueryContext(ctx, query) rows, err := queryer.QueryContext(ctx, query)
@ -320,7 +316,7 @@ func (db *sqlite3) IsColumnExist(queryer core.Queryer, ctx context.Context, tabl
// splitColStr splits a sqlite col strings as fields // splitColStr splits a sqlite col strings as fields
func splitColStr(colStr string) []string { func splitColStr(colStr string) []string {
colStr = strings.TrimSpace(colStr) colStr = strings.TrimSpace(colStr)
var results = make([]string, 0, 10) results := make([]string, 0, 10)
var lastIdx int var lastIdx int
var hasC, hasQuote bool var hasC, hasQuote bool
for i, c := range colStr { for i, c := range colStr {

View File

@ -1120,21 +1120,6 @@ func (engine *Engine) UnMapType(t reflect.Type) {
engine.tagParser.ClearCacheTable(t) engine.tagParser.ClearCacheTable(t)
} }
// Sync the new struct changes to database, this method will automatically add
// table, column, index, unique. but will not delete or change anything.
// If you change some field, you should change the database manually.
func (engine *Engine) Sync(beans ...interface{}) error {
session := engine.NewSession()
defer session.Close()
return session.Sync(beans...)
}
// Sync2 synchronize structs to database tables
// Depricated
func (engine *Engine) Sync2(beans ...interface{}) error {
return engine.Sync(beans...)
}
// CreateTables create tabls according bean // CreateTables create tabls according bean
func (engine *Engine) CreateTables(beans ...interface{}) error { func (engine *Engine) CreateTables(beans ...interface{}) error {
session := engine.NewSession() session := engine.NewSession()

View File

@ -289,6 +289,48 @@ func TestGetColumnsComment(t *testing.T) {
assert.Zero(t, noComment) assert.Zero(t, noComment)
} }
type TestCommentUpdate struct {
HasComment int `xorm:"bigint comment('this is a comment before update')"`
}
func (m *TestCommentUpdate) TableName() string {
return "test_comment_struct"
}
type TestCommentUpdate2 struct {
HasComment int `xorm:"bigint comment('this is a comment after update')"`
}
func (m *TestCommentUpdate2) TableName() string {
return "test_comment_struct"
}
func TestColumnCommentUpdate(t *testing.T) {
comment := "this is a comment after update"
assertSync(t, new(TestCommentUpdate))
assert.NoError(t, testEngine.Sync2(new(TestCommentUpdate2))) // modify table column comment
switch testEngine.Dialect().URI().DBType {
case schemas.POSTGRES, schemas.MYSQL: // only postgres / mysql dialect implement the feature of modify comment in postgres.ModifyColumnSQL
default:
t.Skip()
return
}
tables, err := testEngine.DBMetas()
assert.NoError(t, err)
tableName := "test_comment_struct"
var hasComment string
for _, table := range tables {
if table.Name == tableName {
col := table.GetColumn(testEngine.GetColumnMapper().Obj2Table("HasComment"))
assert.NotNil(t, col)
hasComment = col.Comment
break
}
}
assert.Equal(t, comment, hasComment)
}
func TestGetColumnsLength(t *testing.T) { func TestGetColumnsLength(t *testing.T) {
var max_length int64 var max_length int64
switch testEngine.Dialect().URI().DBType { switch testEngine.Dialect().URI().DBType {

View File

@ -5,6 +5,7 @@
package integrations package integrations
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
"testing" "testing"
@ -458,7 +459,7 @@ func TestSync2_2(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
var tableNames = make(map[string]bool) tableNames := make(map[string]bool)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
tableName := fmt.Sprintf("test_sync2_index_%d", i) tableName := fmt.Sprintf("test_sync2_index_%d", i)
tableNames[tableName] = true tableNames[tableName] = true
@ -536,3 +537,111 @@ func TestModifyColum(t *testing.T) {
_, err := testEngine.Exec(alterSQL) _, err := testEngine.Exec(alterSQL)
assert.NoError(t, err) assert.NoError(t, err)
} }
type TestCollateColumn struct {
Id int64
UserId int64 `xorm:"unique(s)"`
Name string `xorm:"varchar(20) unique(s)"`
dbtype string `xorm:"-"`
}
func (t TestCollateColumn) TableCollations() []*schemas.Collation {
if t.dbtype == string(schemas.MYSQL) {
return []*schemas.Collation{
{
Name: "utf8mb4_general_ci",
Column: "name",
},
}
} else if t.dbtype == string(schemas.MSSQL) {
return []*schemas.Collation{
{
Name: "Latin1_General_CI_AS",
Column: "name",
},
}
}
return nil
}
func TestCollate(t *testing.T) {
assert.NoError(t, PrepareEngine())
assertSync(t, &TestCollateColumn{
dbtype: string(testEngine.Dialect().URI().DBType),
})
_, err := testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "test",
})
assert.NoError(t, err)
_, err = testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "Test",
})
if testEngine.Dialect().URI().DBType == schemas.MYSQL {
ver, err1 := testEngine.DBVersion()
assert.NoError(t, err1)
tables, err1 := testEngine.DBMetas()
assert.NoError(t, err1)
for _, table := range tables {
if table.Name == "test_collate_column" {
col := table.GetColumn("name")
if col == nil {
assert.Error(t, errors.New("not found column"))
return
}
// tidb doesn't follow utf8mb4_general_ci
if col.Collation == "utf8mb4_general_ci" && ver.Edition != "TiDB" {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
break
}
}
} else if testEngine.Dialect().URI().DBType == schemas.MSSQL {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
// Since SQLITE don't support modify column SQL, currrently just ignore
if testEngine.Dialect().URI().DBType != schemas.MYSQL && testEngine.Dialect().URI().DBType != schemas.MSSQL {
return
}
var newCollation string
if testEngine.Dialect().URI().DBType == schemas.MYSQL {
newCollation = "utf8mb4_bin"
} else if testEngine.Dialect().URI().DBType != schemas.MSSQL {
newCollation = "Latin1_General_CS_AS"
} else {
return
}
alterSQL := testEngine.Dialect().ModifyColumnSQL("test_collate_column", &schemas.Column{
Name: "name",
SQLType: schemas.SQLType{
Name: "VARCHAR",
},
Length: 20,
Nullable: true,
DefaultIsEmpty: true,
Collation: newCollation,
})
_, err = testEngine.Exec(alterSQL)
assert.NoError(t, err)
_, err = testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "test1",
})
assert.NoError(t, err)
_, err = testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "Test1",
})
assert.NoError(t, err)
}

View File

@ -89,7 +89,7 @@ func TestCountWithOthers(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
total, err := testEngine.OrderBy("`id` desc").Limit(1).Count(new(CountWithOthers)) total, err := testEngine.OrderBy("count(`id`) desc").Limit(1).Count(new(CountWithOthers))
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, total) assert.EqualValues(t, 2, total)
} }
@ -118,11 +118,11 @@ func TestWithTableName(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
total, err := testEngine.OrderBy("`id` desc").Count(new(CountWithTableName)) total, err := testEngine.OrderBy("count(`id`) desc").Count(new(CountWithTableName))
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, total) assert.EqualValues(t, 2, total)
total, err = testEngine.OrderBy("`id` desc").Count(CountWithTableName{}) total, err = testEngine.OrderBy("count(`id`) desc").Count(CountWithTableName{})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, total) assert.EqualValues(t, 2, total)
} }

View File

@ -12,6 +12,7 @@ import (
"xorm.io/xorm" "xorm.io/xorm"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/names" "xorm.io/xorm/names"
"xorm.io/xorm/schemas"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -1196,3 +1197,43 @@ func TestUpdateFindDate(t *testing.T) {
assert.EqualValues(t, 1, len(tufs)) assert.EqualValues(t, 1, len(tufs))
assert.EqualValues(t, tuf.Tm.Format("2006-01-02"), tufs[0].Tm.Format("2006-01-02")) assert.EqualValues(t, tuf.Tm.Format("2006-01-02"), tufs[0].Tm.Format("2006-01-02"))
} }
func TestBuilderDialect(t *testing.T) {
assert.NoError(t, PrepareEngine())
type TestBuilderDialect struct {
Id int64
Name string `xorm:"index"`
Age2 int
}
type TestBuilderDialectFoo struct {
Id int64
DialectId int64 `xorm:"index"`
Age int
}
assertSync(t, new(TestBuilderDialect), new(TestBuilderDialectFoo))
session := testEngine.NewSession()
defer session.Close()
var dialect string
switch testEngine.Dialect().URI().DBType {
case schemas.MYSQL:
dialect = builder.MYSQL
case schemas.MSSQL:
dialect = builder.MSSQL
case schemas.POSTGRES:
dialect = builder.POSTGRES
case schemas.SQLITE:
dialect = builder.SQLITE
}
tbName := testEngine.TableName(new(TestBuilderDialectFoo), dialect == builder.POSTGRES)
inner := builder.Dialect(dialect).Select("*").From(tbName).Where(builder.Eq{"age": 20})
result := make([]*TestBuilderDialect, 0, 10)
err := testEngine.Table("test_builder_dialect").Where(builder.Eq{"age2": 2}).Join("INNER", inner, "test_builder_dialect_foo.dialect_id = test_builder_dialect.id").Find(&result)
assert.NoError(t, err)
}

View File

@ -5,11 +5,13 @@
package integrations package integrations
import ( import (
"bytes"
"strconv" "strconv"
"testing" "testing"
"time" "time"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -381,3 +383,68 @@ func TestQueryStringWithLimit(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 0, len(data)) assert.EqualValues(t, 0, len(data))
} }
func TestQueryBLOBInMySQL(t *testing.T) {
assert.NoError(t, PrepareEngine())
var err error
type Avatar struct {
Id int64 `xorm:"autoincr pk"`
Avatar []byte `xorm:"BLOB"`
}
assert.NoError(t, testEngine.Sync(new(Avatar)))
testEngine.Delete(Avatar{})
repeatBytes := func(n int, b byte) []byte {
return bytes.Repeat([]byte{b}, n)
}
const N = 10
var data = []Avatar{}
for i := 0; i < N; i++ {
// allocate a []byte that is as twice big as the last one
// so that the underlying buffer will need to reallocate when querying
bs := repeatBytes(1<<(i+2), 'A'+byte(i))
data = append(data, Avatar{
Avatar: bs,
})
}
_, err = testEngine.Insert(data)
assert.NoError(t, err)
defer func() {
testEngine.Delete(Avatar{})
}()
{
records, err := testEngine.QueryInterface("select avatar from " + testEngine.Quote(testEngine.TableName("avatar", true)))
assert.NoError(t, err)
for i, record := range records {
bs := record["avatar"].([]byte)
assert.EqualValues(t, repeatBytes(1<<(i+2), 'A'+byte(i))[:3], bs[:3])
t.Logf("%d => %p => %02x %02x %02x", i, bs, bs[0], bs[1], bs[2])
}
}
{
arr := make([][]interface{}, 0)
err = testEngine.Table(testEngine.Quote(testEngine.TableName("avatar", true))).Cols("avatar").Find(&arr)
assert.NoError(t, err)
for i, record := range arr {
bs := record[0].([]byte)
assert.EqualValues(t, repeatBytes(1<<(i+2), 'A'+byte(i))[:3], bs[:3])
t.Logf("%d => %p => %02x %02x %02x", i, bs, bs[0], bs[1], bs[2])
}
}
{
arr := make([]map[string]interface{}, 0)
err = testEngine.Table(testEngine.Quote(testEngine.TableName("avatar", true))).Cols("avatar").Find(&arr)
assert.NoError(t, err)
for i, record := range arr {
bs := record["avatar"].([]byte)
assert.EqualValues(t, repeatBytes(1<<(i+2), 'A'+byte(i))[:3], bs[:3])
t.Logf("%d => %p => %02x %02x %02x", i, bs, bs[0], bs[1], bs[2])
}
}
}

View File

@ -185,3 +185,36 @@ func TestMultipleTransaction(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 0, len(ms)) assert.EqualValues(t, 0, len(ms))
} }
func TestInsertMulti2InterfaceTransaction(t *testing.T) {
type Multi2InterfaceTransaction struct {
ID uint64 `xorm:"id pk autoincr"`
Name string
Alias string
CreateTime time.Time `xorm:"created"`
UpdateTime time.Time `xorm:"updated"`
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(Multi2InterfaceTransaction))
session := testEngine.NewSession()
defer session.Close()
err := session.Begin()
assert.NoError(t, err)
users := []interface{}{
&Multi2InterfaceTransaction{Name: "a", Alias: "A"},
&Multi2InterfaceTransaction{Name: "b", Alias: "B"},
&Multi2InterfaceTransaction{Name: "c", Alias: "C"},
&Multi2InterfaceTransaction{Name: "d", Alias: "D"},
}
cnt, err := session.Insert(&users)
assert.NoError(t, err)
assert.EqualValues(t, len(users), cnt)
assert.NotPanics(t, func() {
err = session.Commit()
assert.NoError(t, err)
})
}

View File

@ -6,6 +6,7 @@ package statements
import ( import (
"fmt" "fmt"
"strconv"
"strings" "strings"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
@ -26,14 +27,19 @@ func (statement *Statement) ConvertIDSQL(sqlStr string) string {
return "" return ""
} }
var top string var b strings.Builder
b.WriteString("SELECT ")
pLimitN := statement.LimitN pLimitN := statement.LimitN
if pLimitN != nil && statement.dialect.URI().DBType == schemas.MSSQL { if pLimitN != nil && statement.dialect.URI().DBType == schemas.MSSQL {
top = fmt.Sprintf("TOP %d ", *pLimitN) b.WriteString("TOP ")
b.WriteString(strconv.Itoa(*pLimitN))
b.WriteString(" ")
} }
b.WriteString(colstrs)
b.WriteString(" FROM ")
b.WriteString(sqls[1])
newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1]) return b.String()
return newsql
} }
return "" return ""
} }
@ -54,7 +60,7 @@ func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) {
return "", "" return "", ""
} }
var whereStr = sqls[1] whereStr := sqls[1]
// TODO: for postgres only, if any other database? // TODO: for postgres only, if any other database?
var paraStr string var paraStr string

View File

@ -15,82 +15,97 @@ import (
) )
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition interface{}, args ...interface{}) *Statement { func (statement *Statement) Join(joinOP string, joinTable interface{}, condition interface{}, args ...interface{}) *Statement {
var buf strings.Builder statement.joins = append(statement.joins, join{
if len(statement.JoinStr) > 0 { op: joinOP,
fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) table: joinTable,
} else { condition: condition,
fmt.Fprintf(&buf, "%v JOIN ", joinOP) args: args,
} })
condStr := ""
condArgs := []interface{}{}
switch condTp := condition.(type) {
case string:
condStr = condTp
case builder.Cond:
var err error
condStr, condArgs, err = builder.ToSQL(condTp)
if err != nil {
statement.LastError = err
return statement
}
default:
statement.LastError = fmt.Errorf("unsupported join condition type: %v", condTp)
return statement return statement
} }
switch tp := tablename.(type) { func (statement *Statement) writeJoins(w *builder.BytesWriter) error {
case builder.Builder: for _, join := range statement.joins {
subSQL, subQueryArgs, err := tp.ToSQL() if err := statement.writeJoin(w, join); err != nil {
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr))
statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...)
case *builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr))
statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...)
default:
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
if !utils.IsSubQuery(tbName) {
var buf strings.Builder
_ = statement.dialect.Quoter().QuoteTo(&buf, tbName)
tbName = buf.String()
} else {
tbName = statement.ReplaceQuote(tbName)
}
fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condStr))
statement.joinArgs = append(statement.joinArgs, condArgs...)
}
statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...)
return statement
}
func (statement *Statement) writeJoin(w builder.Writer) error {
if statement.JoinStr != "" {
if _, err := fmt.Fprint(w, " ", statement.JoinStr); err != nil {
return err return err
} }
w.Append(statement.joinArgs...)
} }
return nil return nil
} }
func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error {
// write join operator
if _, err := fmt.Fprintf(buf, " %v JOIN", join.op); err != nil {
return err
}
// write join table or subquery
switch tp := join.table.(type) {
case builder.Builder:
if _, err := fmt.Fprintf(buf, " ("); err != nil {
return err
}
if err := tp.WriteTo(statement.QuoteReplacer(buf)); err != nil {
return err
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
if _, err := fmt.Fprintf(buf, ") %s", statement.quote(aliasName)); err != nil {
return err
}
case *builder.Builder:
if _, err := fmt.Fprintf(buf, " ("); err != nil {
return err
}
if err := tp.WriteTo(statement.QuoteReplacer(buf)); err != nil {
return err
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
if _, err := fmt.Fprintf(buf, ") %s", statement.quote(aliasName)); err != nil {
return err
}
default:
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), join.table, true)
if !utils.IsSubQuery(tbName) {
var sb strings.Builder
if err := statement.dialect.Quoter().QuoteTo(&sb, tbName); err != nil {
return err
}
tbName = sb.String()
} else {
tbName = statement.ReplaceQuote(tbName)
}
if _, err := fmt.Fprint(buf, " ", tbName); err != nil {
return err
}
}
// write on condition
if _, err := fmt.Fprint(buf, " ON "); err != nil {
return err
}
switch condTp := join.condition.(type) {
case string:
if _, err := fmt.Fprint(buf, statement.ReplaceQuote(condTp)); err != nil {
return err
}
case builder.Cond:
if err := condTp.WriteTo(statement.QuoteReplacer(buf)); err != nil {
return err
}
default:
return fmt.Errorf("unsupported join condition type: %v", condTp)
}
buf.Append(join.args...)
return nil
}

View File

@ -7,6 +7,7 @@ package statements
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"reflect" "reflect"
"strings" "strings"
@ -29,37 +30,15 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
return "", nil, ErrTableNotFound return "", nil, ErrTableNotFound
} }
columnStr := statement.ColumnStr()
if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr
} else {
if statement.JoinStr == "" {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = "*"
}
}
}
if columnStr == "" {
columnStr = "*"
}
}
if err := statement.ProcessIDParam(); err != nil { if err := statement.ProcessIDParam(); err != nil {
return "", nil, err return "", nil, err
} }
return statement.genSelectSQL(columnStr, true, true) buf := builder.NewWriter()
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil {
return "", nil, err
}
return buf.String(), buf.Args(), nil
} }
// GenSumSQL generates sum SQL // GenSumSQL generates sum SQL
@ -81,13 +60,16 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
} }
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName)) sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
} }
sumSelect := strings.Join(sumStrs, ", ")
if err := statement.MergeConds(bean); err != nil { if err := statement.MergeConds(bean); err != nil {
return "", nil, err return "", nil, err
} }
return statement.genSelectSQL(sumSelect, true, true) buf := builder.NewWriter()
if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true); err != nil {
return "", nil, err
}
return buf.String(), buf.Args(), nil
} }
// GenGetSQL generates Get SQL // GenGetSQL generates Get SQL
@ -108,7 +90,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
columnStr = statement.SelectStr columnStr = statement.SelectStr
} else { } else {
// TODO: always generate column names, not use * even if join // TODO: always generate column names, not use * even if join
if len(statement.JoinStr) == 0 { if len(statement.joins) == 0 {
if len(columnStr) == 0 { if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 { if len(statement.GroupByStr) > 0 {
columnStr = statement.quoteColumnStr(statement.GroupByStr) columnStr = statement.quoteColumnStr(statement.GroupByStr)
@ -139,7 +121,11 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
} }
} }
return statement.genSelectSQL(columnStr, true, true) buf := builder.NewWriter()
if err := statement.writeSelect(buf, columnStr, true); err != nil {
return "", nil, err
}
return buf.String(), buf.Args(), nil
} }
// GenCountSQL generates the SQL for counting // GenCountSQL generates the SQL for counting
@ -148,8 +134,6 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
return statement.GenRawSQL(), statement.RawParams, nil return statement.GenRawSQL(), statement.RawParams, nil
} }
var condArgs []interface{}
var err error
if len(beans) > 0 { if len(beans) > 0 {
if err := statement.SetRefBean(beans[0]); err != nil { if err := statement.SetRefBean(beans[0]); err != nil {
return "", nil, err return "", nil, err
@ -176,19 +160,27 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
subQuerySelect = selectSQL subQuerySelect = selectSQL
} }
sqlStr, condArgs, err := statement.genSelectSQL(subQuerySelect, false, false) buf := builder.NewWriter()
if err != nil { if statement.GroupByStr != "" {
if _, err := fmt.Fprintf(buf, "SELECT %s FROM (", selectSQL); err != nil {
return "", nil, err
}
}
if err := statement.writeSelect(buf, subQuerySelect, false); err != nil {
return "", nil, err return "", nil, err
} }
if statement.GroupByStr != "" { if statement.GroupByStr != "" {
sqlStr = fmt.Sprintf("SELECT %s FROM (%s) sub", selectSQL, sqlStr) if _, err := fmt.Fprintf(buf, ") sub"); err != nil {
return "", nil, err
}
} }
return sqlStr, condArgs, nil return buf.String(), buf.Args(), nil
} }
func (statement *Statement) writeFrom(w builder.Writer) error { func (statement *Statement) writeFrom(w *builder.BytesWriter) error {
if _, err := fmt.Fprint(w, " FROM "); err != nil { if _, err := fmt.Fprint(w, " FROM "); err != nil {
return err return err
} }
@ -198,7 +190,7 @@ func (statement *Statement) writeFrom(w builder.Writer) error {
if err := statement.writeAlias(w); err != nil { if err := statement.writeAlias(w); err != nil {
return err return err
} }
return statement.writeJoin(w) return statement.writeJoins(w)
} }
func (statement *Statement) writeLimitOffset(w builder.Writer) error { func (statement *Statement) writeLimitOffset(w builder.Writer) error {
@ -218,37 +210,73 @@ func (statement *Statement) writeLimitOffset(w builder.Writer) error {
return nil return nil
} }
func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) { func (statement *Statement) writeTop(w builder.Writer) error {
var ( if statement.dialect.URI().DBType != schemas.MSSQL {
distinct string return nil
dialect = statement.dialect }
top, whereStr string if statement.LimitN == nil {
mssqlCondi = builder.NewWriter() return nil
) }
_, err := fmt.Fprintf(w, " TOP %d", *statement.LimitN)
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { return err
distinct = "DISTINCT "
} }
condWriter := builder.NewWriter() func (statement *Statement) writeDistinct(w builder.Writer) error {
if err := statement.cond.WriteTo(statement.QuoteReplacer(condWriter)); err != nil { if statement.IsDistinct && !strings.HasPrefix(statement.SelectStr, "count(") {
return "", nil, err _, err := fmt.Fprint(w, " DISTINCT")
return err
}
return nil
} }
if condWriter.Len() > 0 { func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr string) error {
whereStr = " WHERE " if _, err := fmt.Fprintf(w, "SELECT "); err != nil {
return err
}
if err := statement.writeDistinct(w); err != nil {
return err
}
if err := statement.writeTop(w); err != nil {
return err
}
_, err := fmt.Fprint(w, " ", columnStr)
return err
} }
pLimitN := statement.LimitN func (statement *Statement) writeWhere(w *builder.BytesWriter) error {
if dialect.URI().DBType == schemas.MSSQL { if !statement.cond.IsValid() {
if pLimitN != nil { return statement.writeMssqlPaginationCond(w)
LimitNValue := *pLimitN
top = fmt.Sprintf("TOP %d ", LimitNValue)
} }
if statement.Start > 0 { if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil {
return err
}
return statement.writeMssqlPaginationCond(w)
}
func (statement *Statement) writeForUpdate(w io.Writer) error {
if !statement.IsForUpdate {
return nil
}
if statement.dialect.URI().DBType != schemas.MYSQL {
return errors.New("only support mysql for update")
}
_, err := fmt.Fprint(w, " FOR UPDATE")
return err
}
func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) error {
if statement.dialect.URI().DBType != schemas.MSSQL || statement.Start <= 0 {
return nil
}
if statement.RefTable == nil { if statement.RefTable == nil {
return "", nil, errors.New("Unsupported query limit without reference table") return errors.New("unsupported query limit without reference table")
} }
var column string var column string
if len(statement.RefTable.PKColumns()) == 0 { if len(statement.RefTable.PKColumns()) == 0 {
for _, index := range statement.RefTable.Indexes { for _, index := range statement.RefTable.Indexes {
@ -263,7 +291,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} else { } else {
column = statement.RefTable.PKColumns()[0].Name column = statement.RefTable.PKColumns()[0].Name
} }
if statement.needTableName() { if statement.NeedTableName() {
if len(statement.TableAlias) > 0 { if len(statement.TableAlias) > 0 {
column = fmt.Sprintf("%s.%s", statement.TableAlias, column) column = fmt.Sprintf("%s.%s", statement.TableAlias, column)
} else { } else {
@ -271,100 +299,94 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} }
} }
if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s", subWriter := builder.NewWriter()
if _, err := fmt.Fprintf(subWriter, "(%s NOT IN (SELECT TOP %d %s",
column, statement.Start, column); err != nil { column, statement.Start, column); err != nil {
return "", nil, err return err
} }
if err := statement.writeFrom(mssqlCondi); err != nil { if err := statement.writeFrom(subWriter); err != nil {
return "", nil, err return err
} }
if whereStr != "" { if statement.cond.IsValid() {
if _, err := fmt.Fprint(mssqlCondi, whereStr); err != nil { if _, err := fmt.Fprint(subWriter, " WHERE "); err != nil {
return "", nil, err return err
} }
if err := utils.WriteBuilder(mssqlCondi, statement.QuoteReplacer(condWriter)); err != nil { if err := statement.cond.WriteTo(statement.QuoteReplacer(subWriter)); err != nil {
return "", nil, err return err
} }
} }
if needOrderBy { if err := statement.WriteOrderBy(subWriter); err != nil {
if err := statement.WriteOrderBy(mssqlCondi); err != nil { return err
return "", nil, err
}
}
if err := statement.WriteGroupBy(mssqlCondi); err != nil {
return "", nil, err
}
if _, err := fmt.Fprint(mssqlCondi, "))"); err != nil {
return "", nil, err
} }
if err := statement.writeGroupBy(subWriter); err != nil {
return err
} }
if _, err := fmt.Fprint(subWriter, "))"); err != nil {
return err
} }
buf := builder.NewWriter() if statement.cond.IsValid() {
if _, err := fmt.Fprintf(buf, "SELECT %v%v%v", distinct, top, columnStr); err != nil { if _, err := fmt.Fprint(w, " AND "); err != nil {
return "", nil, err return err
}
if err := statement.writeFrom(buf); err != nil {
return "", nil, err
}
if whereStr != "" {
if _, err := fmt.Fprint(buf, whereStr); err != nil {
return "", nil, err
}
if err := utils.WriteBuilder(buf, statement.QuoteReplacer(condWriter)); err != nil {
return "", nil, err
}
}
if mssqlCondi.Len() > 0 {
if len(whereStr) > 0 {
if _, err := fmt.Fprint(buf, " AND "); err != nil {
return "", nil, err
} }
} else { } else {
if _, err := fmt.Fprint(buf, " WHERE "); err != nil { if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return "", nil, err return err
} }
} }
if err := utils.WriteBuilder(buf, mssqlCondi); err != nil { return utils.WriteBuilder(w, subWriter)
return "", nil, err
}
} }
if err := statement.WriteGroupBy(buf); err != nil { func (statement *Statement) writeOracleLimit(w *builder.BytesWriter, columnStr string) error {
return "", nil, err if statement.LimitN == nil {
return nil
} }
if err := statement.writeHaving(buf); err != nil {
return "", nil, err oldString := w.String()
} w.Reset()
if needOrderBy {
if err := statement.WriteOrderBy(buf); err != nil {
return "", nil, err
}
}
if needLimit {
if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE {
if err := statement.writeLimitOffset(buf); err != nil {
return "", nil, err
}
} else if dialect.URI().DBType == schemas.ORACLE {
if pLimitN != nil {
oldString := buf.String()
buf.Reset()
rawColStr := columnStr rawColStr := columnStr
if rawColStr == "*" { if rawColStr == "*" {
rawColStr = "at.*" rawColStr = "at.*"
} }
fmt.Fprintf(buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", _, err := fmt.Fprintf(w, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start) columnStr, rawColStr, oldString, statement.Start+*statement.LimitN, statement.Start)
} return err
}
}
if statement.IsForUpdate {
return dialect.ForUpdateSQL(buf.String()), buf.Args(), nil
} }
return buf.String(), buf.Args(), nil func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error {
if err := statement.writeSelectColumns(buf, columnStr); err != nil {
return err
}
if err := statement.writeFrom(buf); err != nil {
return err
}
if err := statement.writeWhere(buf); err != nil {
return err
}
if err := statement.writeGroupBy(buf); err != nil {
return err
}
if err := statement.writeHaving(buf); err != nil {
return err
}
if err := statement.WriteOrderBy(buf); err != nil {
return err
}
dialect := statement.dialect
if needLimit {
if dialect.URI().DBType == schemas.ORACLE {
if err := statement.writeOracleLimit(buf, columnStr); err != nil {
return err
}
} else if dialect.URI().DBType != schemas.MSSQL {
if err := statement.writeLimitOffset(buf); err != nil {
return err
}
}
}
return statement.writeForUpdate(buf)
} }
// GenExistSQL generates Exist SQL // GenExistSQL generates Exist SQL
@ -402,7 +424,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s", tableName); err != nil { if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s", tableName); err != nil {
return "", nil, err return "", nil, err
} }
if err := statement.writeJoin(buf); err != nil { if err := statement.writeJoins(buf); err != nil {
return "", nil, err return "", nil, err
} }
if statement.Conds().IsValid() { if statement.Conds().IsValid() {
@ -417,7 +439,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil { if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil {
return "", nil, err return "", nil, err
} }
if err := statement.writeJoin(buf); err != nil { if err := statement.writeJoins(buf); err != nil {
return "", nil, err return "", nil, err
} }
if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { if _, err := fmt.Fprintf(buf, " WHERE "); err != nil {
@ -438,7 +460,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s", tableName); err != nil { if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s", tableName); err != nil {
return "", nil, err return "", nil, err
} }
if err := statement.writeJoin(buf); err != nil { if err := statement.writeJoins(buf); err != nil {
return "", nil, err return "", nil, err
} }
if statement.Conds().IsValid() { if statement.Conds().IsValid() {
@ -457,6 +479,33 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
return buf.String(), buf.Args(), nil return buf.String(), buf.Args(), nil
} }
func (statement *Statement) genSelectColumnStr() string {
// manually select columns
if len(statement.SelectStr) > 0 {
return statement.SelectStr
}
columnStr := statement.ColumnStr()
if columnStr != "" {
return columnStr
}
// autodetect columns
if statement.GroupByStr != "" {
return statement.quoteColumnStr(statement.GroupByStr)
}
if len(statement.joins) != 0 {
return "*"
}
columnStr = statement.genColumnStr()
if columnStr == "" {
columnStr = "*"
}
return columnStr
}
// GenFindSQL generates Find SQL // GenFindSQL generates Find SQL
func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) { func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) {
if statement.RawSQL != "" { if statement.RawSQL != "" {
@ -467,33 +516,11 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa
return "", nil, ErrTableNotFound return "", nil, ErrTableNotFound
} }
columnStr := statement.ColumnStr()
if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr
} else {
if statement.JoinStr == "" {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = "*"
}
}
}
if columnStr == "" {
columnStr = "*"
}
}
statement.cond = statement.cond.And(autoCond) statement.cond = statement.cond.And(autoCond)
return statement.genSelectSQL(columnStr, true, true) buf := builder.NewWriter()
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil {
return "", nil, err
}
return buf.String(), buf.Args(), nil
} }

View File

@ -102,7 +102,7 @@ func (statement *Statement) genColumnStr() string {
buf.WriteString(", ") buf.WriteString(", ")
} }
if statement.JoinStr != "" { if len(statement.joins) > 0 {
if statement.TableAlias != "" { if statement.TableAlias != "" {
buf.WriteString(statement.TableAlias) buf.WriteString(statement.TableAlias)
} else { } else {
@ -119,7 +119,7 @@ func (statement *Statement) genColumnStr() string {
} }
func (statement *Statement) colName(col *schemas.Column, tableName string) string { func (statement *Statement) colName(col *schemas.Column, tableName string) string {
if statement.needTableName() { if statement.NeedTableName() {
nm := tableName nm := tableName
if len(statement.TableAlias) > 0 { if len(statement.TableAlias) > 0 {
nm = statement.TableAlias nm = statement.TableAlias

View File

@ -34,6 +34,13 @@ var (
ErrTableNotFound = errors.New("Table not found") ErrTableNotFound = errors.New("Table not found")
) )
type join struct {
op string
table interface{}
condition interface{}
args []interface{}
}
// Statement save all the sql info for executing SQL // Statement save all the sql info for executing SQL
type Statement struct { type Statement struct {
RefTable *schemas.Table RefTable *schemas.Table
@ -45,8 +52,7 @@ type Statement struct {
idParam schemas.PK idParam schemas.PK
orderStr string orderStr string
orderArgs []interface{} orderArgs []interface{}
JoinStr string joins []join
joinArgs []interface{}
GroupByStr string GroupByStr string
HavingStr string HavingStr string
SelectStr string SelectStr string
@ -123,8 +129,7 @@ func (statement *Statement) Reset() {
statement.LimitN = nil statement.LimitN = nil
statement.ResetOrderBy() statement.ResetOrderBy()
statement.UseCascade = true statement.UseCascade = true
statement.JoinStr = "" statement.joins = nil
statement.joinArgs = make([]interface{}, 0)
statement.GroupByStr = "" statement.GroupByStr = ""
statement.HavingStr = "" statement.HavingStr = ""
statement.ColumnMap = columnMap{} statement.ColumnMap = columnMap{}
@ -205,8 +210,8 @@ func (statement *Statement) SetRefBean(bean interface{}) error {
return nil return nil
} }
func (statement *Statement) needTableName() bool { func (statement *Statement) NeedTableName() bool {
return len(statement.JoinStr) > 0 return len(statement.joins) > 0
} }
// Incr Generate "Update ... Set column = column + arg" statement // Incr Generate "Update ... Set column = column + arg" statement
@ -290,7 +295,7 @@ func (statement *Statement) GroupBy(keys string) *Statement {
return statement return statement
} }
func (statement *Statement) WriteGroupBy(w builder.Writer) error { func (statement *Statement) writeGroupBy(w builder.Writer) error {
if statement.GroupByStr == "" { if statement.GroupByStr == "" {
return nil return nil
} }
@ -605,7 +610,7 @@ func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, i
// MergeConds merge conditions from bean and id // MergeConds merge conditions from bean and id
func (statement *Statement) MergeConds(bean interface{}) error { func (statement *Statement) MergeConds(bean interface{}) error {
if !statement.NoAutoCondition && statement.RefTable != nil { if !statement.NoAutoCondition && statement.RefTable != nil {
addedTableName := (len(statement.JoinStr) > 0) addedTableName := (len(statement.joins) > 0)
autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName) autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
if err != nil { if err != nil {
return err return err
@ -673,7 +678,7 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName
// CondDeleted returns the conditions whether a record is soft deleted. // CondDeleted returns the conditions whether a record is soft deleted.
func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond {
colName := statement.quote(col.Name) colName := statement.quote(col.Name)
if statement.JoinStr != "" { if len(statement.joins) > 0 {
var prefix string var prefix string
if statement.TableAlias != "" { if statement.TableAlias != "" {
prefix = statement.TableAlias prefix = statement.TableAlias

View File

@ -19,7 +19,8 @@ import (
) )
func (statement *Statement) ifAddColUpdate(col *schemas.Column, includeVersion, includeUpdated, includeNil, func (statement *Statement) ifAddColUpdate(col *schemas.Column, includeVersion, includeUpdated, includeNil,
includeAutoIncr, update bool) (bool, error) { includeAutoIncr, update bool,
) (bool, error) {
columnMap := statement.ColumnMap columnMap := statement.ColumnMap
omitColumnMap := statement.OmitColumnMap omitColumnMap := statement.OmitColumnMap
unscoped := statement.unscoped unscoped := statement.unscoped
@ -64,15 +65,16 @@ func (statement *Statement) ifAddColUpdate(col *schemas.Column, includeVersion,
// BuildUpdates auto generating update columnes and values according a struct // BuildUpdates auto generating update columnes and values according a struct
func (statement *Statement) BuildUpdates(tableValue reflect.Value, func (statement *Statement) BuildUpdates(tableValue reflect.Value,
includeVersion, includeUpdated, includeNil, includeVersion, includeUpdated, includeNil,
includeAutoIncr, update bool) ([]string, []interface{}, error) { includeAutoIncr, update bool,
) ([]string, []interface{}, error) {
table := statement.RefTable table := statement.RefTable
allUseBool := statement.allUseBool allUseBool := statement.allUseBool
useAllCols := statement.useAllCols useAllCols := statement.useAllCols
mustColumnMap := statement.MustColumnMap mustColumnMap := statement.MustColumnMap
nullableMap := statement.NullableMap nullableMap := statement.NullableMap
var colNames = make([]string, 0) colNames := make([]string, 0)
var args = make([]interface{}, 0) args := make([]interface{}, 0)
for _, col := range table.Columns() { for _, col := range table.Columns() {
ok, err := statement.ifAddColUpdate(col, includeVersion, includeUpdated, includeNil, ok, err := statement.ifAddColUpdate(col, includeVersion, includeUpdated, includeNil,

10
rows.go
View File

@ -46,8 +46,8 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
if rows.session.statement.RawSQL == "" { if rows.session.statement.RawSQL == "" {
var autoCond builder.Cond var autoCond builder.Cond
var addedTableName = (len(session.statement.JoinStr) > 0) addedTableName := session.statement.NeedTableName()
var table = rows.session.statement.RefTable table := rows.session.statement.RefTable
if !session.statement.NoAutoCondition { if !session.statement.NoAutoCondition {
var err error var err error
@ -103,12 +103,12 @@ func (rows *Rows) Scan(beans ...interface{}) error {
return rows.Err() return rows.Err()
} }
var bean = beans[0] bean := beans[0]
var tp = reflect.TypeOf(bean) tp := reflect.TypeOf(bean)
if tp.Kind() == reflect.Ptr { if tp.Kind() == reflect.Ptr {
tp = tp.Elem() tp = tp.Elem()
} }
var beanKind = tp.Kind() beanKind := tp.Kind()
if len(beans) == 1 { if len(beans) == 1 {
if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType {

10
schemas/collation.go Normal file
View File

@ -0,0 +1,10 @@
// Copyright 2023 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
type Collation struct {
Name string
Column string // blank means it's a table collation
}

View File

@ -45,6 +45,7 @@ type Column struct {
DisableTimeZone bool DisableTimeZone bool
TimeZone *time.Location // column specified time zone TimeZone *time.Location // column specified time zone
Comment string Comment string
Collation string
} }
// NewColumn creates a new column // NewColumn creates a new column
@ -89,6 +90,8 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
v.Set(reflect.New(v.Type().Elem())) v.Set(reflect.New(v.Type().Elem()))
} }
v = v.Elem() v = v.Elem()
} else if v.Kind() == reflect.Interface {
v = reflect.Indirect(v.Elem())
} }
v = v.FieldByIndex([]int{i}) v = v.FieldByIndex([]int{i})
} }

View File

@ -27,6 +27,7 @@ type Table struct {
StoreEngine string StoreEngine string
Charset string Charset string
Comment string Comment string
Collation string
} }
// NewEmptyTable creates an empty table // NewEmptyTable creates an empty table
@ -36,7 +37,8 @@ func NewEmptyTable() *Table {
// NewTable creates a new Table object // NewTable creates a new Table object
func NewTable(name string, t reflect.Type) *Table { func NewTable(name string, t reflect.Type) *Table {
return &Table{Name: name, Type: t, return &Table{
Name: name, Type: t,
columnsSeq: make([]string, 0), columnsSeq: make([]string, 0),
columns: make([]*Column, 0), columns: make([]*Column, 0),
columnsMap: make(map[string][]*Column), columnsMap: make(map[string][]*Column),

View File

@ -352,7 +352,7 @@ func (session *Session) DB() *core.DB {
func (session *Session) canCache() bool { func (session *Session) canCache() bool {
if session.statement.RefTable == nil || if session.statement.RefTable == nil ||
session.statement.JoinStr != "" || session.statement.NeedTableName() ||
session.statement.RawSQL != "" || session.statement.RawSQL != "" ||
!session.statement.UseCache || !session.statement.UseCache ||
session.statement.IsForUpdate || session.statement.IsForUpdate ||

View File

@ -30,7 +30,7 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri
} }
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr) sqlStr = filter.Do(session.ctx, sqlStr)
} }
newsql := session.statement.ConvertIDSQL(sqlStr) newsql := session.statement.ConvertIDSQL(sqlStr)

View File

@ -116,7 +116,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
var ( var (
table = session.statement.RefTable table = session.statement.RefTable
addedTableName = (len(session.statement.JoinStr) > 0) addedTableName = session.statement.NeedTableName()
autoCond builder.Cond autoCond builder.Cond
) )
if tp == tpStruct { if tp == tpStruct {
@ -346,7 +346,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
} }
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr) sqlStr = filter.Do(session.ctx, sqlStr)
} }
newsql := session.statement.ConvertIDSQL(sqlStr) newsql := session.statement.ConvertIDSQL(sqlStr)

View File

@ -280,7 +280,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
} }
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr) sqlStr = filter.Do(session.ctx, sqlStr)
} }
newsql := session.statement.ConvertIDSQL(sqlStr) newsql := session.statement.ConvertIDSQL(sqlStr)
if newsql == "" { if newsql == "" {

View File

@ -353,6 +353,7 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
}
if needCommit { if needCommit {
if err := session.Commit(); err != nil { if err := session.Commit(); err != nil {
return 0, err return 0, err
@ -361,7 +362,6 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
if id == 0 { if id == 0 {
return 0, errors.New("insert successfully but not returned id") return 0, errors.New("insert successfully but not returned id")
} }
}
defer handleAfterInsertProcessorFunc(bean) defer handleAfterInsertProcessorFunc(bean)

View File

@ -13,7 +13,7 @@ import (
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
*sqlStr = filter.Do(*sqlStr) *sqlStr = filter.Do(session.ctx, *sqlStr)
} }
session.lastSQL = *sqlStr session.lastSQL = *sqlStr

View File

@ -15,7 +15,6 @@ import (
"xorm.io/xorm/dialects" "xorm.io/xorm/dialects"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
) )
// Ping test if database is ok // Ping test if database is ok
@ -169,7 +168,7 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
return nil return nil
} }
var seqName = utils.SeqName(tableName) seqName := utils.SeqName(tableName)
exist, err := session.engine.dialect.IsSequenceExist(session.ctx, session.getQueryer(), seqName) exist, err := session.engine.dialect.IsSequenceExist(session.ctx, session.getQueryer(), seqName)
if err != nil { if err != nil {
return err return err
@ -244,228 +243,6 @@ func (session *Session) addUnique(tableName, uqeName string) error {
return err return err
} }
// Sync2 synchronize structs to database tables
// Depricated
func (session *Session) Sync2(beans ...interface{}) error {
return session.Sync(beans...)
}
// Sync synchronize structs to database tables
func (session *Session) Sync(beans ...interface{}) error {
engine := session.engine
if session.isAutoClose {
session.isAutoClose = false
defer session.Close()
}
tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx)
if err != nil {
return err
}
session.autoResetStatement = false
defer func() {
session.autoResetStatement = true
session.resetStatement()
}()
for _, bean := range beans {
v := utils.ReflectValue(bean)
table, err := engine.tagParser.ParseWithCache(v)
if err != nil {
return err
}
var tbName string
if len(session.statement.AltTableName) > 0 {
tbName = session.statement.AltTableName
} else {
tbName = engine.TableName(bean)
}
tbNameWithSchema := engine.tbNameWithSchema(tbName)
var oriTable *schemas.Table
for _, tb := range tables {
if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) {
oriTable = tb
break
}
}
// this is a new table
if oriTable == nil {
err = session.StoreEngine(session.statement.StoreEngine).createTable(bean)
if err != nil {
return err
}
err = session.createUniques(bean)
if err != nil {
return err
}
err = session.createIndexes(bean)
if err != nil {
return err
}
continue
}
// this will modify an old table
if err = engine.loadTableInfo(oriTable); err != nil {
return err
}
// check columns
for _, col := range table.Columns() {
var oriCol *schemas.Column
for _, col2 := range oriTable.Columns() {
if strings.EqualFold(col.Name, col2.Name) {
oriCol = col2
break
}
}
// column is not exist on table
if oriCol == nil {
session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema)
if err = session.addColumn(col.Name); err != nil {
return err
}
continue
}
err = nil
expectedType := engine.dialect.SQLType(col)
curType := engine.dialect.SQLType(oriCol)
if expectedType != curType {
if expectedType == schemas.Text &&
strings.HasPrefix(curType, schemas.Varchar) {
// currently only support mysql & postgres
if engine.dialect.URI().DBType == schemas.MYSQL ||
engine.dialect.URI().DBType == schemas.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
} else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbNameWithSchema, col.Name, curType, expectedType)
}
} else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) {
if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
} else {
if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
if !strings.EqualFold(schemas.SQLTypeName(curType), engine.dialect.Alias(schemas.SQLTypeName(expectedType))) {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
tbNameWithSchema, col.Name, curType, expectedType)
}
}
}
} else if expectedType == schemas.Varchar {
if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
} else if col.Comment != oriCol.Comment {
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
if col.Default != oriCol.Default {
switch {
case col.IsAutoIncrement: // For autoincrement column, don't check default
case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) &&
((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") ||
(strings.EqualFold(col.Default, "false") && oriCol.Default == "0")):
default:
engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s",
tbName, col.Name, oriCol.Default, col.Default)
}
}
if col.Nullable != oriCol.Nullable {
engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v",
tbName, col.Name, oriCol.Nullable, col.Nullable)
}
if err != nil {
return err
}
}
var foundIndexNames = make(map[string]bool)
var addedNames = make(map[string]*schemas.Index)
for name, index := range table.Indexes {
var oriIndex *schemas.Index
for name2, index2 := range oriTable.Indexes {
if index.Equal(index2) {
oriIndex = index2
foundIndexNames[name2] = true
break
}
}
if oriIndex != nil {
if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex)
_, err = session.exec(sql)
if err != nil {
return err
}
oriIndex = nil
}
}
if oriIndex == nil {
addedNames[name] = index
}
}
for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2)
_, err = session.exec(sql)
if err != nil {
return err
}
}
}
for name, index := range addedNames {
if index.Type == schemas.UniqueType {
session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema)
err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == schemas.IndexType {
session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema)
err = session.addIndex(tbNameWithSchema, name)
}
if err != nil {
return err
}
}
// check all the columns which removed from struct fields but left on database tables.
for _, colName := range oriTable.ColumnsSeq() {
if table.GetColumn(colName) == nil {
engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(oriTable.Name, true), colName)
}
}
}
return nil
}
// ImportFile SQL DDL file // ImportFile SQL DDL file
func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) { func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) {
file, err := os.Open(ddlPath) file, err := os.Open(ddlPath)
@ -490,7 +267,7 @@ func (session *Session) Import(r io.Reader) ([]sql.Result, error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
return 0, nil, nil return 0, nil, nil
} }
var oriInSingleQuote = inSingleQuote oriInSingleQuote := inSingleQuote
for i, b := range data { for i, b := range data {
if startComment { if startComment {
if b == '\n' { if b == '\n' {

View File

@ -34,7 +34,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
return ErrCacheFailed return ErrCacheFailed
} }
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
newsql = filter.Do(newsql) newsql = filter.Do(session.ctx, newsql)
} }
session.engine.logger.Debugf("[cache] new sql: %v, %v", oldhead, newsql) session.engine.logger.Debugf("[cache] new sql: %v, %v", oldhead, newsql)

276
sync.go Normal file
View File

@ -0,0 +1,276 @@
// Copyright 2023 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"strings"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
type SyncOptions struct {
WarnIfDatabaseColumnMissed bool
}
type SyncResult struct{}
// Sync the new struct changes to database, this method will automatically add
// table, column, index, unique. but will not delete or change anything.
// If you change some field, you should change the database manually.
func (engine *Engine) Sync(beans ...interface{}) error {
session := engine.NewSession()
defer session.Close()
return session.Sync(beans...)
}
// SyncWithOptions sync the database schemas according options and table structs
func (engine *Engine) SyncWithOptions(opts SyncOptions, beans ...interface{}) (*SyncResult, error) {
session := engine.NewSession()
defer session.Close()
return session.SyncWithOptions(opts, beans...)
}
// Sync2 synchronize structs to database tables
// Depricated
func (engine *Engine) Sync2(beans ...interface{}) error {
return engine.Sync(beans...)
}
// Sync2 synchronize structs to database tables
// Depricated
func (session *Session) Sync2(beans ...interface{}) error {
return session.Sync(beans...)
}
// Sync synchronize structs to database tables
func (session *Session) Sync(beans ...interface{}) error {
_, err := session.SyncWithOptions(SyncOptions{
WarnIfDatabaseColumnMissed: false,
}, beans...)
return err
}
func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{}) (*SyncResult, error) {
engine := session.engine
if session.isAutoClose {
session.isAutoClose = false
defer session.Close()
}
tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx)
if err != nil {
return nil, err
}
session.autoResetStatement = false
defer func() {
session.autoResetStatement = true
session.resetStatement()
}()
var syncResult SyncResult
for _, bean := range beans {
v := utils.ReflectValue(bean)
table, err := engine.tagParser.ParseWithCache(v)
if err != nil {
return nil, err
}
var tbName string
if len(session.statement.AltTableName) > 0 {
tbName = session.statement.AltTableName
} else {
tbName = engine.TableName(bean)
}
tbNameWithSchema := engine.tbNameWithSchema(tbName)
var oriTable *schemas.Table
for _, tb := range tables {
if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) {
oriTable = tb
break
}
}
// this is a new table
if oriTable == nil {
err = session.StoreEngine(session.statement.StoreEngine).createTable(bean)
if err != nil {
return nil, err
}
err = session.createUniques(bean)
if err != nil {
return nil, err
}
err = session.createIndexes(bean)
if err != nil {
return nil, err
}
continue
}
// this will modify an old table
if err = engine.loadTableInfo(oriTable); err != nil {
return nil, err
}
// check columns
for _, col := range table.Columns() {
var oriCol *schemas.Column
for _, col2 := range oriTable.Columns() {
if strings.EqualFold(col.Name, col2.Name) {
oriCol = col2
break
}
}
// column is not exist on table
if oriCol == nil {
session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema)
if err = session.addColumn(col.Name); err != nil {
return nil, err
}
continue
}
err = nil
expectedType := engine.dialect.SQLType(col)
curType := engine.dialect.SQLType(oriCol)
if expectedType != curType {
if expectedType == schemas.Text &&
strings.HasPrefix(curType, schemas.Varchar) {
// currently only support mysql & postgres
if engine.dialect.URI().DBType == schemas.MYSQL ||
engine.dialect.URI().DBType == schemas.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
} else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbNameWithSchema, col.Name, curType, expectedType)
}
} else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) {
if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
} else {
if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
if !strings.EqualFold(schemas.SQLTypeName(curType), engine.dialect.Alias(schemas.SQLTypeName(expectedType))) {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
tbNameWithSchema, col.Name, curType, expectedType)
}
}
}
} else if expectedType == schemas.Varchar {
if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
} else if col.Comment != oriCol.Comment {
if engine.dialect.URI().DBType == schemas.POSTGRES ||
engine.dialect.URI().DBType == schemas.MYSQL {
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
if col.Default != oriCol.Default {
switch {
case col.IsAutoIncrement: // For autoincrement column, don't check default
case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) &&
((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") ||
(strings.EqualFold(col.Default, "false") && oriCol.Default == "0")):
default:
engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s",
tbName, col.Name, oriCol.Default, col.Default)
}
}
if col.Nullable != oriCol.Nullable {
engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v",
tbName, col.Name, oriCol.Nullable, col.Nullable)
}
if err != nil {
return nil, err
}
}
foundIndexNames := make(map[string]bool)
addedNames := make(map[string]*schemas.Index)
for name, index := range table.Indexes {
var oriIndex *schemas.Index
for name2, index2 := range oriTable.Indexes {
if index.Equal(index2) {
oriIndex = index2
foundIndexNames[name2] = true
break
}
}
if oriIndex != nil {
if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex)
_, err = session.exec(sql)
if err != nil {
return nil, err
}
oriIndex = nil
}
}
if oriIndex == nil {
addedNames[name] = index
}
}
for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2)
_, err = session.exec(sql)
if err != nil {
return nil, err
}
}
}
for name, index := range addedNames {
if index.Type == schemas.UniqueType {
session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema)
err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == schemas.IndexType {
session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema)
err = session.addIndex(tbNameWithSchema, name)
}
if err != nil {
return nil, err
}
}
if opts.WarnIfDatabaseColumnMissed {
// check all the columns which removed from struct fields but left on database tables.
for _, colName := range oriTable.ColumnsSeq() {
if table.GetColumn(colName) == nil {
engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(oriTable.Name, true), colName)
}
}
}
}
return &syncResult, nil
}

View File

@ -31,6 +31,12 @@ type TableIndices interface {
var tpTableIndices = reflect.TypeOf((*TableIndices)(nil)).Elem() var tpTableIndices = reflect.TypeOf((*TableIndices)(nil)).Elem()
type TableCollations interface {
TableCollations() []*schemas.Collation
}
var tpTableCollations = reflect.TypeOf((*TableCollations)(nil)).Elem()
// Parser represents a parser for xorm tag // Parser represents a parser for xorm tag
type Parser struct { type Parser struct {
identifier string identifier string
@ -356,6 +362,22 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
} }
} }
collations := tableCollations(v)
for _, collation := range collations {
if collation.Name == "" {
continue
}
if collation.Column == "" {
table.Collation = collation.Name
} else {
col := table.GetColumn(collation.Column)
if col == nil {
return nil, ErrUnsupportedType
}
col.Collation = collation.Name // this may override definition in struct tag
}
}
return table, nil return table, nil
} }
@ -377,3 +399,22 @@ func tableIndices(v reflect.Value) []*schemas.Index {
} }
return nil return nil
} }
func tableCollations(v reflect.Value) []*schemas.Collation {
if v.Type().Implements(tpTableCollations) {
return v.Interface().(TableCollations).TableCollations()
}
if v.Kind() == reflect.Ptr {
v = v.Elem()
if v.Type().Implements(tpTableCollations) {
return v.Interface().(TableCollations).TableCollations()
}
} else if v.CanAddr() {
v1 := v.Addr()
if v1.Type().Implements(tpTableCollations) {
return v1.Interface().(TableCollations).TableCollations()
}
}
return nil
}

View File

@ -123,6 +123,7 @@ var defaultTagHandlers = map[string]Handler{
"COMMENT": CommentTagHandler, "COMMENT": CommentTagHandler,
"EXTENDS": ExtendsTagHandler, "EXTENDS": ExtendsTagHandler,
"UNSIGNED": UnsignedTagHandler, "UNSIGNED": UnsignedTagHandler,
"COLLATE": CollateTagHandler,
} }
func init() { func init() {
@ -282,6 +283,16 @@ func CommentTagHandler(ctx *Context) error {
return nil return nil
} }
func CollateTagHandler(ctx *Context) error {
if len(ctx.params) > 0 {
ctx.col.Collation = ctx.params[0]
} else {
ctx.col.Collation = ctx.nextTag
ctx.ignoreNext = true
}
return nil
}
// SQLTypeTagHandler describes SQL Type tag handler // SQLTypeTagHandler describes SQL Type tag handler
func SQLTypeTagHandler(ctx *Context) error { func SQLTypeTagHandler(ctx *Context) error {
ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname} ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname}

View File

@ -11,11 +11,12 @@ import (
) )
func TestSplitTag(t *testing.T) { func TestSplitTag(t *testing.T) {
var cases = []struct { cases := []struct {
tag string tag string
tags []tag tags []tag
}{ }{
{"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{ {
"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{
{ {
name: "not", name: "not",
}, },
@ -33,13 +34,15 @@ func TestSplitTag(t *testing.T) {
}, },
}, },
}, },
{"TEXT", []tag{ {
"TEXT", []tag{
{ {
name: "TEXT", name: "TEXT",
}, },
}, },
}, },
{"default('2000-01-01 00:00:00')", []tag{ {
"default('2000-01-01 00:00:00')", []tag{
{ {
name: "default", name: "default",
params: []string{ params: []string{
@ -48,7 +51,8 @@ func TestSplitTag(t *testing.T) {
}, },
}, },
}, },
{"json binary", []tag{ {
"json binary", []tag{
{ {
name: "json", name: "json",
}, },
@ -57,14 +61,16 @@ func TestSplitTag(t *testing.T) {
}, },
}, },
}, },
{"numeric(10, 2)", []tag{ {
"numeric(10, 2)", []tag{
{ {
name: "numeric", name: "numeric",
params: []string{"10", "2"}, params: []string{"10", "2"},
}, },
}, },
}, },
{"numeric(10, 2) notnull", []tag{ {
"numeric(10, 2) notnull", []tag{
{ {
name: "numeric", name: "numeric",
params: []string{"10", "2"}, params: []string{"10", "2"},
@ -74,6 +80,16 @@ func TestSplitTag(t *testing.T) {
}, },
}, },
}, },
{
"collate utf8mb4_bin", []tag{
{
name: "collate",
},
{
name: "utf8mb4_bin",
},
},
},
} }
for _, kase := range cases { for _, kase := range cases {
@ -82,7 +98,7 @@ func TestSplitTag(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, len(tags), len(kase.tags)) assert.EqualValues(t, len(tags), len(kase.tags))
for i := 0; i < len(tags); i++ { for i := 0; i < len(tags); i++ {
assert.Equal(t, tags[i], kase.tags[i]) assert.Equal(t, kase.tags[i], tags[i])
} }
}) })
} }