diff --git a/.drone.yml b/.drone.yml deleted file mode 100644 index 2bad4b5a..00000000 --- a/.drone.yml +++ /dev/null @@ -1,437 +0,0 @@ ---- -kind: pipeline -name: test-mysql -environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.io" - CGO_ENABLED: 1 -trigger: - ref: - - refs/heads/master - - refs/pull/*/head -steps: -- name: test-vet - image: golang:1.17 - pull: always - volumes: - - name: cache - path: /go/pkg/mod - commands: - - make vet -- name: test-sqlite3 - image: golang:1.17 - volumes: - - name: cache - path: /go/pkg/mod - depends_on: - - test-vet - commands: - - make fmt-check - - make test - - make test-sqlite3 - - TEST_CACHE_ENABLE=true make test-sqlite3 -- name: test-sqlite - image: golang:1.17 - volumes: - - name: cache - path: /go/pkg/mod - depends_on: - - test-vet - commands: - - make test-sqlite - - TEST_QUOTE_POLICY=reserved make test-sqlite -- name: test-mysql - image: golang:1.17 - pull: never - volumes: - - name: cache - path: /go/pkg/mod - depends_on: - - test-vet - environment: - TEST_MYSQL_HOST: mysql - TEST_MYSQL_CHARSET: utf8 - TEST_MYSQL_DBNAME: xorm_test - TEST_MYSQL_USERNAME: root - TEST_MYSQL_PASSWORD: - commands: - - TEST_CACHE_ENABLE=true make test-mysql - -- name: test-mysql-utf8mb4 - image: golang:1.17 - pull: never - volumes: - - name: cache - path: /go/pkg/mod - depends_on: - - test-mysql - environment: - TEST_MYSQL_HOST: mysql - TEST_MYSQL_CHARSET: utf8mb4 - TEST_MYSQL_DBNAME: xorm_test - TEST_MYSQL_USERNAME: root - TEST_MYSQL_PASSWORD: - commands: - - make test-mysql - - TEST_QUOTE_POLICY=reserved make test-mysql-tls - -volumes: -- name: cache - host: - path: /tmp/cache - -services: -- name: mysql - image: mysql:5.7 - environment: - MYSQL_ALLOW_EMPTY_PASSWORD: yes - MYSQL_DATABASE: xorm_test - ---- -kind: pipeline -name: test-mysql8 -depends_on: - - test-mysql -trigger: - ref: - - refs/heads/master - - refs/pull/*/head -steps: -- name: test-mysql8 - image: golang:1.17 - pull: never - volumes: - - name: cache - path: /go/pkg/mod - environment: - TEST_MYSQL_HOST: mysql8 - TEST_MYSQL_CHARSET: utf8mb4 - TEST_MYSQL_DBNAME: xorm_test - TEST_MYSQL_USERNAME: root - TEST_MYSQL_PASSWORD: - commands: - - make test-mysql - - TEST_CACHE_ENABLE=true make test-mysql - -volumes: -- name: cache - host: - path: /tmp/cache - -services: -- name: mysql8 - image: mysql:8.0 - environment: - MYSQL_ALLOW_EMPTY_PASSWORD: yes - MYSQL_DATABASE: xorm_test - ---- -kind: pipeline -name: test-mariadb -depends_on: - - test-mysql8 -trigger: - ref: - - refs/heads/master - - refs/pull/*/head -steps: -- name: test-mariadb - image: golang:1.17 - pull: never - volumes: - - name: cache - path: /go/pkg/mod - environment: - TEST_MYSQL_HOST: mariadb - TEST_MYSQL_CHARSET: utf8mb4 - TEST_MYSQL_DBNAME: xorm_test - TEST_MYSQL_USERNAME: root - TEST_MYSQL_PASSWORD: - commands: - - make test-mysql - - TEST_QUOTE_POLICY=reserved make test-mysql - -volumes: -- name: cache - host: - path: /tmp/cache - -services: -- name: mariadb - image: mariadb:10.4 - environment: - MYSQL_ALLOW_EMPTY_PASSWORD: yes - MYSQL_DATABASE: xorm_test - ---- -kind: pipeline -name: test-postgres -depends_on: - - test-mariadb -trigger: - ref: - - refs/heads/master - - refs/pull/*/head -steps: -- name: test-postgres - pull: never - image: golang:1.17 - volumes: - - name: cache - path: /go/pkg/mod - environment: - TEST_PGSQL_HOST: pgsql - TEST_PGSQL_DBNAME: xorm_test - TEST_PGSQL_USERNAME: postgres - TEST_PGSQL_PASSWORD: postgres - commands: - - make test-postgres - - TEST_CACHE_ENABLE=true make test-postgres - -- name: test-postgres-schema - pull: never - image: golang:1.17 - volumes: - - name: cache - path: /go/pkg/mod - depends_on: - - test-postgres - environment: - TEST_PGSQL_HOST: pgsql - TEST_PGSQL_SCHEMA: xorm - TEST_PGSQL_DBNAME: xorm_test - TEST_PGSQL_USERNAME: postgres - TEST_PGSQL_PASSWORD: postgres - commands: - - TEST_QUOTE_POLICY=reserved make test-postgres - -- name: test-pgx - pull: never - image: golang:1.17 - volumes: - - name: cache - path: /go/pkg/mod - depends_on: - - test-postgres-schema - environment: - TEST_PGSQL_HOST: pgsql - TEST_PGSQL_DBNAME: xorm_test - TEST_PGSQL_USERNAME: postgres - TEST_PGSQL_PASSWORD: postgres - commands: - - make test-pgx - - TEST_CACHE_ENABLE=true make test-pgx - - TEST_QUOTE_POLICY=reserved make test-pgx - -- name: test-pgx-schema - pull: never - image: golang:1.17 - volumes: - - name: cache - path: /go/pkg/mod - depends_on: - - test-pgx - environment: - TEST_PGSQL_HOST: pgsql - TEST_PGSQL_SCHEMA: xorm - TEST_PGSQL_DBNAME: xorm_test - TEST_PGSQL_USERNAME: postgres - TEST_PGSQL_PASSWORD: postgres - commands: - - make test-pgx - - TEST_CACHE_ENABLE=true make test-pgx - - TEST_QUOTE_POLICY=reserved make test-pgx - -volumes: -- name: cache - host: - path: /tmp/cache - -services: -- name: pgsql - image: postgres:9.5 - environment: - POSTGRES_DB: xorm_test - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres - ---- -kind: pipeline -name: test-mssql -depends_on: - - test-postgres -trigger: - ref: - - refs/heads/master - - refs/pull/*/head -steps: -- name: test-mssql - pull: never - image: golang:1.17 - volumes: - - name: cache - path: /go/pkg/mod - environment: - TEST_MSSQL_HOST: mssql - TEST_MSSQL_DBNAME: xorm_test - TEST_MSSQL_USERNAME: sa - TEST_MSSQL_PASSWORD: "yourStrong(!)Password" - commands: - - make test-mssql - - TEST_MSSQL_DEFAULT_VARCHAR=NVARCHAR TEST_MSSQL_DEFAULT_CHAR=NCHAR make test-mssql - -volumes: -- name: cache - host: - path: /tmp/cache - -services: -- name: mssql - pull: always - image: mcr.microsoft.com/mssql/server:latest - environment: - ACCEPT_EULA: Y - SA_PASSWORD: yourStrong(!)Password - MSSQL_PID: Standard - ---- -kind: pipeline -name: test-tidb -depends_on: - - test-mssql -trigger: - ref: - - refs/heads/master - - refs/pull/*/head -steps: -- name: test-tidb - pull: never - image: golang:1.17 - volumes: - - name: cache - path: /go/pkg/mod - environment: - TEST_TIDB_HOST: "tidb:4000" - TEST_TIDB_DBNAME: xorm_test - TEST_TIDB_USERNAME: root - TEST_TIDB_PASSWORD: - commands: - - make test-tidb - -volumes: -- name: cache - host: - path: /tmp/cache - -services: -- name: tidb - image: pingcap/tidb:v3.0.3 - ---- -kind: pipeline -name: test-cockroach -depends_on: - - test-tidb -trigger: - ref: - - refs/heads/master - - refs/pull/*/head -steps: -- name: test-cockroach - pull: never - image: golang:1.17 - volumes: - - name: cache - path: /go/pkg/mod - environment: - TEST_COCKROACH_HOST: "cockroach:26257" - TEST_COCKROACH_DBNAME: xorm_test - TEST_COCKROACH_USERNAME: root - TEST_COCKROACH_PASSWORD: - commands: - - sleep 10 - - make test-cockroach - -volumes: -- name: cache - host: - path: /tmp/cache - -services: -- name: cockroach - image: cockroachdb/cockroach:v19.2.4 - commands: - - /cockroach/cockroach start --insecure - -# --- -# kind: pipeline -# name: test-dameng -# depends_on: -# - test-cockroach -# trigger: -# ref: -# - refs/heads/master -# - refs/pull/*/head -# steps: -# - name: test-dameng -# pull: never -# image: golang:1.17 -# volumes: -# - name: cache -# path: /go/pkg/mod -# environment: -# TEST_DAMENG_HOST: "dameng:5236" -# TEST_DAMENG_USERNAME: SYSDBA -# TEST_DAMENG_PASSWORD: SYSDBA -# commands: -# - sleep 30 -# - make test-dameng - -# volumes: -# - name: cache -# host: -# path: /tmp/cache - -# services: -# - name: dameng -# image: lunny/dm:v1.0 -# commands: -# - /bin/bash /startDm.sh - ---- -kind: pipeline -name: merge_coverage -depends_on: - - test-mysql - - test-mysql8 - - test-mariadb - - test-postgres - - test-mssql - - test-tidb - - test-cockroach - #- test-dameng -trigger: - ref: - - refs/heads/master - - refs/pull/*/head -steps: -- name: merge_coverage - image: golang:1.17 - commands: - - make coverage - ---- -kind: pipeline -name: release-tag -trigger: - event: - - tag -steps: -- name: release-tag-gitea - pull: always - image: plugins/gitea-release:latest - settings: - base_url: https://gitea.com - title: '${DRONE_TAG} is released' - api_key: - from_secret: gitea_token \ No newline at end of file diff --git a/.gitea/workflows/release-tag.yml b/.gitea/workflows/release-tag.yml new file mode 100644 index 00000000..10ed831e --- /dev/null +++ b/.gitea/workflows/release-tag.yml @@ -0,0 +1,23 @@ +name: release + +on: + push: + tags: + - '*' + +jobs: + release: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: setup go + uses: https://github.com/actions/setup-go@v4 + with: + go-version: '>=1.20.1' + - name: Use Go Action + id: use-go-action + uses: actions/release-action@main + with: + api_key: '${{secrets.RELEASE_TOKEN}}' \ No newline at end of file diff --git a/.gitea/workflows/test-cockroach.yml b/.gitea/workflows/test-cockroach.yml new file mode 100644 index 00000000..0ca18861 --- /dev/null +++ b/.gitea/workflows/test-cockroach.yml @@ -0,0 +1,55 @@ +name: test cockroach +on: + push: + branches: + - master + pull_request: + +env: + GOPROXY: https://goproxy.io,direct + GOPATH: /go_path + GOCACHE: /go_cache + +jobs: + test-cockroach: + name: test cockroach + runs-on: ubuntu-latest + steps: + # - name: cache go path + # id: cache-go-path + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_path + # key: go_path-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_path-${{ github.repository }}- + # go_path- + # - name: cache go cache + # id: cache-go-cache + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_cache + # key: go_cache-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_cache-${{ github.repository }}- + # go_cache- + - uses: actions/setup-go@v3 + with: + go-version: 1.20 + - uses: https://github.com/actions/checkout@v3 + - name: test cockroach + env: + TEST_COCKROACH_HOST: "cockroach:26257" + TEST_COCKROACH_DBNAME: xorm_test + TEST_COCKROACH_USERNAME: root + TEST_COCKROACH_PASSWORD: + run: sleep 20 && make test-cockroach + + services: + cockroach: + image: cockroachdb/cockroach:v19.2.4 + ports: + - 26257:26257 + cmd: + - 'start' + - '--insecure' diff --git a/.gitea/workflows/test-mariadb.yml b/.gitea/workflows/test-mariadb.yml new file mode 100644 index 00000000..466f3858 --- /dev/null +++ b/.gitea/workflows/test-mariadb.yml @@ -0,0 +1,56 @@ +name: test mariadb +on: + push: + branches: + - master + pull_request: + +env: + GOPROXY: https://goproxy.io,direct + GOPATH: /go_path + GOCACHE: /go_cache + +jobs: + lint: + name: test mariadb + runs-on: ubuntu-latest + steps: + # - name: cache go path + # id: cache-go-path + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_path + # key: go_path-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_path-${{ github.repository }}- + # go_path- + # - name: cache go cache + # id: cache-go-cache + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_cache + # key: go_cache-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_cache-${{ github.repository }}- + # go_cache- + - uses: actions/setup-go@v3 + with: + go-version: 1.20 + - uses: https://github.com/actions/checkout@v3 + - name: test mariadb + env: + TEST_MYSQL_HOST: mariadb + TEST_MYSQL_CHARSET: utf8mb4 + TEST_MYSQL_DBNAME: xorm_test + TEST_MYSQL_USERNAME: root + TEST_MYSQL_PASSWORD: + run: TEST_QUOTE_POLICY=reserved make test-mysql + + services: + mariadb: + image: mariadb:10.4 + env: + MYSQL_ALLOW_EMPTY_PASSWORD: yes + MYSQL_DATABASE: xorm_test + ports: + - 3306:3306 \ No newline at end of file diff --git a/.gitea/workflows/test-mssql.yml b/.gitea/workflows/test-mssql.yml new file mode 100644 index 00000000..d02e6956 --- /dev/null +++ b/.gitea/workflows/test-mssql.yml @@ -0,0 +1,56 @@ +name: test mssql +on: + push: + branches: + - master + pull_request: + +env: + GOPROXY: https://goproxy.io,direct + GOPATH: /go_path + GOCACHE: /go_cache + +jobs: + test-mssql: + name: test mssql + runs-on: ubuntu-latest + steps: + # - name: cache go path + # id: cache-go-path + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_path + # key: go_path-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_path-${{ github.repository }}- + # go_path- + # - name: cache go cache + # id: cache-go-cache + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_cache + # key: go_cache-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_cache-${{ github.repository }}- + # go_cache- + - uses: actions/setup-go@v3 + with: + go-version: 1.20 + - uses: https://github.com/actions/checkout@v3 + - name: test mssql + env: + TEST_MSSQL_HOST: mssql + TEST_MSSQL_DBNAME: xorm_test + TEST_MSSQL_USERNAME: sa + TEST_MSSQL_PASSWORD: "yourStrong(!)Password" + run: TEST_MSSQL_DEFAULT_VARCHAR=NVARCHAR TEST_MSSQL_DEFAULT_CHAR=NCHAR make test-mssql + + services: + mssql: + image: mcr.microsoft.com/mssql/server:latest + env: + ACCEPT_EULA: Y + SA_PASSWORD: yourStrong(!)Password + MSSQL_PID: Standard + ports: + - 1433:1433 \ No newline at end of file diff --git a/.gitea/workflows/test-mysql.yml b/.gitea/workflows/test-mysql.yml new file mode 100644 index 00000000..03ee2725 --- /dev/null +++ b/.gitea/workflows/test-mysql.yml @@ -0,0 +1,56 @@ +name: test mysql +on: + push: + branches: + - master + pull_request: + +env: + GOPROXY: https://goproxy.io,direct + GOPATH: /go_path + GOCACHE: /go_cache + +jobs: + test-mysql: + name: test mysql + runs-on: ubuntu-latest + steps: + # - name: cache go path + # id: cache-go-path + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_path + # key: go_path-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_path-${{ github.repository }}- + # go_path- + # - name: cache go cache + # id: cache-go-cache + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_cache + # key: go_cache-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_cache-${{ github.repository }}- + # go_cache- + - uses: actions/setup-go@v3 + with: + go-version: 1.20 + - uses: https://github.com/actions/checkout@v3 + - name: test mysql utf8mb4 + env: + TEST_MYSQL_HOST: mysql + TEST_MYSQL_CHARSET: utf8mb4 + TEST_MYSQL_DBNAME: xorm_test + TEST_MYSQL_USERNAME: root + TEST_MYSQL_PASSWORD: + run: TEST_QUOTE_POLICY=reserved make test-mysql-tls + + services: + mysql: + image: mysql:5.7 + env: + MYSQL_ALLOW_EMPTY_PASSWORD: yes + MYSQL_DATABASE: xorm_test + ports: + - 3306:3306 \ No newline at end of file diff --git a/.gitea/workflows/test-mysql8.yml b/.gitea/workflows/test-mysql8.yml new file mode 100644 index 00000000..3fbd7c30 --- /dev/null +++ b/.gitea/workflows/test-mysql8.yml @@ -0,0 +1,56 @@ +name: test mysql8 +on: + push: + branches: + - master + pull_request: + +env: + GOPROXY: https://goproxy.io,direct + GOPATH: /go_path + GOCACHE: /go_cache + +jobs: + lint: + name: test mysql8 + runs-on: ubuntu-latest + steps: + # - name: cache go path + # id: cache-go-path + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_path + # key: go_path-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_path-${{ github.repository }}- + # go_path- + # - name: cache go cache + # id: cache-go-cache + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_cache + # key: go_cache-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_cache-${{ github.repository }}- + # go_cache- + - uses: actions/setup-go@v3 + with: + go-version: 1.20 + - uses: https://github.com/actions/checkout@v3 + - name: test mysql8 + env: + TEST_MYSQL_HOST: mysql8 + TEST_MYSQL_CHARSET: utf8mb4 + TEST_MYSQL_DBNAME: xorm_test + TEST_MYSQL_USERNAME: root + TEST_MYSQL_PASSWORD: + run: TEST_CACHE_ENABLE=true make test-mysql + + services: + mysql8: + image: mysql:8.0 + env: + MYSQL_ALLOW_EMPTY_PASSWORD: yes + MYSQL_DATABASE: xorm_test + ports: + - 3306:3306 \ No newline at end of file diff --git a/.gitea/workflows/test-postgres.yml b/.gitea/workflows/test-postgres.yml new file mode 100644 index 00000000..89aa72c3 --- /dev/null +++ b/.gitea/workflows/test-postgres.yml @@ -0,0 +1,79 @@ +name: test postgres +on: + push: + branches: + - master + pull_request: + +env: + GOPROXY: https://goproxy.io,direct + GOPATH: /go_path + GOCACHE: /go_cache + +jobs: + lint: + name: test postgres + runs-on: ubuntu-latest + steps: + # - name: cache go path + # id: cache-go-path + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_path + # key: go_path-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_path-${{ github.repository }}- + # go_path- + # - name: cache go cache + # id: cache-go-cache + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_cache + # key: go_cache-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_cache-${{ github.repository }}- + # go_cache- + - uses: actions/setup-go@v3 + with: + go-version: 1.20 + - uses: https://github.com/actions/checkout@v3 + - name: test postgres + env: + TEST_PGSQL_HOST: pgsql + TEST_PGSQL_DBNAME: xorm_test + TEST_PGSQL_USERNAME: postgres + TEST_PGSQL_PASSWORD: postgres + run: TEST_CACHE_ENABLE=true make test-postgres + - name: test postgres with schema + env: + TEST_PGSQL_HOST: pgsql + TEST_PGSQL_SCHEMA: xorm + TEST_PGSQL_DBNAME: xorm_test + TEST_PGSQL_USERNAME: postgres + TEST_PGSQL_PASSWORD: postgres + run: TEST_QUOTE_POLICY=reserved make test-postgres + - name: test pgx + env: + TEST_PGSQL_HOST: pgsql + TEST_PGSQL_DBNAME: xorm_test + TEST_PGSQL_USERNAME: postgres + TEST_PGSQL_PASSWORD: postgres + run: TEST_CACHE_ENABLE=true make test-pgx + - name: test pgx with schema + env: + TEST_PGSQL_HOST: pgsql + TEST_PGSQL_SCHEMA: xorm + TEST_PGSQL_DBNAME: xorm_test + TEST_PGSQL_USERNAME: postgres + TEST_PGSQL_PASSWORD: postgres + run: TEST_QUOTE_POLICY=reserved make test-pgx + + services: + pgsql: + image: postgres:9.5 + env: + POSTGRES_DB: xorm_test + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + ports: + - 5432:5432 \ No newline at end of file diff --git a/.gitea/workflows/test-sqlite.yml b/.gitea/workflows/test-sqlite.yml new file mode 100644 index 00000000..cca2e786 --- /dev/null +++ b/.gitea/workflows/test-sqlite.yml @@ -0,0 +1,49 @@ +name: test sqlite +on: + push: + branches: + - master + pull_request: + +env: + GOPROXY: https://goproxy.io,direct + GOPATH: /go_path + GOCACHE: /go_cache + +jobs: + test-sqlite: + name: unit test & test sqlite + runs-on: ubuntu-latest + steps: + # - name: cache go path + # id: cache-go-path + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_path + # key: go_path-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_path-${{ github.repository }}- + # go_path- + # - name: cache go cache + # id: cache-go-cache + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_cache + # key: go_cache-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_cache-${{ github.repository }}- + # go_cache- + - uses: actions/setup-go@v3 + with: + go-version: 1.20 + - uses: https://github.com/actions/checkout@v3 + - name: vet + run: make vet + - name: format check + run: make fmt-check + - name: unit test + run: make test + - name: test sqlite3 + run: make test-sqlite3 + - name: test sqlite3 with cache + run: TEST_CACHE_ENABLE=true make test-sqlite3 \ No newline at end of file diff --git a/.gitea/workflows/test-tidb.yml b/.gitea/workflows/test-tidb.yml new file mode 100644 index 00000000..fa6e27ad --- /dev/null +++ b/.gitea/workflows/test-tidb.yml @@ -0,0 +1,52 @@ +name: test tidb +on: + push: + branches: + - master + pull_request: + +env: + GOPROXY: https://goproxy.io,direct + GOPATH: /go_path + GOCACHE: /go_cache + +jobs: + test-tidb: + name: test tidb + runs-on: ubuntu-latest + steps: + # - name: cache go path + # id: cache-go-path + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_path + # key: go_path-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_path-${{ github.repository }}- + # go_path- + # - name: cache go cache + # id: cache-go-cache + # uses: https://github.com/actions/cache@v3 + # with: + # path: /go_cache + # key: go_cache-${{ github.repository }}-${{ github.ref_name }} + # restore-keys: | + # go_cache-${{ github.repository }}- + # go_cache- + - uses: actions/setup-go@v3 + with: + go-version: 1.20 + - uses: https://github.com/actions/checkout@v3 + - name: test tidb + env: + TEST_TIDB_HOST: "tidb:4000" + TEST_TIDB_DBNAME: xorm_test + TEST_TIDB_USERNAME: root + TEST_TIDB_PASSWORD: + run: make test-tidb + + services: + tidb: + image: pingcap/tidb:v3.0.3 + ports: + - 4000:4000 \ No newline at end of file diff --git a/convert/interface.go b/convert/interface.go index b0f28c81..2cc8d9f4 100644 --- a/convert/interface.go +++ b/convert/interface.go @@ -24,7 +24,10 @@ func Interface2Interface(userLocation *time.Location, v interface{}) (interface{ return vv.String, nil case *sql.RawBytes: if len([]byte(*vv)) > 0 { - return []byte(*vv), nil + src := []byte(*vv) + dest := make([]byte, len(src)) + copy(dest, src) + return dest, nil } return nil, nil case *sql.NullInt32: diff --git a/dialects/dameng.go b/dialects/dameng.go index 5e92ec2f..23d1836a 100644 --- a/dialects/dameng.go +++ b/dialects/dameng.go @@ -659,7 +659,7 @@ func (db *dameng) DropTableSQL(tableName string) (string, bool) { // ModifyColumnSQL returns a SQL to modify SQL func (db *dameng) ModifyColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, false) + s, _ := ColumnString(db.dialect, col, false, false) return fmt.Sprintf("ALTER TABLE %s MODIFY %s", db.quoter.Quote(tableName), s) } @@ -692,7 +692,7 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl } } - s, _ := ColumnString(db, col, false) + s, _ := ColumnString(db, col, false, false) if _, err := b.WriteString(s); err != nil { return "", false, err } @@ -709,7 +709,13 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl return "", false, err } } - if _, err := b.WriteString(fmt.Sprintf("CONSTRAINT PK_%s PRIMARY KEY (", tableName)); err != nil { + if _, err := b.WriteString("CONSTRAINT PK_"); err != nil { + return "", false, err + } + if _, err := b.WriteString(tableName); err != nil { + return "", false, err + } + if _, err := b.WriteString(" PRIMARY KEY ("); err != nil { return "", false, err } if err := quoter.JoinWrite(&b, pkList, ","); err != nil { @@ -837,7 +843,11 @@ func addSingleQuote(name string) string { if name[0] == '\'' && name[len(name)-1] == '\'' { return name } - return fmt.Sprintf("'%s'", name) + var b strings.Builder + b.WriteRune('\'') + b.WriteString(name) + b.WriteRune('\'') + return b.String() } func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { diff --git a/dialects/dialect.go b/dialects/dialect.go index 70d599e6..8e512c4f 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -85,8 +85,6 @@ type Dialect interface { AddColumnSQL(tableName string, col *schemas.Column) string ModifyColumnSQL(tableName string, col *schemas.Column) string - ForUpdateSQL(query string) string - Filters() []Filter SetParams(params map[string]string) } @@ -135,7 +133,7 @@ func (db *Base) CreateTableSQL(ctx context.Context, queryer core.Queryer, table for i, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, false) b.WriteString(s) if i != len(table.ColumnsSeq())-1 { @@ -209,7 +207,7 @@ func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableNa // AddColumnSQL returns a SQL to add a column func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, true) + s, _ := ColumnString(db.dialect, col, true, false) return fmt.Sprintf("ALTER TABLE %s ADD %s", db.dialect.Quoter().Quote(tableName), s) } @@ -241,22 +239,15 @@ func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string { // ModifyColumnSQL returns a SQL to modify SQL func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, false) + s, _ := ColumnString(db.dialect, col, false, false) return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s) } -// ForUpdateSQL returns for updateSQL -func (db *Base) ForUpdateSQL(query string) string { - return query + " FOR UPDATE" -} - // SetParams set params func (db *Base) SetParams(params map[string]string) { } -var ( - dialects = map[string]func() Dialect{} -) +var dialects = map[string]func() Dialect{} // RegisterDialect register database dialect func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) { @@ -307,7 +298,7 @@ func init() { } // ColumnString generate column description string according dialect -func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) (string, error) { +func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey, supportCollation bool) (string, error) { bd := strings.Builder{} if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil { @@ -322,6 +313,15 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) return "", err } + if supportCollation && col.Collation != "" { + if _, err := bd.WriteString(" COLLATE "); err != nil { + return "", err + } + if _, err := bd.WriteString(col.Collation); err != nil { + return "", err + } + } + if includePrimaryKey && col.IsPrimaryKey { if _, err := bd.WriteString(" PRIMARY KEY"); err != nil { return "", err diff --git a/dialects/filter.go b/dialects/filter.go index bfe2e93e..add67c1b 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -5,13 +5,14 @@ package dialects import ( - "fmt" + "context" + "strconv" "strings" ) // Filter is an interface to filter SQL type Filter interface { - Do(sql string) string + Do(ctx context.Context, sql string) string } // SeqFilter filter SQL replace ?, ? ... to $1, $2 ... @@ -28,10 +29,11 @@ func convertQuestionMark(sql, prefix string, start int) string { var isMaybeLineComment bool var isMaybeComment bool var isMaybeCommentEnd bool - var index = start + index := start for _, c := range sql { if !beginSingleQuote && !isLineComment && !isComment && c == '?' { - buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) + buf.WriteString(prefix) + buf.WriteString(strconv.Itoa(index)) index++ } else { if isMaybeLineComment { @@ -71,6 +73,6 @@ func convertQuestionMark(sql, prefix string, start int) string { } // Do implements Filter -func (s *SeqFilter) Do(sql string) string { +func (s *SeqFilter) Do(ctx context.Context, sql string) string { return convertQuestionMark(sql, s.Prefix, s.Start) } diff --git a/dialects/mssql.go b/dialects/mssql.go index 1b6fe692..dcac9c3f 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -428,7 +428,7 @@ func (db *mssql) DropTableSQL(tableName string) (string, bool) { } func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, false) + s, _ := ColumnString(db.dialect, col, false, true) return fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", db.quoter.Quote(tableName), s) } @@ -454,7 +454,7 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, "default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END), replace(replace(isnull(c.text,''),'(',''),')','') as vdefault, - ISNULL(p.is_primary_key, 0), a.is_identity as is_identity + ISNULL(p.is_primary_key, 0), a.is_identity as is_identity, a.collation_name from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id left join sys.syscomments c on a.default_object_id=c.id @@ -475,9 +475,10 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName colSeq := make([]string, 0) for rows.Next() { var name, ctype, vdefault string + var collation *string var maxLen, precision, scale int64 var nullable, isPK, defaultIsNull, isIncrement bool - err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement) + err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement, &collation) if err != nil { return nil, nil, err } @@ -499,6 +500,9 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName } else { col.Length = maxLen } + if collation != nil { + col.Collation = *collation + } switch ct { case "DATETIMEOFFSET": col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} @@ -646,7 +650,7 @@ func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table for i, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true) b.WriteString(s) if i != len(table.ColumnsSeq())-1 { @@ -665,10 +669,6 @@ func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table return b.String(), true, nil } -func (db *mssql) ForUpdateSQL(query string) string { - return query -} - func (db *mssql) Filters() []Filter { return []Filter{} } diff --git a/dialects/mysql.go b/dialects/mysql.go index 195e1f23..82505707 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -380,12 +380,27 @@ func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableNa func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { quoter := db.dialect.Quoter() - s, _ := ColumnString(db, col, true) - sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), s) + s, _ := ColumnString(db, col, true, true) + var b strings.Builder + b.WriteString("ALTER TABLE ") + quoter.QuoteTo(&b, tableName) + b.WriteString(" ADD ") + b.WriteString(s) if len(col.Comment) > 0 { - sql += " COMMENT '" + col.Comment + "'" + b.WriteString(" COMMENT '") + b.WriteString(col.Comment) + b.WriteString("'") } - return sql + return b.String() +} + +// ModifyColumnSQL returns a SQL to modify SQL +func (db *mysql) ModifyColumnSQL(tableName string, col *schemas.Column) string { + s, _ := ColumnString(db.dialect, col, false, true) + if col.Comment != "" { + s += fmt.Sprintf(" COMMENT '%s'", col.Comment) + } + return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s) } func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { @@ -398,7 +413,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName "SUBSTRING_INDEX(SUBSTRING(VERSION(), 6), '-', 1) >= 7)))))" s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + " `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " + - alreadyQuoted + " AS NEEDS_QUOTE " + + alreadyQuoted + " AS NEEDS_QUOTE, `COLLATION_NAME` " + "FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + " ORDER BY `COLUMNS`.ORDINAL_POSITION ASC" @@ -416,8 +431,8 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName var columnName, nullableStr, colType, colKey, extra, comment string var alreadyQuoted, isUnsigned bool - var colDefault, maxLength *string - err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted) + var colDefault, maxLength, collation *string + err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted, &collation) if err != nil { return nil, nil, err } @@ -433,6 +448,9 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName } else { col.DefaultIsEmpty = true } + if collation != nil { + col.Collation = *collation + } fields := strings.Fields(colType) if len(fields) == 2 && fields[1] == "unsigned" { @@ -525,7 +543,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{db.uri.DBName} - s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + + s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT`, `TABLE_COLLATION` from " + "`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')" rows, err := queryer.QueryContext(ctx, s, args...) @@ -537,9 +555,9 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema tables := make([]*schemas.Table, 0) for rows.Next() { table := schemas.NewEmptyTable() - var name, engine string + var name, engine, collation string var autoIncr, comment *string - err = rows.Scan(&name, &engine, &autoIncr, &comment) + err = rows.Scan(&name, &engine, &autoIncr, &comment, &collation) if err != nil { return nil, err } @@ -549,6 +567,7 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema table.Comment = *comment } table.StoreEngine = engine + table.Collation = collation tables = append(tables, table) } if rows.Err() != nil { @@ -640,7 +659,7 @@ func (db *mysql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table for i, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true) b.WriteString(s) if len(col.Comment) > 0 { diff --git a/dialects/oracle.go b/dialects/oracle.go index 72c26ce2..a5f8a5b2 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -609,7 +609,7 @@ func (db *oracle) IsReserved(name string) bool { } func (db *oracle) DropTableSQL(tableName string) (string, bool) { - return fmt.Sprintf("DROP TABLE `%s`", tableName), false + return fmt.Sprintf("DROP TABLE \"%s\"", tableName), false } func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { @@ -628,7 +628,7 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl /*if col.IsPrimaryKey && len(pkList) == 1 { sql += col.String(b.dialect) } else {*/ - s, _ := ColumnString(db, col, false) + s, _ := ColumnString(db, col, false, false) sql += s // } sql = strings.TrimSpace(sql) @@ -645,6 +645,10 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl return sql, false, nil } +func (db *oracle) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) { + return db.HasRecords(queryer, ctx, `SELECT sequence_name FROM user_sequences WHERE sequence_name = :1`, seqName) +} + func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { switch quotePolicy { case QuotePolicyNone: diff --git a/dialects/postgres.go b/dialects/postgres.go index 5efe54f4..28196891 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -992,7 +992,7 @@ func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tabl } func (db *postgres) AddColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, true) + s, _ := ColumnString(db.dialect, col, true, false) quoter := db.dialect.Quoter() addColumnSQL := "" @@ -1078,7 +1078,7 @@ FROM pg_attribute f LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name -WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` +WHERE n.nspname= s.table_schema AND c.relkind = 'r' AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` schema := db.getSchema() if schema != "" { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 4ff9a39e..62519b6b 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -193,11 +193,11 @@ func (db *sqlite3) Features() *DialectFeatures { func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) { switch quotePolicy { case QuotePolicyNone: - var q = sqlite3Quoter + q := sqlite3Quoter q.IsReserved = schemas.AlwaysNoReserve db.quoter = q case QuotePolicyReserved: - var q = sqlite3Quoter + q := sqlite3Quoter q.IsReserved = db.IsReserved db.quoter = q case QuotePolicyAlways: @@ -291,10 +291,6 @@ func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string { return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } -func (db *sqlite3) ForUpdateSQL(query string) string { - return query -} - func (db *sqlite3) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { query := "SELECT * FROM " + tableName + " LIMIT 0" rows, err := queryer.QueryContext(ctx, query) @@ -320,7 +316,7 @@ func (db *sqlite3) IsColumnExist(queryer core.Queryer, ctx context.Context, tabl // splitColStr splits a sqlite col strings as fields func splitColStr(colStr string) []string { colStr = strings.TrimSpace(colStr) - var results = make([]string, 0, 10) + results := make([]string, 0, 10) var lastIdx int var hasC, hasQuote bool for i, c := range colStr { diff --git a/engine.go b/engine.go index 389819e7..aa9d8050 100644 --- a/engine.go +++ b/engine.go @@ -1120,21 +1120,6 @@ func (engine *Engine) UnMapType(t reflect.Type) { engine.tagParser.ClearCacheTable(t) } -// Sync the new struct changes to database, this method will automatically add -// table, column, index, unique. but will not delete or change anything. -// If you change some field, you should change the database manually. -func (engine *Engine) Sync(beans ...interface{}) error { - session := engine.NewSession() - defer session.Close() - return session.Sync(beans...) -} - -// Sync2 synchronize structs to database tables -// Depricated -func (engine *Engine) Sync2(beans ...interface{}) error { - return engine.Sync(beans...) -} - // CreateTables create tabls according bean func (engine *Engine) CreateTables(beans ...interface{}) error { session := engine.NewSession() diff --git a/integrations/engine_test.go b/integrations/engine_test.go index 730a424e..86ed7344 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -289,6 +289,48 @@ func TestGetColumnsComment(t *testing.T) { assert.Zero(t, noComment) } +type TestCommentUpdate struct { + HasComment int `xorm:"bigint comment('this is a comment before update')"` +} + +func (m *TestCommentUpdate) TableName() string { + return "test_comment_struct" +} + +type TestCommentUpdate2 struct { + HasComment int `xorm:"bigint comment('this is a comment after update')"` +} + +func (m *TestCommentUpdate2) TableName() string { + return "test_comment_struct" +} + +func TestColumnCommentUpdate(t *testing.T) { + comment := "this is a comment after update" + assertSync(t, new(TestCommentUpdate)) + assert.NoError(t, testEngine.Sync2(new(TestCommentUpdate2))) // modify table column comment + + switch testEngine.Dialect().URI().DBType { + case schemas.POSTGRES, schemas.MYSQL: // only postgres / mysql dialect implement the feature of modify comment in postgres.ModifyColumnSQL + default: + t.Skip() + return + } + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + tableName := "test_comment_struct" + var hasComment string + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn(testEngine.GetColumnMapper().Obj2Table("HasComment")) + assert.NotNil(t, col) + hasComment = col.Comment + break + } + } + assert.Equal(t, comment, hasComment) +} + func TestGetColumnsLength(t *testing.T) { var max_length int64 switch testEngine.Dialect().URI().DBType { diff --git a/integrations/session_schema_test.go b/integrations/schema_test.go similarity index 85% rename from integrations/session_schema_test.go rename to integrations/schema_test.go index 3212d027..149c6394 100644 --- a/integrations/session_schema_test.go +++ b/integrations/schema_test.go @@ -5,6 +5,7 @@ package integrations import ( + "errors" "fmt" "strings" "testing" @@ -325,14 +326,14 @@ func TestIsTableEmpty(t *testing.T) { type PictureEmpty struct { Id int64 - Url string `xorm:"unique"` //image's url + Url string `xorm:"unique"` // image's url Title string Description string Created time.Time `xorm:"created"` ILike int PageView int From_url string // nolint - Pre_url string `xorm:"unique"` //pre view image's url + Pre_url string `xorm:"unique"` // pre view image's url Uid int64 } @@ -458,7 +459,7 @@ func TestSync2_2(t *testing.T) { assert.NoError(t, PrepareEngine()) - var tableNames = make(map[string]bool) + tableNames := make(map[string]bool) for i := 0; i < 10; i++ { tableName := fmt.Sprintf("test_sync2_index_%d", i) tableNames[tableName] = true @@ -536,3 +537,111 @@ func TestModifyColum(t *testing.T) { _, err := testEngine.Exec(alterSQL) assert.NoError(t, err) } + +type TestCollateColumn struct { + Id int64 + UserId int64 `xorm:"unique(s)"` + Name string `xorm:"varchar(20) unique(s)"` + dbtype string `xorm:"-"` +} + +func (t TestCollateColumn) TableCollations() []*schemas.Collation { + if t.dbtype == string(schemas.MYSQL) { + return []*schemas.Collation{ + { + Name: "utf8mb4_general_ci", + Column: "name", + }, + } + } else if t.dbtype == string(schemas.MSSQL) { + return []*schemas.Collation{ + { + Name: "Latin1_General_CI_AS", + Column: "name", + }, + } + } + return nil +} + +func TestCollate(t *testing.T) { + assert.NoError(t, PrepareEngine()) + assertSync(t, &TestCollateColumn{ + dbtype: string(testEngine.Dialect().URI().DBType), + }) + + _, err := testEngine.Insert(&TestCollateColumn{ + UserId: 1, + Name: "test", + }) + assert.NoError(t, err) + _, err = testEngine.Insert(&TestCollateColumn{ + UserId: 1, + Name: "Test", + }) + if testEngine.Dialect().URI().DBType == schemas.MYSQL { + ver, err1 := testEngine.DBVersion() + assert.NoError(t, err1) + + tables, err1 := testEngine.DBMetas() + assert.NoError(t, err1) + for _, table := range tables { + if table.Name == "test_collate_column" { + col := table.GetColumn("name") + if col == nil { + assert.Error(t, errors.New("not found column")) + return + } + // tidb doesn't follow utf8mb4_general_ci + if col.Collation == "utf8mb4_general_ci" && ver.Edition != "TiDB" { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + break + } + } + } else if testEngine.Dialect().URI().DBType == schemas.MSSQL { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + // Since SQLITE don't support modify column SQL, currrently just ignore + if testEngine.Dialect().URI().DBType != schemas.MYSQL && testEngine.Dialect().URI().DBType != schemas.MSSQL { + return + } + + var newCollation string + if testEngine.Dialect().URI().DBType == schemas.MYSQL { + newCollation = "utf8mb4_bin" + } else if testEngine.Dialect().URI().DBType != schemas.MSSQL { + newCollation = "Latin1_General_CS_AS" + } else { + return + } + + alterSQL := testEngine.Dialect().ModifyColumnSQL("test_collate_column", &schemas.Column{ + Name: "name", + SQLType: schemas.SQLType{ + Name: "VARCHAR", + }, + Length: 20, + Nullable: true, + DefaultIsEmpty: true, + Collation: newCollation, + }) + _, err = testEngine.Exec(alterSQL) + assert.NoError(t, err) + + _, err = testEngine.Insert(&TestCollateColumn{ + UserId: 1, + Name: "test1", + }) + assert.NoError(t, err) + _, err = testEngine.Insert(&TestCollateColumn{ + UserId: 1, + Name: "Test1", + }) + assert.NoError(t, err) +} diff --git a/integrations/session_count_test.go b/integrations/session_count_test.go index 13d84edb..079602c3 100644 --- a/integrations/session_count_test.go +++ b/integrations/session_count_test.go @@ -89,7 +89,7 @@ func TestCountWithOthers(t *testing.T) { }) assert.NoError(t, err) - total, err := testEngine.OrderBy("`id` desc").Limit(1).Count(new(CountWithOthers)) + total, err := testEngine.OrderBy("count(`id`) desc").Limit(1).Count(new(CountWithOthers)) assert.NoError(t, err) assert.EqualValues(t, 2, total) } @@ -118,11 +118,11 @@ func TestWithTableName(t *testing.T) { }) assert.NoError(t, err) - total, err := testEngine.OrderBy("`id` desc").Count(new(CountWithTableName)) + total, err := testEngine.OrderBy("count(`id`) desc").Count(new(CountWithTableName)) assert.NoError(t, err) assert.EqualValues(t, 2, total) - total, err = testEngine.OrderBy("`id` desc").Count(CountWithTableName{}) + total, err = testEngine.OrderBy("count(`id`) desc").Count(CountWithTableName{}) assert.NoError(t, err) assert.EqualValues(t, 2, total) } diff --git a/integrations/session_find_test.go b/integrations/session_find_test.go index 5c2a4c68..65df5aee 100644 --- a/integrations/session_find_test.go +++ b/integrations/session_find_test.go @@ -12,6 +12,7 @@ import ( "xorm.io/xorm" "xorm.io/xorm/internal/utils" "xorm.io/xorm/names" + "xorm.io/xorm/schemas" "github.com/stretchr/testify/assert" ) @@ -1196,3 +1197,43 @@ func TestUpdateFindDate(t *testing.T) { assert.EqualValues(t, 1, len(tufs)) assert.EqualValues(t, tuf.Tm.Format("2006-01-02"), tufs[0].Tm.Format("2006-01-02")) } + +func TestBuilderDialect(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestBuilderDialect struct { + Id int64 + Name string `xorm:"index"` + Age2 int + } + + type TestBuilderDialectFoo struct { + Id int64 + DialectId int64 `xorm:"index"` + Age int + } + + assertSync(t, new(TestBuilderDialect), new(TestBuilderDialectFoo)) + + session := testEngine.NewSession() + defer session.Close() + + var dialect string + switch testEngine.Dialect().URI().DBType { + case schemas.MYSQL: + dialect = builder.MYSQL + case schemas.MSSQL: + dialect = builder.MSSQL + case schemas.POSTGRES: + dialect = builder.POSTGRES + case schemas.SQLITE: + dialect = builder.SQLITE + } + + tbName := testEngine.TableName(new(TestBuilderDialectFoo), dialect == builder.POSTGRES) + + inner := builder.Dialect(dialect).Select("*").From(tbName).Where(builder.Eq{"age": 20}) + result := make([]*TestBuilderDialect, 0, 10) + err := testEngine.Table("test_builder_dialect").Where(builder.Eq{"age2": 2}).Join("INNER", inner, "test_builder_dialect_foo.dialect_id = test_builder_dialect.id").Find(&result) + assert.NoError(t, err) +} diff --git a/integrations/session_query_test.go b/integrations/session_query_test.go index b72f7ef2..00b7d7a6 100644 --- a/integrations/session_query_test.go +++ b/integrations/session_query_test.go @@ -5,11 +5,13 @@ package integrations import ( + "bytes" "strconv" "testing" "time" "xorm.io/builder" + "xorm.io/xorm/schemas" "github.com/stretchr/testify/assert" @@ -381,3 +383,68 @@ func TestQueryStringWithLimit(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 0, len(data)) } + +func TestQueryBLOBInMySQL(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + var err error + type Avatar struct { + Id int64 `xorm:"autoincr pk"` + Avatar []byte `xorm:"BLOB"` + } + + assert.NoError(t, testEngine.Sync(new(Avatar))) + testEngine.Delete(Avatar{}) + + repeatBytes := func(n int, b byte) []byte { + return bytes.Repeat([]byte{b}, n) + } + + const N = 10 + var data = []Avatar{} + for i := 0; i < N; i++ { + // allocate a []byte that is as twice big as the last one + // so that the underlying buffer will need to reallocate when querying + bs := repeatBytes(1<<(i+2), 'A'+byte(i)) + data = append(data, Avatar{ + Avatar: bs, + }) + } + _, err = testEngine.Insert(data) + assert.NoError(t, err) + defer func() { + testEngine.Delete(Avatar{}) + }() + + { + records, err := testEngine.QueryInterface("select avatar from " + testEngine.Quote(testEngine.TableName("avatar", true))) + assert.NoError(t, err) + for i, record := range records { + bs := record["avatar"].([]byte) + assert.EqualValues(t, repeatBytes(1<<(i+2), 'A'+byte(i))[:3], bs[:3]) + t.Logf("%d => %p => %02x %02x %02x", i, bs, bs[0], bs[1], bs[2]) + } + } + + { + arr := make([][]interface{}, 0) + err = testEngine.Table(testEngine.Quote(testEngine.TableName("avatar", true))).Cols("avatar").Find(&arr) + assert.NoError(t, err) + for i, record := range arr { + bs := record[0].([]byte) + assert.EqualValues(t, repeatBytes(1<<(i+2), 'A'+byte(i))[:3], bs[:3]) + t.Logf("%d => %p => %02x %02x %02x", i, bs, bs[0], bs[1], bs[2]) + } + } + + { + arr := make([]map[string]interface{}, 0) + err = testEngine.Table(testEngine.Quote(testEngine.TableName("avatar", true))).Cols("avatar").Find(&arr) + assert.NoError(t, err) + for i, record := range arr { + bs := record["avatar"].([]byte) + assert.EqualValues(t, repeatBytes(1<<(i+2), 'A'+byte(i))[:3], bs[:3]) + t.Logf("%d => %p => %02x %02x %02x", i, bs, bs[0], bs[1], bs[2]) + } + } +} diff --git a/integrations/session_tx_test.go b/integrations/session_tx_test.go index 8d6519d0..890e755d 100644 --- a/integrations/session_tx_test.go +++ b/integrations/session_tx_test.go @@ -185,3 +185,36 @@ func TestMultipleTransaction(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 0, len(ms)) } + +func TestInsertMulti2InterfaceTransaction(t *testing.T) { + + type Multi2InterfaceTransaction struct { + ID uint64 `xorm:"id pk autoincr"` + Name string + Alias string + CreateTime time.Time `xorm:"created"` + UpdateTime time.Time `xorm:"updated"` + } + assert.NoError(t, PrepareEngine()) + assertSync(t, new(Multi2InterfaceTransaction)) + session := testEngine.NewSession() + defer session.Close() + err := session.Begin() + assert.NoError(t, err) + + users := []interface{}{ + &Multi2InterfaceTransaction{Name: "a", Alias: "A"}, + &Multi2InterfaceTransaction{Name: "b", Alias: "B"}, + &Multi2InterfaceTransaction{Name: "c", Alias: "C"}, + &Multi2InterfaceTransaction{Name: "d", Alias: "D"}, + } + cnt, err := session.Insert(&users) + + assert.NoError(t, err) + assert.EqualValues(t, len(users), cnt) + + assert.NotPanics(t, func() { + err = session.Commit() + assert.NoError(t, err) + }) +} diff --git a/internal/statements/cache.go b/internal/statements/cache.go index 669cd018..9dd76754 100644 --- a/internal/statements/cache.go +++ b/internal/statements/cache.go @@ -6,6 +6,7 @@ package statements import ( "fmt" + "strconv" "strings" "xorm.io/xorm/internal/utils" @@ -26,14 +27,19 @@ func (statement *Statement) ConvertIDSQL(sqlStr string) string { return "" } - var top string + var b strings.Builder + b.WriteString("SELECT ") pLimitN := statement.LimitN if pLimitN != nil && statement.dialect.URI().DBType == schemas.MSSQL { - top = fmt.Sprintf("TOP %d ", *pLimitN) + b.WriteString("TOP ") + b.WriteString(strconv.Itoa(*pLimitN)) + b.WriteString(" ") } + b.WriteString(colstrs) + b.WriteString(" FROM ") + b.WriteString(sqls[1]) - newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1]) - return newsql + return b.String() } return "" } @@ -54,7 +60,7 @@ func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) { return "", "" } - var whereStr = sqls[1] + whereStr := sqls[1] // TODO: for postgres only, if any other database? var paraStr string diff --git a/internal/statements/join.go b/internal/statements/join.go index adf349e7..61a1b4de 100644 --- a/internal/statements/join.go +++ b/internal/statements/join.go @@ -15,82 +15,97 @@ import ( ) // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (statement *Statement) Join(joinOP string, tablename interface{}, condition interface{}, args ...interface{}) *Statement { - var buf strings.Builder - if len(statement.JoinStr) > 0 { - fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) - } else { - fmt.Fprintf(&buf, "%v JOIN ", joinOP) - } - - condStr := "" - condArgs := []interface{}{} - switch condTp := condition.(type) { - case string: - condStr = condTp - case builder.Cond: - var err error - condStr, condArgs, err = builder.ToSQL(condTp) - if err != nil { - statement.LastError = err - return statement - } - default: - statement.LastError = fmt.Errorf("unsupported join condition type: %v", condTp) - return statement - } - - switch tp := tablename.(type) { - case builder.Builder: - subSQL, subQueryArgs, err := tp.ToSQL() - if err != nil { - statement.LastError = err - return statement - } - - fields := strings.Split(tp.TableName(), ".") - aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) - aliasName = schemas.CommonQuoter.Trim(aliasName) - - fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr)) - statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...) - case *builder.Builder: - subSQL, subQueryArgs, err := tp.ToSQL() - if err != nil { - statement.LastError = err - return statement - } - - fields := strings.Split(tp.TableName(), ".") - aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) - aliasName = schemas.CommonQuoter.Trim(aliasName) - - fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr)) - statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...) - default: - tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) - if !utils.IsSubQuery(tbName) { - var buf strings.Builder - _ = statement.dialect.Quoter().QuoteTo(&buf, tbName) - tbName = buf.String() - } else { - tbName = statement.ReplaceQuote(tbName) - } - fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condStr)) - statement.joinArgs = append(statement.joinArgs, condArgs...) - } - - statement.JoinStr = buf.String() - statement.joinArgs = append(statement.joinArgs, args...) +func (statement *Statement) Join(joinOP string, joinTable interface{}, condition interface{}, args ...interface{}) *Statement { + statement.joins = append(statement.joins, join{ + op: joinOP, + table: joinTable, + condition: condition, + args: args, + }) return statement } -func (statement *Statement) writeJoin(w builder.Writer) error { - if statement.JoinStr != "" { - if _, err := fmt.Fprint(w, " ", statement.JoinStr); err != nil { +func (statement *Statement) writeJoins(w *builder.BytesWriter) error { + for _, join := range statement.joins { + if err := statement.writeJoin(w, join); err != nil { return err } - w.Append(statement.joinArgs...) } return nil } + +func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error { + // write join operator + if _, err := fmt.Fprintf(buf, " %v JOIN", join.op); err != nil { + return err + } + + // write join table or subquery + switch tp := join.table.(type) { + case builder.Builder: + if _, err := fmt.Fprintf(buf, " ("); err != nil { + return err + } + if err := tp.WriteTo(statement.QuoteReplacer(buf)); err != nil { + return err + } + + fields := strings.Split(tp.TableName(), ".") + aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) + aliasName = schemas.CommonQuoter.Trim(aliasName) + + if _, err := fmt.Fprintf(buf, ") %s", statement.quote(aliasName)); err != nil { + return err + } + case *builder.Builder: + if _, err := fmt.Fprintf(buf, " ("); err != nil { + return err + } + if err := tp.WriteTo(statement.QuoteReplacer(buf)); err != nil { + return err + } + + fields := strings.Split(tp.TableName(), ".") + aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) + aliasName = schemas.CommonQuoter.Trim(aliasName) + + if _, err := fmt.Fprintf(buf, ") %s", statement.quote(aliasName)); err != nil { + return err + } + default: + tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), join.table, true) + if !utils.IsSubQuery(tbName) { + var sb strings.Builder + if err := statement.dialect.Quoter().QuoteTo(&sb, tbName); err != nil { + return err + } + tbName = sb.String() + } else { + tbName = statement.ReplaceQuote(tbName) + } + if _, err := fmt.Fprint(buf, " ", tbName); err != nil { + return err + } + } + + // write on condition + if _, err := fmt.Fprint(buf, " ON "); err != nil { + return err + } + + switch condTp := join.condition.(type) { + case string: + if _, err := fmt.Fprint(buf, statement.ReplaceQuote(condTp)); err != nil { + return err + } + case builder.Cond: + if err := condTp.WriteTo(statement.QuoteReplacer(buf)); err != nil { + return err + } + default: + return fmt.Errorf("unsupported join condition type: %v", condTp) + } + buf.Append(join.args...) + + return nil +} diff --git a/internal/statements/query.go b/internal/statements/query.go index f72c8602..cea8be6d 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -7,6 +7,7 @@ package statements import ( "errors" "fmt" + "io" "reflect" "strings" @@ -29,37 +30,15 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int return "", nil, ErrTableNotFound } - columnStr := statement.ColumnStr() - if len(statement.SelectStr) > 0 { - columnStr = statement.SelectStr - } else { - if statement.JoinStr == "" { - if columnStr == "" { - if statement.GroupByStr != "" { - columnStr = statement.quoteColumnStr(statement.GroupByStr) - } else { - columnStr = statement.genColumnStr() - } - } - } else { - if columnStr == "" { - if statement.GroupByStr != "" { - columnStr = statement.quoteColumnStr(statement.GroupByStr) - } else { - columnStr = "*" - } - } - } - if columnStr == "" { - columnStr = "*" - } - } - if err := statement.ProcessIDParam(); err != nil { return "", nil, err } - return statement.genSelectSQL(columnStr, true, true) + buf := builder.NewWriter() + if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { + return "", nil, err + } + return buf.String(), buf.Args(), nil } // GenSumSQL generates sum SQL @@ -81,13 +60,16 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri } sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName)) } - sumSelect := strings.Join(sumStrs, ", ") if err := statement.MergeConds(bean); err != nil { return "", nil, err } - return statement.genSelectSQL(sumSelect, true, true) + buf := builder.NewWriter() + if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true); err != nil { + return "", nil, err + } + return buf.String(), buf.Args(), nil } // GenGetSQL generates Get SQL @@ -108,7 +90,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, columnStr = statement.SelectStr } else { // TODO: always generate column names, not use * even if join - if len(statement.JoinStr) == 0 { + if len(statement.joins) == 0 { if len(columnStr) == 0 { if len(statement.GroupByStr) > 0 { columnStr = statement.quoteColumnStr(statement.GroupByStr) @@ -139,7 +121,11 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, } } - return statement.genSelectSQL(columnStr, true, true) + buf := builder.NewWriter() + if err := statement.writeSelect(buf, columnStr, true); err != nil { + return "", nil, err + } + return buf.String(), buf.Args(), nil } // GenCountSQL generates the SQL for counting @@ -148,8 +134,6 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa return statement.GenRawSQL(), statement.RawParams, nil } - var condArgs []interface{} - var err error if len(beans) > 0 { if err := statement.SetRefBean(beans[0]); err != nil { return "", nil, err @@ -176,19 +160,27 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa subQuerySelect = selectSQL } - sqlStr, condArgs, err := statement.genSelectSQL(subQuerySelect, false, false) - if err != nil { + buf := builder.NewWriter() + if statement.GroupByStr != "" { + if _, err := fmt.Fprintf(buf, "SELECT %s FROM (", selectSQL); err != nil { + return "", nil, err + } + } + + if err := statement.writeSelect(buf, subQuerySelect, false); err != nil { return "", nil, err } if statement.GroupByStr != "" { - sqlStr = fmt.Sprintf("SELECT %s FROM (%s) sub", selectSQL, sqlStr) + if _, err := fmt.Fprintf(buf, ") sub"); err != nil { + return "", nil, err + } } - return sqlStr, condArgs, nil + return buf.String(), buf.Args(), nil } -func (statement *Statement) writeFrom(w builder.Writer) error { +func (statement *Statement) writeFrom(w *builder.BytesWriter) error { if _, err := fmt.Fprint(w, " FROM "); err != nil { return err } @@ -198,7 +190,7 @@ func (statement *Statement) writeFrom(w builder.Writer) error { if err := statement.writeAlias(w); err != nil { return err } - return statement.writeJoin(w) + return statement.writeJoins(w) } func (statement *Statement) writeLimitOffset(w builder.Writer) error { @@ -218,153 +210,183 @@ func (statement *Statement) writeLimitOffset(w builder.Writer) error { return nil } -func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) { - var ( - distinct string - dialect = statement.dialect - top, whereStr string - mssqlCondi = builder.NewWriter() - ) +func (statement *Statement) writeTop(w builder.Writer) error { + if statement.dialect.URI().DBType != schemas.MSSQL { + return nil + } + if statement.LimitN == nil { + return nil + } + _, err := fmt.Fprintf(w, " TOP %d", *statement.LimitN) + return err +} - if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { - distinct = "DISTINCT " +func (statement *Statement) writeDistinct(w builder.Writer) error { + if statement.IsDistinct && !strings.HasPrefix(statement.SelectStr, "count(") { + _, err := fmt.Fprint(w, " DISTINCT") + return err + } + return nil +} + +func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr string) error { + if _, err := fmt.Fprintf(w, "SELECT "); err != nil { + return err + } + if err := statement.writeDistinct(w); err != nil { + return err + } + if err := statement.writeTop(w); err != nil { + return err + } + _, err := fmt.Fprint(w, " ", columnStr) + return err +} + +func (statement *Statement) writeWhere(w *builder.BytesWriter) error { + if !statement.cond.IsValid() { + return statement.writeMssqlPaginationCond(w) + } + if _, err := fmt.Fprint(w, " WHERE "); err != nil { + return err + } + if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil { + return err + } + return statement.writeMssqlPaginationCond(w) +} + +func (statement *Statement) writeForUpdate(w io.Writer) error { + if !statement.IsForUpdate { + return nil } - condWriter := builder.NewWriter() - if err := statement.cond.WriteTo(statement.QuoteReplacer(condWriter)); err != nil { - return "", nil, err + if statement.dialect.URI().DBType != schemas.MYSQL { + return errors.New("only support mysql for update") + } + _, err := fmt.Fprint(w, " FOR UPDATE") + return err +} + +func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) error { + if statement.dialect.URI().DBType != schemas.MSSQL || statement.Start <= 0 { + return nil } - if condWriter.Len() > 0 { - whereStr = " WHERE " + if statement.RefTable == nil { + return errors.New("unsupported query limit without reference table") } - pLimitN := statement.LimitN - if dialect.URI().DBType == schemas.MSSQL { - if pLimitN != nil { - LimitNValue := *pLimitN - top = fmt.Sprintf("TOP %d ", LimitNValue) + var column string + if len(statement.RefTable.PKColumns()) == 0 { + for _, index := range statement.RefTable.Indexes { + if len(index.Cols) == 1 { + column = index.Cols[0] + break + } } - if statement.Start > 0 { - if statement.RefTable == nil { - return "", nil, errors.New("Unsupported query limit without reference table") - } - var column string - if len(statement.RefTable.PKColumns()) == 0 { - for _, index := range statement.RefTable.Indexes { - if len(index.Cols) == 1 { - column = index.Cols[0] - break - } - } - if len(column) == 0 { - column = statement.RefTable.ColumnsSeq()[0] - } - } else { - column = statement.RefTable.PKColumns()[0].Name - } - if statement.needTableName() { - if len(statement.TableAlias) > 0 { - column = fmt.Sprintf("%s.%s", statement.TableAlias, column) - } else { - column = fmt.Sprintf("%s.%s", statement.TableName(), column) - } - } - - if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s", - column, statement.Start, column); err != nil { - return "", nil, err - } - if err := statement.writeFrom(mssqlCondi); err != nil { - return "", nil, err - } - if whereStr != "" { - if _, err := fmt.Fprint(mssqlCondi, whereStr); err != nil { - return "", nil, err - } - if err := utils.WriteBuilder(mssqlCondi, statement.QuoteReplacer(condWriter)); err != nil { - return "", nil, err - } - } - if needOrderBy { - if err := statement.WriteOrderBy(mssqlCondi); err != nil { - return "", nil, err - } - } - if err := statement.WriteGroupBy(mssqlCondi); err != nil { - return "", nil, err - } - if _, err := fmt.Fprint(mssqlCondi, "))"); err != nil { - return "", nil, err - } + if len(column) == 0 { + column = statement.RefTable.ColumnsSeq()[0] + } + } else { + column = statement.RefTable.PKColumns()[0].Name + } + if statement.NeedTableName() { + if len(statement.TableAlias) > 0 { + column = fmt.Sprintf("%s.%s", statement.TableAlias, column) + } else { + column = fmt.Sprintf("%s.%s", statement.TableName(), column) } } - buf := builder.NewWriter() - if _, err := fmt.Fprintf(buf, "SELECT %v%v%v", distinct, top, columnStr); err != nil { - return "", nil, err + subWriter := builder.NewWriter() + if _, err := fmt.Fprintf(subWriter, "(%s NOT IN (SELECT TOP %d %s", + column, statement.Start, column); err != nil { + return err + } + if err := statement.writeFrom(subWriter); err != nil { + return err + } + if statement.cond.IsValid() { + if _, err := fmt.Fprint(subWriter, " WHERE "); err != nil { + return err + } + if err := statement.cond.WriteTo(statement.QuoteReplacer(subWriter)); err != nil { + return err + } + } + if err := statement.WriteOrderBy(subWriter); err != nil { + return err + } + if err := statement.writeGroupBy(subWriter); err != nil { + return err + } + if _, err := fmt.Fprint(subWriter, "))"); err != nil { + return err + } + + if statement.cond.IsValid() { + if _, err := fmt.Fprint(w, " AND "); err != nil { + return err + } + } else { + if _, err := fmt.Fprint(w, " WHERE "); err != nil { + return err + } + } + + return utils.WriteBuilder(w, subWriter) +} + +func (statement *Statement) writeOracleLimit(w *builder.BytesWriter, columnStr string) error { + if statement.LimitN == nil { + return nil + } + + oldString := w.String() + w.Reset() + rawColStr := columnStr + if rawColStr == "*" { + rawColStr = "at.*" + } + _, err := fmt.Fprintf(w, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", + columnStr, rawColStr, oldString, statement.Start+*statement.LimitN, statement.Start) + return err +} + +func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error { + if err := statement.writeSelectColumns(buf, columnStr); err != nil { + return err } if err := statement.writeFrom(buf); err != nil { - return "", nil, err + return err } - if whereStr != "" { - if _, err := fmt.Fprint(buf, whereStr); err != nil { - return "", nil, err - } - if err := utils.WriteBuilder(buf, statement.QuoteReplacer(condWriter)); err != nil { - return "", nil, err - } + if err := statement.writeWhere(buf); err != nil { + return err } - if mssqlCondi.Len() > 0 { - if len(whereStr) > 0 { - if _, err := fmt.Fprint(buf, " AND "); err != nil { - return "", nil, err - } - } else { - if _, err := fmt.Fprint(buf, " WHERE "); err != nil { - return "", nil, err - } - } - - if err := utils.WriteBuilder(buf, mssqlCondi); err != nil { - return "", nil, err - } - } - - if err := statement.WriteGroupBy(buf); err != nil { - return "", nil, err + if err := statement.writeGroupBy(buf); err != nil { + return err } if err := statement.writeHaving(buf); err != nil { - return "", nil, err + return err } - if needOrderBy { - if err := statement.WriteOrderBy(buf); err != nil { - return "", nil, err - } - } - if needLimit { - if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE { - if err := statement.writeLimitOffset(buf); err != nil { - return "", nil, err - } - } else if dialect.URI().DBType == schemas.ORACLE { - if pLimitN != nil { - oldString := buf.String() - buf.Reset() - rawColStr := columnStr - if rawColStr == "*" { - rawColStr = "at.*" - } - fmt.Fprintf(buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", - columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start) - } - } - } - if statement.IsForUpdate { - return dialect.ForUpdateSQL(buf.String()), buf.Args(), nil + if err := statement.WriteOrderBy(buf); err != nil { + return err } - return buf.String(), buf.Args(), nil + dialect := statement.dialect + if needLimit { + if dialect.URI().DBType == schemas.ORACLE { + if err := statement.writeOracleLimit(buf, columnStr); err != nil { + return err + } + } else if dialect.URI().DBType != schemas.MSSQL { + if err := statement.writeLimitOffset(buf); err != nil { + return err + } + } + } + return statement.writeForUpdate(buf) } // GenExistSQL generates Exist SQL @@ -402,7 +424,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s", tableName); err != nil { return "", nil, err } - if err := statement.writeJoin(buf); err != nil { + if err := statement.writeJoins(buf); err != nil { return "", nil, err } if statement.Conds().IsValid() { @@ -417,7 +439,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil { return "", nil, err } - if err := statement.writeJoin(buf); err != nil { + if err := statement.writeJoins(buf); err != nil { return "", nil, err } if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { @@ -438,7 +460,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s", tableName); err != nil { return "", nil, err } - if err := statement.writeJoin(buf); err != nil { + if err := statement.writeJoins(buf); err != nil { return "", nil, err } if statement.Conds().IsValid() { @@ -457,6 +479,33 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac return buf.String(), buf.Args(), nil } +func (statement *Statement) genSelectColumnStr() string { + // manually select columns + if len(statement.SelectStr) > 0 { + return statement.SelectStr + } + + columnStr := statement.ColumnStr() + if columnStr != "" { + return columnStr + } + + // autodetect columns + if statement.GroupByStr != "" { + return statement.quoteColumnStr(statement.GroupByStr) + } + + if len(statement.joins) != 0 { + return "*" + } + + columnStr = statement.genColumnStr() + if columnStr == "" { + columnStr = "*" + } + return columnStr +} + // GenFindSQL generates Find SQL func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) { if statement.RawSQL != "" { @@ -467,33 +516,11 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa return "", nil, ErrTableNotFound } - columnStr := statement.ColumnStr() - if len(statement.SelectStr) > 0 { - columnStr = statement.SelectStr - } else { - if statement.JoinStr == "" { - if columnStr == "" { - if statement.GroupByStr != "" { - columnStr = statement.quoteColumnStr(statement.GroupByStr) - } else { - columnStr = statement.genColumnStr() - } - } - } else { - if columnStr == "" { - if statement.GroupByStr != "" { - columnStr = statement.quoteColumnStr(statement.GroupByStr) - } else { - columnStr = "*" - } - } - } - if columnStr == "" { - columnStr = "*" - } - } - statement.cond = statement.cond.And(autoCond) - return statement.genSelectSQL(columnStr, true, true) + buf := builder.NewWriter() + if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { + return "", nil, err + } + return buf.String(), buf.Args(), nil } diff --git a/internal/statements/select.go b/internal/statements/select.go index 2bd2e94d..59161d76 100644 --- a/internal/statements/select.go +++ b/internal/statements/select.go @@ -102,7 +102,7 @@ func (statement *Statement) genColumnStr() string { buf.WriteString(", ") } - if statement.JoinStr != "" { + if len(statement.joins) > 0 { if statement.TableAlias != "" { buf.WriteString(statement.TableAlias) } else { @@ -119,7 +119,7 @@ func (statement *Statement) genColumnStr() string { } func (statement *Statement) colName(col *schemas.Column, tableName string) string { - if statement.needTableName() { + if statement.NeedTableName() { nm := tableName if len(statement.TableAlias) > 0 { nm = statement.TableAlias diff --git a/internal/statements/statement.go b/internal/statements/statement.go index a8fe34fa..ae38ca27 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -34,6 +34,13 @@ var ( ErrTableNotFound = errors.New("Table not found") ) +type join struct { + op string + table interface{} + condition interface{} + args []interface{} +} + // Statement save all the sql info for executing SQL type Statement struct { RefTable *schemas.Table @@ -45,8 +52,7 @@ type Statement struct { idParam schemas.PK orderStr string orderArgs []interface{} - JoinStr string - joinArgs []interface{} + joins []join GroupByStr string HavingStr string SelectStr string @@ -123,8 +129,7 @@ func (statement *Statement) Reset() { statement.LimitN = nil statement.ResetOrderBy() statement.UseCascade = true - statement.JoinStr = "" - statement.joinArgs = make([]interface{}, 0) + statement.joins = nil statement.GroupByStr = "" statement.HavingStr = "" statement.ColumnMap = columnMap{} @@ -205,8 +210,8 @@ func (statement *Statement) SetRefBean(bean interface{}) error { return nil } -func (statement *Statement) needTableName() bool { - return len(statement.JoinStr) > 0 +func (statement *Statement) NeedTableName() bool { + return len(statement.joins) > 0 } // Incr Generate "Update ... Set column = column + arg" statement @@ -290,7 +295,7 @@ func (statement *Statement) GroupBy(keys string) *Statement { return statement } -func (statement *Statement) WriteGroupBy(w builder.Writer) error { +func (statement *Statement) writeGroupBy(w builder.Writer) error { if statement.GroupByStr == "" { return nil } @@ -605,7 +610,7 @@ func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, i // MergeConds merge conditions from bean and id func (statement *Statement) MergeConds(bean interface{}) error { if !statement.NoAutoCondition && statement.RefTable != nil { - addedTableName := (len(statement.JoinStr) > 0) + addedTableName := (len(statement.joins) > 0) autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName) if err != nil { return err @@ -673,7 +678,7 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName // CondDeleted returns the conditions whether a record is soft deleted. func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { colName := statement.quote(col.Name) - if statement.JoinStr != "" { + if len(statement.joins) > 0 { var prefix string if statement.TableAlias != "" { prefix = statement.TableAlias diff --git a/internal/statements/update.go b/internal/statements/update.go index 40159e0c..4dc54780 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -19,7 +19,8 @@ import ( ) func (statement *Statement) ifAddColUpdate(col *schemas.Column, includeVersion, includeUpdated, includeNil, - includeAutoIncr, update bool) (bool, error) { + includeAutoIncr, update bool, +) (bool, error) { columnMap := statement.ColumnMap omitColumnMap := statement.OmitColumnMap unscoped := statement.unscoped @@ -64,15 +65,16 @@ func (statement *Statement) ifAddColUpdate(col *schemas.Column, includeVersion, // BuildUpdates auto generating update columnes and values according a struct func (statement *Statement) BuildUpdates(tableValue reflect.Value, includeVersion, includeUpdated, includeNil, - includeAutoIncr, update bool) ([]string, []interface{}, error) { + includeAutoIncr, update bool, +) ([]string, []interface{}, error) { table := statement.RefTable allUseBool := statement.allUseBool useAllCols := statement.useAllCols mustColumnMap := statement.MustColumnMap nullableMap := statement.NullableMap - var colNames = make([]string, 0) - var args = make([]interface{}, 0) + colNames := make([]string, 0) + args := make([]interface{}, 0) for _, col := range table.Columns() { ok, err := statement.ifAddColUpdate(col, includeVersion, includeUpdated, includeNil, diff --git a/rows.go b/rows.go index 90a851a3..a42eedb9 100644 --- a/rows.go +++ b/rows.go @@ -46,8 +46,8 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { if rows.session.statement.RawSQL == "" { var autoCond builder.Cond - var addedTableName = (len(session.statement.JoinStr) > 0) - var table = rows.session.statement.RefTable + addedTableName := session.statement.NeedTableName() + table := rows.session.statement.RefTable if !session.statement.NoAutoCondition { var err error @@ -103,12 +103,12 @@ func (rows *Rows) Scan(beans ...interface{}) error { return rows.Err() } - var bean = beans[0] - var tp = reflect.TypeOf(bean) + bean := beans[0] + tp := reflect.TypeOf(bean) if tp.Kind() == reflect.Ptr { tp = tp.Elem() } - var beanKind = tp.Kind() + beanKind := tp.Kind() if len(beans) == 1 { if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { diff --git a/schemas/collation.go b/schemas/collation.go new file mode 100644 index 00000000..acec5268 --- /dev/null +++ b/schemas/collation.go @@ -0,0 +1,10 @@ +// Copyright 2023 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +type Collation struct { + Name string + Column string // blank means it's a table collation +} diff --git a/schemas/column.go b/schemas/column.go index 001769cd..08d34b91 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -45,6 +45,7 @@ type Column struct { DisableTimeZone bool TimeZone *time.Location // column specified time zone Comment string + Collation string } // NewColumn creates a new column @@ -89,6 +90,8 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { v.Set(reflect.New(v.Type().Elem())) } v = v.Elem() + } else if v.Kind() == reflect.Interface { + v = reflect.Indirect(v.Elem()) } v = v.FieldByIndex([]int{i}) } diff --git a/schemas/table.go b/schemas/table.go index 91b33e06..5c38cc70 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -27,6 +27,7 @@ type Table struct { StoreEngine string Charset string Comment string + Collation string } // NewEmptyTable creates an empty table @@ -36,7 +37,8 @@ func NewEmptyTable() *Table { // NewTable creates a new Table object func NewTable(name string, t reflect.Type) *Table { - return &Table{Name: name, Type: t, + return &Table{ + Name: name, Type: t, columnsSeq: make([]string, 0), columns: make([]*Column, 0), columnsMap: make(map[string][]*Column), diff --git a/session.go b/session.go index 5e3d0ac9..14d0781e 100644 --- a/session.go +++ b/session.go @@ -352,7 +352,7 @@ func (session *Session) DB() *core.DB { func (session *Session) canCache() bool { if session.statement.RefTable == nil || - session.statement.JoinStr != "" || + session.statement.NeedTableName() || session.statement.RawSQL != "" || !session.statement.UseCache || session.statement.IsForUpdate || diff --git a/session_delete.go b/session_delete.go index d36b9e52..6c63d9b0 100644 --- a/session_delete.go +++ b/session_delete.go @@ -30,7 +30,7 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri } for _, filter := range session.engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr) + sqlStr = filter.Do(session.ctx, sqlStr) } newsql := session.statement.ConvertIDSQL(sqlStr) diff --git a/session_find.go b/session_find.go index 548d9127..1026910c 100644 --- a/session_find.go +++ b/session_find.go @@ -116,7 +116,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) var ( table = session.statement.RefTable - addedTableName = (len(session.statement.JoinStr) > 0) + addedTableName = session.statement.NeedTableName() autoCond builder.Cond ) if tp == tpStruct { @@ -346,7 +346,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in } for _, filter := range session.engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr) + sqlStr = filter.Do(session.ctx, sqlStr) } newsql := session.statement.ConvertIDSQL(sqlStr) diff --git a/session_get.go b/session_get.go index 607bc5f0..0d590330 100644 --- a/session_get.go +++ b/session_get.go @@ -280,7 +280,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf } for _, filter := range session.engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr) + sqlStr = filter.Do(session.ctx, sqlStr) } newsql := session.statement.ConvertIDSQL(sqlStr) if newsql == "" { diff --git a/session_insert.go b/session_insert.go index fc025613..cfa26d39 100644 --- a/session_insert.go +++ b/session_insert.go @@ -353,15 +353,15 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { if err != nil { return 0, err } - if needCommit { - if err := session.Commit(); err != nil { - return 0, err - } - } - if id == 0 { - return 0, errors.New("insert successfully but not returned id") + } + if needCommit { + if err := session.Commit(); err != nil { + return 0, err } } + if id == 0 { + return 0, errors.New("insert successfully but not returned id") + } defer handleAfterInsertProcessorFunc(bean) diff --git a/session_raw.go b/session_raw.go index add584d0..99f6be99 100644 --- a/session_raw.go +++ b/session_raw.go @@ -13,7 +13,7 @@ import ( func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { for _, filter := range session.engine.dialect.Filters() { - *sqlStr = filter.Do(*sqlStr) + *sqlStr = filter.Do(session.ctx, *sqlStr) } session.lastSQL = *sqlStr diff --git a/session_schema.go b/session_schema.go index e66c3b42..830ba08a 100644 --- a/session_schema.go +++ b/session_schema.go @@ -15,7 +15,6 @@ import ( "xorm.io/xorm/dialects" "xorm.io/xorm/internal/utils" - "xorm.io/xorm/schemas" ) // Ping test if database is ok @@ -169,7 +168,7 @@ func (session *Session) dropTable(beanOrTableName interface{}) error { return nil } - var seqName = utils.SeqName(tableName) + seqName := utils.SeqName(tableName) exist, err := session.engine.dialect.IsSequenceExist(session.ctx, session.getQueryer(), seqName) if err != nil { return err @@ -244,228 +243,6 @@ func (session *Session) addUnique(tableName, uqeName string) error { return err } -// Sync2 synchronize structs to database tables -// Depricated -func (session *Session) Sync2(beans ...interface{}) error { - return session.Sync(beans...) -} - -// Sync synchronize structs to database tables -func (session *Session) Sync(beans ...interface{}) error { - engine := session.engine - - if session.isAutoClose { - session.isAutoClose = false - defer session.Close() - } - - tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx) - if err != nil { - return err - } - - session.autoResetStatement = false - defer func() { - session.autoResetStatement = true - session.resetStatement() - }() - - for _, bean := range beans { - v := utils.ReflectValue(bean) - table, err := engine.tagParser.ParseWithCache(v) - if err != nil { - return err - } - var tbName string - if len(session.statement.AltTableName) > 0 { - tbName = session.statement.AltTableName - } else { - tbName = engine.TableName(bean) - } - tbNameWithSchema := engine.tbNameWithSchema(tbName) - - var oriTable *schemas.Table - for _, tb := range tables { - if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) { - oriTable = tb - break - } - } - - // this is a new table - if oriTable == nil { - err = session.StoreEngine(session.statement.StoreEngine).createTable(bean) - if err != nil { - return err - } - - err = session.createUniques(bean) - if err != nil { - return err - } - - err = session.createIndexes(bean) - if err != nil { - return err - } - continue - } - - // this will modify an old table - if err = engine.loadTableInfo(oriTable); err != nil { - return err - } - - // check columns - for _, col := range table.Columns() { - var oriCol *schemas.Column - for _, col2 := range oriTable.Columns() { - if strings.EqualFold(col.Name, col2.Name) { - oriCol = col2 - break - } - } - - // column is not exist on table - if oriCol == nil { - session.statement.RefTable = table - session.statement.SetTableName(tbNameWithSchema) - if err = session.addColumn(col.Name); err != nil { - return err - } - continue - } - - err = nil - expectedType := engine.dialect.SQLType(col) - curType := engine.dialect.SQLType(oriCol) - if expectedType != curType { - if expectedType == schemas.Text && - strings.HasPrefix(curType, schemas.Varchar) { - // currently only support mysql & postgres - if engine.dialect.URI().DBType == schemas.MYSQL || - engine.dialect.URI().DBType == schemas.POSTGRES { - engine.logger.Infof("Table %s column %s change type from %s to %s\n", - tbNameWithSchema, col.Name, curType, expectedType) - _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) - } else { - engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", - tbNameWithSchema, col.Name, curType, expectedType) - } - } else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) { - if engine.dialect.URI().DBType == schemas.MYSQL { - if oriCol.Length < col.Length { - engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", - tbNameWithSchema, col.Name, oriCol.Length, col.Length) - _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) - } - } - } else { - if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { - if !strings.EqualFold(schemas.SQLTypeName(curType), engine.dialect.Alias(schemas.SQLTypeName(expectedType))) { - engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", - tbNameWithSchema, col.Name, curType, expectedType) - } - } - } - } else if expectedType == schemas.Varchar { - if engine.dialect.URI().DBType == schemas.MYSQL { - if oriCol.Length < col.Length { - engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", - tbNameWithSchema, col.Name, oriCol.Length, col.Length) - _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) - } - } - } else if col.Comment != oriCol.Comment { - _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) - } - - if col.Default != oriCol.Default { - switch { - case col.IsAutoIncrement: // For autoincrement column, don't check default - case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) && - ((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") || - (strings.EqualFold(col.Default, "false") && oriCol.Default == "0")): - default: - engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s", - tbName, col.Name, oriCol.Default, col.Default) - } - } - if col.Nullable != oriCol.Nullable { - engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v", - tbName, col.Name, oriCol.Nullable, col.Nullable) - } - - if err != nil { - return err - } - } - - var foundIndexNames = make(map[string]bool) - var addedNames = make(map[string]*schemas.Index) - - for name, index := range table.Indexes { - var oriIndex *schemas.Index - for name2, index2 := range oriTable.Indexes { - if index.Equal(index2) { - oriIndex = index2 - foundIndexNames[name2] = true - break - } - } - - if oriIndex != nil { - if oriIndex.Type != index.Type { - sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex) - _, err = session.exec(sql) - if err != nil { - return err - } - oriIndex = nil - } - } - - if oriIndex == nil { - addedNames[name] = index - } - } - - for name2, index2 := range oriTable.Indexes { - if _, ok := foundIndexNames[name2]; !ok { - sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2) - _, err = session.exec(sql) - if err != nil { - return err - } - } - } - - for name, index := range addedNames { - if index.Type == schemas.UniqueType { - session.statement.RefTable = table - session.statement.SetTableName(tbNameWithSchema) - err = session.addUnique(tbNameWithSchema, name) - } else if index.Type == schemas.IndexType { - session.statement.RefTable = table - session.statement.SetTableName(tbNameWithSchema) - err = session.addIndex(tbNameWithSchema, name) - } - if err != nil { - return err - } - } - - // check all the columns which removed from struct fields but left on database tables. - for _, colName := range oriTable.ColumnsSeq() { - if table.GetColumn(colName) == nil { - engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(oriTable.Name, true), colName) - } - } - } - - return nil -} - // ImportFile SQL DDL file func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) { file, err := os.Open(ddlPath) @@ -490,7 +267,7 @@ func (session *Session) Import(r io.Reader) ([]sql.Result, error) { if atEOF && len(data) == 0 { return 0, nil, nil } - var oriInSingleQuote = inSingleQuote + oriInSingleQuote := inSingleQuote for i, b := range data { if startComment { if b == '\n' { diff --git a/session_update.go b/session_update.go index e7104710..1f80e70f 100644 --- a/session_update.go +++ b/session_update.go @@ -34,7 +34,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri return ErrCacheFailed } for _, filter := range session.engine.dialect.Filters() { - newsql = filter.Do(newsql) + newsql = filter.Do(session.ctx, newsql) } session.engine.logger.Debugf("[cache] new sql: %v, %v", oldhead, newsql) diff --git a/sync.go b/sync.go new file mode 100644 index 00000000..11e75404 --- /dev/null +++ b/sync.go @@ -0,0 +1,276 @@ +// Copyright 2023 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "strings" + + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" +) + +type SyncOptions struct { + WarnIfDatabaseColumnMissed bool +} + +type SyncResult struct{} + +// Sync the new struct changes to database, this method will automatically add +// table, column, index, unique. but will not delete or change anything. +// If you change some field, you should change the database manually. +func (engine *Engine) Sync(beans ...interface{}) error { + session := engine.NewSession() + defer session.Close() + return session.Sync(beans...) +} + +// SyncWithOptions sync the database schemas according options and table structs +func (engine *Engine) SyncWithOptions(opts SyncOptions, beans ...interface{}) (*SyncResult, error) { + session := engine.NewSession() + defer session.Close() + return session.SyncWithOptions(opts, beans...) +} + +// Sync2 synchronize structs to database tables +// Depricated +func (engine *Engine) Sync2(beans ...interface{}) error { + return engine.Sync(beans...) +} + +// Sync2 synchronize structs to database tables +// Depricated +func (session *Session) Sync2(beans ...interface{}) error { + return session.Sync(beans...) +} + +// Sync synchronize structs to database tables +func (session *Session) Sync(beans ...interface{}) error { + _, err := session.SyncWithOptions(SyncOptions{ + WarnIfDatabaseColumnMissed: false, + }, beans...) + return err +} + +func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{}) (*SyncResult, error) { + engine := session.engine + + if session.isAutoClose { + session.isAutoClose = false + defer session.Close() + } + + tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx) + if err != nil { + return nil, err + } + + session.autoResetStatement = false + defer func() { + session.autoResetStatement = true + session.resetStatement() + }() + + var syncResult SyncResult + + for _, bean := range beans { + v := utils.ReflectValue(bean) + table, err := engine.tagParser.ParseWithCache(v) + if err != nil { + return nil, err + } + var tbName string + if len(session.statement.AltTableName) > 0 { + tbName = session.statement.AltTableName + } else { + tbName = engine.TableName(bean) + } + tbNameWithSchema := engine.tbNameWithSchema(tbName) + + var oriTable *schemas.Table + for _, tb := range tables { + if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) { + oriTable = tb + break + } + } + + // this is a new table + if oriTable == nil { + err = session.StoreEngine(session.statement.StoreEngine).createTable(bean) + if err != nil { + return nil, err + } + + err = session.createUniques(bean) + if err != nil { + return nil, err + } + + err = session.createIndexes(bean) + if err != nil { + return nil, err + } + continue + } + + // this will modify an old table + if err = engine.loadTableInfo(oriTable); err != nil { + return nil, err + } + + // check columns + for _, col := range table.Columns() { + var oriCol *schemas.Column + for _, col2 := range oriTable.Columns() { + if strings.EqualFold(col.Name, col2.Name) { + oriCol = col2 + break + } + } + + // column is not exist on table + if oriCol == nil { + session.statement.RefTable = table + session.statement.SetTableName(tbNameWithSchema) + if err = session.addColumn(col.Name); err != nil { + return nil, err + } + continue + } + + err = nil + expectedType := engine.dialect.SQLType(col) + curType := engine.dialect.SQLType(oriCol) + if expectedType != curType { + if expectedType == schemas.Text && + strings.HasPrefix(curType, schemas.Varchar) { + // currently only support mysql & postgres + if engine.dialect.URI().DBType == schemas.MYSQL || + engine.dialect.URI().DBType == schemas.POSTGRES { + engine.logger.Infof("Table %s column %s change type from %s to %s\n", + tbNameWithSchema, col.Name, curType, expectedType) + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) + } else { + engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", + tbNameWithSchema, col.Name, curType, expectedType) + } + } else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) { + if engine.dialect.URI().DBType == schemas.MYSQL { + if oriCol.Length < col.Length { + engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", + tbNameWithSchema, col.Name, oriCol.Length, col.Length) + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) + } + } + } else { + if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { + if !strings.EqualFold(schemas.SQLTypeName(curType), engine.dialect.Alias(schemas.SQLTypeName(expectedType))) { + engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", + tbNameWithSchema, col.Name, curType, expectedType) + } + } + } + } else if expectedType == schemas.Varchar { + if engine.dialect.URI().DBType == schemas.MYSQL { + if oriCol.Length < col.Length { + engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", + tbNameWithSchema, col.Name, oriCol.Length, col.Length) + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) + } + } + } else if col.Comment != oriCol.Comment { + if engine.dialect.URI().DBType == schemas.POSTGRES || + engine.dialect.URI().DBType == schemas.MYSQL { + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) + } + } + + if col.Default != oriCol.Default { + switch { + case col.IsAutoIncrement: // For autoincrement column, don't check default + case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) && + ((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") || + (strings.EqualFold(col.Default, "false") && oriCol.Default == "0")): + default: + engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s", + tbName, col.Name, oriCol.Default, col.Default) + } + } + if col.Nullable != oriCol.Nullable { + engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v", + tbName, col.Name, oriCol.Nullable, col.Nullable) + } + + if err != nil { + return nil, err + } + } + + foundIndexNames := make(map[string]bool) + addedNames := make(map[string]*schemas.Index) + + for name, index := range table.Indexes { + var oriIndex *schemas.Index + for name2, index2 := range oriTable.Indexes { + if index.Equal(index2) { + oriIndex = index2 + foundIndexNames[name2] = true + break + } + } + + if oriIndex != nil { + if oriIndex.Type != index.Type { + sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex) + _, err = session.exec(sql) + if err != nil { + return nil, err + } + oriIndex = nil + } + } + + if oriIndex == nil { + addedNames[name] = index + } + } + + for name2, index2 := range oriTable.Indexes { + if _, ok := foundIndexNames[name2]; !ok { + sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2) + _, err = session.exec(sql) + if err != nil { + return nil, err + } + } + } + + for name, index := range addedNames { + if index.Type == schemas.UniqueType { + session.statement.RefTable = table + session.statement.SetTableName(tbNameWithSchema) + err = session.addUnique(tbNameWithSchema, name) + } else if index.Type == schemas.IndexType { + session.statement.RefTable = table + session.statement.SetTableName(tbNameWithSchema) + err = session.addIndex(tbNameWithSchema, name) + } + if err != nil { + return nil, err + } + } + + if opts.WarnIfDatabaseColumnMissed { + // check all the columns which removed from struct fields but left on database tables. + for _, colName := range oriTable.ColumnsSeq() { + if table.GetColumn(colName) == nil { + engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(oriTable.Name, true), colName) + } + } + } + } + + return &syncResult, nil +} diff --git a/tags/parser.go b/tags/parser.go index 028f8d0b..53ef0c10 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -31,6 +31,12 @@ type TableIndices interface { var tpTableIndices = reflect.TypeOf((*TableIndices)(nil)).Elem() +type TableCollations interface { + TableCollations() []*schemas.Collation +} + +var tpTableCollations = reflect.TypeOf((*TableCollations)(nil)).Elem() + // Parser represents a parser for xorm tag type Parser struct { identifier string @@ -356,6 +362,22 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { } } + collations := tableCollations(v) + for _, collation := range collations { + if collation.Name == "" { + continue + } + if collation.Column == "" { + table.Collation = collation.Name + } else { + col := table.GetColumn(collation.Column) + if col == nil { + return nil, ErrUnsupportedType + } + col.Collation = collation.Name // this may override definition in struct tag + } + } + return table, nil } @@ -377,3 +399,22 @@ func tableIndices(v reflect.Value) []*schemas.Index { } return nil } + +func tableCollations(v reflect.Value) []*schemas.Collation { + if v.Type().Implements(tpTableCollations) { + return v.Interface().(TableCollations).TableCollations() + } + + if v.Kind() == reflect.Ptr { + v = v.Elem() + if v.Type().Implements(tpTableCollations) { + return v.Interface().(TableCollations).TableCollations() + } + } else if v.CanAddr() { + v1 := v.Addr() + if v1.Type().Implements(tpTableCollations) { + return v1.Interface().(TableCollations).TableCollations() + } + } + return nil +} diff --git a/tags/tag.go b/tags/tag.go index 41d525e1..024c9c18 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -123,6 +123,7 @@ var defaultTagHandlers = map[string]Handler{ "COMMENT": CommentTagHandler, "EXTENDS": ExtendsTagHandler, "UNSIGNED": UnsignedTagHandler, + "COLLATE": CollateTagHandler, } func init() { @@ -282,6 +283,16 @@ func CommentTagHandler(ctx *Context) error { return nil } +func CollateTagHandler(ctx *Context) error { + if len(ctx.params) > 0 { + ctx.col.Collation = ctx.params[0] + } else { + ctx.col.Collation = ctx.nextTag + ctx.ignoreNext = true + } + return nil +} + // SQLTypeTagHandler describes SQL Type tag handler func SQLTypeTagHandler(ctx *Context) error { ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname} diff --git a/tags/tag_test.go b/tags/tag_test.go index 3ceeefd1..6c456f2a 100644 --- a/tags/tag_test.go +++ b/tags/tag_test.go @@ -11,68 +11,84 @@ import ( ) func TestSplitTag(t *testing.T) { - var cases = []struct { + cases := []struct { tag string tags []tag }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{ - { - name: "not", - }, - { - name: "null", - }, - { - name: "default", - }, - { - name: "'2000-01-01 00:00:00'", - }, - { - name: "TIMESTAMP", - }, - }, - }, - {"TEXT", []tag{ - { - name: "TEXT", - }, - }, - }, - {"default('2000-01-01 00:00:00')", []tag{ - { - name: "default", - params: []string{ - "'2000-01-01 00:00:00'", + { + "not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{ + { + name: "not", + }, + { + name: "null", + }, + { + name: "default", + }, + { + name: "'2000-01-01 00:00:00'", + }, + { + name: "TIMESTAMP", }, }, }, - }, - {"json binary", []tag{ - { - name: "json", - }, - { - name: "binary", + { + "TEXT", []tag{ + { + name: "TEXT", + }, }, }, - }, - {"numeric(10, 2)", []tag{ - { - name: "numeric", - params: []string{"10", "2"}, + { + "default('2000-01-01 00:00:00')", []tag{ + { + name: "default", + params: []string{ + "'2000-01-01 00:00:00'", + }, + }, }, }, - }, - {"numeric(10, 2) notnull", []tag{ - { - name: "numeric", - params: []string{"10", "2"}, - }, - { - name: "notnull", + { + "json binary", []tag{ + { + name: "json", + }, + { + name: "binary", + }, }, }, + { + "numeric(10, 2)", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + }, + }, + { + "numeric(10, 2) notnull", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + { + name: "notnull", + }, + }, + }, + { + "collate utf8mb4_bin", []tag{ + { + name: "collate", + }, + { + name: "utf8mb4_bin", + }, + }, }, } @@ -82,7 +98,7 @@ func TestSplitTag(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, len(tags), len(kase.tags)) for i := 0; i < len(tags); i++ { - assert.Equal(t, tags[i], kase.tags[i]) + assert.Equal(t, kase.tags[i], tags[i]) } }) }