merge main branch

This commit is contained in:
Lunny Xiao 2023-06-11 17:44:53 +08:00
commit 53fb564447
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
57 changed files with 1862 additions and 1365 deletions

View File

@ -1,437 +0,0 @@
---
kind: pipeline
name: test-mysql
environment:
GO111MODULE: "on"
GOPROXY: "https://goproxy.io"
CGO_ENABLED: 1
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-vet
image: golang:1.15
pull: always
volumes:
- name: cache
path: /go/pkg/mod
commands:
- make vet
- name: test-sqlite3
image: golang:1.15
volumes:
- name: cache
path: /go/pkg/mod
depends_on:
- test-vet
commands:
- make fmt-check
- make test
- make test-sqlite3
- TEST_CACHE_ENABLE=true make test-sqlite3
- name: test-sqlite
image: golang:1.15
volumes:
- name: cache
path: /go/pkg/mod
depends_on:
- test-vet
commands:
- make test-sqlite
- TEST_QUOTE_POLICY=reserved make test-sqlite
- name: test-mysql
image: golang:1.15
pull: never
volumes:
- name: cache
path: /go/pkg/mod
depends_on:
- test-vet
environment:
TEST_MYSQL_HOST: mysql
TEST_MYSQL_CHARSET: utf8
TEST_MYSQL_DBNAME: xorm_test
TEST_MYSQL_USERNAME: root
TEST_MYSQL_PASSWORD:
commands:
- TEST_CACHE_ENABLE=true make test-mysql
- name: test-mysql-utf8mb4
image: golang:1.15
pull: never
volumes:
- name: cache
path: /go/pkg/mod
depends_on:
- test-mysql
environment:
TEST_MYSQL_HOST: mysql
TEST_MYSQL_CHARSET: utf8mb4
TEST_MYSQL_DBNAME: xorm_test
TEST_MYSQL_USERNAME: root
TEST_MYSQL_PASSWORD:
commands:
- make test-mysql
- TEST_QUOTE_POLICY=reserved make test-mysql-tls
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: mysql
image: mysql:5.7
environment:
MYSQL_ALLOW_EMPTY_PASSWORD: yes
MYSQL_DATABASE: xorm_test
---
kind: pipeline
name: test-mysql8
depends_on:
- test-mysql
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-mysql8
image: golang:1.15
pull: never
volumes:
- name: cache
path: /go/pkg/mod
environment:
TEST_MYSQL_HOST: mysql8
TEST_MYSQL_CHARSET: utf8mb4
TEST_MYSQL_DBNAME: xorm_test
TEST_MYSQL_USERNAME: root
TEST_MYSQL_PASSWORD:
commands:
- make test-mysql
- TEST_CACHE_ENABLE=true make test-mysql
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: mysql8
image: mysql:8.0
environment:
MYSQL_ALLOW_EMPTY_PASSWORD: yes
MYSQL_DATABASE: xorm_test
---
kind: pipeline
name: test-mariadb
depends_on:
- test-mysql8
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-mariadb
image: golang:1.15
pull: never
volumes:
- name: cache
path: /go/pkg/mod
environment:
TEST_MYSQL_HOST: mariadb
TEST_MYSQL_CHARSET: utf8mb4
TEST_MYSQL_DBNAME: xorm_test
TEST_MYSQL_USERNAME: root
TEST_MYSQL_PASSWORD:
commands:
- make test-mysql
- TEST_QUOTE_POLICY=reserved make test-mysql
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: mariadb
image: mariadb:10.4
environment:
MYSQL_ALLOW_EMPTY_PASSWORD: yes
MYSQL_DATABASE: xorm_test
---
kind: pipeline
name: test-postgres
depends_on:
- test-mariadb
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-postgres
pull: never
image: golang:1.15
volumes:
- name: cache
path: /go/pkg/mod
environment:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
commands:
- make test-postgres
- TEST_CACHE_ENABLE=true make test-postgres
- name: test-postgres-schema
pull: never
image: golang:1.15
volumes:
- name: cache
path: /go/pkg/mod
depends_on:
- test-postgres
environment:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_SCHEMA: xorm
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
commands:
- TEST_QUOTE_POLICY=reserved make test-postgres
- name: test-pgx
pull: never
image: golang:1.15
volumes:
- name: cache
path: /go/pkg/mod
depends_on:
- test-postgres-schema
environment:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
commands:
- make test-pgx
- TEST_CACHE_ENABLE=true make test-pgx
- TEST_QUOTE_POLICY=reserved make test-pgx
- name: test-pgx-schema
pull: never
image: golang:1.15
volumes:
- name: cache
path: /go/pkg/mod
depends_on:
- test-pgx
environment:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_SCHEMA: xorm
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
commands:
- make test-pgx
- TEST_CACHE_ENABLE=true make test-pgx
- TEST_QUOTE_POLICY=reserved make test-pgx
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: pgsql
image: postgres:9.5
environment:
POSTGRES_DB: xorm_test
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
---
kind: pipeline
name: test-mssql
depends_on:
- test-postgres
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-mssql
pull: never
image: golang:1.15
volumes:
- name: cache
path: /go/pkg/mod
environment:
TEST_MSSQL_HOST: mssql
TEST_MSSQL_DBNAME: xorm_test
TEST_MSSQL_USERNAME: sa
TEST_MSSQL_PASSWORD: "yourStrong(!)Password"
commands:
- make test-mssql
- TEST_MSSQL_DEFAULT_VARCHAR=NVARCHAR TEST_MSSQL_DEFAULT_CHAR=NCHAR make test-mssql
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: mssql
pull: always
image: mcr.microsoft.com/mssql/server:latest
environment:
ACCEPT_EULA: Y
SA_PASSWORD: yourStrong(!)Password
MSSQL_PID: Standard
---
kind: pipeline
name: test-tidb
depends_on:
- test-mssql
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-tidb
pull: never
image: golang:1.15
volumes:
- name: cache
path: /go/pkg/mod
environment:
TEST_TIDB_HOST: "tidb:4000"
TEST_TIDB_DBNAME: xorm_test
TEST_TIDB_USERNAME: root
TEST_TIDB_PASSWORD:
commands:
- make test-tidb
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: tidb
image: pingcap/tidb:v3.0.3
---
kind: pipeline
name: test-cockroach
depends_on:
- test-tidb
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-cockroach
pull: never
image: golang:1.15
volumes:
- name: cache
path: /go/pkg/mod
environment:
TEST_COCKROACH_HOST: "cockroach:26257"
TEST_COCKROACH_DBNAME: xorm_test
TEST_COCKROACH_USERNAME: root
TEST_COCKROACH_PASSWORD:
commands:
- sleep 10
- make test-cockroach
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: cockroach
image: cockroachdb/cockroach:v19.2.4
commands:
- /cockroach/cockroach start --insecure
---
kind: pipeline
name: test-dameng
depends_on:
- test-cockroach
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: test-dameng
pull: never
image: golang:1.15
volumes:
- name: cache
path: /go/pkg/mod
environment:
TEST_DAMENG_HOST: "dameng:5236"
TEST_DAMENG_USERNAME: SYSDBA
TEST_DAMENG_PASSWORD: SYSDBA
commands:
- sleep 30
- make test-dameng
volumes:
- name: cache
host:
path: /tmp/cache
services:
- name: dameng
image: lunny/dm:v1.0
commands:
- /bin/bash /startDm.sh
---
kind: pipeline
name: merge_coverage
depends_on:
- test-mysql
- test-mysql8
- test-mariadb
- test-postgres
- test-mssql
- test-tidb
- test-cockroach
#- test-dameng
trigger:
ref:
- refs/heads/master
- refs/pull/*/head
steps:
- name: merge_coverage
image: golang:1.15
commands:
- make coverage
---
kind: pipeline
name: release-tag
trigger:
event:
- tag
steps:
- name: release-tag-gitea
pull: always
image: plugins/gitea-release:latest
settings:
base_url: https://gitea.com
title: '${DRONE_TAG} is released'
api_key:
from_secret: gitea_token

View File

@ -0,0 +1,23 @@
name: release
on:
push:
tags:
- '*'
jobs:
release:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: setup go
uses: https://github.com/actions/setup-go@v4
with:
go-version: '>=1.20.1'
- name: Use Go Action
id: use-go-action
uses: actions/release-action@main
with:
api_key: '${{secrets.RELEASE_TOKEN}}'

View File

@ -0,0 +1,55 @@
name: test cockroach
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
test-cockroach:
name: test cockroach
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: test cockroach
env:
TEST_COCKROACH_HOST: "cockroach:26257"
TEST_COCKROACH_DBNAME: xorm_test
TEST_COCKROACH_USERNAME: root
TEST_COCKROACH_PASSWORD:
run: sleep 20 && make test-cockroach
services:
cockroach:
image: cockroachdb/cockroach:v19.2.4
ports:
- 26257:26257
cmd:
- 'start'
- '--insecure'

View File

@ -0,0 +1,54 @@
name: test dameng
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
lint:
name: test postgres
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: test dameng
env:
TEST_DAMENG_HOST: "dameng:5236"
TEST_DAMENG_USERNAME: SYSDBA
TEST_DAMENG_PASSWORD: SYSDBA
run: make test-dameng
services:
dameng:
image: lunny/dm:v1.0
cmd:
- /bin/bash
- /startDm.sh
ports:
- 5236:5236

View File

@ -0,0 +1,56 @@
name: test mariadb
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
lint:
name: test mariadb
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: test mariadb
env:
TEST_MYSQL_HOST: mariadb
TEST_MYSQL_CHARSET: utf8mb4
TEST_MYSQL_DBNAME: xorm_test
TEST_MYSQL_USERNAME: root
TEST_MYSQL_PASSWORD:
run: TEST_QUOTE_POLICY=reserved make test-mysql
services:
mariadb:
image: mariadb:10.4
env:
MYSQL_ALLOW_EMPTY_PASSWORD: yes
MYSQL_DATABASE: xorm_test
ports:
- 3306:3306

View File

@ -0,0 +1,56 @@
name: test mssql
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
test-mssql:
name: test mssql
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: test mssql
env:
TEST_MSSQL_HOST: mssql
TEST_MSSQL_DBNAME: xorm_test
TEST_MSSQL_USERNAME: sa
TEST_MSSQL_PASSWORD: "yourStrong(!)Password"
run: TEST_MSSQL_DEFAULT_VARCHAR=NVARCHAR TEST_MSSQL_DEFAULT_CHAR=NCHAR make test-mssql
services:
mssql:
image: mcr.microsoft.com/mssql/server:latest
env:
ACCEPT_EULA: Y
SA_PASSWORD: yourStrong(!)Password
MSSQL_PID: Standard
ports:
- 1433:1433

View File

@ -0,0 +1,56 @@
name: test mysql
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
test-mysql:
name: test mysql
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: test mysql utf8mb4
env:
TEST_MYSQL_HOST: mysql
TEST_MYSQL_CHARSET: utf8mb4
TEST_MYSQL_DBNAME: xorm_test
TEST_MYSQL_USERNAME: root
TEST_MYSQL_PASSWORD:
run: TEST_QUOTE_POLICY=reserved make test-mysql-tls
services:
mysql:
image: mysql:5.7
env:
MYSQL_ALLOW_EMPTY_PASSWORD: yes
MYSQL_DATABASE: xorm_test
ports:
- 3306:3306

View File

@ -0,0 +1,56 @@
name: test mysql8
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
lint:
name: test mysql8
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: test mysql8
env:
TEST_MYSQL_HOST: mysql8
TEST_MYSQL_CHARSET: utf8mb4
TEST_MYSQL_DBNAME: xorm_test
TEST_MYSQL_USERNAME: root
TEST_MYSQL_PASSWORD:
run: TEST_CACHE_ENABLE=true make test-mysql
services:
mysql8:
image: mysql:8.0
env:
MYSQL_ALLOW_EMPTY_PASSWORD: yes
MYSQL_DATABASE: xorm_test
ports:
- 3306:3306

View File

@ -0,0 +1,79 @@
name: test postgres
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
lint:
name: test postgres
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: test postgres
env:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
run: TEST_CACHE_ENABLE=true make test-postgres
- name: test postgres with schema
env:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_SCHEMA: xorm
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
run: TEST_QUOTE_POLICY=reserved make test-postgres
- name: test pgx
env:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
run: TEST_CACHE_ENABLE=true make test-pgx
- name: test pgx with schema
env:
TEST_PGSQL_HOST: pgsql
TEST_PGSQL_SCHEMA: xorm
TEST_PGSQL_DBNAME: xorm_test
TEST_PGSQL_USERNAME: postgres
TEST_PGSQL_PASSWORD: postgres
run: TEST_QUOTE_POLICY=reserved make test-pgx
services:
pgsql:
image: postgres:9.5
env:
POSTGRES_DB: xorm_test
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
ports:
- 5432:5432

View File

@ -0,0 +1,49 @@
name: test sqlite
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
test-sqlite:
name: unit test & test sqlite
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: vet
run: make vet
- name: format check
run: make fmt-check
- name: unit test
run: make test
- name: test sqlite3
run: make test-sqlite3
- name: test sqlite3 with cache
run: TEST_CACHE_ENABLE=true make test-sqlite3

View File

@ -0,0 +1,52 @@
name: test tidb
on:
push:
branches:
- master
pull_request:
env:
GOPROXY: https://goproxy.io,direct
GOPATH: /go_path
GOCACHE: /go_cache
jobs:
test-tidb:
name: test tidb
runs-on: ubuntu-latest
steps:
# - name: cache go path
# id: cache-go-path
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_path
# key: go_path-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_path-${{ github.repository }}-
# go_path-
# - name: cache go cache
# id: cache-go-cache
# uses: https://github.com/actions/cache@v3
# with:
# path: /go_cache
# key: go_cache-${{ github.repository }}-${{ github.ref_name }}
# restore-keys: |
# go_cache-${{ github.repository }}-
# go_cache-
- uses: actions/setup-go@v3
with:
go-version: 1.20
- uses: https://github.com/actions/checkout@v3
- name: test tidb
env:
TEST_TIDB_HOST: "tidb:4000"
TEST_TIDB_DBNAME: xorm_test
TEST_TIDB_USERNAME: root
TEST_TIDB_PASSWORD:
run: make test-tidb
services:
tidb:
image: pingcap/tidb:v3.0.3
ports:
- 4000:4000

View File

@ -3,6 +3,29 @@
This changelog goes through all the changes that have been made in each release This changelog goes through all the changes that have been made in each release
without substantial changes to our git log. without substantial changes to our git log.
## [1.3.2](https://gitea.com/xorm/xorm/releases/tag/1.3.2) - 2022-09-03
* BUGFIXES
* Change schemas.Column to use int64 (#2160)
* MISC
* Prevent Sync failure with non-regular indexes on Postgres (#2174)
## [1.3.1](https://gitea.com/xorm/xorm/releases/tag/1.3.1) - 2022-06-03
* BREAKING
* Refactor orderby and support arguments (#2150)
* return a clear error for set TEXT type as compare condition (#2062)
* BUGFIXES
* Fix oid index for postgres (#2154)
* Add ORDER BY SEQ_IN_INDEX to MySQL GetIndexes to Fix IndexTests (#2152)
* some improvement (#2136)
* ENHANCEMENTS
* Add interface to allow structs to provide specific index information (#2137)
* MySQL/MariaDB: return max length for text columns (#2133)
* PostgreSQL: enable comment on column (#2131)
* TESTING
* Add test for find date (#2121)
## [1.3.0](https://gitea.com/xorm/xorm/releases/tag/1.3.0) - 2022-04-14 ## [1.3.0](https://gitea.com/xorm/xorm/releases/tag/1.3.0) - 2022-04-14
* BREAKING * BREAKING

View File

@ -53,6 +53,7 @@ Drivers for Go's sql package which currently support database/sql includes:
* Oracle * Oracle
- [github.com/godror/godror](https://github.com/godror/godror) (experiment) - [github.com/godror/godror](https://github.com/godror/godror) (experiment)
- [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (experiment) - [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (experiment)
- [github.com/sijms/go-ora](https://github.com/sijms/go-ora) (experiment)
## Installation ## Installation

View File

@ -40,6 +40,13 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t
} }
dt = dt.In(convertedLocation) dt = dt.In(convertedLocation)
return &dt, nil return &dt, nil
} else if len(s) >= 21 && s[10] == 'T' && s[19] == '.' {
dt, err := time.Parse(time.RFC3339, s)
if err != nil {
return nil, err
}
dt = dt.In(convertedLocation)
return &dt, nil
} else if len(s) >= 21 && s[19] == '.' { } else if len(s) >= 21 && s[19] == '.' {
var layout = "2006-01-02 15:04:05." + strings.Repeat("0", len(s)-20) var layout = "2006-01-02 15:04:05." + strings.Repeat("0", len(s)-20)
dt, err := time.ParseInLocation(layout, s, originalLocation) dt, err := time.ParseInLocation(layout, s, originalLocation)

View File

@ -16,10 +16,19 @@ func TestString2Time(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
var kases = map[string]time.Time{ var kases = map[string]time.Time{
"2021-08-10": time.Date(2021, 8, 10, 8, 0, 0, 0, expectedLoc), "2021-08-10": time.Date(2021, 8, 10, 8, 0, 0, 0, expectedLoc),
"2021-06-06T22:58:20+08:00": time.Date(2021, 6, 6, 22, 58, 20, 0, expectedLoc), "2021-07-11 10:44:00": time.Date(2021, 7, 11, 18, 44, 0, 0, expectedLoc),
"2021-07-11 10:44:00": time.Date(2021, 7, 11, 18, 44, 0, 0, expectedLoc), "2021-07-11 10:44:00.999": time.Date(2021, 7, 11, 18, 44, 0, 999000000, expectedLoc),
"2021-08-10T10:33:04Z": time.Date(2021, 8, 10, 18, 33, 04, 0, expectedLoc), "2021-07-11 10:44:00.999999": time.Date(2021, 7, 11, 18, 44, 0, 999999000, expectedLoc),
"2021-07-11 10:44:00.999999999": time.Date(2021, 7, 11, 18, 44, 0, 999999999, expectedLoc),
"2021-06-06T22:58:20+08:00": time.Date(2021, 6, 6, 22, 58, 20, 0, expectedLoc),
"2021-06-06T22:58:20.999+08:00": time.Date(2021, 6, 6, 22, 58, 20, 999000000, expectedLoc),
"2021-06-06T22:58:20.999999+08:00": time.Date(2021, 6, 6, 22, 58, 20, 999999000, expectedLoc),
"2021-06-06T22:58:20.999999999+08:00": time.Date(2021, 6, 6, 22, 58, 20, 999999999, expectedLoc),
"2021-08-10T10:33:04Z": time.Date(2021, 8, 10, 18, 33, 04, 0, expectedLoc),
"2021-08-10T10:33:04.999Z": time.Date(2021, 8, 10, 18, 33, 04, 999000000, expectedLoc),
"2021-08-10T10:33:04.999999Z": time.Date(2021, 8, 10, 18, 33, 04, 999999000, expectedLoc),
"2021-08-10T10:33:04.999999999Z": time.Date(2021, 8, 10, 18, 33, 04, 999999999, expectedLoc),
} }
for layout, tm := range kases { for layout, tm := range kases {
t.Run(layout, func(t *testing.T) { t.Run(layout, func(t *testing.T) {

View File

@ -622,9 +622,9 @@ func (db *dameng) SQLType(c *schemas.Column) string {
hasLen2 := (c.Length2 > 0) hasLen2 := (c.Length2 > 0)
if hasLen2 { if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + "," + strconv.FormatInt(c.Length2, 10) + ")"
} else if hasLen1 { } else if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + ")"
} }
return res return res
} }
@ -709,7 +709,13 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
return "", false, err return "", false, err
} }
} }
if _, err := b.WriteString(fmt.Sprintf("CONSTRAINT PK_%s PRIMARY KEY (", tableName)); err != nil { if _, err := b.WriteString("CONSTRAINT PK_"); err != nil {
return "", false, err
}
if _, err := b.WriteString(tableName); err != nil {
return "", false, err
}
if _, err := b.WriteString(" PRIMARY KEY ("); err != nil {
return "", false, err return "", false, err
} }
if err := quoter.JoinWrite(&b, pkList, ","); err != nil { if err := quoter.JoinWrite(&b, pkList, ","); err != nil {
@ -729,11 +735,11 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
func (db *dameng) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *dameng) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy { switch quotePolicy {
case QuotePolicyNone: case QuotePolicyNone:
var q = damengQuoter q := damengQuoter
q.IsReserved = schemas.AlwaysNoReserve q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q db.quoter = q
case QuotePolicyReserved: case QuotePolicyReserved:
var q = damengQuoter q := damengQuoter
q.IsReserved = db.IsReserved q.IsReserved = db.IsReserved
db.quoter = q db.quoter = q
case QuotePolicyAlways: case QuotePolicyAlways:
@ -792,7 +798,7 @@ type dmClobObject interface {
ReadString(int, int) (string, error) ReadString(int, int) (string, error)
} }
//var _ dmClobObject = &dm.DmClob{} // var _ dmClobObject = &dm.DmClob{}
func (d *dmClobScanner) Scan(data interface{}) error { func (d *dmClobScanner) Scan(data interface{}) error {
if data == nil { if data == nil {
@ -837,7 +843,11 @@ func addSingleQuote(name string) string {
if name[0] == '\'' && name[len(name)-1] == '\'' { if name[0] == '\'' && name[len(name)-1] == '\'' {
return name return name
} }
return fmt.Sprintf("'%s'", name) var b strings.Builder
b.WriteRune('\'')
b.WriteString(name)
b.WriteRune('\'')
return b.String()
} }
func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
@ -927,7 +937,7 @@ func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
var ( var (
ignore bool ignore bool
dt string dt string
len1, len2 int len1, len2 int64
) )
dts := strings.Split(dataType.String, "(") dts := strings.Split(dataType.String, "(")
@ -935,10 +945,10 @@ func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
if len(dts) > 1 { if len(dts) > 1 {
lens := strings.Split(dts[1][:len(dts[1])-1], ",") lens := strings.Split(dts[1][:len(dts[1])-1], ",")
if len(lens) > 1 { if len(lens) > 1 {
len1, _ = strconv.Atoi(lens[0]) len1, _ = strconv.ParseInt(lens[0], 10, 64)
len2, _ = strconv.Atoi(lens[1]) len2, _ = strconv.ParseInt(lens[1], 10, 64)
} else { } else {
len1, _ = strconv.Atoi(lens[0]) len1, _ = strconv.ParseInt(lens[0], 10, 64)
} }
} }
@ -972,9 +982,9 @@ func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
} }
if col.SQLType.Name == "TIMESTAMP" { if col.SQLType.Name == "TIMESTAMP" {
col.Length = int(dataScale.Int64) col.Length = dataScale.Int64
} else { } else {
col.Length = int(dataLen.Int64) col.Length = dataLen.Int64
} }
if col.SQLType.IsTime() { if col.SQLType.IsTime() {
@ -1140,8 +1150,8 @@ func (d *damengDriver) GenScanResult(colType string) (interface{}, error) {
} }
func (d *damengDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, vv ...interface{}) error { func (d *damengDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, vv ...interface{}) error {
var scanResults = make([]interface{}, 0, len(types)) scanResults := make([]interface{}, 0, len(types))
var replaces = make([]bool, 0, len(types)) replaces := make([]bool, 0, len(types))
var err error var err error
for i, v := range vv { for i, v := range vv {
var replaced bool var replaced bool

View File

@ -290,6 +290,7 @@ func regDrvsNDialects() bool {
"sqlite": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }}, "sqlite": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }},
"oci8": {"oracle", func() Driver { return &oci8Driver{} }, func() Dialect { return &oracle{} }}, "oci8": {"oracle", func() Driver { return &oci8Driver{} }, func() Dialect { return &oracle{} }},
"godror": {"oracle", func() Driver { return &godrorDriver{} }, func() Dialect { return &oracle{} }}, "godror": {"oracle", func() Driver { return &godrorDriver{} }, func() Dialect { return &oracle{} }},
"oracle": {"oracle", func() Driver { return &oracleDriver{} }, func() Dialect { return &oracle{} }},
} }
for driverName, v := range providedDrvsNDialects { for driverName, v := range providedDrvsNDialects {

View File

@ -5,13 +5,14 @@
package dialects package dialects
import ( import (
"fmt" "context"
"strconv"
"strings" "strings"
) )
// Filter is an interface to filter SQL // Filter is an interface to filter SQL
type Filter interface { type Filter interface {
Do(sql string) string Do(ctx context.Context, sql string) string
} }
// SeqFilter filter SQL replace ?, ? ... to $1, $2 ... // SeqFilter filter SQL replace ?, ? ... to $1, $2 ...
@ -28,10 +29,11 @@ func convertQuestionMark(sql, prefix string, start int) string {
var isMaybeLineComment bool var isMaybeLineComment bool
var isMaybeComment bool var isMaybeComment bool
var isMaybeCommentEnd bool var isMaybeCommentEnd bool
var index = start index := start
for _, c := range sql { for _, c := range sql {
if !beginSingleQuote && !isLineComment && !isComment && c == '?' { if !beginSingleQuote && !isLineComment && !isComment && c == '?' {
buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) buf.WriteString(prefix)
buf.WriteString(strconv.Itoa(index))
index++ index++
} else { } else {
if isMaybeLineComment { if isMaybeLineComment {
@ -71,6 +73,6 @@ func convertQuestionMark(sql, prefix string, start int) string {
} }
// Do implements Filter // Do implements Filter
func (s *SeqFilter) Do(sql string) string { func (s *SeqFilter) Do(ctx context.Context, sql string) string {
return convertQuestionMark(sql, s.Prefix, s.Start) return convertQuestionMark(sql, s.Prefix, s.Start)
} }

View File

@ -229,7 +229,7 @@ func (db *mssql) Init(uri *URI) error {
func (db *mssql) SetParams(params map[string]string) { func (db *mssql) SetParams(params map[string]string) {
defaultVarchar, ok := params["DEFAULT_VARCHAR"] defaultVarchar, ok := params["DEFAULT_VARCHAR"]
if ok { if ok {
var t = strings.ToUpper(defaultVarchar) t := strings.ToUpper(defaultVarchar)
switch t { switch t {
case "NVARCHAR", "VARCHAR": case "NVARCHAR", "VARCHAR":
db.defaultVarchar = t db.defaultVarchar = t
@ -242,7 +242,7 @@ func (db *mssql) SetParams(params map[string]string) {
defaultChar, ok := params["DEFAULT_CHAR"] defaultChar, ok := params["DEFAULT_CHAR"]
if ok { if ok {
var t = strings.ToUpper(defaultChar) t := strings.ToUpper(defaultChar)
switch t { switch t {
case "NCHAR", "CHAR": case "NCHAR", "CHAR":
db.defaultChar = t db.defaultChar = t
@ -375,9 +375,9 @@ func (db *mssql) SQLType(c *schemas.Column) string {
hasLen2 := (c.Length2 > 0) hasLen2 := (c.Length2 > 0)
if hasLen2 { if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + "," + strconv.FormatInt(c.Length2, 10) + ")"
} else if hasLen1 { } else if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + ")"
} }
return res return res
} }
@ -403,11 +403,11 @@ func (db *mssql) IsReserved(name string) bool {
func (db *mssql) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *mssql) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy { switch quotePolicy {
case QuotePolicyNone: case QuotePolicyNone:
var q = mssqlQuoter q := mssqlQuoter
q.IsReserved = schemas.AlwaysNoReserve q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q db.quoter = q
case QuotePolicyReserved: case QuotePolicyReserved:
var q = mssqlQuoter q := mssqlQuoter
q.IsReserved = db.IsReserved q.IsReserved = db.IsReserved
db.quoter = q db.quoter = q
case QuotePolicyAlways: case QuotePolicyAlways:
@ -475,7 +475,7 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
var name, ctype, vdefault string var name, ctype, vdefault string
var maxLen, precision, scale int var maxLen, precision, scale int64
var nullable, isPK, defaultIsNull, isIncrement bool var nullable, isPK, defaultIsNull, isIncrement bool
err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement) err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement)
if err != nil { if err != nil {

View File

@ -330,9 +330,9 @@ func (db *mysql) SQLType(c *schemas.Column) string {
} }
if hasLen2 { if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + "," + strconv.FormatInt(c.Length2, 10) + ")"
} else if hasLen1 { } else if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + ")"
} }
if isUnsigned { if isUnsigned {
@ -381,11 +381,17 @@ func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableNa
func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string {
quoter := db.dialect.Quoter() quoter := db.dialect.Quoter()
s, _ := ColumnString(db, col, true) s, _ := ColumnString(db, col, true)
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), s) var b strings.Builder
b.WriteString("ALTER TABLE ")
quoter.QuoteTo(&b, tableName)
b.WriteString(" ADD ")
b.WriteString(s)
if len(col.Comment) > 0 { if len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'" b.WriteString(" COMMENT '")
b.WriteString(col.Comment)
b.WriteString("'")
} }
return sql return b.String()
} }
func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
@ -400,7 +406,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
" `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " + " `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " +
alreadyQuoted + " AS NEEDS_QUOTE " + alreadyQuoted + " AS NEEDS_QUOTE " +
"FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + "FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" +
" ORDER BY `COLUMNS`.ORDINAL_POSITION" " ORDER BY `COLUMNS`.ORDINAL_POSITION ASC"
rows, err := queryer.QueryContext(ctx, s, args...) rows, err := queryer.QueryContext(ctx, s, args...)
if err != nil { if err != nil {
@ -444,7 +450,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
// Remove the /* mariadb-5.3 */ suffix from coltypes // Remove the /* mariadb-5.3 */ suffix from coltypes
colName = strings.TrimSuffix(colName, "/* mariadb-5.3 */") colName = strings.TrimSuffix(colName, "/* mariadb-5.3 */")
colType = strings.ToUpper(colName) colType = strings.ToUpper(colName)
var len1, len2 int var len1, len2 int64
if len(cts) == 2 { if len(cts) == 2 {
idx := strings.Index(cts[1], ")") idx := strings.Index(cts[1], ")")
if colType == schemas.Enum && cts[1][0] == '\'' { // enum if colType == schemas.Enum && cts[1][0] == '\'' { // enum
@ -465,12 +471,12 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
} }
} else { } else {
lens := strings.Split(cts[1][0:idx], ",") lens := strings.Split(cts[1][0:idx], ",")
len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) len1, err = strconv.ParseInt(strings.TrimSpace(lens[0]), 10, 64)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if len(lens) == 2 { if len(lens) == 2 {
len2, err = strconv.Atoi(lens[1]) len2, err = strconv.ParseInt(lens[1], 10, 64)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -479,7 +485,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
} else { } else {
switch colType { switch colType {
case "MEDIUMTEXT", "LONGTEXT", "TEXT": case "MEDIUMTEXT", "LONGTEXT", "TEXT":
len1, err = strconv.Atoi(*maxLength) len1, err = strconv.ParseInt(*maxLength, 10, 64)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -738,8 +744,9 @@ func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
} }
func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) { func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) {
colType = strings.Replace(colType, "UNSIGNED ", "", -1)
switch colType { switch colType {
case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET": case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET", "JSON":
var s sql.NullString var s sql.NullString
return &s, nil return &s, nil
case "BIGINT": case "BIGINT":

View File

@ -548,7 +548,14 @@ func (db *oracle) Features() *DialectFeatures {
func (db *oracle) SQLType(c *schemas.Column) string { func (db *oracle) SQLType(c *schemas.Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt, schemas.Bool, schemas.Serial, schemas.BigSerial: case schemas.Bool:
if c.Default == "true" {
c.Default = "1"
} else if c.Default == "false" {
c.Default = "0"
}
res = "NUMBER(1,0)"
case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt, schemas.Serial, schemas.BigSerial:
res = "NUMBER" res = "NUMBER"
case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea: case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea:
return schemas.Blob return schemas.Blob
@ -570,9 +577,9 @@ func (db *oracle) SQLType(c *schemas.Column) string {
hasLen2 := (c.Length2 > 0) hasLen2 := (c.Length2 > 0)
if hasLen2 { if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + "," + strconv.FormatInt(c.Length2, 10) + ")"
} else if hasLen1 { } else if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + ")"
} }
return res return res
} }
@ -602,11 +609,11 @@ func (db *oracle) IsReserved(name string) bool {
} }
func (db *oracle) DropTableSQL(tableName string) (string, bool) { func (db *oracle) DropTableSQL(tableName string) (string, bool) {
return fmt.Sprintf("DROP TABLE `%s`", tableName), false return fmt.Sprintf("DROP TABLE \"%s\"", tableName), false
} }
func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
var sql = "CREATE TABLE " sql := "CREATE TABLE "
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
} }
@ -638,14 +645,18 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
return sql, false, nil return sql, false, nil
} }
func (db *oracle) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) {
return db.HasRecords(queryer, ctx, `SELECT sequence_name FROM user_sequences WHERE sequence_name = :1`, seqName)
}
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy { switch quotePolicy {
case QuotePolicyNone: case QuotePolicyNone:
var q = oracleQuoter q := oracleQuoter
q.IsReserved = schemas.AlwaysNoReserve q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q db.quoter = q
case QuotePolicyReserved: case QuotePolicyReserved:
var q = oracleQuoter q := oracleQuoter
q.IsReserved = db.IsReserved q.IsReserved = db.IsReserved
db.quoter = q db.quoter = q
case QuotePolicyAlways: case QuotePolicyAlways:
@ -690,7 +701,7 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
col.Indexes = make(map[string]int) col.Indexes = make(map[string]int)
var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string
var dataLen int var dataLen int64
err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision, err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision,
&dataScale, &nullable) &dataScale, &nullable)
@ -713,16 +724,16 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
var ignore bool var ignore bool
var dt string var dt string
var len1, len2 int var len1, len2 int64
dts := strings.Split(*dataType, "(") dts := strings.Split(*dataType, "(")
dt = dts[0] dt = dts[0]
if len(dts) > 1 { if len(dts) > 1 {
lens := strings.Split(dts[1][:len(dts[1])-1], ",") lens := strings.Split(dts[1][:len(dts[1])-1], ",")
if len(lens) > 1 { if len(lens) > 1 {
len1, _ = strconv.Atoi(lens[0]) len1, _ = strconv.ParseInt(lens[0], 10, 64)
len2, _ = strconv.Atoi(lens[1]) len2, _ = strconv.ParseInt(lens[1], 10, 64)
} else { } else {
len1, _ = strconv.Atoi(lens[0]) len1, _ = strconv.ParseInt(lens[0], 10, 64)
} }
} }
@ -932,3 +943,7 @@ func (o *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) {
} }
return db, nil return db, nil
} }
type oracleDriver struct {
godrorDriver
}

View File

@ -862,11 +862,11 @@ func (db *postgres) needQuote(name string) bool {
func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy { switch quotePolicy {
case QuotePolicyNone: case QuotePolicyNone:
var q = postgresQuoter q := postgresQuoter
q.IsReserved = schemas.AlwaysNoReserve q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q db.quoter = q
case QuotePolicyReserved: case QuotePolicyReserved:
var q = postgresQuoter q := postgresQuoter
q.IsReserved = db.needQuote q.IsReserved = db.needQuote
db.quoter = q db.quoter = q
case QuotePolicyAlways: case QuotePolicyAlways:
@ -934,9 +934,9 @@ func (db *postgres) SQLType(c *schemas.Column) string {
hasLen2 := (c.Length2 > 0) hasLen2 := (c.Length2 > 0)
if hasLen2 { if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + "," + strconv.FormatInt(c.Length2, 10) + ")"
} else if hasLen1 { } else if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + ")"
} }
return res return res
} }
@ -1030,11 +1030,10 @@ func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string
tableParts := strings.Split(strings.Replace(tableName, `"`, "", -1), ".") tableParts := strings.Split(strings.Replace(tableName, `"`, "", -1), ".")
tableName = tableParts[len(tableParts)-1] tableName = tableParts[len(tableParts)-1]
if !strings.HasPrefix(idxName, "UQE_") && if index.IsRegular {
!strings.HasPrefix(idxName, "IDX_") { if index.Type == schemas.UniqueType && !strings.HasPrefix(idxName, "UQE_") {
if index.Type == schemas.UniqueType {
idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name) idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
} else { } else if index.Type == schemas.IndexType && !strings.HasPrefix(idxName, "IDX_") {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
} }
} }
@ -1110,9 +1109,9 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
return nil, nil, err return nil, nil, err
} }
var maxLen int var maxLen int64
if maxLenStr != nil { if maxLenStr != nil {
maxLen, err = strconv.Atoi(*maxLenStr) maxLen, err = strconv.ParseInt(*maxLenStr, 10, 64)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -1125,7 +1124,7 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
col.Name = strings.Trim(colName, `" `) col.Name = strings.Trim(colName, `" `)
if colDefault != nil { if colDefault != nil {
var theDefault = *colDefault theDefault := *colDefault
// cockroach has type with the default value with ::: // cockroach has type with the default value with :::
// and postgres with ::, we should remove them before store them // and postgres with ::, we should remove them before store them
idx := strings.Index(theDefault, ":::") idx := strings.Index(theDefault, ":::")
@ -1186,7 +1185,7 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
startIdx := strings.Index(strings.ToLower(dataType), "string(") startIdx := strings.Index(strings.ToLower(dataType), "string(")
if startIdx != -1 && strings.HasSuffix(dataType, ")") { if startIdx != -1 && strings.HasSuffix(dataType, ")") {
length := dataType[startIdx+8 : len(dataType)-1] length := dataType[startIdx+8 : len(dataType)-1]
l, _ := strconv.Atoi(length) l, _ := strconv.ParseInt(length, 10, 64)
col.SQLType = schemas.SQLType{Name: "STRING", DefaultLength: l, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: "STRING", DefaultLength: l, DefaultLength2: 0}
} else { } else {
col.SQLType = schemas.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0}
@ -1301,15 +1300,8 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN
} }
colNames = getIndexColName(indexdef) colNames = getIndexColName(indexdef)
isSkip := false // Oid It's a special index. You can't put it in. TODO: This is not perfect.
//Oid It's a special index. You can't put it in if indexName == tableName+"_oid_index" && len(colNames) == 1 && colNames[0] == "oid" {
for _, element := range colNames {
if "oid" == element {
isSkip = true
break
}
}
if isSkip {
continue continue
} }
@ -1351,14 +1343,14 @@ func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, ta
commentSQL := "; " commentSQL := "; "
if table.Comment != "" { if table.Comment != "" {
// support schema.table -> "schema"."table" // support schema.table -> "schema"."table"
commentSQL += fmt.Sprintf("COMMENT ON TABLE %s IS '%s'", quoter.Quote(tableName), table.Comment) commentSQL += fmt.Sprintf("COMMENT ON TABLE %s IS '%s'; ", quoter.Quote(tableName), table.Comment)
} }
for _, colName := range table.ColumnsSeq() { for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
if len(col.Comment) > 0 { if len(col.Comment) > 0 {
commentSQL += fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoter.Quote(tableName), quoter.Quote(col.Name), col.Comment) commentSQL += fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'; ", quoter.Quote(tableName), quoter.Quote(col.Name), col.Comment)
} }
} }

View File

@ -23,7 +23,7 @@ func FormatColumnTime(dialect Dialect, dbLocation *time.Location, col *schemas.C
} }
} }
var tmZone = dbLocation tmZone := dbLocation
if col.TimeZone != nil { if col.TimeZone != nil {
tmZone = col.TimeZone tmZone = col.TimeZone
} }
@ -34,15 +34,17 @@ func FormatColumnTime(dialect Dialect, dbLocation *time.Location, col *schemas.C
case schemas.Date: case schemas.Date:
return t.Format("2006-01-02"), nil return t.Format("2006-01-02"), nil
case schemas.Time: case schemas.Time:
var layout = "15:04:05" layout := "15:04:05"
if col.Length > 0 { if col.Length > 0 {
layout += "." + strings.Repeat("0", col.Length) // we can use int(...) casting here as it's very unlikely to a huge sized field
layout += "." + strings.Repeat("0", int(col.Length))
} }
return t.Format(layout), nil return t.Format(layout), nil
case schemas.DateTime, schemas.TimeStamp: case schemas.DateTime, schemas.TimeStamp:
var layout = "2006-01-02 15:04:05" layout := "2006-01-02 15:04:05"
if col.Length > 0 { if col.Length > 0 {
layout += "." + strings.Repeat("0", col.Length) // we can use int(...) casting here as it's very unlikely to a huge sized field
layout += "." + strings.Repeat("0", int(col.Length))
} }
return t.Format(layout), nil return t.Format(layout), nil
case schemas.Varchar: case schemas.Varchar:

251
doc.go
View File

@ -3,247 +3,246 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
/* /*
Package xorm is a simple and powerful ORM for Go. Package xorm is a simple and powerful ORM for Go.
Installation # Installation
Make sure you have installed Go 1.11+ and then: Make sure you have installed Go 1.11+ and then:
go get xorm.io/xorm go get xorm.io/xorm
Create Engine # Create Engine
Firstly, we should create an engine for a database Firstly, we should create an engine for a database
engine, err := xorm.NewEngine(driverName, dataSourceName) engine, err := xorm.NewEngine(driverName, dataSourceName)
Method NewEngine's parameters are the same as sql.Open which depend drivers' implementation. Method NewEngine's parameters are the same as sql.Open which depend drivers' implementation.
Generally, one engine for an application is enough. You can define it as a package variable. Generally, one engine for an application is enough. You can define it as a package variable.
Raw Methods # Raw Methods
XORM supports raw SQL execution: XORM supports raw SQL execution:
1. query with a SQL string, the returned results is []map[string][]byte 1. query with a SQL string, the returned results is []map[string][]byte
results, err := engine.Query("select * from user") results, err := engine.Query("select * from user")
2. query with a SQL string, the returned results is []map[string]string 2. query with a SQL string, the returned results is []map[string]string
results, err := engine.QueryString("select * from user") results, err := engine.QueryString("select * from user")
3. query with a SQL string, the returned results is []map[string]interface{} 3. query with a SQL string, the returned results is []map[string]interface{}
results, err := engine.QueryInterface("select * from user") results, err := engine.QueryInterface("select * from user")
4. execute with a SQL string, the returned results 4. execute with a SQL string, the returned results
affected, err := engine.Exec("update user set .... where ...") affected, err := engine.Exec("update user set .... where ...")
ORM Methods # ORM Methods
There are 8 major ORM methods and many helpful methods to use to operate database. There are 8 major ORM methods and many helpful methods to use to operate database.
1. Insert one or multiple records to database 1. Insert one or multiple records to database
affected, err := engine.Insert(&struct) affected, err := engine.Insert(&struct)
// INSERT INTO struct () values () // INSERT INTO struct () values ()
affected, err := engine.Insert(&struct1, &struct2) affected, err := engine.Insert(&struct1, &struct2)
// INSERT INTO struct1 () values () // INSERT INTO struct1 () values ()
// INSERT INTO struct2 () values () // INSERT INTO struct2 () values ()
affected, err := engine.Insert(&sliceOfStruct) affected, err := engine.Insert(&sliceOfStruct)
// INSERT INTO struct () values (),(),() // INSERT INTO struct () values (),(),()
affected, err := engine.Insert(&struct1, &sliceOfStruct2) affected, err := engine.Insert(&struct1, &sliceOfStruct2)
// INSERT INTO struct1 () values () // INSERT INTO struct1 () values ()
// INSERT INTO struct2 () values (),(),() // INSERT INTO struct2 () values (),(),()
2. Query one record or one variable from database 2. Query one record or one variable from database
has, err := engine.Get(&user) has, err := engine.Get(&user)
// SELECT * FROM user LIMIT 1 // SELECT * FROM user LIMIT 1
var id int64 var id int64
has, err := engine.Table("user").Where("name = ?", name).Get(&id) has, err := engine.Table("user").Where("name = ?", name).Get(&id)
// SELECT id FROM user WHERE name = ? LIMIT 1 // SELECT id FROM user WHERE name = ? LIMIT 1
var id int64 var id int64
var name string var name string
has, err := engine.Table(&user).Cols("id", "name").Get(&id, &name) has, err := engine.Table(&user).Cols("id", "name").Get(&id, &name)
// SELECT id, name FROM user LIMIT 1 // SELECT id, name FROM user LIMIT 1
3. Query multiple records from database 3. Query multiple records from database
var sliceOfStructs []Struct var sliceOfStructs []Struct
err := engine.Find(&sliceOfStructs) err := engine.Find(&sliceOfStructs)
// SELECT * FROM user // SELECT * FROM user
var mapOfStructs = make(map[int64]Struct) var mapOfStructs = make(map[int64]Struct)
err := engine.Find(&mapOfStructs) err := engine.Find(&mapOfStructs)
// SELECT * FROM user // SELECT * FROM user
var int64s []int64 var int64s []int64
err := engine.Table("user").Cols("id").Find(&int64s) err := engine.Table("user").Cols("id").Find(&int64s)
// SELECT id FROM user // SELECT id FROM user
4. Query multiple records and record by record handle, there two methods, one is Iterate, 4. Query multiple records and record by record handle, there two methods, one is Iterate,
another is Rows another is Rows
err := engine.Iterate(new(User), func(i int, bean interface{}) error { err := engine.Iterate(new(User), func(i int, bean interface{}) error {
// do something // do something
}) })
// SELECT * FROM user // SELECT * FROM user
rows, err := engine.Rows(...) rows, err := engine.Rows(...)
// SELECT * FROM user // SELECT * FROM user
defer rows.Close() defer rows.Close()
bean := new(Struct) bean := new(Struct)
for rows.Next() { for rows.Next() {
err = rows.Scan(bean) err = rows.Scan(bean)
} }
or or
rows, err := engine.Cols("name", "age").Rows(...) rows, err := engine.Cols("name", "age").Rows(...)
// SELECT * FROM user // SELECT * FROM user
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var name string var name string
var age int var age int
err = rows.Scan(&name, &age) err = rows.Scan(&name, &age)
} }
5. Update one or more records 5. Update one or more records
affected, err := engine.ID(...).Update(&user) affected, err := engine.ID(...).Update(&user)
// UPDATE user SET ... // UPDATE user SET ...
6. Delete one or more records, Delete MUST has condition 6. Delete one or more records, Delete MUST has condition
affected, err := engine.Where(...).Delete(&user) affected, err := engine.Where(...).Delete(&user)
// DELETE FROM user Where ... // DELETE FROM user Where ...
7. Count records 7. Count records
counts, err := engine.Count(&user) counts, err := engine.Count(&user)
// SELECT count(*) AS total FROM user // SELECT count(*) AS total FROM user
counts, err := engine.SQL("select count(*) FROM user").Count() counts, err := engine.SQL("select count(*) FROM user").Count()
// select count(*) FROM user // select count(*) FROM user
8. Sum records 8. Sum records
sumFloat64, err := engine.Sum(&user, "id") sumFloat64, err := engine.Sum(&user, "id")
// SELECT sum(id) from user // SELECT sum(id) from user
sumFloat64s, err := engine.Sums(&user, "id1", "id2") sumFloat64s, err := engine.Sums(&user, "id1", "id2")
// SELECT sum(id1), sum(id2) from user // SELECT sum(id1), sum(id2) from user
sumInt64s, err := engine.SumsInt(&user, "id1", "id2") sumInt64s, err := engine.SumsInt(&user, "id1", "id2")
// SELECT sum(id1), sum(id2) from user // SELECT sum(id1), sum(id2) from user
Conditions # Conditions
The above 8 methods could use with condition methods chainable. The above 8 methods could use with condition methods chainable.
Notice: the above 8 methods should be the last chainable method. Notice: the above 8 methods should be the last chainable method.
1. ID, In 1. ID, In
engine.ID(1).Get(&user) // for single primary key engine.ID(1).Get(&user) // for single primary key
// SELECT * FROM user WHERE id = 1 // SELECT * FROM user WHERE id = 1
engine.ID(schemas.PK{1, 2}).Get(&user) // for composite primary keys engine.ID(schemas.PK{1, 2}).Get(&user) // for composite primary keys
// SELECT * FROM user WHERE id1 = 1 AND id2 = 2 // SELECT * FROM user WHERE id1 = 1 AND id2 = 2
engine.In("id", 1, 2, 3).Find(&users) engine.In("id", 1, 2, 3).Find(&users)
// SELECT * FROM user WHERE id IN (1, 2, 3) // SELECT * FROM user WHERE id IN (1, 2, 3)
engine.In("id", []int{1, 2, 3}).Find(&users) engine.In("id", []int{1, 2, 3}).Find(&users)
// SELECT * FROM user WHERE id IN (1, 2, 3) // SELECT * FROM user WHERE id IN (1, 2, 3)
2. Where, And, Or 2. Where, And, Or
engine.Where().And().Or().Find() engine.Where().And().Or().Find()
// SELECT * FROM user WHERE (.. AND ..) OR ... // SELECT * FROM user WHERE (.. AND ..) OR ...
3. OrderBy, Asc, Desc 3. OrderBy, Asc, Desc
engine.Asc().Desc().Find() engine.Asc().Desc().Find()
// SELECT * FROM user ORDER BY .. ASC, .. DESC // SELECT * FROM user ORDER BY .. ASC, .. DESC
engine.OrderBy().Find() engine.OrderBy().Find()
// SELECT * FROM user ORDER BY .. // SELECT * FROM user ORDER BY ..
4. Limit, Top 4. Limit, Top
engine.Limit().Find() engine.Limit().Find()
// SELECT * FROM user LIMIT .. OFFSET .. // SELECT * FROM user LIMIT .. OFFSET ..
engine.Top(5).Find() engine.Top(5).Find()
// SELECT TOP 5 * FROM user // for mssql // SELECT TOP 5 * FROM user // for mssql
// SELECT * FROM user LIMIT .. OFFSET 0 //for other databases // SELECT * FROM user LIMIT .. OFFSET 0 //for other databases
5. SQL, let you custom SQL 5. SQL, let you custom SQL
var users []User var users []User
engine.SQL("select * from user").Find(&users) engine.SQL("select * from user").Find(&users)
6. Cols, Omit, Distinct 6. Cols, Omit, Distinct
var users []*User var users []*User
engine.Cols("col1, col2").Find(&users) engine.Cols("col1, col2").Find(&users)
// SELECT col1, col2 FROM user // SELECT col1, col2 FROM user
engine.Cols("col1", "col2").Where().Update(user) engine.Cols("col1", "col2").Where().Update(user)
// UPDATE user set col1 = ?, col2 = ? Where ... // UPDATE user set col1 = ?, col2 = ? Where ...
engine.Omit("col1").Find(&users) engine.Omit("col1").Find(&users)
// SELECT col2, col3 FROM user // SELECT col2, col3 FROM user
engine.Omit("col1").Insert(&user) engine.Omit("col1").Insert(&user)
// INSERT INTO table (non-col1) VALUES () // INSERT INTO table (non-col1) VALUES ()
engine.Distinct("col1").Find(&users) engine.Distinct("col1").Find(&users)
// SELECT DISTINCT col1 FROM user // SELECT DISTINCT col1 FROM user
7. Join, GroupBy, Having 7. Join, GroupBy, Having
engine.GroupBy("name").Having("name='xlw'").Find(&users) engine.GroupBy("name").Having("name='xlw'").Find(&users)
//SELECT * FROM user GROUP BY name HAVING name='xlw' //SELECT * FROM user GROUP BY name HAVING name='xlw'
engine.Join("LEFT", "userdetail", "user.id=userdetail.id").Find(&users) engine.Join("LEFT", "userdetail", "user.id=userdetail.id").Find(&users)
//SELECT * FROM user LEFT JOIN userdetail ON user.id=userdetail.id //SELECT * FROM user LEFT JOIN userdetail ON user.id=userdetail.id
Builder # Builder
xorm could work with xorm.io/builder directly. xorm could work with xorm.io/builder directly.
1. With Where 1. With Where
var cond = builder.Eq{"a":1, "b":2} var cond = builder.Eq{"a":1, "b":2}
engine.Where(cond).Find(&users) engine.Where(cond).Find(&users)
2. With In 2. With In
var subQuery = builder.Select("name").From("group") var subQuery = builder.Select("name").From("group")
engine.In("group_name", subQuery).Find(&users) engine.In("group_name", subQuery).Find(&users)
3. With Join 3. With Join
var subQuery = builder.Select("name").From("group") var subQuery = builder.Select("name").From("group")
engine.Join("INNER", subQuery, "group.id = user.group_id").Find(&users) engine.Join("INNER", subQuery, "group.id = user.group_id").Find(&users)
4. With SetExprs 4. With SetExprs
var subQuery = builder.Select("name").From("group") var subQuery = builder.Select("name").From("group")
engine.ID(1).SetExprs("name", subQuery).Update(new(User)) engine.ID(1).SetExprs("name", subQuery).Update(new(User))
5. With SQL 5. With SQL
var query = builder.Select("name").From("group") var query = builder.Select("name").From("group")
results, err := engine.SQL(query).Find(&groups) results, err := engine.SQL(query).Find(&groups)
6. With Query 6. With Query
var query = builder.Select("name").From("group") var query = builder.Select("name").From("group")
results, err := engine.Query(query) results, err := engine.Query(query)
results, err := engine.QueryString(query) results, err := engine.QueryString(query)
results, err := engine.QueryInterface(query) results, err := engine.QueryInterface(query)
7. With Exec 7. With Exec
var query = builder.Insert("a, b").Into("table1").Select("b, c").From("table2") var query = builder.Insert("a, b").Into("table1").Select("b, c").From("table2")
results, err := engine.Exec(query) results, err := engine.Exec(query)
More usage, please visit http://xorm.io/docs More usage, please visit http://xorm.io/docs
*/ */

View File

@ -254,6 +254,11 @@ func (engine *Engine) SetConnMaxLifetime(d time.Duration) {
engine.DB().SetConnMaxLifetime(d) engine.DB().SetConnMaxLifetime(d)
} }
// SetConnMaxIdleTime sets the maximum amount of time a connection may be idle.
func (engine *Engine) SetConnMaxIdleTime(d time.Duration) {
engine.DB().SetConnMaxIdleTime(d)
}
// SetMaxOpenConns is only available for go 1.2+ // SetMaxOpenConns is only available for go 1.2+
func (engine *Engine) SetMaxOpenConns(conns int) { func (engine *Engine) SetMaxOpenConns(conns int) {
engine.DB().SetMaxOpenConns(conns) engine.DB().SetMaxOpenConns(conns)
@ -330,7 +335,7 @@ func (engine *Engine) Ping() error {
// SQL method let's you manually write raw SQL and operate // SQL method let's you manually write raw SQL and operate
// For example: // For example:
// //
// engine.SQL("select * from user").Find(&users) // engine.SQL("select * from user").Find(&users)
// //
// This code will execute "select * from user" and set the records to users // This code will execute "select * from user" and set the records to users
func (engine *Engine) SQL(query interface{}, args ...interface{}) *Session { func (engine *Engine) SQL(query interface{}, args ...interface{}) *Session {
@ -380,7 +385,7 @@ func (engine *Engine) loadTableInfo(table *schemas.Table) error {
seq = 0 seq = 0
} }
} }
var colName = strings.Trim(parts[0], `"`) colName := strings.Trim(parts[0], `"`)
if col := table.GetColumn(colName); col != nil { if col := table.GetColumn(colName); col != nil {
col.Indexes[index.Name] = index.Type col.Indexes[index.Name] = index.Type
} else { } else {
@ -502,9 +507,9 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
} }
} }
var dstTableName = dstTable.Name dstTableName := dstTable.Name
var quoter = dstDialect.Quoter().Quote quoter := dstDialect.Quoter().Quote
var quotedDstTableName = quoter(dstTable.Name) quotedDstTableName := quoter(dstTable.Name)
if dstDialect.URI().Schema != "" { if dstDialect.URI().Schema != "" {
dstTableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, dstTable.Name) dstTableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, dstTable.Name)
quotedDstTableName = fmt.Sprintf("%s.%s", quoter(dstDialect.URI().Schema), quoter(dstTable.Name)) quotedDstTableName = fmt.Sprintf("%s.%s", quoter(dstDialect.URI().Schema), quoter(dstTable.Name))
@ -815,6 +820,9 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
return err return err
} }
} }
// !datbeohbbh! if no error, manually close
rows.Close()
sess.Close()
} }
return nil return nil
} }
@ -996,9 +1004,8 @@ func (engine *Engine) Desc(colNames ...string) *Session {
// Asc will generate "ORDER BY column1,column2 Asc" // Asc will generate "ORDER BY column1,column2 Asc"
// This method can chainable use. // This method can chainable use.
// //
// engine.Desc("name").Asc("age").Find(&users) // engine.Desc("name").Asc("age").Find(&users)
// // SELECT * FROM user ORDER BY name DESC, age ASC // // SELECT * FROM user ORDER BY name DESC, age ASC
//
func (engine *Engine) Asc(colNames ...string) *Session { func (engine *Engine) Asc(colNames ...string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.isAutoClose = true session.isAutoClose = true
@ -1006,10 +1013,10 @@ func (engine *Engine) Asc(colNames ...string) *Session {
} }
// OrderBy will generate "ORDER BY order" // OrderBy will generate "ORDER BY order"
func (engine *Engine) OrderBy(order string) *Session { func (engine *Engine) OrderBy(order interface{}, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.isAutoClose = true session.isAutoClose = true
return session.OrderBy(order) return session.OrderBy(order, args...)
} }
// Prepare enables prepare statement // Prepare enables prepare statement
@ -1020,7 +1027,7 @@ func (engine *Engine) Prepare() *Session {
} }
// Join the join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN // Join the join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (engine *Engine) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session { func (engine *Engine) Join(joinOperator string, tablename interface{}, condition interface{}, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.isAutoClose = true session.isAutoClose = true
return session.Join(joinOperator, tablename, condition, args...) return session.Join(joinOperator, tablename, condition, args...)
@ -1220,9 +1227,10 @@ func (engine *Engine) InsertOne(bean interface{}) (int64, error) {
// Update records, bean's non-empty fields are updated contents, // Update records, bean's non-empty fields are updated contents,
// condiBean' non-empty filds are conditions // condiBean' non-empty filds are conditions
// CAUTION: // CAUTION:
// 1.bool will defaultly be updated content nor conditions //
// You should call UseBool if you have bool to use. // 1.bool will defaultly be updated content nor conditions
// 2.float32 & float64 may be not inexact as conditions // You should call UseBool if you have bool to use.
// 2.float32 & float64 may be not inexact as conditions
func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64, error) { func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64, error) {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
@ -1230,12 +1238,21 @@ func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64
} }
// Delete records, bean's non-empty fields are conditions // Delete records, bean's non-empty fields are conditions
// At least one condition must be set.
func (engine *Engine) Delete(beans ...interface{}) (int64, error) { func (engine *Engine) Delete(beans ...interface{}) (int64, error) {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
return session.Delete(beans...) return session.Delete(beans...)
} }
// Truncate records, bean's non-empty fields are conditions
// In contrast to Delete this method allows deletes without conditions.
func (engine *Engine) Truncate(beans ...interface{}) (int64, error) {
session := engine.NewSession()
defer session.Close()
return session.Truncate(beans...)
}
// Get retrieve one record from table, bean's non-empty fields // Get retrieve one record from table, bean's non-empty fields
// are conditions // are conditions
func (engine *Engine) Get(beans ...interface{}) (bool, error) { func (engine *Engine) Get(beans ...interface{}) (bool, error) {

2
go.mod
View File

@ -17,5 +17,5 @@ require (
github.com/syndtr/goleveldb v1.0.0 github.com/syndtr/goleveldb v1.0.0
github.com/ziutek/mymysql v1.5.4 github.com/ziutek/mymysql v1.5.4
modernc.org/sqlite v1.14.2 modernc.org/sqlite v1.14.2
xorm.io/builder v0.3.9 xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978
) )

4
go.sum
View File

@ -659,5 +659,5 @@ modernc.org/z v1.2.19 h1:BGyRFWhDVn5LFS5OcX4Yd/MlpRTOc7hOPTdcIpCiUao=
modernc.org/z v1.2.19/go.mod h1:+ZpP0pc4zz97eukOzW3xagV/lS82IpPN9NGG5pNF9vY= modernc.org/z v1.2.19/go.mod h1:+ZpP0pc4zz97eukOzW3xagV/lS82IpPN9NGG5pNF9vY=
sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o=
sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU= sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU=
xorm.io/builder v0.3.9 h1:Sd65/LdWyO7LR8+Cbd+e7mm3sK/7U9k0jS3999IDHMc= xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978 h1:bvLlAPW1ZMTWA32LuZMBEGHAUOcATZjzHcotf3SWweM=
xorm.io/builder v0.3.9/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE=

View File

@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build dm
// +build dm // +build dm
package integrations package integrations

View File

@ -290,13 +290,11 @@ func TestGetColumnsComment(t *testing.T) {
} }
func TestGetColumnsLength(t *testing.T) { func TestGetColumnsLength(t *testing.T) {
var max_length int var max_length int64
switch testEngine.Dialect().URI().DBType { switch testEngine.Dialect().URI().DBType {
case case schemas.POSTGRES:
schemas.POSTGRES:
max_length = 0 max_length = 0
case case schemas.MYSQL:
schemas.MYSQL:
max_length = 65535 max_length = 65535
default: default:
t.Skip() t.Skip()

View File

@ -208,7 +208,7 @@ func TestUnscopeDelete(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var nowUnix = time.Now().Unix() nowUnix := time.Now().Unix()
var s UnscopeDeleteStruct var s UnscopeDeleteStruct
cnt, err = testEngine.ID(1).Delete(&s) cnt, err = testEngine.ID(1).Delete(&s)
assert.NoError(t, err) assert.NoError(t, err)
@ -266,3 +266,28 @@ func TestDelete2(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, has) assert.False(t, has)
} }
func TestTruncate(t *testing.T) {
assert.NoError(t, PrepareEngine())
type TruncateUser struct {
Uid int64 `xorm:"id pk not null autoincr"`
}
assert.NoError(t, testEngine.Sync(new(TruncateUser)))
cnt, err := testEngine.Insert(&TruncateUser{})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
_, err = testEngine.Delete(&TruncateUser{})
assert.Error(t, err)
_, err = testEngine.Truncate(&TruncateUser{})
assert.NoError(t, err)
user2 := TruncateUser{}
has, err := testEngine.ID(1).Get(&user2)
assert.NoError(t, err)
assert.False(t, has)
}

View File

@ -8,6 +8,7 @@ import (
"testing" "testing"
"time" "time"
"xorm.io/builder"
"xorm.io/xorm" "xorm.io/xorm"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/names" "xorm.io/xorm/names"
@ -247,6 +248,10 @@ func TestOrder(t *testing.T) {
users2 := make([]Userinfo, 0) users2 := make([]Userinfo, 0)
err = testEngine.Asc("id", "username").Desc("height").Find(&users2) err = testEngine.Asc("id", "username").Desc("height").Find(&users2)
assert.NoError(t, err) assert.NoError(t, err)
users = make([]Userinfo, 0)
err = testEngine.OrderBy("CASE WHEN username LIKE ? THEN 0 ELSE 1 END DESC", "a").Find(&users)
assert.NoError(t, err)
} }
func TestGroupBy(t *testing.T) { func TestGroupBy(t *testing.T) {
@ -961,6 +966,10 @@ func TestFindJoin(t *testing.T) {
scenes = make([]SceneItem, 0) scenes = make([]SceneItem, 0)
err = testEngine.Join("INNER", "order", "`scene_item`.`device_id`=`order`.`id`").Find(&scenes) err = testEngine.Join("INNER", "order", "`scene_item`.`device_id`=`order`.`id`").Find(&scenes)
assert.NoError(t, err) assert.NoError(t, err)
scenes = make([]SceneItem, 0)
err = testEngine.Join("INNER", "order", builder.Expr("`scene_item`.`device_id`=`order`.`id`")).Find(&scenes)
assert.NoError(t, err)
} }
func TestJoinFindLimit(t *testing.T) { func TestJoinFindLimit(t *testing.T) {

View File

@ -744,7 +744,8 @@ func TestInsertMap(t *testing.T) {
assert.EqualValues(t, "lunny", ims[3].Name) assert.EqualValues(t, "lunny", ims[3].Name)
} }
/*INSERT INTO `issue` (`repo_id`, `poster_id`, ... ,`name`, `content`, ... ,`index`) /*
INSERT INTO `issue` (`repo_id`, `poster_id`, ... ,`name`, `content`, ... ,`index`)
SELECT $1, $2, ..., $14, $15, ..., MAX(`index`) + 1 FROM `issue` WHERE `repo_id` = $1; SELECT $1, $2, ..., $14, $15, ..., MAX(`index`) + 1 FROM `issue` WHERE `repo_id` = $1;
*/ */
func TestInsertWhere(t *testing.T) { func TestInsertWhere(t *testing.T) {

View File

@ -185,3 +185,36 @@ func TestMultipleTransaction(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 0, len(ms)) assert.EqualValues(t, 0, len(ms))
} }
func TestInsertMulti2InterfaceTransaction(t *testing.T) {
type Multi2InterfaceTransaction struct {
ID uint64 `xorm:"id pk autoincr"`
Name string
Alias string
CreateTime time.Time `xorm:"created"`
UpdateTime time.Time `xorm:"updated"`
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(Multi2InterfaceTransaction))
session := testEngine.NewSession()
defer session.Close()
err := session.Begin()
assert.NoError(t, err)
users := []interface{}{
&Multi2InterfaceTransaction{Name: "a", Alias: "A"},
&Multi2InterfaceTransaction{Name: "b", Alias: "B"},
&Multi2InterfaceTransaction{Name: "c", Alias: "C"},
&Multi2InterfaceTransaction{Name: "d", Alias: "D"},
}
cnt, err := session.Insert(&users)
assert.NoError(t, err)
assert.EqualValues(t, len(users), cnt)
assert.NotPanics(t, func() {
err = session.Commit()
assert.NoError(t, err)
})
}

View File

@ -31,6 +31,7 @@ type Interface interface {
Decr(column string, arg ...interface{}) *Session Decr(column string, arg ...interface{}) *Session
Desc(...string) *Session Desc(...string) *Session
Delete(...interface{}) (int64, error) Delete(...interface{}) (int64, error)
Truncate(...interface{}) (int64, error)
Distinct(columns ...string) *Session Distinct(columns ...string) *Session
DropIndexes(bean interface{}) error DropIndexes(bean interface{}) error
Exec(sqlOrArgs ...interface{}) (sql.Result, error) Exec(sqlOrArgs ...interface{}) (sql.Result, error)
@ -52,9 +53,9 @@ type Interface interface {
NoAutoCondition(...bool) *Session NoAutoCondition(...bool) *Session
NotIn(string, ...interface{}) *Session NotIn(string, ...interface{}) *Session
Nullable(...string) *Session Nullable(...string) *Session
Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session Join(joinOperator string, tablename interface{}, condition interface{}, args ...interface{}) *Session
Omit(columns ...string) *Session Omit(columns ...string) *Session
OrderBy(order string) *Session OrderBy(order interface{}, args ...interface{}) *Session
Ping() error Ping() error
Query(sqlOrArgs ...interface{}) (resultsSlice []map[string][]byte, err error) Query(sqlOrArgs ...interface{}) (resultsSlice []map[string][]byte, err error)
QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error)

View File

@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build gojson
// +build gojson // +build gojson
package json package json

View File

@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build jsoniter
// +build jsoniter // +build jsoniter
package json package json

View File

@ -6,6 +6,7 @@ package statements
import ( import (
"fmt" "fmt"
"strconv"
"strings" "strings"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
@ -26,14 +27,19 @@ func (statement *Statement) ConvertIDSQL(sqlStr string) string {
return "" return ""
} }
var top string var b strings.Builder
b.WriteString("SELECT ")
pLimitN := statement.LimitN pLimitN := statement.LimitN
if pLimitN != nil && statement.dialect.URI().DBType == schemas.MSSQL { if pLimitN != nil && statement.dialect.URI().DBType == schemas.MSSQL {
top = fmt.Sprintf("TOP %d ", *pLimitN) b.WriteString("TOP ")
b.WriteString(strconv.Itoa(*pLimitN))
b.WriteString(" ")
} }
b.WriteString(colstrs)
b.WriteString(" FROM ")
b.WriteString(sqls[1])
newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1]) return b.String()
return newsql
} }
return "" return ""
} }
@ -54,7 +60,7 @@ func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) {
return "", "" return "", ""
} }
var whereStr = sqls[1] whereStr := sqls[1]
// TODO: for postgres only, if any other database? // TODO: for postgres only, if any other database?
var paraStr string var paraStr string

111
internal/statements/cond.go Normal file
View File

@ -0,0 +1,111 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"xorm.io/builder"
"xorm.io/xorm/schemas"
)
type QuoteReplacer struct {
*builder.BytesWriter
quoter schemas.Quoter
}
func (q *QuoteReplacer) Write(p []byte) (n int, err error) {
c := q.quoter.Replace(string(p))
return q.BytesWriter.Builder.WriteString(c)
}
func (statement *Statement) QuoteReplacer(w *builder.BytesWriter) *QuoteReplacer {
return &QuoteReplacer{
BytesWriter: w,
quoter: statement.dialect.Quoter(),
}
}
// Where add Where statement
func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
return statement.And(query, args...)
}
// And add Where & and statement
func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
switch qr := query.(type) {
case string:
cond := builder.Expr(qr, args...)
statement.cond = statement.cond.And(cond)
case map[string]interface{}:
cond := make(builder.Eq)
for k, v := range qr {
cond[statement.quote(k)] = v
}
statement.cond = statement.cond.And(cond)
case builder.Cond:
statement.cond = statement.cond.And(qr)
for _, v := range args {
if vv, ok := v.(builder.Cond); ok {
statement.cond = statement.cond.And(vv)
}
}
default:
statement.LastError = ErrConditionType
}
return statement
}
// Or add Where & Or statement
func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
switch qr := query.(type) {
case string:
cond := builder.Expr(qr, args...)
statement.cond = statement.cond.Or(cond)
case map[string]interface{}:
cond := make(builder.Eq)
for k, v := range qr {
cond[statement.quote(k)] = v
}
statement.cond = statement.cond.Or(cond)
case builder.Cond:
statement.cond = statement.cond.Or(qr)
for _, v := range args {
if vv, ok := v.(builder.Cond); ok {
statement.cond = statement.cond.Or(vv)
}
}
default:
statement.LastError = ErrConditionType
}
return statement
}
// In generate "Where column IN (?) " statement
func (statement *Statement) In(column string, args ...interface{}) *Statement {
in := builder.In(statement.quote(column), args...)
statement.cond = statement.cond.And(in)
return statement
}
// NotIn generate "Where column NOT IN (?) " statement
func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
notIn := builder.NotIn(statement.quote(column), args...)
statement.cond = statement.cond.And(notIn)
return statement
}
// SetNoAutoCondition if you do not want convert bean's field as query condition, then use this function
func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement {
statement.NoAutoCondition = true
if len(no) > 0 {
statement.NoAutoCondition = no[0]
}
return statement
}
// Conds returns condtions
func (statement *Statement) Conds() builder.Cond {
return statement.cond
}

View File

@ -0,0 +1,96 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"fmt"
"strings"
"xorm.io/builder"
"xorm.io/xorm/dialects"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition interface{}, args ...interface{}) *Statement {
var buf strings.Builder
if len(statement.JoinStr) > 0 {
fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
} else {
fmt.Fprintf(&buf, "%v JOIN ", joinOP)
}
condStr := ""
condArgs := []interface{}{}
switch condTp := condition.(type) {
case string:
condStr = condTp
case builder.Cond:
var err error
condStr, condArgs, err = builder.ToSQL(condTp)
if err != nil {
statement.LastError = err
return statement
}
default:
statement.LastError = fmt.Errorf("unsupported join condition type: %v", condTp)
return statement
}
switch tp := tablename.(type) {
case builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr))
statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...)
case *builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr))
statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...)
default:
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
if !utils.IsSubQuery(tbName) {
var buf strings.Builder
_ = statement.dialect.Quoter().QuoteTo(&buf, tbName)
tbName = buf.String()
} else {
tbName = statement.ReplaceQuote(tbName)
}
fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condStr))
statement.joinArgs = append(statement.joinArgs, condArgs...)
}
statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...)
return statement
}
func (statement *Statement) writeJoin(w builder.Writer) error {
if statement.JoinStr != "" {
if _, err := fmt.Fprint(w, " ", statement.JoinStr); err != nil {
return err
}
w.Append(statement.joinArgs...)
}
return nil
}

View File

@ -0,0 +1,90 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"fmt"
"strings"
"xorm.io/builder"
)
func (statement *Statement) HasOrderBy() bool {
return statement.orderStr != ""
}
// ResetOrderBy reset ordery conditions
func (statement *Statement) ResetOrderBy() {
statement.orderStr = ""
statement.orderArgs = nil
}
// WriteOrderBy write order by to writer
func (statement *Statement) WriteOrderBy(w builder.Writer) error {
if len(statement.orderStr) > 0 {
if _, err := fmt.Fprintf(w, " ORDER BY %s", statement.orderStr); err != nil {
return err
}
w.Append(statement.orderArgs...)
}
return nil
}
// OrderBy generate "Order By order" statement
func (statement *Statement) OrderBy(order interface{}, args ...interface{}) *Statement {
if len(statement.orderStr) > 0 {
statement.orderStr += ", "
}
var rawOrder string
switch t := order.(type) {
case (*builder.Expression):
rawOrder = t.Content()
args = t.Args()
case string:
rawOrder = t
default:
statement.LastError = ErrUnSupportedSQLType
return statement
}
statement.orderStr += statement.ReplaceQuote(rawOrder)
if len(args) > 0 {
statement.orderArgs = append(statement.orderArgs, args...)
}
return statement
}
// Desc generate `ORDER BY xx DESC`
func (statement *Statement) Desc(colNames ...string) *Statement {
var buf strings.Builder
if len(statement.orderStr) > 0 {
fmt.Fprint(&buf, statement.orderStr, ", ")
}
for i, col := range colNames {
if i > 0 {
fmt.Fprint(&buf, ", ")
}
_ = statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " DESC")
}
statement.orderStr = buf.String()
return statement
}
// Asc provide asc order by query condition, the input parameters are columns.
func (statement *Statement) Asc(colNames ...string) *Statement {
var buf strings.Builder
if len(statement.orderStr) > 0 {
fmt.Fprint(&buf, statement.orderStr, ", ")
}
for i, col := range colNames {
if i > 0 {
fmt.Fprint(&buf, ", ")
}
_ = statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " ASC")
}
statement.orderStr = buf.String()
return statement
}

View File

@ -11,6 +11,7 @@ import (
"strings" "strings"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -28,7 +29,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
return "", nil, ErrTableNotFound return "", nil, ErrTableNotFound
} }
var columnStr = statement.ColumnStr() columnStr := statement.ColumnStr()
if len(statement.SelectStr) > 0 { if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr columnStr = statement.SelectStr
} else { } else {
@ -58,19 +59,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
return "", nil, err return "", nil, err
} }
sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) return statement.genSelectSQL(columnStr, true, true)
if err != nil {
return "", nil, err
}
args := append(statement.joinArgs, condArgs...)
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}
return sqlStr, args, nil
} }
// GenSumSQL generates sum SQL // GenSumSQL generates sum SQL
@ -83,7 +72,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
return "", nil, err return "", nil, err
} }
var sumStrs = make([]string, 0, len(columns)) sumStrs := make([]string, 0, len(columns))
for _, colName := range columns { for _, colName := range columns {
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") { if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
colName = statement.quote(colName) colName = statement.quote(colName)
@ -94,16 +83,11 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
} }
sumSelect := strings.Join(sumStrs, ", ") sumSelect := strings.Join(sumStrs, ", ")
if err := statement.mergeConds(bean); err != nil { if err := statement.MergeConds(bean); err != nil {
return "", nil, err return "", nil, err
} }
sqlStr, condArgs, err := statement.genSelectSQL(sumSelect, true, true) return statement.genSelectSQL(sumSelect, true, true)
if err != nil {
return "", nil, err
}
return sqlStr, append(statement.joinArgs, condArgs...), nil
} }
// GenGetSQL generates Get SQL // GenGetSQL generates Get SQL
@ -119,7 +103,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
} }
} }
var columnStr = statement.ColumnStr() columnStr := statement.ColumnStr()
if len(statement.SelectStr) > 0 { if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr columnStr = statement.SelectStr
} else { } else {
@ -146,7 +130,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
} }
if isStruct { if isStruct {
if err := statement.mergeConds(bean); err != nil { if err := statement.MergeConds(bean); err != nil {
return "", nil, err return "", nil, err
} }
} else { } else {
@ -155,12 +139,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
} }
} }
sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) return statement.genSelectSQL(columnStr, true, true)
if err != nil {
return "", nil, err
}
return sqlStr, append(statement.joinArgs, condArgs...), nil
} }
// GenCountSQL generates the SQL for counting // GenCountSQL generates the SQL for counting
@ -175,12 +154,12 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
if err := statement.SetRefBean(beans[0]); err != nil { if err := statement.SetRefBean(beans[0]); err != nil {
return "", nil, err return "", nil, err
} }
if err := statement.mergeConds(beans[0]); err != nil { if err := statement.MergeConds(beans[0]); err != nil {
return "", nil, err return "", nil, err
} }
} }
var selectSQL = statement.SelectStr selectSQL := statement.SelectStr
if len(selectSQL) <= 0 { if len(selectSQL) <= 0 {
if statement.IsDistinct { if statement.IsDistinct {
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr()) selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr())
@ -206,55 +185,58 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
sqlStr = fmt.Sprintf("SELECT %s FROM (%s) sub", selectSQL, sqlStr) sqlStr = fmt.Sprintf("SELECT %s FROM (%s) sub", selectSQL, sqlStr)
} }
return sqlStr, append(statement.joinArgs, condArgs...), nil return sqlStr, condArgs, nil
} }
func (statement *Statement) fromBuilder() *strings.Builder { func (statement *Statement) writeFrom(w builder.Writer) error {
var builder strings.Builder if _, err := fmt.Fprint(w, " FROM "); err != nil {
var quote = statement.quote return err
var dialect = statement.dialect
builder.WriteString(" FROM ")
if dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
builder.WriteString(statement.TableName())
} else {
builder.WriteString(quote(statement.TableName()))
} }
if err := statement.writeTableName(w); err != nil {
return err
}
if err := statement.writeAlias(w); err != nil {
return err
}
return statement.writeJoin(w)
}
if statement.TableAlias != "" { func (statement *Statement) writeLimitOffset(w builder.Writer) error {
if dialect.URI().DBType == schemas.ORACLE { if statement.Start > 0 {
builder.WriteString(" ") if statement.LimitN != nil {
} else { _, err := fmt.Fprintf(w, " LIMIT %v OFFSET %v", *statement.LimitN, statement.Start)
builder.WriteString(" AS ") return err
} }
builder.WriteString(quote(statement.TableAlias)) _, err := fmt.Fprintf(w, " LIMIT 0 OFFSET %v", statement.Start)
return err
} }
if statement.JoinStr != "" { if statement.LimitN != nil {
builder.WriteString(" ") _, err := fmt.Fprint(w, " LIMIT ", *statement.LimitN)
builder.WriteString(statement.JoinStr) return err
} }
return &builder // no limit statement
return nil
} }
func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) { func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) {
var ( var (
distinct string distinct string
dialect = statement.dialect dialect = statement.dialect
fromStr = statement.fromBuilder().String() top, whereStr string
top, mssqlCondi, whereStr string mssqlCondi = builder.NewWriter()
) )
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
distinct = "DISTINCT " distinct = "DISTINCT "
} }
condSQL, condArgs, err := statement.GenCondSQL(statement.cond) condWriter := builder.NewWriter()
if err != nil { if err := statement.cond.WriteTo(statement.QuoteReplacer(condWriter)); err != nil {
return "", nil, err return "", nil, err
} }
if len(condSQL) > 0 {
whereStr = fmt.Sprintf(" WHERE %s", condSQL) if condWriter.Len() > 0 {
whereStr = " WHERE "
} }
pLimitN := statement.LimitN pLimitN := statement.LimitN
@ -289,49 +271,81 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} }
} }
var orderStr string if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s",
if needOrderBy && len(statement.OrderStr) > 0 { column, statement.Start, column); err != nil {
orderStr = fmt.Sprintf(" ORDER BY %s", statement.OrderStr) return "", nil, err
} }
if err := statement.writeFrom(mssqlCondi); err != nil {
var groupStr string return "", nil, err
if len(statement.GroupByStr) > 0 { }
groupStr = fmt.Sprintf(" GROUP BY %s", statement.GroupByStr) if whereStr != "" {
if _, err := fmt.Fprint(mssqlCondi, whereStr); err != nil {
return "", nil, err
}
if err := utils.WriteBuilder(mssqlCondi, statement.QuoteReplacer(condWriter)); err != nil {
return "", nil, err
}
}
if needOrderBy {
if err := statement.WriteOrderBy(mssqlCondi); err != nil {
return "", nil, err
}
}
if err := statement.WriteGroupBy(mssqlCondi); err != nil {
return "", nil, err
}
if _, err := fmt.Fprint(mssqlCondi, "))"); err != nil {
return "", nil, err
} }
mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
} }
} }
var buf strings.Builder buf := builder.NewWriter()
fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) if _, err := fmt.Fprintf(buf, "SELECT %v%v%v", distinct, top, columnStr); err != nil {
if len(mssqlCondi) > 0 { return "", nil, err
}
if err := statement.writeFrom(buf); err != nil {
return "", nil, err
}
if whereStr != "" {
if _, err := fmt.Fprint(buf, whereStr); err != nil {
return "", nil, err
}
if err := utils.WriteBuilder(buf, statement.QuoteReplacer(condWriter)); err != nil {
return "", nil, err
}
}
if mssqlCondi.Len() > 0 {
if len(whereStr) > 0 { if len(whereStr) > 0 {
fmt.Fprint(&buf, " AND ", mssqlCondi) if _, err := fmt.Fprint(buf, " AND "); err != nil {
return "", nil, err
}
} else { } else {
fmt.Fprint(&buf, " WHERE ", mssqlCondi) if _, err := fmt.Fprint(buf, " WHERE "); err != nil {
return "", nil, err
}
}
if err := utils.WriteBuilder(buf, mssqlCondi); err != nil {
return "", nil, err
} }
} }
if statement.GroupByStr != "" { if err := statement.WriteGroupBy(buf); err != nil {
fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr) return "", nil, err
} }
if statement.HavingStr != "" { if err := statement.writeHaving(buf); err != nil {
fmt.Fprint(&buf, " ", statement.HavingStr) return "", nil, err
} }
if needOrderBy && statement.OrderStr != "" { if needOrderBy {
fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) if err := statement.WriteOrderBy(buf); err != nil {
return "", nil, err
}
} }
if needLimit { if needLimit {
if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE { if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE {
if statement.Start > 0 { if err := statement.writeLimitOffset(buf); err != nil {
if pLimitN != nil { return "", nil, err
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
} else {
fmt.Fprintf(&buf, " LIMIT 0 OFFSET %v", statement.Start)
}
} else if pLimitN != nil {
fmt.Fprint(&buf, " LIMIT ", *pLimitN)
} }
} else if dialect.URI().DBType == schemas.ORACLE { } else if dialect.URI().DBType == schemas.ORACLE {
if pLimitN != nil { if pLimitN != nil {
@ -341,16 +355,16 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
if rawColStr == "*" { if rawColStr == "*" {
rawColStr = "at.*" rawColStr = "at.*"
} }
fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", fmt.Fprintf(buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start) columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start)
} }
} }
} }
if statement.IsForUpdate { if statement.IsForUpdate {
return dialect.ForUpdateSQL(buf.String()), condArgs, nil return dialect.ForUpdateSQL(buf.String()), buf.Args(), nil
} }
return buf.String(), condArgs, nil return buf.String(), buf.Args(), nil
} }
// GenExistSQL generates Exist SQL // GenExistSQL generates Exist SQL
@ -359,10 +373,6 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
return statement.GenRawSQL(), statement.RawParams, nil return statement.GenRawSQL(), statement.RawParams, nil
} }
var sqlStr string
var args []interface{}
var joinStr string
var err error
var b interface{} var b interface{}
if len(bean) > 0 { if len(bean) > 0 {
b = bean[0] b = bean[0]
@ -381,45 +391,70 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
if len(tableName) <= 0 { if len(tableName) <= 0 {
return "", nil, ErrTableNotFound return "", nil, ErrTableNotFound
} }
if statement.RefTable == nil { if statement.RefTable != nil {
tableName = statement.quote(tableName) return statement.Limit(1).GenGetSQL(b)
if len(statement.JoinStr) > 0 { }
joinStr = statement.JoinStr
}
tableName = statement.quote(tableName)
buf := builder.NewWriter()
if statement.dialect.URI().DBType == schemas.MSSQL {
if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s", tableName); err != nil {
return "", nil, err
}
if err := statement.writeJoin(buf); err != nil {
return "", nil, err
}
if statement.Conds().IsValid() { if statement.Conds().IsValid() {
condSQL, condArgs, err := statement.GenCondSQL(statement.Conds()) if _, err := fmt.Fprintf(buf, " WHERE "); err != nil {
if err != nil {
return "", nil, err return "", nil, err
} }
if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil {
if statement.dialect.URI().DBType == schemas.MSSQL { return "", nil, err
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL)
} else if statement.dialect.URI().DBType == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL)
} else {
sqlStr = fmt.Sprintf("SELECT 1 FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL)
} }
args = condArgs }
} else { } else if statement.dialect.URI().DBType == schemas.ORACLE {
if statement.dialect.URI().DBType == schemas.MSSQL { if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr) return "", nil, err
} else if statement.dialect.URI().DBType == schemas.ORACLE { }
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr) if err := statement.writeJoin(buf); err != nil {
} else { return "", nil, err
sqlStr = fmt.Sprintf("SELECT 1 FROM %s %s LIMIT 1", tableName, joinStr) }
if _, err := fmt.Fprintf(buf, " WHERE "); err != nil {
return "", nil, err
}
if statement.Conds().IsValid() {
if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil {
return "", nil, err
} }
args = []interface{}{} if _, err := fmt.Fprintf(buf, " AND "); err != nil {
return "", nil, err
}
}
if _, err := fmt.Fprintf(buf, "ROWNUM=1"); err != nil {
return "", nil, err
} }
} else { } else {
statement.Limit(1) if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s", tableName); err != nil {
sqlStr, args, err = statement.GenGetSQL(b) return "", nil, err
if err != nil { }
if err := statement.writeJoin(buf); err != nil {
return "", nil, err
}
if statement.Conds().IsValid() {
if _, err := fmt.Fprintf(buf, " WHERE "); err != nil {
return "", nil, err
}
if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil {
return "", nil, err
}
}
if _, err := fmt.Fprintf(buf, " LIMIT 1"); err != nil {
return "", nil, err return "", nil, err
} }
} }
return sqlStr, args, nil return buf.String(), buf.Args(), nil
} }
// GenFindSQL generates Find SQL // GenFindSQL generates Find SQL
@ -428,15 +463,11 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa
return statement.GenRawSQL(), statement.RawParams, nil return statement.GenRawSQL(), statement.RawParams, nil
} }
var sqlStr string
var args []interface{}
var err error
if len(statement.TableName()) <= 0 { if len(statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound return "", nil, ErrTableNotFound
} }
var columnStr = statement.ColumnStr() columnStr := statement.ColumnStr()
if len(statement.SelectStr) > 0 { if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr columnStr = statement.SelectStr
} else { } else {
@ -464,16 +495,5 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa
statement.cond = statement.cond.And(autoCond) statement.cond = statement.cond.And(autoCond)
sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) return statement.genSelectSQL(columnStr, true, true)
if err != nil {
return "", nil, err
}
args = append(statement.joinArgs, condArgs...)
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}
return sqlStr, args, nil
} }

View File

@ -0,0 +1,137 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"fmt"
"strings"
"xorm.io/xorm/schemas"
)
// Select replace select
func (statement *Statement) Select(str string) *Statement {
statement.SelectStr = statement.ReplaceQuote(str)
return statement
}
func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0, len(columns))
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
newColumns = append(newColumns, strings.TrimSpace(c))
}
}
return newColumns
}
// Cols generate "col1, col2" statement
func (statement *Statement) Cols(columns ...string) *Statement {
cols := col2NewCols(columns...)
for _, nc := range cols {
statement.ColumnMap.Add(nc)
}
return statement
}
// ColumnStr returns column string
func (statement *Statement) ColumnStr() string {
return statement.dialect.Quoter().Join(statement.ColumnMap, ", ")
}
// AllCols update use only: update all columns
func (statement *Statement) AllCols() *Statement {
statement.useAllCols = true
return statement
}
// MustCols update use only: must update columns
func (statement *Statement) MustCols(columns ...string) *Statement {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.MustColumnMap[strings.ToLower(nc)] = true
}
return statement
}
// UseBool indicates that use bool fields as update contents and query contiditions
func (statement *Statement) UseBool(columns ...string) *Statement {
if len(columns) > 0 {
statement.MustCols(columns...)
} else {
statement.allUseBool = true
}
return statement
}
// Omit do not use the columns
func (statement *Statement) Omit(columns ...string) {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.OmitColumnMap = append(statement.OmitColumnMap, nc)
}
}
func (statement *Statement) genColumnStr() string {
if statement.RefTable == nil {
return ""
}
var buf strings.Builder
columns := statement.RefTable.Columns()
for _, col := range columns {
if statement.OmitColumnMap.Contain(col.Name) {
continue
}
if len(statement.ColumnMap) > 0 && !statement.ColumnMap.Contain(col.Name) {
continue
}
if col.MapType == schemas.ONLYTODB {
continue
}
if buf.Len() != 0 {
buf.WriteString(", ")
}
if statement.JoinStr != "" {
if statement.TableAlias != "" {
buf.WriteString(statement.TableAlias)
} else {
buf.WriteString(statement.TableName())
}
buf.WriteString(".")
}
statement.dialect.Quoter().QuoteTo(&buf, col.Name)
}
return buf.String()
}
func (statement *Statement) colName(col *schemas.Column, tableName string) string {
if statement.needTableName() {
nm := tableName
if len(statement.TableAlias) > 0 {
nm = statement.TableAlias
}
return fmt.Sprintf("%s.%s", statement.quote(nm), statement.quote(col.Name))
}
return statement.quote(col.Name)
}
// Distinct generates "DISTINCT col1, col2 " statement
func (statement *Statement) Distinct(columns ...string) *Statement {
statement.IsDistinct = true
statement.Cols(columns...)
return statement
}

View File

@ -43,7 +43,8 @@ type Statement struct {
Start int Start int
LimitN *int LimitN *int
idParam schemas.PK idParam schemas.PK
OrderStr string orderStr string
orderArgs []interface{}
JoinStr string JoinStr string
joinArgs []interface{} joinArgs []interface{}
GroupByStr string GroupByStr string
@ -101,15 +102,6 @@ func (statement *Statement) GenRawSQL() string {
return statement.ReplaceQuote(statement.RawSQL) return statement.ReplaceQuote(statement.RawSQL)
} }
// GenCondSQL generates condition SQL
func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []interface{}, error) {
condSQL, condArgs, err := builder.ToSQL(condOrBuilder)
if err != nil {
return "", nil, err
}
return statement.ReplaceQuote(condSQL), condArgs, nil
}
// ReplaceQuote replace sql key words with quote // ReplaceQuote replace sql key words with quote
func (statement *Statement) ReplaceQuote(sql string) string { func (statement *Statement) ReplaceQuote(sql string) string {
if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL || if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL ||
@ -129,7 +121,7 @@ func (statement *Statement) Reset() {
statement.RefTable = nil statement.RefTable = nil
statement.Start = 0 statement.Start = 0
statement.LimitN = nil statement.LimitN = nil
statement.OrderStr = "" statement.ResetOrderBy()
statement.UseCascade = true statement.UseCascade = true
statement.JoinStr = "" statement.JoinStr = ""
statement.joinArgs = make([]interface{}, 0) statement.joinArgs = make([]interface{}, 0)
@ -164,21 +156,6 @@ func (statement *Statement) Reset() {
statement.LastError = nil statement.LastError = nil
} }
// SetNoAutoCondition if you do not want convert bean's field as query condition, then use this function
func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement {
statement.NoAutoCondition = true
if len(no) > 0 {
statement.NoAutoCondition = no[0]
}
return statement
}
// Alias set the table alias
func (statement *Statement) Alias(alias string) *Statement {
statement.TableAlias = alias
return statement
}
// SQL adds raw sql statement // SQL adds raw sql statement
func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement { func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
switch query.(type) { switch query.(type) {
@ -198,80 +175,10 @@ func (statement *Statement) SQL(query interface{}, args ...interface{}) *Stateme
return statement return statement
} }
// Where add Where statement
func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
return statement.And(query, args...)
}
func (statement *Statement) quote(s string) string { func (statement *Statement) quote(s string) string {
return statement.dialect.Quoter().Quote(s) return statement.dialect.Quoter().Quote(s)
} }
// And add Where & and statement
func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
switch qr := query.(type) {
case string:
cond := builder.Expr(qr, args...)
statement.cond = statement.cond.And(cond)
case map[string]interface{}:
cond := make(builder.Eq)
for k, v := range qr {
cond[statement.quote(k)] = v
}
statement.cond = statement.cond.And(cond)
case builder.Cond:
statement.cond = statement.cond.And(qr)
for _, v := range args {
if vv, ok := v.(builder.Cond); ok {
statement.cond = statement.cond.And(vv)
}
}
default:
statement.LastError = ErrConditionType
}
return statement
}
// Or add Where & Or statement
func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
switch qr := query.(type) {
case string:
cond := builder.Expr(qr, args...)
statement.cond = statement.cond.Or(cond)
case map[string]interface{}:
cond := make(builder.Eq)
for k, v := range qr {
cond[statement.quote(k)] = v
}
statement.cond = statement.cond.Or(cond)
case builder.Cond:
statement.cond = statement.cond.Or(qr)
for _, v := range args {
if vv, ok := v.(builder.Cond); ok {
statement.cond = statement.cond.Or(vv)
}
}
default:
statement.LastError = ErrConditionType
}
return statement
}
// In generate "Where column IN (?) " statement
func (statement *Statement) In(column string, args ...interface{}) *Statement {
in := builder.In(statement.quote(column), args...)
statement.cond = statement.cond.And(in)
return statement
}
// NotIn generate "Where column NOT IN (?) " statement
func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
notIn := builder.NotIn(statement.quote(column), args...)
statement.cond = statement.cond.And(notIn)
return statement
}
// SetRefValue set ref value // SetRefValue set ref value
func (statement *Statement) SetRefValue(v reflect.Value) error { func (statement *Statement) SetRefValue(v reflect.Value) error {
var err error var err error
@ -302,26 +209,6 @@ func (statement *Statement) needTableName() bool {
return len(statement.JoinStr) > 0 return len(statement.JoinStr) > 0
} }
func (statement *Statement) colName(col *schemas.Column, tableName string) string {
if statement.needTableName() {
nm := tableName
if len(statement.TableAlias) > 0 {
nm = statement.TableAlias
}
return fmt.Sprintf("%s.%s", statement.quote(nm), statement.quote(col.Name))
}
return statement.quote(col.Name)
}
// TableName return current tableName
func (statement *Statement) TableName() string {
if statement.AltTableName != "" {
return statement.AltTableName
}
return statement.tableName
}
// Incr Generate "Update ... Set column = column + arg" statement // Incr Generate "Update ... Set column = column + arg" statement
func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
if len(arg) > 0 { if len(arg) > 0 {
@ -352,85 +239,12 @@ func (statement *Statement) SetExpr(column string, expression interface{}) *Stat
return statement return statement
} }
// Distinct generates "DISTINCT col1, col2 " statement
func (statement *Statement) Distinct(columns ...string) *Statement {
statement.IsDistinct = true
statement.Cols(columns...)
return statement
}
// ForUpdate generates "SELECT ... FOR UPDATE" statement // ForUpdate generates "SELECT ... FOR UPDATE" statement
func (statement *Statement) ForUpdate() *Statement { func (statement *Statement) ForUpdate() *Statement {
statement.IsForUpdate = true statement.IsForUpdate = true
return statement return statement
} }
// Select replace select
func (statement *Statement) Select(str string) *Statement {
statement.SelectStr = statement.ReplaceQuote(str)
return statement
}
func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0, len(columns))
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
newColumns = append(newColumns, strings.TrimSpace(c))
}
}
return newColumns
}
// Cols generate "col1, col2" statement
func (statement *Statement) Cols(columns ...string) *Statement {
cols := col2NewCols(columns...)
for _, nc := range cols {
statement.ColumnMap.Add(nc)
}
return statement
}
// ColumnStr returns column string
func (statement *Statement) ColumnStr() string {
return statement.dialect.Quoter().Join(statement.ColumnMap, ", ")
}
// AllCols update use only: update all columns
func (statement *Statement) AllCols() *Statement {
statement.useAllCols = true
return statement
}
// MustCols update use only: must update columns
func (statement *Statement) MustCols(columns ...string) *Statement {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.MustColumnMap[strings.ToLower(nc)] = true
}
return statement
}
// UseBool indicates that use bool fields as update contents and query contiditions
func (statement *Statement) UseBool(columns ...string) *Statement {
if len(columns) > 0 {
statement.MustCols(columns...)
} else {
statement.allUseBool = true
}
return statement
}
// Omit do not use the columns
func (statement *Statement) Omit(columns ...string) {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.OmitColumnMap = append(statement.OmitColumnMap, nc)
}
}
// Nullable Update use only: update columns to null when value is nullable and zero-value // Nullable Update use only: update columns to null when value is nullable and zero-value
func (statement *Statement) Nullable(columns ...string) { func (statement *Statement) Nullable(columns ...string) {
newColumns := col2NewCols(columns...) newColumns := col2NewCols(columns...)
@ -454,54 +268,6 @@ func (statement *Statement) Limit(limit int, start ...int) *Statement {
return statement return statement
} }
// OrderBy generate "Order By order" statement
func (statement *Statement) OrderBy(order string) *Statement {
if len(statement.OrderStr) > 0 {
statement.OrderStr += ", "
}
statement.OrderStr += statement.ReplaceQuote(order)
return statement
}
// Desc generate `ORDER BY xx DESC`
func (statement *Statement) Desc(colNames ...string) *Statement {
var buf strings.Builder
if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, statement.OrderStr, ", ")
}
for i, col := range colNames {
if i > 0 {
fmt.Fprint(&buf, ", ")
}
_ = statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " DESC")
}
statement.OrderStr = buf.String()
return statement
}
// Asc provide asc order by query condition, the input parameters are columns.
func (statement *Statement) Asc(colNames ...string) *Statement {
var buf strings.Builder
if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, statement.OrderStr, ", ")
}
for i, col := range colNames {
if i > 0 {
fmt.Fprint(&buf, ", ")
}
_ = statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " ASC")
}
statement.OrderStr = buf.String()
return statement
}
// Conds returns condtions
func (statement *Statement) Conds() builder.Cond {
return statement.cond
}
// SetTable tempororily set table name, the parameter could be a string or a pointer of struct // SetTable tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) SetTable(tableNameOrBean interface{}) error { func (statement *Statement) SetTable(tableNameOrBean interface{}) error {
v := rValue(tableNameOrBean) v := rValue(tableNameOrBean)
@ -518,71 +284,34 @@ func (statement *Statement) SetTable(tableNameOrBean interface{}) error {
return nil return nil
} }
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
var buf strings.Builder
if len(statement.JoinStr) > 0 {
fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
} else {
fmt.Fprintf(&buf, "%v JOIN ", joinOP)
}
switch tp := tablename.(type) {
case builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition))
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
case *builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition))
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default:
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
if !utils.IsSubQuery(tbName) {
var buf strings.Builder
_ = statement.dialect.Quoter().QuoteTo(&buf, tbName)
tbName = buf.String()
} else {
tbName = statement.ReplaceQuote(tbName)
}
fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condition))
}
statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...)
return statement
}
// GroupBy generate "Group By keys" statement // GroupBy generate "Group By keys" statement
func (statement *Statement) GroupBy(keys string) *Statement { func (statement *Statement) GroupBy(keys string) *Statement {
statement.GroupByStr = statement.ReplaceQuote(keys) statement.GroupByStr = statement.ReplaceQuote(keys)
return statement return statement
} }
func (statement *Statement) WriteGroupBy(w builder.Writer) error {
if statement.GroupByStr == "" {
return nil
}
_, err := fmt.Fprintf(w, " GROUP BY %s", statement.GroupByStr)
return err
}
// Having generate "Having conditions" statement // Having generate "Having conditions" statement
func (statement *Statement) Having(conditions string) *Statement { func (statement *Statement) Having(conditions string) *Statement {
statement.HavingStr = fmt.Sprintf("HAVING %v", statement.ReplaceQuote(conditions)) statement.HavingStr = fmt.Sprintf("HAVING %v", statement.ReplaceQuote(conditions))
return statement return statement
} }
func (statement *Statement) writeHaving(w builder.Writer) error {
if statement.HavingStr == "" {
return nil
}
_, err := fmt.Fprint(w, " ", statement.HavingStr)
return err
}
// SetUnscoped always disable struct tag "deleted" // SetUnscoped always disable struct tag "deleted"
func (statement *Statement) SetUnscoped() *Statement { func (statement *Statement) SetUnscoped() *Statement {
statement.unscoped = true statement.unscoped = true
@ -594,47 +323,6 @@ func (statement *Statement) GetUnscoped() bool {
return statement.unscoped return statement.unscoped
} }
func (statement *Statement) genColumnStr() string {
if statement.RefTable == nil {
return ""
}
var buf strings.Builder
columns := statement.RefTable.Columns()
for _, col := range columns {
if statement.OmitColumnMap.Contain(col.Name) {
continue
}
if len(statement.ColumnMap) > 0 && !statement.ColumnMap.Contain(col.Name) {
continue
}
if col.MapType == schemas.ONLYTODB {
continue
}
if buf.Len() != 0 {
buf.WriteString(", ")
}
if statement.JoinStr != "" {
if statement.TableAlias != "" {
buf.WriteString(statement.TableAlias)
} else {
buf.WriteString(statement.TableName())
}
buf.WriteString(".")
}
statement.dialect.Quoter().QuoteTo(&buf, col.Name)
}
return buf.String()
}
// GenIndexSQL generated create index SQL // GenIndexSQL generated create index SQL
func (statement *Statement) GenIndexSQL() []string { func (statement *Statement) GenIndexSQL() []string {
var sqls []string var sqls []string
@ -914,7 +602,8 @@ func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, i
statement.unscoped, statement.MustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) statement.unscoped, statement.MustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
} }
func (statement *Statement) mergeConds(bean interface{}) error { // MergeConds merge conditions from bean and id
func (statement *Statement) MergeConds(bean interface{}) error {
if !statement.NoAutoCondition && statement.RefTable != nil { if !statement.NoAutoCondition && statement.RefTable != nil {
addedTableName := (len(statement.JoinStr) > 0) addedTableName := (len(statement.JoinStr) > 0)
autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName) autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
@ -927,15 +616,6 @@ func (statement *Statement) mergeConds(bean interface{}) error {
return statement.ProcessIDParam() return statement.ProcessIDParam()
} }
// GenConds generates conditions
func (statement *Statement) GenConds(bean interface{}) (string, []interface{}, error) {
if err := statement.mergeConds(bean); err != nil {
return "", nil, err
}
return statement.GenCondSQL(statement.cond)
}
func (statement *Statement) quoteColumnStr(columnStr string) string { func (statement *Statement) quoteColumnStr(columnStr string) string {
columns := strings.Split(columnStr, ",") columns := strings.Split(columnStr, ",")
return statement.dialect.Quoter().Join(columns, ",") return statement.dialect.Quoter().Join(columns, ",")

View File

@ -0,0 +1,56 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"fmt"
"strings"
"xorm.io/builder"
"xorm.io/xorm/schemas"
)
// TableName return current tableName
func (statement *Statement) TableName() string {
if statement.AltTableName != "" {
return statement.AltTableName
}
return statement.tableName
}
// Alias set the table alias
func (statement *Statement) Alias(alias string) *Statement {
statement.TableAlias = alias
return statement
}
func (statement *Statement) writeAlias(w builder.Writer) error {
if statement.TableAlias != "" {
if statement.dialect.URI().DBType == schemas.ORACLE {
if _, err := fmt.Fprint(w, " ", statement.quote(statement.TableAlias)); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(w, " AS ", statement.quote(statement.TableAlias)); err != nil {
return err
}
}
}
return nil
}
func (statement *Statement) writeTableName(w builder.Writer) error {
if statement.dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
if _, err := fmt.Fprint(w, statement.TableName()); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(w, statement.quote(statement.TableName())); err != nil {
return err
}
}
return nil
}

27
internal/utils/builder.go Normal file
View File

@ -0,0 +1,27 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package utils
import (
"fmt"
"xorm.io/builder"
)
type BuildReader interface {
String() string
Args() []interface{}
}
// WriteBuilder writes writers to one
func WriteBuilder(w *builder.BytesWriter, inputs ...BuildReader) error {
for _, input := range inputs {
if _, err := fmt.Fprint(w, input.String()); err != nil {
return err
}
w.Append(input.Args()...)
}
return nil
}

View File

@ -26,8 +26,8 @@ type Column struct {
FieldIndex []int // Available only when parsed from a struct FieldIndex []int // Available only when parsed from a struct
SQLType SQLType SQLType SQLType
IsJSON bool IsJSON bool
Length int Length int64
Length2 int Length2 int64
Nullable bool Nullable bool
Default string Default string
Indexes map[string]int Indexes map[string]int
@ -48,7 +48,7 @@ type Column struct {
} }
// NewColumn creates a new column // NewColumn creates a new column
func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable bool) *Column { func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int64, nullable bool) *Column {
return &Column{ return &Column{
Name: name, Name: name,
IsJSON: sqlType.IsJson(), IsJSON: sqlType.IsJson(),
@ -82,13 +82,15 @@ func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) {
// ValueOfV returns column's filed of struct's value accept reflevt value // ValueOfV returns column's filed of struct's value accept reflevt value
func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
var v = *dataStruct v := *dataStruct
for _, i := range col.FieldIndex { for _, i := range col.FieldIndex {
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
if v.IsNil() { if v.IsNil() {
v.Set(reflect.New(v.Type().Elem())) v.Set(reflect.New(v.Type().Elem()))
} }
v = v.Elem() v = v.Elem()
} else if v.Kind() == reflect.Interface {
v = reflect.Indirect(v.Elem())
} }
v = v.FieldByIndex([]int{i}) v = v.FieldByIndex([]int{i})
} }

View File

@ -163,17 +163,18 @@ func (q Quoter) quoteWordTo(buf *strings.Builder, word string) error {
} }
// QuoteTo quotes the table or column names. i.e. if the quotes are [ and ] // QuoteTo quotes the table or column names. i.e. if the quotes are [ and ]
// name -> [name] //
// `name` -> [name] // name -> [name]
// [name] -> [name] // `name` -> [name]
// schema.name -> [schema].[name] // [name] -> [name]
// `schema`.`name` -> [schema].[name] // schema.name -> [schema].[name]
// `schema`.name -> [schema].[name] // `schema`.`name` -> [schema].[name]
// schema.`name` -> [schema].[name] // `schema`.name -> [schema].[name]
// [schema].name -> [schema].[name] // schema.`name` -> [schema].[name]
// schema.[name] -> [schema].[name] // [schema].name -> [schema].[name]
// name AS a -> [name] AS a // schema.[name] -> [schema].[name]
// schema.name AS a -> [schema].[name] AS a // name AS a -> [name] AS a
// schema.name AS a -> [schema].[name] AS a
func (q Quoter) QuoteTo(buf *strings.Builder, value string) error { func (q Quoter) QuoteTo(buf *strings.Builder, value string) error {
var i int var i int
for i < len(value) { for i < len(value) {

View File

@ -28,8 +28,8 @@ const (
// SQLType represents SQL types // SQLType represents SQL types
type SQLType struct { type SQLType struct {
Name string Name string
DefaultLength int DefaultLength int64
DefaultLength2 int DefaultLength2 int64
} }
// enumerates all columns types // enumerates all columns types

View File

@ -275,8 +275,8 @@ func (session *Session) Limit(limit int, start ...int) *Session {
// OrderBy provide order by query condition, the input parameter is the content // OrderBy provide order by query condition, the input parameter is the content
// after order by on a sql statement. // after order by on a sql statement.
func (session *Session) OrderBy(order string) *Session { func (session *Session) OrderBy(order interface{}, args ...interface{}) *Session {
session.statement.OrderBy(order) session.statement.OrderBy(order, args...)
return session return session
} }
@ -330,7 +330,7 @@ func (session *Session) NoCache() *Session {
} }
// Join join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN // Join join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (session *Session) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session { func (session *Session) Join(joinOperator string, tablename interface{}, condition interface{}, args ...interface{}) *Session {
session.statement.Join(joinOperator, tablename, condition, args...) session.statement.Join(joinOperator, tablename, condition, args...)
return session return session
} }
@ -794,3 +794,9 @@ func (session *Session) PingContext(ctx context.Context) error {
session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
return session.DB().PingContext(ctx) return session.DB().PingContext(ctx)
} }
// disable version check
func (session *Session) NoVersionCheck() *Session {
session.statement.CheckVersion = false
return session
}

View File

@ -9,7 +9,9 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"xorm.io/builder"
"xorm.io/xorm/caches" "xorm.io/xorm/caches"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -28,7 +30,7 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri
} }
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr) sqlStr = filter.Do(session.ctx, sqlStr)
} }
newsql := session.statement.ConvertIDSQL(sqlStr) newsql := session.statement.ConvertIDSQL(sqlStr)
@ -89,7 +91,18 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri
} }
// Delete records, bean's non-empty fields are conditions // Delete records, bean's non-empty fields are conditions
// At least one condition must be set.
func (session *Session) Delete(beans ...interface{}) (int64, error) { func (session *Session) Delete(beans ...interface{}) (int64, error) {
return session.delete(beans, true)
}
// Truncate records, bean's non-empty fields are conditions
// In contrast to Delete this method allows deletes without conditions.
func (session *Session) Truncate(beans ...interface{}) (int64, error) {
return session.delete(beans, false)
}
func (session *Session) delete(beans []interface{}, mustHaveConditions bool) (int64, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
@ -99,10 +112,9 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
} }
var ( var (
condSQL string condWriter = builder.NewWriter()
condArgs []interface{} err error
err error bean interface{}
bean interface{}
) )
if len(beans) > 0 { if len(beans) > 0 {
bean = beans[0] bean = beans[0]
@ -116,115 +128,97 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
processor.BeforeDelete() processor.BeforeDelete()
} }
condSQL, condArgs, err = session.statement.GenConds(bean) if err = session.statement.MergeConds(bean); err != nil {
} else { return 0, err
condSQL, condArgs, err = session.statement.GenCondSQL(session.statement.Conds()) }
} }
if err != nil {
if err = session.statement.Conds().WriteTo(session.statement.QuoteReplacer(condWriter)); err != nil {
return 0, err return 0, err
} }
pLimitN := session.statement.LimitN pLimitN := session.statement.LimitN
if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { if mustHaveConditions && condWriter.Len() == 0 && (pLimitN == nil || *pLimitN == 0) {
return 0, ErrNeedDeletedCond return 0, ErrNeedDeletedCond
} }
var tableNameNoQuote = session.statement.TableName() tableNameNoQuote := session.statement.TableName()
var tableName = session.engine.Quote(tableNameNoQuote) tableName := session.engine.Quote(tableNameNoQuote)
var table = session.statement.RefTable table := session.statement.RefTable
var deleteSQL string deleteSQLWriter := builder.NewWriter()
if len(condSQL) > 0 { fmt.Fprintf(deleteSQLWriter, "DELETE FROM %v", tableName)
deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL) if condWriter.Len() > 0 {
} else { fmt.Fprintf(deleteSQLWriter, " WHERE %v", condWriter.String())
deleteSQL = fmt.Sprintf("DELETE FROM %v", tableName) deleteSQLWriter.Append(condWriter.Args()...)
} }
var orderSQL string orderSQLWriter := builder.NewWriter()
if len(session.statement.OrderStr) > 0 { if err := session.statement.WriteOrderBy(orderSQLWriter); err != nil {
orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr) return 0, err
} }
if pLimitN != nil && *pLimitN > 0 { if pLimitN != nil && *pLimitN > 0 {
limitNValue := *pLimitN limitNValue := *pLimitN
orderSQL += fmt.Sprintf(" LIMIT %d", limitNValue) if _, err := fmt.Fprintf(orderSQLWriter, " LIMIT %d", limitNValue); err != nil {
return 0, err
}
} }
if len(orderSQL) > 0 { orderCondWriter := builder.NewWriter()
if orderSQLWriter.Len() > 0 {
switch session.engine.dialect.URI().DBType { switch session.engine.dialect.URI().DBType {
case schemas.POSTGRES: case schemas.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) if condWriter.Len() > 0 {
if len(condSQL) > 0 { fmt.Fprintf(orderCondWriter, " AND ")
deleteSQL += " AND " + inSQL
} else { } else {
deleteSQL += " WHERE " + inSQL fmt.Fprintf(orderCondWriter, " WHERE ")
} }
fmt.Fprintf(orderCondWriter, "ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQLWriter.String())
orderCondWriter.Append(orderSQLWriter.Args()...)
case schemas.SQLITE: case schemas.SQLITE:
inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) if condWriter.Len() > 0 {
if len(condSQL) > 0 { fmt.Fprintf(orderCondWriter, " AND ")
deleteSQL += " AND " + inSQL
} else { } else {
deleteSQL += " WHERE " + inSQL fmt.Fprintf(orderCondWriter, " WHERE ")
} }
fmt.Fprintf(orderCondWriter, "rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQLWriter.String())
// TODO: how to handle delete limit on mssql? // TODO: how to handle delete limit on mssql?
case schemas.MSSQL: case schemas.MSSQL:
return 0, ErrNotImplemented return 0, ErrNotImplemented
default: default:
deleteSQL += orderSQL fmt.Fprint(orderCondWriter, orderSQLWriter.String())
orderCondWriter.Append(orderSQLWriter.Args()...)
} }
} }
var realSQL string realSQLWriter := builder.NewWriter()
argsForCache := make([]interface{}, 0, len(condArgs)*2) argsForCache := make([]interface{}, 0, len(deleteSQLWriter.Args())*2)
copy(argsForCache, deleteSQLWriter.Args())
argsForCache = append(deleteSQLWriter.Args(), argsForCache...)
if session.statement.GetUnscoped() || table == nil || table.DeletedColumn() == nil { // tag "deleted" is disabled if session.statement.GetUnscoped() || table == nil || table.DeletedColumn() == nil { // tag "deleted" is disabled
realSQL = deleteSQL if err := utils.WriteBuilder(realSQLWriter, deleteSQLWriter, orderCondWriter); err != nil {
copy(argsForCache, condArgs) return 0, err
argsForCache = append(condArgs, argsForCache...) }
} else { } else {
// !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for caches.
copy(argsForCache, condArgs)
argsForCache = append(condArgs, argsForCache...)
deletedColumn := table.DeletedColumn() deletedColumn := table.DeletedColumn()
realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", if _, err := fmt.Fprintf(realSQLWriter, "UPDATE %v SET %v = ? WHERE %v",
session.engine.Quote(session.statement.TableName()), session.engine.Quote(session.statement.TableName()),
session.engine.Quote(deletedColumn.Name), session.engine.Quote(deletedColumn.Name),
condSQL) condWriter.String()); err != nil {
return 0, err
if len(orderSQL) > 0 {
switch session.engine.dialect.URI().DBType {
case schemas.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
realSQL += " AND " + inSQL
} else {
realSQL += " WHERE " + inSQL
}
case schemas.SQLITE:
inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
realSQL += " AND " + inSQL
} else {
realSQL += " WHERE " + inSQL
}
// TODO: how to handle delete limit on mssql?
case schemas.MSSQL:
return 0, ErrNotImplemented
default:
realSQL += orderSQL
}
} }
// !oinume! Insert nowTime to the head of session.statement.Params
condArgs = append(condArgs, "")
paramsLen := len(condArgs)
copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1])
val, t, err := session.engine.nowTime(deletedColumn) val, t, err := session.engine.nowTime(deletedColumn)
if err != nil { if err != nil {
return 0, err return 0, err
} }
condArgs[0] = val realSQLWriter.Append(val)
realSQLWriter.Append(condWriter.Args()...)
var colName = deletedColumn.Name if err := utils.WriteBuilder(realSQLWriter, orderCondWriter); err != nil {
return 0, err
}
colName := deletedColumn.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)
@ -232,11 +226,11 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
} }
if cacher := session.engine.GetCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache { if cacher := session.engine.GetCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache {
_ = session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) _ = session.cacheDelete(table, tableNameNoQuote, deleteSQLWriter.String(), argsForCache...)
} }
session.statement.RefTable = table session.statement.RefTable = table
res, err := session.exec(realSQL, condArgs...) res, err := session.exec(realSQLWriter.String(), realSQLWriter.Args()...)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -60,9 +60,7 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte
if len(session.statement.ColumnMap) > 0 && !session.statement.IsDistinct { if len(session.statement.ColumnMap) > 0 && !session.statement.IsDistinct {
session.statement.ColumnMap = []string{} session.statement.ColumnMap = []string{}
} }
if session.statement.OrderStr != "" { session.statement.ResetOrderBy()
session.statement.OrderStr = ""
}
if session.statement.LimitN != nil { if session.statement.LimitN != nil {
session.statement.LimitN = nil session.statement.LimitN = nil
} }
@ -85,15 +83,15 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
var isSlice = sliceValue.Kind() == reflect.Slice isSlice := sliceValue.Kind() == reflect.Slice
var isMap = sliceValue.Kind() == reflect.Map isMap := sliceValue.Kind() == reflect.Map
if !isSlice && !isMap { if !isSlice && !isMap {
return errors.New("needs a pointer to a slice or a map") return errors.New("needs a pointer to a slice or a map")
} }
sliceElementType := sliceValue.Type().Elem() sliceElementType := sliceValue.Type().Elem()
var tp = tpStruct tp := tpStruct
if session.statement.RefTable == nil { if session.statement.RefTable == nil {
if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct { if sliceElementType.Elem().Kind() == reflect.Struct {
@ -190,7 +188,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
return err return err
} }
var newElemFunc = func(fields []string) reflect.Value { newElemFunc := func(fields []string) reflect.Value {
return utils.New(elemType, len(fields), len(fields)) return utils.New(elemType, len(fields), len(fields))
} }
@ -235,7 +233,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
} }
if elemType.Kind() == reflect.Struct { if elemType.Kind() == reflect.Struct {
var newValue = newElemFunc(fields) newValue := newElemFunc(fields)
tb, err := session.engine.tagParser.ParseWithCache(newValue) tb, err := session.engine.tagParser.ParseWithCache(newValue)
if err != nil { if err != nil {
return err return err
@ -249,7 +247,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
} }
for rows.Next() { for rows.Next() {
var newValue = newElemFunc(fields) newValue := newElemFunc(fields)
bean := newValue.Interface() bean := newValue.Interface()
switch elemType.Kind() { switch elemType.Kind() {
@ -285,7 +283,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
} }
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr) sqlStr = filter.Do(session.ctx, sqlStr)
} }
newsql := session.statement.ConvertIDSQL(sqlStr) newsql := session.statement.ConvertIDSQL(sqlStr)
@ -310,7 +308,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache") session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache")
return ErrCacheFailed return ErrCacheFailed
} }
var res = make([]string, len(table.PrimaryKeys)) res := make([]string, len(table.PrimaryKeys))
err = rows.ScanSlice(&res) err = rows.ScanSlice(&res)
if err != nil { if err != nil {
return err return err
@ -342,7 +340,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
ididxes := make(map[string]int) ididxes := make(map[string]int)
var ides []schemas.PK var ides []schemas.PK
var temps = make([]interface{}, len(ids)) temps := make([]interface{}, len(ids))
for idx, id := range ids { for idx, id := range ids {
sid, err := id.ToString() sid, err := id.ToString()
@ -457,7 +455,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean)))) sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean))))
} }
} else if sliceValue.Kind() == reflect.Map { } else if sliceValue.Kind() == reflect.Map {
var key = ids[j] key := ids[j]
keyType := sliceValue.Type().Key() keyType := sliceValue.Type().Key()
keyValue := reflect.New(keyType) keyValue := reflect.New(keyType)
var ikey interface{} var ikey interface{}

View File

@ -278,7 +278,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
} }
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr) sqlStr = filter.Do(session.ctx, sqlStr)
} }
newsql := session.statement.ConvertIDSQL(sqlStr) newsql := session.statement.ConvertIDSQL(sqlStr)
if newsql == "" { if newsql == "" {

View File

@ -353,15 +353,15 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
if needCommit { }
if err := session.Commit(); err != nil { if needCommit {
return 0, err if err := session.Commit(); err != nil {
} return 0, err
}
if id == 0 {
return 0, errors.New("insert successfully but not returned id")
} }
} }
if id == 0 {
return 0, errors.New("insert successfully but not returned id")
}
defer handleAfterInsertProcessorFunc(bean) defer handleAfterInsertProcessorFunc(bean)

View File

@ -13,7 +13,7 @@ import (
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
*sqlStr = filter.Do(*sqlStr) *sqlStr = filter.Do(session.ctx, *sqlStr)
} }
session.lastSQL = *sqlStr session.lastSQL = *sqlStr

View File

@ -34,7 +34,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
return ErrCacheFailed return ErrCacheFailed
} }
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
newsql = filter.Do(newsql) newsql = filter.Do(session.ctx, newsql)
} }
session.engine.logger.Debugf("[cache] new sql: %v, %v", oldhead, newsql) session.engine.logger.Debugf("[cache] new sql: %v, %v", oldhead, newsql)
@ -60,7 +60,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
ids = make([]schemas.PK, 0) ids = make([]schemas.PK, 0)
for rows.Next() { for rows.Next() {
var res = make([]string, len(table.PrimaryKeys)) res := make([]string, len(table.PrimaryKeys))
err = rows.ScanSlice(&res) err = rows.ScanSlice(&res)
if err != nil { if err != nil {
return err return err
@ -145,9 +145,10 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
// Update records, bean's non-empty fields are updated contents, // Update records, bean's non-empty fields are updated contents,
// condiBean' non-empty filds are conditions // condiBean' non-empty filds are conditions
// CAUTION: // CAUTION:
// 1.bool will defaultly be updated content nor conditions //
// You should call UseBool if you have bool to use. // 1.bool will defaultly be updated content nor conditions
// 2.float32 & float64 may be not inexact as conditions // You should call UseBool if you have bool to use.
// 2.float32 & float64 may be not inexact as conditions
func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) { func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
@ -176,8 +177,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
// -- // --
var err error var err error
var isMap = t.Kind() == reflect.Map isMap := t.Kind() == reflect.Map
var isStruct = t.Kind() == reflect.Struct isStruct := t.Kind() == reflect.Struct
if isStruct { if isStruct {
if err := session.statement.SetRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return 0, err return 0, err
@ -226,7 +227,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
args = append(args, val) args = append(args, val)
} }
var colName = col.Name colName := col.Name
if isStruct { if isStruct {
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName) col := table.GetColumn(colName)
@ -258,10 +259,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp) colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp)
case *builder.Builder: case *builder.Builder:
subQuery, subArgs, err := session.statement.GenCondSQL(tp) subQuery, subArgs, err := builder.ToSQL(tp)
if err != nil { if err != nil {
return 0, err return 0, err
} }
subQuery = session.statement.ReplaceQuote(subQuery)
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")") colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")")
args = append(args, subArgs...) args = append(args, subArgs...)
default: default:
@ -279,7 +281,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
condBeanIsStruct := false condBeanIsStruct := false
if len(condiBean) > 0 { if len(condiBean) > 0 {
if c, ok := condiBean[0].(map[string]interface{}); ok { if c, ok := condiBean[0].(map[string]interface{}); ok {
var eq = make(builder.Eq) eq := make(builder.Eq)
for k, v := range c { for k, v := range c {
eq[session.engine.Quote(k)] = v eq[session.engine.Quote(k)] = v
} }
@ -323,11 +325,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
st := session.statement st := session.statement
var ( var (
sqlStr string
condArgs []interface{}
condSQL string
cond = session.statement.Conds().And(autoCond) cond = session.statement.Conds().And(autoCond)
doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion) doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion)
verValue *reflect.Value verValue *reflect.Value
) )
@ -347,70 +345,65 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, ErrNoColumnsTobeUpdated return 0, ErrNoColumnsTobeUpdated
} }
condSQL, condArgs, err = session.statement.GenCondSQL(cond) whereWriter := builder.NewWriter()
if err != nil { if cond.IsValid() {
fmt.Fprint(whereWriter, "WHERE ")
}
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
return 0, err
}
if err := st.WriteOrderBy(whereWriter); err != nil {
return 0, err return 0, err
} }
if len(condSQL) > 0 { tableName := session.statement.TableName()
condSQL = "WHERE " + condSQL
}
if st.OrderStr != "" {
condSQL += fmt.Sprintf(" ORDER BY %v", st.OrderStr)
}
var tableName = session.statement.TableName()
// TODO: Oracle support needed // TODO: Oracle support needed
var top string var top string
if st.LimitN != nil { if st.LimitN != nil {
limitValue := *st.LimitN limitValue := *st.LimitN
switch session.engine.dialect.URI().DBType { switch session.engine.dialect.URI().DBType {
case schemas.MYSQL: case schemas.MYSQL:
condSQL += fmt.Sprintf(" LIMIT %d", limitValue) fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
case schemas.SQLITE: case schemas.SQLITE:
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...)) session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
condSQL, condArgs, err = session.statement.GenCondSQL(cond)
if err != nil { whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
return 0, err return 0, err
} }
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
case schemas.POSTGRES: case schemas.POSTGRES:
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...)) session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
condSQL, condArgs, err = session.statement.GenCondSQL(cond)
if err != nil { whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
return 0, err return 0, err
} }
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
case schemas.MSSQL: case schemas.MSSQL:
if st.OrderStr != "" && table != nil && len(table.PrimaryKeys) == 1 { if st.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
session.engine.Quote(tableName), condSQL), condArgs...) session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)
condSQL, condArgs, err = session.statement.GenCondSQL(cond) whereWriter = builder.NewWriter()
if err != nil { fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(whereWriter); err != nil {
return 0, err return 0, err
} }
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
} else { } else {
top = fmt.Sprintf("TOP (%d) ", limitValue) top = fmt.Sprintf("TOP (%d) ", limitValue)
} }
} }
} }
var tableAlias = session.engine.Quote(tableName) tableAlias := session.engine.Quote(tableName)
var fromSQL string var fromSQL string
if session.statement.TableAlias != "" { if session.statement.TableAlias != "" {
switch session.engine.dialect.URI().DBType { switch session.engine.dialect.URI().DBType {
@ -422,14 +415,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v%v", updateWriter := builder.NewWriter()
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v",
top, top,
tableAlias, tableAlias,
strings.Join(colNames, ", "), strings.Join(colNames, ", "),
fromSQL, fromSQL); err != nil {
condSQL) return 0, err
}
if err := utils.WriteBuilder(updateWriter, whereWriter); err != nil {
return 0, err
}
res, err := session.exec(sqlStr, append(args, condArgs...)...) res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...)
if err != nil { if err != nil {
return 0, err return 0, err
} else if doIncVer { } else if doIncVer {
@ -535,7 +533,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
} }
args = append(args, val) args = append(args, val)
var colName = col.Name colName := col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)

View File

@ -557,6 +557,29 @@ func TestParseWithJSON(t *testing.T) {
assert.True(t, table.Columns()[0].IsJSON) assert.True(t, table.Columns()[0].IsJSON)
} }
func TestParseWithJSONB(t *testing.T) {
parser := NewParser(
"db",
dialects.QueryDialect("postgres"),
names.GonicMapper{
"JSONB": true,
},
names.SnakeMapper{},
caches.NewManager(),
)
type StructWithJSONB struct {
Default1 []string `db:"jsonb"`
}
table, err := parser.Parse(reflect.ValueOf(new(StructWithJSONB)))
assert.NoError(t, err)
assert.EqualValues(t, "struct_with_jsonb", table.Name)
assert.EqualValues(t, 1, len(table.Columns()))
assert.EqualValues(t, "default1", table.Columns()[0].Name)
assert.True(t, table.Columns()[0].IsJSON)
}
func TestParseWithSQLType(t *testing.T) { func TestParseWithSQLType(t *testing.T) {
parser := NewParser( parser := NewParser(
"db", "db",

View File

@ -99,33 +99,31 @@ type Context struct {
// Handler describes tag handler for XORM // Handler describes tag handler for XORM
type Handler func(ctx *Context) error type Handler func(ctx *Context) error
var ( // defaultTagHandlers enumerates all the default tag handler
// defaultTagHandlers enumerates all the default tag handler var defaultTagHandlers = map[string]Handler{
defaultTagHandlers = map[string]Handler{ "-": IgnoreHandler,
"-": IgnoreHandler, "<-": OnlyFromDBTagHandler,
"<-": OnlyFromDBTagHandler, "->": OnlyToDBTagHandler,
"->": OnlyToDBTagHandler, "PK": PKTagHandler,
"PK": PKTagHandler, "NULL": NULLTagHandler,
"NULL": NULLTagHandler, "NOT": NotTagHandler,
"NOT": NotTagHandler, "AUTOINCR": AutoIncrTagHandler,
"AUTOINCR": AutoIncrTagHandler, "DEFAULT": DefaultTagHandler,
"DEFAULT": DefaultTagHandler, "CREATED": CreatedTagHandler,
"CREATED": CreatedTagHandler, "UPDATED": UpdatedTagHandler,
"UPDATED": UpdatedTagHandler, "DELETED": DeletedTagHandler,
"DELETED": DeletedTagHandler, "VERSION": VersionTagHandler,
"VERSION": VersionTagHandler, "UTC": UTCTagHandler,
"UTC": UTCTagHandler, "LOCAL": LocalTagHandler,
"LOCAL": LocalTagHandler, "NOTNULL": NotNullTagHandler,
"NOTNULL": NotNullTagHandler, "INDEX": IndexTagHandler,
"INDEX": IndexTagHandler, "UNIQUE": UniqueTagHandler,
"UNIQUE": UniqueTagHandler, "CACHE": CacheTagHandler,
"CACHE": CacheTagHandler, "NOCACHE": NoCacheTagHandler,
"NOCACHE": NoCacheTagHandler, "COMMENT": CommentTagHandler,
"COMMENT": CommentTagHandler, "EXTENDS": ExtendsTagHandler,
"EXTENDS": ExtendsTagHandler, "UNSIGNED": UnsignedTagHandler,
"UNSIGNED": UnsignedTagHandler, }
}
)
func init() { func init() {
for k := range schemas.SqlTypes { for k := range schemas.SqlTypes {
@ -287,7 +285,7 @@ func CommentTagHandler(ctx *Context) error {
// SQLTypeTagHandler describes SQL Type tag handler // SQLTypeTagHandler describes SQL Type tag handler
func SQLTypeTagHandler(ctx *Context) error { func SQLTypeTagHandler(ctx *Context) error {
ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname} ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname}
if ctx.tagUname == "JSON" { if ctx.tagUname == "JSON" || ctx.tagUname == "JSONB" {
ctx.col.IsJSON = true ctx.col.IsJSON = true
} }
if len(ctx.params) == 0 { if len(ctx.params) == 0 {
@ -312,16 +310,16 @@ func SQLTypeTagHandler(ctx *Context) error {
default: default:
var err error var err error
if len(ctx.params) == 2 { if len(ctx.params) == 2 {
ctx.col.Length, err = strconv.Atoi(ctx.params[0]) ctx.col.Length, err = strconv.ParseInt(ctx.params[0], 10, 64)
if err != nil { if err != nil {
return err return err
} }
ctx.col.Length2, err = strconv.Atoi(ctx.params[1]) ctx.col.Length2, err = strconv.ParseInt(ctx.params[1], 10, 64)
if err != nil { if err != nil {
return err return err
} }
} else if len(ctx.params) == 1 { } else if len(ctx.params) == 1 {
ctx.col.Length, err = strconv.Atoi(ctx.params[0]) ctx.col.Length, err = strconv.ParseInt(ctx.params[0], 10, 64)
if err != nil { if err != nil {
return err return err
} }
@ -332,8 +330,8 @@ func SQLTypeTagHandler(ctx *Context) error {
// ExtendsTagHandler describes extends tag handler // ExtendsTagHandler describes extends tag handler
func ExtendsTagHandler(ctx *Context) error { func ExtendsTagHandler(ctx *Context) error {
var fieldValue = ctx.fieldValue fieldValue := ctx.fieldValue
var isPtr = false isPtr := false
switch fieldValue.Kind() { switch fieldValue.Kind() {
case reflect.Ptr: case reflect.Ptr:
f := fieldValue.Type().Elem() f := fieldValue.Type().Elem()
@ -355,7 +353,7 @@ func ExtendsTagHandler(ctx *Context) error {
col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName) col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName)
col.FieldIndex = append(ctx.col.FieldIndex, col.FieldIndex...) col.FieldIndex = append(ctx.col.FieldIndex, col.FieldIndex...)
var tagPrefix = ctx.col.FieldName tagPrefix := ctx.col.FieldName
if len(ctx.params) > 0 { if len(ctx.params) > 0 {
col.Nullable = isPtr col.Nullable = isPtr
tagPrefix = strings.Trim(ctx.params[0], "'") tagPrefix = strings.Trim(ctx.params[0], "'")
@ -378,7 +376,7 @@ func ExtendsTagHandler(ctx *Context) error {
} }
} }
default: default:
//TODO: warning // TODO: warning
} }
return ErrIgnoreField return ErrIgnoreField
} }