Merge branch 'master' into lunny/add_alias_table
This commit is contained in:
commit
3e0887c5c2
20
.drone.yml
20
.drone.yml
|
@ -249,11 +249,11 @@ volumes:
|
||||||
services:
|
services:
|
||||||
- name: mssql
|
- name: mssql
|
||||||
pull: always
|
pull: always
|
||||||
image: microsoft/mssql-server-linux:latest
|
image: mcr.microsoft.com/mssql/server:latest
|
||||||
environment:
|
environment:
|
||||||
ACCEPT_EULA: Y
|
ACCEPT_EULA: Y
|
||||||
SA_PASSWORD: yourStrong(!)Password
|
SA_PASSWORD: yourStrong(!)Password
|
||||||
MSSQL_PID: Developer
|
MSSQL_PID: Standard
|
||||||
|
|
||||||
---
|
---
|
||||||
kind: pipeline
|
kind: pipeline
|
||||||
|
@ -347,3 +347,19 @@ steps:
|
||||||
image: golang:1.15
|
image: golang:1.15
|
||||||
commands:
|
commands:
|
||||||
- make coverage
|
- make coverage
|
||||||
|
|
||||||
|
---
|
||||||
|
kind: pipeline
|
||||||
|
name: release-tag
|
||||||
|
trigger:
|
||||||
|
event:
|
||||||
|
- tag
|
||||||
|
steps:
|
||||||
|
- name: release-tag-gitea
|
||||||
|
pull: always
|
||||||
|
image: plugins/gitea-release:latest
|
||||||
|
settings:
|
||||||
|
base_url: https://gitea.com
|
||||||
|
title: '${DRONE_TAG} is released'
|
||||||
|
api_key:
|
||||||
|
from_secret: gitea_token
|
|
@ -37,3 +37,4 @@ test.db.sql
|
||||||
test.db
|
test.db
|
||||||
integrations/*.sql
|
integrations/*.sql
|
||||||
integrations/test_sqlite*
|
integrations/test_sqlite*
|
||||||
|
cover.out
|
14
.revive.toml
14
.revive.toml
|
@ -8,20 +8,22 @@ warningCode = 1
|
||||||
[rule.context-as-argument]
|
[rule.context-as-argument]
|
||||||
[rule.context-keys-type]
|
[rule.context-keys-type]
|
||||||
[rule.dot-imports]
|
[rule.dot-imports]
|
||||||
|
[rule.empty-lines]
|
||||||
|
[rule.errorf]
|
||||||
[rule.error-return]
|
[rule.error-return]
|
||||||
[rule.error-strings]
|
[rule.error-strings]
|
||||||
[rule.error-naming]
|
[rule.error-naming]
|
||||||
[rule.exported]
|
[rule.exported]
|
||||||
[rule.if-return]
|
[rule.if-return]
|
||||||
[rule.increment-decrement]
|
[rule.increment-decrement]
|
||||||
[rule.var-naming]
|
[rule.indent-error-flow]
|
||||||
arguments = [["ID", "UID", "UUID", "URL", "JSON"], []]
|
|
||||||
[rule.var-declaration]
|
|
||||||
[rule.package-comments]
|
[rule.package-comments]
|
||||||
[rule.range]
|
[rule.range]
|
||||||
[rule.receiver-naming]
|
[rule.receiver-naming]
|
||||||
|
[rule.struct-tag]
|
||||||
[rule.time-naming]
|
[rule.time-naming]
|
||||||
[rule.unexported-return]
|
[rule.unexported-return]
|
||||||
[rule.indent-error-flow]
|
[rule.unnecessary-stmt]
|
||||||
[rule.errorf]
|
[rule.var-declaration]
|
||||||
[rule.struct-tag]
|
[rule.var-naming]
|
||||||
|
arguments = [["ID", "UID", "UUID", "URL", "JSON"], []]
|
35
CHANGELOG.md
35
CHANGELOG.md
|
@ -3,6 +3,41 @@
|
||||||
This changelog goes through all the changes that have been made in each release
|
This changelog goes through all the changes that have been made in each release
|
||||||
without substantial changes to our git log.
|
without substantial changes to our git log.
|
||||||
|
|
||||||
|
## [1.1.2](https://gitea.com/xorm/xorm/releases/tag/1.1.2) - 2021-07-04
|
||||||
|
|
||||||
|
* BUILD
|
||||||
|
* Add release tag (#1966)
|
||||||
|
|
||||||
|
## [1.1.1](https://gitea.com/xorm/xorm/releases/tag/1.1.1) - 2021-07-03
|
||||||
|
|
||||||
|
* BUGFIXES
|
||||||
|
* Ignore comments when deciding when to replace question marks. #1954 (#1955)
|
||||||
|
* Fix bug didn't reset statement on update (#1939)
|
||||||
|
* Fix create table with struct missing columns (#1938)
|
||||||
|
* Fix #929 (#1936)
|
||||||
|
* Fix exist (#1921)
|
||||||
|
* ENHANCEMENTS
|
||||||
|
* Improve get field value of bean (#1961)
|
||||||
|
* refactor splitTag function (#1960)
|
||||||
|
* Fix #1663 (#1952)
|
||||||
|
* fix pg GetColumns missing comment (#1949)
|
||||||
|
* Support build flag jsoniter to replace default json (#1916)
|
||||||
|
* refactor exprParam (#1825)
|
||||||
|
* Add DBVersion (#1723)
|
||||||
|
* TESTING
|
||||||
|
* Add test to confirm #1247 resolved (#1951)
|
||||||
|
* Add test for dump table with default value (#1950)
|
||||||
|
* Test for #1486 (#1942)
|
||||||
|
* Add sync tests to confirm #539 is gone (#1937)
|
||||||
|
* test for unsigned int32 (#1923)
|
||||||
|
* Add tests for array store (#1922)
|
||||||
|
* BUILD
|
||||||
|
* Remove mymysql from ci (#1928)
|
||||||
|
* MISC
|
||||||
|
* fix lint (#1953)
|
||||||
|
* Compitable with cockroach (#1930)
|
||||||
|
* Replace goracle with godror (#1914)
|
||||||
|
|
||||||
## [1.1.0](https://gitea.com/xorm/xorm/releases/tag/1.1.0) - 2021-05-14
|
## [1.1.0](https://gitea.com/xorm/xorm/releases/tag/1.1.0) - 2021-05-14
|
||||||
|
|
||||||
* FEATURES
|
* FEATURES
|
||||||
|
|
458
convert.go
458
convert.go
|
@ -5,12 +5,15 @@
|
||||||
package xorm
|
package xorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"xorm.io/xorm/convert"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
|
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
|
||||||
|
@ -37,6 +40,12 @@ func asString(src interface{}) string {
|
||||||
return v
|
return v
|
||||||
case []byte:
|
case []byte:
|
||||||
return string(v)
|
return string(v)
|
||||||
|
case *sql.NullString:
|
||||||
|
return v.String
|
||||||
|
case *sql.NullInt32:
|
||||||
|
return fmt.Sprintf("%d", v.Int32)
|
||||||
|
case *sql.NullInt64:
|
||||||
|
return fmt.Sprintf("%d", v.Int64)
|
||||||
}
|
}
|
||||||
rv := reflect.ValueOf(src)
|
rv := reflect.ValueOf(src)
|
||||||
switch rv.Kind() {
|
switch rv.Kind() {
|
||||||
|
@ -54,6 +63,156 @@ func asString(src interface{}) string {
|
||||||
return fmt.Sprintf("%v", src)
|
return fmt.Sprintf("%v", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func asInt64(src interface{}) (int64, error) {
|
||||||
|
switch v := src.(type) {
|
||||||
|
case int:
|
||||||
|
return int64(v), nil
|
||||||
|
case int16:
|
||||||
|
return int64(v), nil
|
||||||
|
case int32:
|
||||||
|
return int64(v), nil
|
||||||
|
case int8:
|
||||||
|
return int64(v), nil
|
||||||
|
case int64:
|
||||||
|
return v, nil
|
||||||
|
case uint:
|
||||||
|
return int64(v), nil
|
||||||
|
case uint8:
|
||||||
|
return int64(v), nil
|
||||||
|
case uint16:
|
||||||
|
return int64(v), nil
|
||||||
|
case uint32:
|
||||||
|
return int64(v), nil
|
||||||
|
case uint64:
|
||||||
|
return int64(v), nil
|
||||||
|
case []byte:
|
||||||
|
return strconv.ParseInt(string(v), 10, 64)
|
||||||
|
case string:
|
||||||
|
return strconv.ParseInt(v, 10, 64)
|
||||||
|
case *sql.NullString:
|
||||||
|
return strconv.ParseInt(v.String, 10, 64)
|
||||||
|
case *sql.NullInt32:
|
||||||
|
return int64(v.Int32), nil
|
||||||
|
case *sql.NullInt64:
|
||||||
|
return int64(v.Int64), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rv := reflect.ValueOf(src)
|
||||||
|
switch rv.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
return rv.Int(), nil
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
return int64(rv.Uint()), nil
|
||||||
|
case reflect.Float64:
|
||||||
|
return int64(rv.Float()), nil
|
||||||
|
case reflect.Float32:
|
||||||
|
return int64(rv.Float()), nil
|
||||||
|
case reflect.String:
|
||||||
|
return strconv.ParseInt(rv.String(), 10, 64)
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("unsupported value %T as int64", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func asUint64(src interface{}) (uint64, error) {
|
||||||
|
switch v := src.(type) {
|
||||||
|
case int:
|
||||||
|
return uint64(v), nil
|
||||||
|
case int16:
|
||||||
|
return uint64(v), nil
|
||||||
|
case int32:
|
||||||
|
return uint64(v), nil
|
||||||
|
case int8:
|
||||||
|
return uint64(v), nil
|
||||||
|
case int64:
|
||||||
|
return uint64(v), nil
|
||||||
|
case uint:
|
||||||
|
return uint64(v), nil
|
||||||
|
case uint8:
|
||||||
|
return uint64(v), nil
|
||||||
|
case uint16:
|
||||||
|
return uint64(v), nil
|
||||||
|
case uint32:
|
||||||
|
return uint64(v), nil
|
||||||
|
case uint64:
|
||||||
|
return v, nil
|
||||||
|
case []byte:
|
||||||
|
return strconv.ParseUint(string(v), 10, 64)
|
||||||
|
case string:
|
||||||
|
return strconv.ParseUint(v, 10, 64)
|
||||||
|
case *sql.NullString:
|
||||||
|
return strconv.ParseUint(v.String, 10, 64)
|
||||||
|
case *sql.NullInt32:
|
||||||
|
return uint64(v.Int32), nil
|
||||||
|
case *sql.NullInt64:
|
||||||
|
return uint64(v.Int64), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rv := reflect.ValueOf(src)
|
||||||
|
switch rv.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
return uint64(rv.Int()), nil
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
return uint64(rv.Uint()), nil
|
||||||
|
case reflect.Float64:
|
||||||
|
return uint64(rv.Float()), nil
|
||||||
|
case reflect.Float32:
|
||||||
|
return uint64(rv.Float()), nil
|
||||||
|
case reflect.String:
|
||||||
|
return strconv.ParseUint(rv.String(), 10, 64)
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("unsupported value %T as uint64", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func asFloat64(src interface{}) (float64, error) {
|
||||||
|
switch v := src.(type) {
|
||||||
|
case int:
|
||||||
|
return float64(v), nil
|
||||||
|
case int16:
|
||||||
|
return float64(v), nil
|
||||||
|
case int32:
|
||||||
|
return float64(v), nil
|
||||||
|
case int8:
|
||||||
|
return float64(v), nil
|
||||||
|
case int64:
|
||||||
|
return float64(v), nil
|
||||||
|
case uint:
|
||||||
|
return float64(v), nil
|
||||||
|
case uint8:
|
||||||
|
return float64(v), nil
|
||||||
|
case uint16:
|
||||||
|
return float64(v), nil
|
||||||
|
case uint32:
|
||||||
|
return float64(v), nil
|
||||||
|
case uint64:
|
||||||
|
return float64(v), nil
|
||||||
|
case []byte:
|
||||||
|
return strconv.ParseFloat(string(v), 64)
|
||||||
|
case string:
|
||||||
|
return strconv.ParseFloat(v, 64)
|
||||||
|
case *sql.NullString:
|
||||||
|
return strconv.ParseFloat(v.String, 64)
|
||||||
|
case *sql.NullInt32:
|
||||||
|
return float64(v.Int32), nil
|
||||||
|
case *sql.NullInt64:
|
||||||
|
return float64(v.Int64), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rv := reflect.ValueOf(src)
|
||||||
|
switch rv.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
return float64(rv.Int()), nil
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
return float64(rv.Uint()), nil
|
||||||
|
case reflect.Float64:
|
||||||
|
return float64(rv.Float()), nil
|
||||||
|
case reflect.Float32:
|
||||||
|
return float64(rv.Float()), nil
|
||||||
|
case reflect.String:
|
||||||
|
return strconv.ParseFloat(rv.String(), 64)
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("unsupported value %T as int64", src)
|
||||||
|
}
|
||||||
|
|
||||||
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
|
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
|
||||||
switch rv.Kind() {
|
switch rv.Kind() {
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
@ -76,7 +235,7 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
|
||||||
// convertAssign copies to dest the value in src, converting it if possible.
|
// convertAssign copies to dest the value in src, converting it if possible.
|
||||||
// An error is returned if the copy would result in loss of information.
|
// An error is returned if the copy would result in loss of information.
|
||||||
// dest should be a pointer type.
|
// dest should be a pointer type.
|
||||||
func convertAssign(dest, src interface{}) error {
|
func convertAssign(dest, src interface{}, originalLocation *time.Location, convertedLocation *time.Location) error {
|
||||||
// Common cases, without reflect.
|
// Common cases, without reflect.
|
||||||
switch s := src.(type) {
|
switch s := src.(type) {
|
||||||
case string:
|
case string:
|
||||||
|
@ -143,6 +302,163 @@ func convertAssign(dest, src interface{}) error {
|
||||||
*d = nil
|
*d = nil
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
case *sql.NullString:
|
||||||
|
switch d := dest.(type) {
|
||||||
|
case *int:
|
||||||
|
if s.Valid {
|
||||||
|
*d, _ = strconv.Atoi(s.String)
|
||||||
|
}
|
||||||
|
case *int64:
|
||||||
|
if s.Valid {
|
||||||
|
*d, _ = strconv.ParseInt(s.String, 10, 64)
|
||||||
|
}
|
||||||
|
case *string:
|
||||||
|
if s.Valid {
|
||||||
|
*d = s.String
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *time.Time:
|
||||||
|
if s.Valid {
|
||||||
|
var err error
|
||||||
|
dt, err := convert.String2Time(s.String, originalLocation, convertedLocation)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*d = *dt
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *sql.NullTime:
|
||||||
|
if s.Valid {
|
||||||
|
var err error
|
||||||
|
dt, err := convert.String2Time(s.String, originalLocation, convertedLocation)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
d.Valid = true
|
||||||
|
d.Time = *dt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case *sql.NullInt32:
|
||||||
|
switch d := dest.(type) {
|
||||||
|
case *int:
|
||||||
|
if s.Valid {
|
||||||
|
*d = int(s.Int32)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *int8:
|
||||||
|
if s.Valid {
|
||||||
|
*d = int8(s.Int32)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *int16:
|
||||||
|
if s.Valid {
|
||||||
|
*d = int16(s.Int32)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *int32:
|
||||||
|
if s.Valid {
|
||||||
|
*d = s.Int32
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *int64:
|
||||||
|
if s.Valid {
|
||||||
|
*d = int64(s.Int32)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case *sql.NullInt64:
|
||||||
|
switch d := dest.(type) {
|
||||||
|
case *int:
|
||||||
|
if s.Valid {
|
||||||
|
*d = int(s.Int64)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *int8:
|
||||||
|
if s.Valid {
|
||||||
|
*d = int8(s.Int64)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *int16:
|
||||||
|
if s.Valid {
|
||||||
|
*d = int16(s.Int64)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *int32:
|
||||||
|
if s.Valid {
|
||||||
|
*d = int32(s.Int64)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *int64:
|
||||||
|
if s.Valid {
|
||||||
|
*d = s.Int64
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case *sql.NullFloat64:
|
||||||
|
switch d := dest.(type) {
|
||||||
|
case *int:
|
||||||
|
if s.Valid {
|
||||||
|
*d = int(s.Float64)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *float64:
|
||||||
|
if s.Valid {
|
||||||
|
*d = s.Float64
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case *sql.NullBool:
|
||||||
|
switch d := dest.(type) {
|
||||||
|
case *bool:
|
||||||
|
if s.Valid {
|
||||||
|
*d = s.Bool
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case *sql.NullTime:
|
||||||
|
switch d := dest.(type) {
|
||||||
|
case *time.Time:
|
||||||
|
if s.Valid {
|
||||||
|
*d = s.Time
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *string:
|
||||||
|
if s.Valid {
|
||||||
|
*d = s.Time.In(convertedLocation).Format("2006-01-02 15:04:05")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case *NullUint32:
|
||||||
|
switch d := dest.(type) {
|
||||||
|
case *uint8:
|
||||||
|
if s.Valid {
|
||||||
|
*d = uint8(s.Uint32)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *uint16:
|
||||||
|
if s.Valid {
|
||||||
|
*d = uint16(s.Uint32)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case *uint:
|
||||||
|
if s.Valid {
|
||||||
|
*d = uint(s.Uint32)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case *NullUint64:
|
||||||
|
switch d := dest.(type) {
|
||||||
|
case *uint64:
|
||||||
|
if s.Valid {
|
||||||
|
*d = s.Uint64
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case *sql.RawBytes:
|
||||||
|
switch d := dest.(type) {
|
||||||
|
case convert.Conversion:
|
||||||
|
return d.FromDB(*s)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var sv reflect.Value
|
var sv reflect.Value
|
||||||
|
@ -175,7 +491,10 @@ func convertAssign(dest, src interface{}) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
dpv := reflect.ValueOf(dest)
|
return convertAssignV(reflect.ValueOf(dest), src, originalLocation, convertedLocation)
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, convertedLocation *time.Location) error {
|
||||||
if dpv.Kind() != reflect.Ptr {
|
if dpv.Kind() != reflect.Ptr {
|
||||||
return errors.New("destination not a pointer")
|
return errors.New("destination not a pointer")
|
||||||
}
|
}
|
||||||
|
@ -183,9 +502,7 @@ func convertAssign(dest, src interface{}) error {
|
||||||
return errNilPtr
|
return errNilPtr
|
||||||
}
|
}
|
||||||
|
|
||||||
if !sv.IsValid() {
|
var sv = reflect.ValueOf(src)
|
||||||
sv = reflect.ValueOf(src)
|
|
||||||
}
|
|
||||||
|
|
||||||
dv := reflect.Indirect(dpv)
|
dv := reflect.Indirect(dpv)
|
||||||
if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
|
if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
|
||||||
|
@ -211,31 +528,28 @@ func convertAssign(dest, src interface{}) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
dv.Set(reflect.New(dv.Type().Elem()))
|
dv.Set(reflect.New(dv.Type().Elem()))
|
||||||
return convertAssign(dv.Interface(), src)
|
return convertAssign(dv.Interface(), src, originalLocation, convertedLocation)
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
s := asString(src)
|
i64, err := asInt64(src)
|
||||||
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = strconvErr(err)
|
err = strconvErr(err)
|
||||||
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
|
return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err)
|
||||||
}
|
}
|
||||||
dv.SetInt(i64)
|
dv.SetInt(i64)
|
||||||
return nil
|
return nil
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
s := asString(src)
|
u64, err := asUint64(src)
|
||||||
u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = strconvErr(err)
|
err = strconvErr(err)
|
||||||
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
|
return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err)
|
||||||
}
|
}
|
||||||
dv.SetUint(u64)
|
dv.SetUint(u64)
|
||||||
return nil
|
return nil
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
s := asString(src)
|
f64, err := asFloat64(src)
|
||||||
f64, err := strconv.ParseFloat(s, dv.Type().Bits())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = strconvErr(err)
|
err = strconvErr(err)
|
||||||
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
|
return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err)
|
||||||
}
|
}
|
||||||
dv.SetFloat(f64)
|
dv.SetFloat(f64)
|
||||||
return nil
|
return nil
|
||||||
|
@ -244,7 +558,7 @@ func convertAssign(dest, src interface{}) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
|
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dpv.Interface())
|
||||||
}
|
}
|
||||||
|
|
||||||
func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
|
func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
|
||||||
|
@ -376,47 +690,79 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) {
|
||||||
return v.Interface(), nil
|
return v.Interface(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func int64ToIntValue(id int64, tp reflect.Type) reflect.Value {
|
var (
|
||||||
var v interface{}
|
_ sql.Scanner = &NullUint64{}
|
||||||
kind := tp.Kind()
|
)
|
||||||
|
|
||||||
if kind == reflect.Ptr {
|
// NullUint64 represents an uint64 that may be null.
|
||||||
kind = tp.Elem().Kind()
|
// NullUint64 implements the Scanner interface so
|
||||||
}
|
// it can be used as a scan destination, similar to NullString.
|
||||||
|
type NullUint64 struct {
|
||||||
switch kind {
|
Uint64 uint64
|
||||||
case reflect.Int16:
|
Valid bool
|
||||||
temp := int16(id)
|
|
||||||
v = &temp
|
|
||||||
case reflect.Int32:
|
|
||||||
temp := int32(id)
|
|
||||||
v = &temp
|
|
||||||
case reflect.Int:
|
|
||||||
temp := int(id)
|
|
||||||
v = &temp
|
|
||||||
case reflect.Int64:
|
|
||||||
temp := id
|
|
||||||
v = &temp
|
|
||||||
case reflect.Uint16:
|
|
||||||
temp := uint16(id)
|
|
||||||
v = &temp
|
|
||||||
case reflect.Uint32:
|
|
||||||
temp := uint32(id)
|
|
||||||
v = &temp
|
|
||||||
case reflect.Uint64:
|
|
||||||
temp := uint64(id)
|
|
||||||
v = &temp
|
|
||||||
case reflect.Uint:
|
|
||||||
temp := uint(id)
|
|
||||||
v = &temp
|
|
||||||
}
|
|
||||||
|
|
||||||
if tp.Kind() == reflect.Ptr {
|
|
||||||
return reflect.ValueOf(v).Convert(tp)
|
|
||||||
}
|
|
||||||
return reflect.ValueOf(v).Elem().Convert(tp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func int64ToInt(id int64, tp reflect.Type) interface{} {
|
// Scan implements the Scanner interface.
|
||||||
return int64ToIntValue(id, tp).Interface()
|
func (n *NullUint64) Scan(value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
n.Uint64, n.Valid = 0, false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
n.Valid = true
|
||||||
|
var err error
|
||||||
|
n.Uint64, err = asUint64(value)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the driver Valuer interface.
|
||||||
|
func (n NullUint64) Value() (driver.Value, error) {
|
||||||
|
if !n.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return n.Uint64, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ sql.Scanner = &NullUint32{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// NullUint32 represents an uint32 that may be null.
|
||||||
|
// NullUint32 implements the Scanner interface so
|
||||||
|
// it can be used as a scan destination, similar to NullString.
|
||||||
|
type NullUint32 struct {
|
||||||
|
Uint32 uint32
|
||||||
|
Valid bool // Valid is true if Uint32 is not NULL
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the Scanner interface.
|
||||||
|
func (n *NullUint32) Scan(value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
n.Uint32, n.Valid = 0, false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
n.Valid = true
|
||||||
|
i64, err := asUint64(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
n.Uint32 = uint32(i64)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the driver Valuer interface.
|
||||||
|
func (n NullUint32) Value() (driver.Value, error) {
|
||||||
|
if !n.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return int64(n.Uint32), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ sql.Scanner = &EmptyScanner{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type EmptyScanner struct{}
|
||||||
|
|
||||||
|
func (EmptyScanner) Scan(value interface{}) error {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
// Copyright 2021 The Xorm Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Interface2Interface(userLocation *time.Location, v interface{}) (interface{}, error) {
|
||||||
|
if v == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
switch vv := v.(type) {
|
||||||
|
case *int64:
|
||||||
|
return *vv, nil
|
||||||
|
case *int8:
|
||||||
|
return *vv, nil
|
||||||
|
case *sql.NullString:
|
||||||
|
return vv.String, nil
|
||||||
|
case *sql.RawBytes:
|
||||||
|
if len([]byte(*vv)) > 0 {
|
||||||
|
return []byte(*vv), nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
case *sql.NullInt32:
|
||||||
|
return vv.Int32, nil
|
||||||
|
case *sql.NullInt64:
|
||||||
|
return vv.Int64, nil
|
||||||
|
case *sql.NullFloat64:
|
||||||
|
return vv.Float64, nil
|
||||||
|
case *sql.NullBool:
|
||||||
|
if vv.Valid {
|
||||||
|
return vv.Bool, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
case *sql.NullTime:
|
||||||
|
if vv.Valid {
|
||||||
|
return vv.Time.In(userLocation).Format("2006-01-02 15:04:05"), nil
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("convert assign string unsupported type: %#v", vv)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,30 @@
|
||||||
|
// Copyright 2021 The Xorm Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// String2Time converts a string to time with original location
|
||||||
|
func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) {
|
||||||
|
if len(s) == 19 {
|
||||||
|
dt, err := time.ParseInLocation("2006-01-02 15:04:05", s, originalLocation)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dt = dt.In(convertedLocation)
|
||||||
|
return &dt, nil
|
||||||
|
} else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' {
|
||||||
|
dt, err := time.ParseInLocation("2006-01-02T15:04:05Z", s, originalLocation)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dt = dt.In(convertedLocation)
|
||||||
|
return &dt, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unsupported convertion from %s to time", s)
|
||||||
|
}
|
|
@ -5,12 +5,29 @@
|
||||||
package dialects
|
package dialects
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"xorm.io/xorm/core"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ScanContext represents a context when Scan
|
||||||
|
type ScanContext struct {
|
||||||
|
DBLocation *time.Location
|
||||||
|
UserLocation *time.Location
|
||||||
|
}
|
||||||
|
|
||||||
|
type DriverFeatures struct {
|
||||||
|
SupportNullable bool
|
||||||
|
}
|
||||||
|
|
||||||
// Driver represents a database driver
|
// Driver represents a database driver
|
||||||
type Driver interface {
|
type Driver interface {
|
||||||
Parse(string, string) (*URI, error)
|
Parse(string, string) (*URI, error)
|
||||||
|
Features() DriverFeatures
|
||||||
|
GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface
|
||||||
|
Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -59,3 +76,15 @@ func OpenDialect(driverName, connstr string) (Dialect, error) {
|
||||||
|
|
||||||
return dialect, nil
|
return dialect, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type baseDriver struct{}
|
||||||
|
|
||||||
|
func (b *baseDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, v ...interface{}) error {
|
||||||
|
return rows.Scan(v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *baseDriver) Features() DriverFeatures {
|
||||||
|
return DriverFeatures{
|
||||||
|
SupportNullable: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -23,13 +23,45 @@ type SeqFilter struct {
|
||||||
func convertQuestionMark(sql, prefix string, start int) string {
|
func convertQuestionMark(sql, prefix string, start int) string {
|
||||||
var buf strings.Builder
|
var buf strings.Builder
|
||||||
var beginSingleQuote bool
|
var beginSingleQuote bool
|
||||||
|
var isLineComment bool
|
||||||
|
var isComment bool
|
||||||
|
var isMaybeLineComment bool
|
||||||
|
var isMaybeComment bool
|
||||||
|
var isMaybeCommentEnd bool
|
||||||
var index = start
|
var index = start
|
||||||
for _, c := range sql {
|
for _, c := range sql {
|
||||||
if !beginSingleQuote && c == '?' {
|
if !beginSingleQuote && !isLineComment && !isComment && c == '?' {
|
||||||
buf.WriteString(fmt.Sprintf("%s%v", prefix, index))
|
buf.WriteString(fmt.Sprintf("%s%v", prefix, index))
|
||||||
index++
|
index++
|
||||||
} else {
|
} else {
|
||||||
if c == '\'' {
|
if isMaybeLineComment {
|
||||||
|
if c == '-' {
|
||||||
|
isLineComment = true
|
||||||
|
}
|
||||||
|
isMaybeLineComment = false
|
||||||
|
} else if isMaybeComment {
|
||||||
|
if c == '*' {
|
||||||
|
isComment = true
|
||||||
|
}
|
||||||
|
isMaybeComment = false
|
||||||
|
} else if isMaybeCommentEnd {
|
||||||
|
if c == '/' {
|
||||||
|
isComment = false
|
||||||
|
}
|
||||||
|
isMaybeCommentEnd = false
|
||||||
|
} else if isLineComment {
|
||||||
|
if c == '\n' {
|
||||||
|
isLineComment = false
|
||||||
|
}
|
||||||
|
} else if isComment {
|
||||||
|
if c == '*' {
|
||||||
|
isMaybeCommentEnd = true
|
||||||
|
}
|
||||||
|
} else if !beginSingleQuote && c == '-' {
|
||||||
|
isMaybeLineComment = true
|
||||||
|
} else if !beginSingleQuote && c == '/' {
|
||||||
|
isMaybeComment = true
|
||||||
|
} else if c == '\'' {
|
||||||
beginSingleQuote = !beginSingleQuote
|
beginSingleQuote = !beginSingleQuote
|
||||||
}
|
}
|
||||||
buf.WriteRune(c)
|
buf.WriteRune(c)
|
||||||
|
|
|
@ -19,3 +19,60 @@ func TestSeqFilter(t *testing.T) {
|
||||||
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1))
|
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSeqFilterLineComment(t *testing.T) {
|
||||||
|
var kases = map[string]string{
|
||||||
|
`SELECT *
|
||||||
|
FROM TABLE1
|
||||||
|
WHERE foo='bar'
|
||||||
|
AND a=? -- it's a comment
|
||||||
|
AND b=?`: `SELECT *
|
||||||
|
FROM TABLE1
|
||||||
|
WHERE foo='bar'
|
||||||
|
AND a=$1 -- it's a comment
|
||||||
|
AND b=$2`,
|
||||||
|
`SELECT *
|
||||||
|
FROM TABLE1
|
||||||
|
WHERE foo='bar'
|
||||||
|
AND a=? -- it's a comment?
|
||||||
|
AND b=?`: `SELECT *
|
||||||
|
FROM TABLE1
|
||||||
|
WHERE foo='bar'
|
||||||
|
AND a=$1 -- it's a comment?
|
||||||
|
AND b=$2`,
|
||||||
|
`SELECT *
|
||||||
|
FROM TABLE1
|
||||||
|
WHERE a=? -- it's a comment? and that's okay?
|
||||||
|
AND b=?`: `SELECT *
|
||||||
|
FROM TABLE1
|
||||||
|
WHERE a=$1 -- it's a comment? and that's okay?
|
||||||
|
AND b=$2`,
|
||||||
|
}
|
||||||
|
for sql, result := range kases {
|
||||||
|
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSeqFilterComment(t *testing.T) {
|
||||||
|
var kases = map[string]string{
|
||||||
|
`SELECT *
|
||||||
|
FROM TABLE1
|
||||||
|
WHERE a=? /* it's a comment */
|
||||||
|
AND b=?`: `SELECT *
|
||||||
|
FROM TABLE1
|
||||||
|
WHERE a=$1 /* it's a comment */
|
||||||
|
AND b=$2`,
|
||||||
|
`SELECT /* it's a comment * ?
|
||||||
|
More comment on the next line! */ *
|
||||||
|
FROM TABLE1
|
||||||
|
WHERE a=? /**/
|
||||||
|
AND b=?`: `SELECT /* it's a comment * ?
|
||||||
|
More comment on the next line! */ *
|
||||||
|
FROM TABLE1
|
||||||
|
WHERE a=$1 /**/
|
||||||
|
AND b=$2`,
|
||||||
|
}
|
||||||
|
for sql, result := range kases {
|
||||||
|
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ package dialects
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -624,6 +625,7 @@ func (db *mssql) Filters() []Filter {
|
||||||
}
|
}
|
||||||
|
|
||||||
type odbcDriver struct {
|
type odbcDriver struct {
|
||||||
|
baseDriver
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
|
@ -652,3 +654,26 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
}
|
}
|
||||||
return &URI{DBName: dbName, DBType: schemas.MSSQL}, nil
|
return &URI{DBName: dbName, DBType: schemas.MSSQL}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *odbcDriver) GenScanResult(colType string) (interface{}, error) {
|
||||||
|
switch colType {
|
||||||
|
case "VARCHAR", "TEXT", "CHAR", "NVARCHAR", "NCHAR", "NTEXT":
|
||||||
|
fallthrough
|
||||||
|
case "DATE", "DATETIME", "DATETIME2", "TIME":
|
||||||
|
var s sql.NullString
|
||||||
|
return &s, nil
|
||||||
|
case "FLOAT", "REAL":
|
||||||
|
var s sql.NullFloat64
|
||||||
|
return &s, nil
|
||||||
|
case "BIGINT", "DATETIMEOFFSET":
|
||||||
|
var s sql.NullInt64
|
||||||
|
return &s, nil
|
||||||
|
case "TINYINT", "SMALLINT", "INT":
|
||||||
|
var s sql.NullInt32
|
||||||
|
return &s, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
var r sql.RawBytes
|
||||||
|
return &r, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ package dialects
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
@ -14,6 +15,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"xorm.io/xorm/convert"
|
||||||
"xorm.io/xorm/core"
|
"xorm.io/xorm/core"
|
||||||
"xorm.io/xorm/schemas"
|
"xorm.io/xorm/schemas"
|
||||||
)
|
)
|
||||||
|
@ -645,7 +647,125 @@ func (db *mysql) Filters() []Filter {
|
||||||
return []Filter{}
|
return []Filter{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mysqlDriver struct {
|
||||||
|
baseDriver
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
|
dsnPattern := regexp.MustCompile(
|
||||||
|
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
|
||||||
|
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
|
||||||
|
`\/(?P<dbname>.*?)` + // /dbname
|
||||||
|
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1¶mN=valueN]
|
||||||
|
matches := dsnPattern.FindStringSubmatch(dataSourceName)
|
||||||
|
// tlsConfigRegister := make(map[string]*tls.Config)
|
||||||
|
names := dsnPattern.SubexpNames()
|
||||||
|
|
||||||
|
uri := &URI{DBType: schemas.MYSQL}
|
||||||
|
|
||||||
|
for i, match := range matches {
|
||||||
|
switch names[i] {
|
||||||
|
case "dbname":
|
||||||
|
uri.DBName = match
|
||||||
|
case "params":
|
||||||
|
if len(match) > 0 {
|
||||||
|
kvs := strings.Split(match, "&")
|
||||||
|
for _, kv := range kvs {
|
||||||
|
splits := strings.Split(kv, "=")
|
||||||
|
if len(splits) == 2 {
|
||||||
|
switch splits[0] {
|
||||||
|
case "charset":
|
||||||
|
uri.Charset = splits[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return uri, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) {
|
||||||
|
switch colType {
|
||||||
|
case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET":
|
||||||
|
var s sql.NullString
|
||||||
|
return &s, nil
|
||||||
|
case "BIGINT":
|
||||||
|
var s sql.NullInt64
|
||||||
|
return &s, nil
|
||||||
|
case "TINYINT", "SMALLINT", "MEDIUMINT", "INT":
|
||||||
|
var s sql.NullInt32
|
||||||
|
return &s, nil
|
||||||
|
case "FLOAT", "REAL", "DOUBLE PRECISION":
|
||||||
|
var s sql.NullFloat64
|
||||||
|
return &s, nil
|
||||||
|
case "DECIMAL", "NUMERIC":
|
||||||
|
var s sql.NullString
|
||||||
|
return &s, nil
|
||||||
|
case "DATETIME":
|
||||||
|
var s sql.NullTime
|
||||||
|
return &s, nil
|
||||||
|
case "BIT":
|
||||||
|
var s sql.RawBytes
|
||||||
|
return &s, nil
|
||||||
|
case "BINARY", "VARBINARY", "TINYBLOB", "BLOB", "MEDIUMBLOB", "LONGBLOB":
|
||||||
|
var r sql.RawBytes
|
||||||
|
return &r, nil
|
||||||
|
default:
|
||||||
|
var r sql.RawBytes
|
||||||
|
return &r, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mysqlDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, scanResults ...interface{}) error {
|
||||||
|
var v2 = make([]interface{}, 0, len(scanResults))
|
||||||
|
var turnBackIdxes = make([]int, 0, 5)
|
||||||
|
for i, vv := range scanResults {
|
||||||
|
switch vv.(type) {
|
||||||
|
case *time.Time:
|
||||||
|
v2 = append(v2, &sql.NullString{})
|
||||||
|
turnBackIdxes = append(turnBackIdxes, i)
|
||||||
|
case *sql.NullTime:
|
||||||
|
v2 = append(v2, &sql.NullString{})
|
||||||
|
turnBackIdxes = append(turnBackIdxes, i)
|
||||||
|
default:
|
||||||
|
v2 = append(v2, scanResults[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := rows.Scan(v2...); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, i := range turnBackIdxes {
|
||||||
|
switch t := scanResults[i].(type) {
|
||||||
|
case *time.Time:
|
||||||
|
var s = *(v2[i].(*sql.NullString))
|
||||||
|
if !s.Valid {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
dt, err := convert.String2Time(s.String, ctx.DBLocation, ctx.UserLocation)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*t = *dt
|
||||||
|
case *sql.NullTime:
|
||||||
|
var s = *(v2[i].(*sql.NullString))
|
||||||
|
if !s.Valid {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
dt, err := convert.String2Time(s.String, ctx.DBLocation, ctx.UserLocation)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t.Time = *dt
|
||||||
|
t.Valid = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type mymysqlDriver struct {
|
type mymysqlDriver struct {
|
||||||
|
mysqlDriver
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
|
@ -696,41 +816,3 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
|
|
||||||
return uri, nil
|
return uri, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type mysqlDriver struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
|
||||||
dsnPattern := regexp.MustCompile(
|
|
||||||
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
|
|
||||||
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
|
|
||||||
`\/(?P<dbname>.*?)` + // /dbname
|
|
||||||
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1¶mN=valueN]
|
|
||||||
matches := dsnPattern.FindStringSubmatch(dataSourceName)
|
|
||||||
// tlsConfigRegister := make(map[string]*tls.Config)
|
|
||||||
names := dsnPattern.SubexpNames()
|
|
||||||
|
|
||||||
uri := &URI{DBType: schemas.MYSQL}
|
|
||||||
|
|
||||||
for i, match := range matches {
|
|
||||||
switch names[i] {
|
|
||||||
case "dbname":
|
|
||||||
uri.DBName = match
|
|
||||||
case "params":
|
|
||||||
if len(match) > 0 {
|
|
||||||
kvs := strings.Split(match, "&")
|
|
||||||
for _, kv := range kvs {
|
|
||||||
splits := strings.Split(kv, "=")
|
|
||||||
if len(splits) == 2 {
|
|
||||||
switch splits[0] {
|
|
||||||
case "charset":
|
|
||||||
uri.Charset = splits[1]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return uri, nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ package dialects
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
@ -823,6 +824,7 @@ func (db *oracle) Filters() []Filter {
|
||||||
}
|
}
|
||||||
|
|
||||||
type godrorDriver struct {
|
type godrorDriver struct {
|
||||||
|
baseDriver
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
|
@ -848,7 +850,28 @@ func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error)
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *godrorDriver) GenScanResult(colType string) (interface{}, error) {
|
||||||
|
switch colType {
|
||||||
|
case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB":
|
||||||
|
var s sql.NullString
|
||||||
|
return &s, nil
|
||||||
|
case "NUMBER":
|
||||||
|
var s sql.NullString
|
||||||
|
return &s, nil
|
||||||
|
case "DATE":
|
||||||
|
var s sql.NullTime
|
||||||
|
return &s, nil
|
||||||
|
case "BLOB":
|
||||||
|
var r sql.RawBytes
|
||||||
|
return &r, nil
|
||||||
|
default:
|
||||||
|
var r sql.RawBytes
|
||||||
|
return &r, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type oci8Driver struct {
|
type oci8Driver struct {
|
||||||
|
godrorDriver
|
||||||
}
|
}
|
||||||
|
|
||||||
// dataSourceName=user/password@ipv4:port/dbname
|
// dataSourceName=user/password@ipv4:port/dbname
|
||||||
|
|
|
@ -6,6 +6,7 @@ package dialects
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -1056,12 +1057,13 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab
|
||||||
|
|
||||||
func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length,
|
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, description,
|
||||||
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
|
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
|
||||||
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
|
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
|
||||||
FROM pg_attribute f
|
FROM pg_attribute f
|
||||||
JOIN pg_class c ON c.oid = f.attrelid JOIN pg_type t ON t.oid = f.atttypid
|
JOIN pg_class c ON c.oid = f.attrelid JOIN pg_type t ON t.oid = f.atttypid
|
||||||
LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum
|
LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum
|
||||||
|
LEFT JOIN pg_description de ON f.attrelid=de.objoid AND f.attnum=de.objsubid
|
||||||
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
|
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||||
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
|
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
|
||||||
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
|
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
|
||||||
|
@ -1090,9 +1092,9 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
|
||||||
col.Indexes = make(map[string]int)
|
col.Indexes = make(map[string]int)
|
||||||
|
|
||||||
var colName, isNullable, dataType string
|
var colName, isNullable, dataType string
|
||||||
var maxLenStr, colDefault *string
|
var maxLenStr, colDefault, description *string
|
||||||
var isPK, isUnique bool
|
var isPK, isUnique bool
|
||||||
err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &isPK, &isUnique)
|
err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &description, &isPK, &isUnique)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -1138,6 +1140,10 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
|
||||||
col.DefaultIsEmpty = true
|
col.DefaultIsEmpty = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if description != nil {
|
||||||
|
col.Comment = *description
|
||||||
|
}
|
||||||
|
|
||||||
if isPK {
|
if isPK {
|
||||||
col.IsPrimaryKey = true
|
col.IsPrimaryKey = true
|
||||||
}
|
}
|
||||||
|
@ -1305,6 +1311,13 @@ func (db *postgres) Filters() []Filter {
|
||||||
}
|
}
|
||||||
|
|
||||||
type pqDriver struct {
|
type pqDriver struct {
|
||||||
|
baseDriver
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *pqDriver) Features() DriverFeatures {
|
||||||
|
return DriverFeatures{
|
||||||
|
SupportNullable: false,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type values map[string]string
|
type values map[string]string
|
||||||
|
@ -1381,6 +1394,36 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *pqDriver) GenScanResult(colType string) (interface{}, error) {
|
||||||
|
switch colType {
|
||||||
|
case "VARCHAR", "TEXT":
|
||||||
|
var s sql.NullString
|
||||||
|
return &s, nil
|
||||||
|
case "BIGINT":
|
||||||
|
var s sql.NullInt64
|
||||||
|
return &s, nil
|
||||||
|
case "TINYINT", "INT", "INT8", "INT4":
|
||||||
|
var s sql.NullInt32
|
||||||
|
return &s, nil
|
||||||
|
case "FLOAT", "FLOAT4":
|
||||||
|
var s sql.NullFloat64
|
||||||
|
return &s, nil
|
||||||
|
case "DATETIME", "TIMESTAMP":
|
||||||
|
var s sql.NullTime
|
||||||
|
return &s, nil
|
||||||
|
case "BIT":
|
||||||
|
var s sql.RawBytes
|
||||||
|
return &s, nil
|
||||||
|
case "BOOL":
|
||||||
|
var s sql.NullBool
|
||||||
|
return &s, nil
|
||||||
|
default:
|
||||||
|
fmt.Printf("unknow postgres database type: %v\n", colType)
|
||||||
|
var r sql.RawBytes
|
||||||
|
return &r, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type pqDriverPgx struct {
|
type pqDriverPgx struct {
|
||||||
pqDriver
|
pqDriver
|
||||||
}
|
}
|
||||||
|
|
|
@ -540,6 +540,7 @@ func (db *sqlite3) Filters() []Filter {
|
||||||
}
|
}
|
||||||
|
|
||||||
type sqlite3Driver struct {
|
type sqlite3Driver struct {
|
||||||
|
baseDriver
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) {
|
func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
|
@ -549,3 +550,35 @@ func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
|
|
||||||
return &URI{DBType: schemas.SQLITE, DBName: dataSourceName}, nil
|
return &URI{DBType: schemas.SQLITE, DBName: dataSourceName}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) {
|
||||||
|
switch colType {
|
||||||
|
case "TEXT":
|
||||||
|
var s sql.NullString
|
||||||
|
return &s, nil
|
||||||
|
case "INTEGER":
|
||||||
|
var s sql.NullInt64
|
||||||
|
return &s, nil
|
||||||
|
case "DATETIME":
|
||||||
|
var s sql.NullTime
|
||||||
|
return &s, nil
|
||||||
|
case "REAL":
|
||||||
|
var s sql.NullFloat64
|
||||||
|
return &s, nil
|
||||||
|
case "NUMERIC":
|
||||||
|
var s sql.NullString
|
||||||
|
return &s, nil
|
||||||
|
case "BLOB":
|
||||||
|
var s sql.RawBytes
|
||||||
|
return &s, nil
|
||||||
|
default:
|
||||||
|
var r sql.NullString
|
||||||
|
return &r, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *sqlite3Driver) Features() DriverFeatures {
|
||||||
|
return DriverFeatures{
|
||||||
|
SupportNullable: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
18
engine.go
18
engine.go
|
@ -35,6 +35,7 @@ type Engine struct {
|
||||||
cacherMgr *caches.Manager
|
cacherMgr *caches.Manager
|
||||||
defaultContext context.Context
|
defaultContext context.Context
|
||||||
dialect dialects.Dialect
|
dialect dialects.Dialect
|
||||||
|
driver dialects.Driver
|
||||||
engineGroup *EngineGroup
|
engineGroup *EngineGroup
|
||||||
logger log.ContextLogger
|
logger log.ContextLogger
|
||||||
tagParser *tags.Parser
|
tagParser *tags.Parser
|
||||||
|
@ -72,6 +73,7 @@ func newEngine(driverName, dataSourceName string, dialect dialects.Dialect, db *
|
||||||
|
|
||||||
engine := &Engine{
|
engine := &Engine{
|
||||||
dialect: dialect,
|
dialect: dialect,
|
||||||
|
driver: dialects.QueryDriver(driverName),
|
||||||
TZLocation: time.Local,
|
TZLocation: time.Local,
|
||||||
defaultContext: context.Background(),
|
defaultContext: context.Background(),
|
||||||
cacherMgr: cacherMgr,
|
cacherMgr: cacherMgr,
|
||||||
|
@ -444,7 +446,7 @@ func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
||||||
return engine.dumpTables(tables, w, tp...)
|
return engine.dumpTables(tables, w, tp...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func formatColumnValue(dstDialect dialects.Dialect, d interface{}, col *schemas.Column) string {
|
func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d interface{}, col *schemas.Column) string {
|
||||||
if d == nil {
|
if d == nil {
|
||||||
return "NULL"
|
return "NULL"
|
||||||
}
|
}
|
||||||
|
@ -473,10 +475,8 @@ func formatColumnValue(dstDialect dialects.Dialect, d interface{}, col *schemas.
|
||||||
|
|
||||||
return "'" + strings.Replace(v, "'", "''", -1) + "'"
|
return "'" + strings.Replace(v, "'", "''", -1) + "'"
|
||||||
} else if col.SQLType.IsTime() {
|
} else if col.SQLType.IsTime() {
|
||||||
if dstDialect.URI().DBType == schemas.MSSQL && col.SQLType.Name == schemas.DateTime {
|
|
||||||
if t, ok := d.(time.Time); ok {
|
if t, ok := d.(time.Time); ok {
|
||||||
return "'" + t.UTC().Format("2006-01-02 15:04:05") + "'"
|
return "'" + t.In(dbLocation).Format("2006-01-02 15:04:05") + "'"
|
||||||
}
|
|
||||||
}
|
}
|
||||||
var v = fmt.Sprintf("%s", d)
|
var v = fmt.Sprintf("%s", d)
|
||||||
if strings.HasSuffix(v, " +0000 UTC") {
|
if strings.HasSuffix(v, " +0000 UTC") {
|
||||||
|
@ -652,12 +652,8 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
||||||
return errors.New("unknown column error")
|
return errors.New("unknown column error")
|
||||||
}
|
}
|
||||||
|
|
||||||
fields := strings.Split(col.FieldName, ".")
|
field := dataStruct.FieldByIndex(col.FieldIndex)
|
||||||
field := dataStruct
|
temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, field.Interface(), col)
|
||||||
for _, fieldName := range fields {
|
|
||||||
field = field.FieldByName(fieldName)
|
|
||||||
}
|
|
||||||
temp += "," + formatColumnValue(dstDialect, field.Interface(), col)
|
|
||||||
}
|
}
|
||||||
_, err = io.WriteString(w, temp[1:]+");\n")
|
_, err = io.WriteString(w, temp[1:]+");\n")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -684,7 +680,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
||||||
return errors.New("unknow column error")
|
return errors.New("unknow column error")
|
||||||
}
|
}
|
||||||
|
|
||||||
temp += "," + formatColumnValue(dstDialect, d, col)
|
temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, d, col)
|
||||||
}
|
}
|
||||||
_, err = io.WriteString(w, temp[1:]+");\n")
|
_, err = io.WriteString(w, temp[1:]+");\n")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -176,6 +176,23 @@ func TestDumpTables(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDumpTables2(t *testing.T) {
|
||||||
|
assert.NoError(t, PrepareEngine())
|
||||||
|
|
||||||
|
type TestDumpTableStruct2 struct {
|
||||||
|
Id int64
|
||||||
|
Created time.Time `xorm:"Default CURRENT_TIMESTAMP"`
|
||||||
|
}
|
||||||
|
|
||||||
|
assertSync(t, new(TestDumpTableStruct2))
|
||||||
|
|
||||||
|
fp := fmt.Sprintf("./dump2-%v-table.sql", testEngine.Dialect().URI().DBType)
|
||||||
|
os.Remove(fp)
|
||||||
|
tb, err := testEngine.TableInfo(new(TestDumpTableStruct2))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, fp))
|
||||||
|
}
|
||||||
|
|
||||||
func TestSetSchema(t *testing.T) {
|
func TestSetSchema(t *testing.T) {
|
||||||
assert.NoError(t, PrepareEngine())
|
assert.NoError(t, PrepareEngine())
|
||||||
|
|
||||||
|
@ -209,3 +226,39 @@ func TestDBVersion(t *testing.T) {
|
||||||
|
|
||||||
fmt.Println(testEngine.Dialect().URI().DBType, "version is", version)
|
fmt.Println(testEngine.Dialect().URI().DBType, "version is", version)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetColumns(t *testing.T) {
|
||||||
|
if testEngine.Dialect().URI().DBType != schemas.POSTGRES {
|
||||||
|
t.Skip()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
type TestCommentStruct struct {
|
||||||
|
HasComment int
|
||||||
|
NoComment int
|
||||||
|
}
|
||||||
|
|
||||||
|
assertSync(t, new(TestCommentStruct))
|
||||||
|
|
||||||
|
comment := "this is a comment"
|
||||||
|
sql := fmt.Sprintf("comment on column %s.%s is '%s'", testEngine.TableName(new(TestCommentStruct), true), "has_comment", comment)
|
||||||
|
_, err := testEngine.Exec(sql)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
tables, err := testEngine.DBMetas()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
tableName := testEngine.GetColumnMapper().Obj2Table("TestCommentStruct")
|
||||||
|
var hasComment, noComment string
|
||||||
|
for _, table := range tables {
|
||||||
|
if table.Name == tableName {
|
||||||
|
col := table.GetColumn("has_comment")
|
||||||
|
assert.NotNil(t, col)
|
||||||
|
hasComment = col.Comment
|
||||||
|
col2 := table.GetColumn("no_comment")
|
||||||
|
assert.NotNil(t, col2)
|
||||||
|
noComment = col2.Comment
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, comment, hasComment)
|
||||||
|
assert.Zero(t, noComment)
|
||||||
|
}
|
||||||
|
|
|
@ -406,16 +406,16 @@ func TestFindMapPtrString(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFindBit(t *testing.T) {
|
func TestFindBool(t *testing.T) {
|
||||||
type FindBitStruct struct {
|
type FindBoolStruct struct {
|
||||||
Id int64
|
Id int64
|
||||||
Msg bool `xorm:"bit"`
|
Msg bool
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.NoError(t, PrepareEngine())
|
assert.NoError(t, PrepareEngine())
|
||||||
assertSync(t, new(FindBitStruct))
|
assertSync(t, new(FindBoolStruct))
|
||||||
|
|
||||||
cnt, err := testEngine.Insert([]FindBitStruct{
|
cnt, err := testEngine.Insert([]FindBoolStruct{
|
||||||
{
|
{
|
||||||
Msg: false,
|
Msg: false,
|
||||||
},
|
},
|
||||||
|
@ -426,14 +426,13 @@ func TestFindBit(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.EqualValues(t, 2, cnt)
|
assert.EqualValues(t, 2, cnt)
|
||||||
|
|
||||||
var results = make([]FindBitStruct, 0, 2)
|
var results = make([]FindBoolStruct, 0, 2)
|
||||||
err = testEngine.Find(&results)
|
err = testEngine.Find(&results)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.EqualValues(t, 2, len(results))
|
assert.EqualValues(t, 2, len(results))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFindMark(t *testing.T) {
|
func TestFindMark(t *testing.T) {
|
||||||
|
|
||||||
type Mark struct {
|
type Mark struct {
|
||||||
Mark1 string `xorm:"VARCHAR(1)"`
|
Mark1 string `xorm:"VARCHAR(1)"`
|
||||||
Mark2 string `xorm:"VARCHAR(1)"`
|
Mark2 string `xorm:"VARCHAR(1)"`
|
||||||
|
@ -468,7 +467,7 @@ func TestFindAndCountOneFunc(t *testing.T) {
|
||||||
type FindAndCountStruct struct {
|
type FindAndCountStruct struct {
|
||||||
Id int64
|
Id int64
|
||||||
Content string
|
Content string
|
||||||
Msg bool `xorm:"bit"`
|
Msg bool
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.NoError(t, PrepareEngine())
|
assert.NoError(t, PrepareEngine())
|
||||||
|
|
|
@ -32,7 +32,6 @@ func TestInsertOne(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInsertMulti(t *testing.T) {
|
func TestInsertMulti(t *testing.T) {
|
||||||
|
|
||||||
assert.NoError(t, PrepareEngine())
|
assert.NoError(t, PrepareEngine())
|
||||||
type TestMulti struct {
|
type TestMulti struct {
|
||||||
Id int64 `xorm:"int(11) pk"`
|
Id int64 `xorm:"int(11) pk"`
|
||||||
|
@ -78,7 +77,6 @@ func insertMultiDatas(step int, datas interface{}) (num int64, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func callbackLooper(datas interface{}, step int, actionFunc func(interface{}) error) (err error) {
|
func callbackLooper(datas interface{}, step int, actionFunc func(interface{}) error) (err error) {
|
||||||
|
|
||||||
sliceValue := reflect.Indirect(reflect.ValueOf(datas))
|
sliceValue := reflect.Indirect(reflect.ValueOf(datas))
|
||||||
if sliceValue.Kind() != reflect.Slice {
|
if sliceValue.Kind() != reflect.Slice {
|
||||||
return fmt.Errorf("not slice")
|
return fmt.Errorf("not slice")
|
||||||
|
@ -170,16 +168,16 @@ func TestInsertAutoIncr(t *testing.T) {
|
||||||
assert.Greater(t, user.Uid, int64(0))
|
assert.Greater(t, user.Uid, int64(0))
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultInsert struct {
|
func TestInsertDefault(t *testing.T) {
|
||||||
|
assert.NoError(t, PrepareEngine())
|
||||||
|
|
||||||
|
type DefaultInsert struct {
|
||||||
Id int64
|
Id int64
|
||||||
Status int `xorm:"default -1"`
|
Status int `xorm:"default -1"`
|
||||||
Name string
|
Name string
|
||||||
Created time.Time `xorm:"created"`
|
Created time.Time `xorm:"created"`
|
||||||
Updated time.Time `xorm:"updated"`
|
Updated time.Time `xorm:"updated"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInsertDefault(t *testing.T) {
|
|
||||||
assert.NoError(t, PrepareEngine())
|
|
||||||
|
|
||||||
di := new(DefaultInsert)
|
di := new(DefaultInsert)
|
||||||
err := testEngine.Sync2(di)
|
err := testEngine.Sync2(di)
|
||||||
|
@ -197,15 +195,15 @@ func TestInsertDefault(t *testing.T) {
|
||||||
assert.EqualValues(t, di2.Created.Unix(), di.Created.Unix())
|
assert.EqualValues(t, di2.Created.Unix(), di.Created.Unix())
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultInsert2 struct {
|
func TestInsertDefault2(t *testing.T) {
|
||||||
|
assert.NoError(t, PrepareEngine())
|
||||||
|
|
||||||
|
type DefaultInsert2 struct {
|
||||||
Id int64
|
Id int64
|
||||||
Name string
|
Name string
|
||||||
Url string `xorm:"text"`
|
Url string `xorm:"text"`
|
||||||
CheckTime time.Time `xorm:"not null default '2000-01-01 00:00:00' TIMESTAMP"`
|
CheckTime time.Time `xorm:"not null default '2000-01-01 00:00:00' TIMESTAMP"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInsertDefault2(t *testing.T) {
|
|
||||||
assert.NoError(t, PrepareEngine())
|
|
||||||
|
|
||||||
di := new(DefaultInsert2)
|
di := new(DefaultInsert2)
|
||||||
err := testEngine.Sync2(di)
|
err := testEngine.Sync2(di)
|
||||||
|
@ -1026,3 +1024,44 @@ func TestInsertIntSlice(t *testing.T) {
|
||||||
assert.True(t, has)
|
assert.True(t, has)
|
||||||
assert.EqualValues(t, v3, v4)
|
assert.EqualValues(t, v3, v4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInsertDeleted(t *testing.T) {
|
||||||
|
assert.NoError(t, PrepareEngine())
|
||||||
|
|
||||||
|
type InsertDeletedStructNotRight struct {
|
||||||
|
ID uint64 `xorm:"'ID' pk autoincr"`
|
||||||
|
DeletedAt time.Time `xorm:"'DELETED_AT' deleted notnull"`
|
||||||
|
}
|
||||||
|
// notnull tag will be ignored
|
||||||
|
err := testEngine.Sync2(new(InsertDeletedStructNotRight))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
type InsertDeletedStruct struct {
|
||||||
|
ID uint64 `xorm:"'ID' pk autoincr"`
|
||||||
|
DeletedAt time.Time `xorm:"'DELETED_AT' deleted"`
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, testEngine.Sync2(new(InsertDeletedStruct)))
|
||||||
|
|
||||||
|
var v InsertDeletedStruct
|
||||||
|
_, err = testEngine.Insert(&v)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var v2 InsertDeletedStruct
|
||||||
|
has, err := testEngine.Get(&v2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, has)
|
||||||
|
|
||||||
|
_, err = testEngine.ID(v.ID).Delete(new(InsertDeletedStruct))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var v3 InsertDeletedStruct
|
||||||
|
has, err = testEngine.Get(&v3)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, has)
|
||||||
|
|
||||||
|
var v4 InsertDeletedStruct
|
||||||
|
has, err = testEngine.Unscoped().Get(&v4)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, has)
|
||||||
|
}
|
||||||
|
|
|
@ -52,7 +52,7 @@ func TestQueryString2(t *testing.T) {
|
||||||
|
|
||||||
type GetVar3 struct {
|
type GetVar3 struct {
|
||||||
Id int64 `xorm:"autoincr pk"`
|
Id int64 `xorm:"autoincr pk"`
|
||||||
Msg bool `xorm:"bit"`
|
Msg bool
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.NoError(t, testEngine.Sync2(new(GetVar3)))
|
assert.NoError(t, testEngine.Sync2(new(GetVar3)))
|
||||||
|
@ -107,6 +107,16 @@ func toFloat64(i interface{}) float64 {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toBool(i interface{}) bool {
|
||||||
|
switch t := i.(type) {
|
||||||
|
case int32:
|
||||||
|
return t > 0
|
||||||
|
case bool:
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func TestQueryInterface(t *testing.T) {
|
func TestQueryInterface(t *testing.T) {
|
||||||
assert.NoError(t, PrepareEngine())
|
assert.NoError(t, PrepareEngine())
|
||||||
|
|
||||||
|
@ -132,10 +142,10 @@ func TestQueryInterface(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 1, len(records))
|
assert.Equal(t, 1, len(records))
|
||||||
assert.Equal(t, 5, len(records[0]))
|
assert.Equal(t, 5, len(records[0]))
|
||||||
assert.EqualValues(t, 1, toInt64(records[0]["id"]))
|
assert.EqualValues(t, int64(1), records[0]["id"])
|
||||||
assert.Equal(t, "hi", toString(records[0]["msg"]))
|
assert.Equal(t, "hi", records[0]["msg"])
|
||||||
assert.EqualValues(t, 28, toInt64(records[0]["age"]))
|
assert.EqualValues(t, 28, records[0]["age"])
|
||||||
assert.EqualValues(t, 1.5, toFloat64(records[0]["money"]))
|
assert.EqualValues(t, 1.5, records[0]["money"])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryNoParams(t *testing.T) {
|
func TestQueryNoParams(t *testing.T) {
|
||||||
|
@ -192,7 +202,7 @@ func TestQueryStringNoParam(t *testing.T) {
|
||||||
|
|
||||||
type GetVar4 struct {
|
type GetVar4 struct {
|
||||||
Id int64 `xorm:"autoincr pk"`
|
Id int64 `xorm:"autoincr pk"`
|
||||||
Msg bool `xorm:"bit"`
|
Msg bool
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.NoError(t, testEngine.Sync2(new(GetVar4)))
|
assert.NoError(t, testEngine.Sync2(new(GetVar4)))
|
||||||
|
@ -229,7 +239,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
|
||||||
|
|
||||||
type GetVar6 struct {
|
type GetVar6 struct {
|
||||||
Id int64 `xorm:"autoincr pk"`
|
Id int64 `xorm:"autoincr pk"`
|
||||||
Msg bool `xorm:"bit"`
|
Msg bool
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.NoError(t, testEngine.Sync2(new(GetVar6)))
|
assert.NoError(t, testEngine.Sync2(new(GetVar6)))
|
||||||
|
@ -266,7 +276,7 @@ func TestQueryInterfaceNoParam(t *testing.T) {
|
||||||
|
|
||||||
type GetVar5 struct {
|
type GetVar5 struct {
|
||||||
Id int64 `xorm:"autoincr pk"`
|
Id int64 `xorm:"autoincr pk"`
|
||||||
Msg bool `xorm:"bit"`
|
Msg bool
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.NoError(t, testEngine.Sync2(new(GetVar5)))
|
assert.NoError(t, testEngine.Sync2(new(GetVar5)))
|
||||||
|
@ -280,14 +290,14 @@ func TestQueryInterfaceNoParam(t *testing.T) {
|
||||||
records, err := testEngine.Table("get_var5").Limit(1).QueryInterface()
|
records, err := testEngine.Table("get_var5").Limit(1).QueryInterface()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.EqualValues(t, 1, len(records))
|
assert.EqualValues(t, 1, len(records))
|
||||||
assert.EqualValues(t, 1, toInt64(records[0]["id"]))
|
assert.EqualValues(t, 1, records[0]["id"])
|
||||||
assert.EqualValues(t, 0, toInt64(records[0]["msg"]))
|
assert.False(t, toBool(records[0]["msg"]))
|
||||||
|
|
||||||
records, err = testEngine.Table("get_var5").Where(builder.Eq{"id": 1}).QueryInterface()
|
records, err = testEngine.Table("get_var5").Where(builder.Eq{"id": 1}).QueryInterface()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.EqualValues(t, 1, len(records))
|
assert.EqualValues(t, 1, len(records))
|
||||||
assert.EqualValues(t, 1, toInt64(records[0]["id"]))
|
assert.EqualValues(t, 1, records[0]["id"])
|
||||||
assert.EqualValues(t, 0, toInt64(records[0]["msg"]))
|
assert.False(t, toBool(records[0]["msg"]))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryWithBuilder(t *testing.T) {
|
func TestQueryWithBuilder(t *testing.T) {
|
||||||
|
|
|
@ -472,6 +472,11 @@ func TestUpdateIncrDecr(t *testing.T) {
|
||||||
cnt, err = testEngine.ID(col1.Id).Cols(colName).Incr(colName).Update(col1)
|
cnt, err = testEngine.ID(col1.Id).Cols(colName).Incr(colName).Update(col1)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.EqualValues(t, 1, cnt)
|
assert.EqualValues(t, 1, cnt)
|
||||||
|
|
||||||
|
testEngine.SetColumnMapper(testEngine.GetColumnMapper())
|
||||||
|
cnt, err = testEngine.Cols(colName).Decr(colName, 2).ID(col1.Id).Update(new(UpdateIncr))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, cnt)
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpdatedUpdate struct {
|
type UpdatedUpdate struct {
|
||||||
|
|
|
@ -27,6 +27,7 @@ type Expr struct {
|
||||||
Arg interface{}
|
Arg interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteArgs writes args to the writer
|
||||||
func (expr *Expr) WriteArgs(w *builder.BytesWriter) error {
|
func (expr *Expr) WriteArgs(w *builder.BytesWriter) error {
|
||||||
switch arg := expr.Arg.(type) {
|
switch arg := expr.Arg.(type) {
|
||||||
case *builder.Builder:
|
case *builder.Builder:
|
||||||
|
|
|
@ -17,7 +17,7 @@ func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schem
|
||||||
if _, err := buf.WriteString(" OUTPUT Inserted."); err != nil {
|
if _, err := buf.WriteString(" OUTPUT Inserted."); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := buf.WriteString(table.AutoIncrement); err != nil {
|
if err := statement.dialect.Quoter().QuoteTo(buf, table.AutoIncrement); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -343,7 +343,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
var joinStr string
|
var joinStr string
|
||||||
var err error
|
var err error
|
||||||
var b interface{} = nil
|
var b interface{}
|
||||||
if len(bean) > 0 {
|
if len(bean) > 0 {
|
||||||
b = bean[0]
|
b = bean[0]
|
||||||
beanValue := reflect.ValueOf(bean[0])
|
beanValue := reflect.ValueOf(bean[0])
|
||||||
|
|
|
@ -208,20 +208,18 @@ func (statement *Statement) quote(s string) string {
|
||||||
|
|
||||||
// And add Where & and statement
|
// And add Where & and statement
|
||||||
func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
|
func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
|
||||||
switch query.(type) {
|
switch qr := query.(type) {
|
||||||
case string:
|
case string:
|
||||||
cond := builder.Expr(query.(string), args...)
|
cond := builder.Expr(qr, args...)
|
||||||
statement.cond = statement.cond.And(cond)
|
statement.cond = statement.cond.And(cond)
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
queryMap := query.(map[string]interface{})
|
cond := make(builder.Eq)
|
||||||
newMap := make(map[string]interface{})
|
for k, v := range qr {
|
||||||
for k, v := range queryMap {
|
cond[statement.quote(k)] = v
|
||||||
newMap[statement.quote(k)] = v
|
|
||||||
}
|
}
|
||||||
statement.cond = statement.cond.And(builder.Eq(newMap))
|
|
||||||
case builder.Cond:
|
|
||||||
cond := query.(builder.Cond)
|
|
||||||
statement.cond = statement.cond.And(cond)
|
statement.cond = statement.cond.And(cond)
|
||||||
|
case builder.Cond:
|
||||||
|
statement.cond = statement.cond.And(qr)
|
||||||
for _, v := range args {
|
for _, v := range args {
|
||||||
if vv, ok := v.(builder.Cond); ok {
|
if vv, ok := v.(builder.Cond); ok {
|
||||||
statement.cond = statement.cond.And(vv)
|
statement.cond = statement.cond.And(vv)
|
||||||
|
@ -236,23 +234,25 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme
|
||||||
|
|
||||||
// Or add Where & Or statement
|
// Or add Where & Or statement
|
||||||
func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
|
func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
|
||||||
switch query.(type) {
|
switch qr := query.(type) {
|
||||||
case string:
|
case string:
|
||||||
cond := builder.Expr(query.(string), args...)
|
cond := builder.Expr(qr, args...)
|
||||||
statement.cond = statement.cond.Or(cond)
|
statement.cond = statement.cond.Or(cond)
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
cond := builder.Eq(query.(map[string]interface{}))
|
cond := make(builder.Eq)
|
||||||
|
for k, v := range qr {
|
||||||
|
cond[statement.quote(k)] = v
|
||||||
|
}
|
||||||
statement.cond = statement.cond.Or(cond)
|
statement.cond = statement.cond.Or(cond)
|
||||||
case builder.Cond:
|
case builder.Cond:
|
||||||
cond := query.(builder.Cond)
|
statement.cond = statement.cond.Or(qr)
|
||||||
statement.cond = statement.cond.Or(cond)
|
|
||||||
for _, v := range args {
|
for _, v := range args {
|
||||||
if vv, ok := v.(builder.Cond); ok {
|
if vv, ok := v.(builder.Cond); ok {
|
||||||
statement.cond = statement.cond.Or(vv)
|
statement.cond = statement.cond.Or(vv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
// TODO: not support condition type
|
statement.LastError = ErrConditionType
|
||||||
}
|
}
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
|
@ -734,6 +734,8 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
|
||||||
//engine.logger.Warn(err)
|
//engine.logger.Warn(err)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
|
} else if fieldValuePtr == nil {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if col.IsDeleted && !unscoped { // tag "deleted" is enabled
|
if col.IsDeleted && !unscoped { // tag "deleted" is enabled
|
||||||
|
@ -976,7 +978,7 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName
|
||||||
|
|
||||||
// CondDeleted returns the conditions whether a record is soft deleted.
|
// CondDeleted returns the conditions whether a record is soft deleted.
|
||||||
func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond {
|
func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond {
|
||||||
var colName = col.Name
|
var colName = statement.quote(col.Name)
|
||||||
if statement.JoinStr != "" {
|
if statement.JoinStr != "" {
|
||||||
var prefix string
|
var prefix string
|
||||||
if statement.TableAlias != "" {
|
if statement.TableAlias != "" {
|
||||||
|
|
|
@ -78,7 +78,6 @@ func TestColumnsStringGeneration(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
|
func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
|
||||||
|
|
||||||
b.StopTimer()
|
b.StopTimer()
|
||||||
|
|
||||||
mapCols := make(map[string]bool)
|
mapCols := make(map[string]bool)
|
||||||
|
@ -101,9 +100,7 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
|
||||||
b.StartTimer()
|
b.StartTimer()
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
|
|
||||||
for _, col := range cols {
|
for _, col := range cols {
|
||||||
|
|
||||||
if _, ok := getFlagForColumn(mapCols, col); !ok {
|
if _, ok := getFlagForColumn(mapCols, col); !ok {
|
||||||
b.Fatal("Unexpected result")
|
b.Fatal("Unexpected result")
|
||||||
}
|
}
|
||||||
|
@ -112,7 +109,6 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) {
|
func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) {
|
||||||
|
|
||||||
b.StopTimer()
|
b.StopTimer()
|
||||||
|
|
||||||
mapCols := make(map[string]bool)
|
mapCols := make(map[string]bool)
|
||||||
|
@ -131,9 +127,7 @@ func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) {
|
||||||
b.StartTimer()
|
b.StartTimer()
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
|
|
||||||
for _, col := range cols {
|
for _, col := range cols {
|
||||||
|
|
||||||
if _, ok := getFlagForColumn(mapCols, col); ok {
|
if _, ok := getFlagForColumn(mapCols, col); ok {
|
||||||
b.Fatal("Unexpected result")
|
b.Fatal("Unexpected result")
|
||||||
}
|
}
|
||||||
|
|
|
@ -88,6 +88,9 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
if fieldValuePtr == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
fieldValue := *fieldValuePtr
|
fieldValue := *fieldValuePtr
|
||||||
fieldType := reflect.TypeOf(fieldValue.Interface())
|
fieldType := reflect.TypeOf(fieldValue.Interface())
|
||||||
|
|
|
@ -132,7 +132,6 @@ func (s *SimpleLogger) Error(v ...interface{}) {
|
||||||
if s.level <= LOG_ERR {
|
if s.level <= LOG_ERR {
|
||||||
s.ERR.Output(2, fmt.Sprintln(v...))
|
s.ERR.Output(2, fmt.Sprintln(v...))
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Errorf implement ILogger
|
// Errorf implement ILogger
|
||||||
|
@ -140,7 +139,6 @@ func (s *SimpleLogger) Errorf(format string, v ...interface{}) {
|
||||||
if s.level <= LOG_ERR {
|
if s.level <= LOG_ERR {
|
||||||
s.ERR.Output(2, fmt.Sprintf(format, v...))
|
s.ERR.Output(2, fmt.Sprintf(format, v...))
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Debug implement ILogger
|
// Debug implement ILogger
|
||||||
|
@ -148,7 +146,6 @@ func (s *SimpleLogger) Debug(v ...interface{}) {
|
||||||
if s.level <= LOG_DEBUG {
|
if s.level <= LOG_DEBUG {
|
||||||
s.DEBUG.Output(2, fmt.Sprintln(v...))
|
s.DEBUG.Output(2, fmt.Sprintln(v...))
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Debugf implement ILogger
|
// Debugf implement ILogger
|
||||||
|
@ -156,7 +153,6 @@ func (s *SimpleLogger) Debugf(format string, v ...interface{}) {
|
||||||
if s.level <= LOG_DEBUG {
|
if s.level <= LOG_DEBUG {
|
||||||
s.DEBUG.Output(2, fmt.Sprintf(format, v...))
|
s.DEBUG.Output(2, fmt.Sprintf(format, v...))
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Info implement ILogger
|
// Info implement ILogger
|
||||||
|
@ -164,7 +160,6 @@ func (s *SimpleLogger) Info(v ...interface{}) {
|
||||||
if s.level <= LOG_INFO {
|
if s.level <= LOG_INFO {
|
||||||
s.INFO.Output(2, fmt.Sprintln(v...))
|
s.INFO.Output(2, fmt.Sprintln(v...))
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Infof implement ILogger
|
// Infof implement ILogger
|
||||||
|
@ -172,7 +167,6 @@ func (s *SimpleLogger) Infof(format string, v ...interface{}) {
|
||||||
if s.level <= LOG_INFO {
|
if s.level <= LOG_INFO {
|
||||||
s.INFO.Output(2, fmt.Sprintf(format, v...))
|
s.INFO.Output(2, fmt.Sprintf(format, v...))
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Warn implement ILogger
|
// Warn implement ILogger
|
||||||
|
@ -180,7 +174,6 @@ func (s *SimpleLogger) Warn(v ...interface{}) {
|
||||||
if s.level <= LOG_WARNING {
|
if s.level <= LOG_WARNING {
|
||||||
s.WARN.Output(2, fmt.Sprintln(v...))
|
s.WARN.Output(2, fmt.Sprintln(v...))
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Warnf implement ILogger
|
// Warnf implement ILogger
|
||||||
|
@ -188,7 +181,6 @@ func (s *SimpleLogger) Warnf(format string, v ...interface{}) {
|
||||||
if s.level <= LOG_WARNING {
|
if s.level <= LOG_WARNING {
|
||||||
s.WARN.Output(2, fmt.Sprintf(format, v...))
|
s.WARN.Output(2, fmt.Sprintf(format, v...))
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Level implement ILogger
|
// Level implement ILogger
|
||||||
|
@ -199,7 +191,6 @@ func (s *SimpleLogger) Level() LogLevel {
|
||||||
// SetLevel implement ILogger
|
// SetLevel implement ILogger
|
||||||
func (s *SimpleLogger) SetLevel(l LogLevel) {
|
func (s *SimpleLogger) SetLevel(l LogLevel) {
|
||||||
s.level = l
|
s.level = l
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShowSQL implement ILogger
|
// ShowSQL implement ILogger
|
||||||
|
|
|
@ -0,0 +1,303 @@
|
||||||
|
// Copyright 2021 The Xorm Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package xorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"xorm.io/xorm/convert"
|
||||||
|
"xorm.io/xorm/core"
|
||||||
|
"xorm.io/xorm/dialects"
|
||||||
|
)
|
||||||
|
|
||||||
|
// genScanResultsByBeanNullabale generates scan result
|
||||||
|
func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) {
|
||||||
|
switch t := bean.(type) {
|
||||||
|
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes:
|
||||||
|
return t, false, nil
|
||||||
|
case *time.Time:
|
||||||
|
return &sql.NullTime{}, true, nil
|
||||||
|
case *string:
|
||||||
|
return &sql.NullString{}, true, nil
|
||||||
|
case *int, *int8, *int16, *int32:
|
||||||
|
return &sql.NullInt32{}, true, nil
|
||||||
|
case *int64:
|
||||||
|
return &sql.NullInt64{}, true, nil
|
||||||
|
case *uint, *uint8, *uint16, *uint32:
|
||||||
|
return &NullUint32{}, true, nil
|
||||||
|
case *uint64:
|
||||||
|
return &NullUint64{}, true, nil
|
||||||
|
case *float32, *float64:
|
||||||
|
return &sql.NullFloat64{}, true, nil
|
||||||
|
case *bool:
|
||||||
|
return &sql.NullBool{}, true, nil
|
||||||
|
case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString,
|
||||||
|
time.Time,
|
||||||
|
string,
|
||||||
|
int, int8, int16, int32, int64,
|
||||||
|
uint, uint8, uint16, uint32, uint64,
|
||||||
|
float32, float64,
|
||||||
|
bool:
|
||||||
|
return nil, false, fmt.Errorf("unsupported scan type: %t", t)
|
||||||
|
case convert.Conversion:
|
||||||
|
return &sql.RawBytes{}, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tp := reflect.TypeOf(bean).Elem()
|
||||||
|
switch tp.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
return &sql.NullString{}, true, nil
|
||||||
|
case reflect.Int64:
|
||||||
|
return &sql.NullInt64{}, true, nil
|
||||||
|
case reflect.Int32, reflect.Int, reflect.Int16, reflect.Int8:
|
||||||
|
return &sql.NullInt32{}, true, nil
|
||||||
|
case reflect.Uint64:
|
||||||
|
return &NullUint64{}, true, nil
|
||||||
|
case reflect.Uint32, reflect.Uint, reflect.Uint16, reflect.Uint8:
|
||||||
|
return &NullUint32{}, true, nil
|
||||||
|
default:
|
||||||
|
return nil, false, fmt.Errorf("unsupported type: %#v", bean)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func genScanResultsByBean(bean interface{}) (interface{}, bool, error) {
|
||||||
|
switch t := bean.(type) {
|
||||||
|
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString,
|
||||||
|
*string,
|
||||||
|
*int, *int8, *int16, *int32, *int64,
|
||||||
|
*uint, *uint8, *uint16, *uint32, *uint64,
|
||||||
|
*float32, *float64,
|
||||||
|
*bool:
|
||||||
|
return t, false, nil
|
||||||
|
case *time.Time:
|
||||||
|
return &sql.NullTime{}, true, nil
|
||||||
|
case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString,
|
||||||
|
time.Time,
|
||||||
|
string,
|
||||||
|
int, int8, int16, int32, int64,
|
||||||
|
uint, uint8, uint16, uint32, uint64,
|
||||||
|
bool:
|
||||||
|
return nil, false, fmt.Errorf("unsupported scan type: %t", t)
|
||||||
|
case convert.Conversion:
|
||||||
|
return &sql.RawBytes{}, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tp := reflect.TypeOf(bean).Elem()
|
||||||
|
switch tp.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
return new(string), true, nil
|
||||||
|
case reflect.Int64:
|
||||||
|
return new(int64), true, nil
|
||||||
|
case reflect.Int32:
|
||||||
|
return new(int32), true, nil
|
||||||
|
case reflect.Int:
|
||||||
|
return new(int32), true, nil
|
||||||
|
case reflect.Int16:
|
||||||
|
return new(int32), true, nil
|
||||||
|
case reflect.Int8:
|
||||||
|
return new(int32), true, nil
|
||||||
|
case reflect.Uint64:
|
||||||
|
return new(uint64), true, nil
|
||||||
|
case reflect.Uint32:
|
||||||
|
return new(uint32), true, nil
|
||||||
|
case reflect.Uint:
|
||||||
|
return new(uint), true, nil
|
||||||
|
case reflect.Uint16:
|
||||||
|
return new(uint16), true, nil
|
||||||
|
case reflect.Uint8:
|
||||||
|
return new(uint8), true, nil
|
||||||
|
case reflect.Float32:
|
||||||
|
return new(float32), true, nil
|
||||||
|
case reflect.Float64:
|
||||||
|
return new(float64), true, nil
|
||||||
|
default:
|
||||||
|
return nil, false, fmt.Errorf("unsupported type: %#v", bean)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) {
|
||||||
|
var scanResults = make([]interface{}, len(fields))
|
||||||
|
for i := 0; i < len(fields); i++ {
|
||||||
|
var s sql.NullString
|
||||||
|
scanResults[i] = &s
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Scan(scanResults...); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make(map[string]string, len(fields))
|
||||||
|
for ii, key := range fields {
|
||||||
|
s := scanResults[ii].(*sql.NullString)
|
||||||
|
result[key] = s.String
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string][]byte, error) {
|
||||||
|
var scanResults = make([]interface{}, len(fields))
|
||||||
|
for i := 0; i < len(fields); i++ {
|
||||||
|
var s sql.NullString
|
||||||
|
scanResults[i] = &s
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Scan(scanResults...); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make(map[string][]byte, len(fields))
|
||||||
|
for ii, key := range fields {
|
||||||
|
s := scanResults[ii].(*sql.NullString)
|
||||||
|
result[key] = []byte(s.String)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *Engine) scanStringInterface(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) {
|
||||||
|
var scanResults = make([]interface{}, len(types))
|
||||||
|
for i := 0; i < len(types); i++ {
|
||||||
|
var s sql.NullString
|
||||||
|
scanResults[i] = &s
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := engine.driver.Scan(&dialects.ScanContext{
|
||||||
|
DBLocation: engine.DatabaseTZ,
|
||||||
|
UserLocation: engine.TZLocation,
|
||||||
|
}, rows, types, scanResults...); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return scanResults, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// scan is a wrap of driver.Scan but will automatically change the input values according requirements
|
||||||
|
func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.ColumnType, vv ...interface{}) error {
|
||||||
|
var scanResults = make([]interface{}, 0, len(types))
|
||||||
|
var replaces = make([]bool, 0, len(types))
|
||||||
|
var err error
|
||||||
|
for _, v := range vv {
|
||||||
|
var replaced bool
|
||||||
|
var scanResult interface{}
|
||||||
|
if _, ok := v.(sql.Scanner); !ok {
|
||||||
|
var useNullable = true
|
||||||
|
if engine.driver.Features().SupportNullable {
|
||||||
|
nullable, ok := types[0].Nullable()
|
||||||
|
useNullable = ok && nullable
|
||||||
|
}
|
||||||
|
|
||||||
|
if useNullable {
|
||||||
|
scanResult, replaced, err = genScanResultsByBeanNullable(v)
|
||||||
|
} else {
|
||||||
|
scanResult, replaced, err = genScanResultsByBean(v)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
scanResult = v
|
||||||
|
}
|
||||||
|
scanResults = append(scanResults, scanResult)
|
||||||
|
replaces = append(replaces, replaced)
|
||||||
|
}
|
||||||
|
|
||||||
|
var scanCtx = dialects.ScanContext{
|
||||||
|
DBLocation: engine.DatabaseTZ,
|
||||||
|
UserLocation: engine.TZLocation,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = engine.driver.Scan(&scanCtx, rows, types, scanResults...); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, replaced := range replaces {
|
||||||
|
if replaced {
|
||||||
|
if err = convertAssign(vv[i], scanResults[i], scanCtx.DBLocation, engine.TZLocation); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *Engine) scanInterfaces(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) {
|
||||||
|
var scanResultContainers = make([]interface{}, len(types))
|
||||||
|
for i := 0; i < len(types); i++ {
|
||||||
|
scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
scanResultContainers[i] = scanResult
|
||||||
|
}
|
||||||
|
if err := engine.driver.Scan(&dialects.ScanContext{
|
||||||
|
DBLocation: engine.DatabaseTZ,
|
||||||
|
UserLocation: engine.TZLocation,
|
||||||
|
}, rows, types, scanResultContainers...); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return scanResultContainers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) {
|
||||||
|
scanResults, err := engine.scanStringInterface(rows, types)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var results = make([]string, 0, len(fields))
|
||||||
|
for i := 0; i < len(fields); i++ {
|
||||||
|
results = append(results, scanResults[i].(*sql.NullString).String)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) {
|
||||||
|
fields, err := rows.Columns()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
types, err := rows.ColumnTypes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for rows.Next() {
|
||||||
|
result, err := row2mapBytes(rows, types, fields)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resultsSlice = append(resultsSlice, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resultsSlice, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *Engine) row2mapInterface(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]interface{}, error) {
|
||||||
|
var resultsMap = make(map[string]interface{}, len(fields))
|
||||||
|
var scanResultContainers = make([]interface{}, len(fields))
|
||||||
|
for i := 0; i < len(fields); i++ {
|
||||||
|
scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
scanResultContainers[i] = scanResult
|
||||||
|
}
|
||||||
|
if err := engine.driver.Scan(&dialects.ScanContext{
|
||||||
|
DBLocation: engine.DatabaseTZ,
|
||||||
|
UserLocation: engine.TZLocation,
|
||||||
|
}, rows, types, scanResultContainers...); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for ii, key := range fields {
|
||||||
|
res, err := convert.Interface2Interface(engine.TZLocation, scanResultContainers[ii])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resultsMap[key] = res
|
||||||
|
}
|
||||||
|
return resultsMap, nil
|
||||||
|
}
|
|
@ -6,10 +6,8 @@ package schemas
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -25,6 +23,7 @@ type Column struct {
|
||||||
Name string
|
Name string
|
||||||
TableName string
|
TableName string
|
||||||
FieldName string // Available only when parsed from a struct
|
FieldName string // Available only when parsed from a struct
|
||||||
|
FieldIndex []int // Available only when parsed from a struct
|
||||||
SQLType SQLType
|
SQLType SQLType
|
||||||
IsJSON bool
|
IsJSON bool
|
||||||
Length int
|
Length int
|
||||||
|
@ -83,41 +82,17 @@ func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) {
|
||||||
|
|
||||||
// ValueOfV returns column's filed of struct's value accept reflevt value
|
// ValueOfV returns column's filed of struct's value accept reflevt value
|
||||||
func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
|
func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
|
||||||
var fieldValue reflect.Value
|
var v = *dataStruct
|
||||||
fieldPath := strings.Split(col.FieldName, ".")
|
for _, i := range col.FieldIndex {
|
||||||
|
if v.Kind() == reflect.Ptr {
|
||||||
if dataStruct.Type().Kind() == reflect.Map {
|
if v.IsNil() {
|
||||||
keyValue := reflect.ValueOf(fieldPath[len(fieldPath)-1])
|
v.Set(reflect.New(v.Type().Elem()))
|
||||||
fieldValue = dataStruct.MapIndex(keyValue)
|
|
||||||
return &fieldValue, nil
|
|
||||||
} else if dataStruct.Type().Kind() == reflect.Interface {
|
|
||||||
structValue := reflect.ValueOf(dataStruct.Interface())
|
|
||||||
dataStruct = &structValue
|
|
||||||
}
|
}
|
||||||
|
v = v.Elem()
|
||||||
level := len(fieldPath)
|
|
||||||
fieldValue = dataStruct.FieldByName(fieldPath[0])
|
|
||||||
for i := 0; i < level-1; i++ {
|
|
||||||
if !fieldValue.IsValid() {
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
if fieldValue.Kind() == reflect.Struct {
|
v = v.FieldByIndex([]int{i})
|
||||||
fieldValue = fieldValue.FieldByName(fieldPath[i+1])
|
|
||||||
} else if fieldValue.Kind() == reflect.Ptr {
|
|
||||||
if fieldValue.IsNil() {
|
|
||||||
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
|
|
||||||
}
|
}
|
||||||
fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1])
|
return &v, nil
|
||||||
} else {
|
|
||||||
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !fieldValue.IsValid() {
|
|
||||||
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &fieldValue, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertID converts id content to suitable type according column type
|
// ConvertID converts id content to suitable type according column type
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
package schemas
|
package schemas
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -159,24 +158,8 @@ func (table *Table) IDOfV(rv reflect.Value) (PK, error) {
|
||||||
for i, col := range table.PKColumns() {
|
for i, col := range table.PKColumns() {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
fieldName := col.FieldName
|
pkField := v.FieldByIndex(col.FieldIndex)
|
||||||
for {
|
|
||||||
parts := strings.SplitN(fieldName, ".", 2)
|
|
||||||
if len(parts) == 1 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
v = v.FieldByName(parts[0])
|
|
||||||
if v.Kind() == reflect.Ptr {
|
|
||||||
v = v.Elem()
|
|
||||||
}
|
|
||||||
if v.Kind() != reflect.Struct {
|
|
||||||
return nil, fmt.Errorf("Unsupported read value of column %s from field %s", col.Name, col.FieldName)
|
|
||||||
}
|
|
||||||
fieldName = parts[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
pkField := v.FieldByName(fieldName)
|
|
||||||
switch pkField.Kind() {
|
switch pkField.Kind() {
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
pk[i], err = col.ConvertID(pkField.String())
|
pk[i], err = col.ConvertID(pkField.String())
|
||||||
|
|
|
@ -27,7 +27,6 @@ var testsGetColumn = []struct {
|
||||||
var table *Table
|
var table *Table
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
||||||
table = NewEmptyTable()
|
table = NewEmptyTable()
|
||||||
|
|
||||||
var name string
|
var name string
|
||||||
|
@ -41,7 +40,6 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetColumn(t *testing.T) {
|
func TestGetColumn(t *testing.T) {
|
||||||
|
|
||||||
for _, test := range testsGetColumn {
|
for _, test := range testsGetColumn {
|
||||||
if table.GetColumn(test.name) == nil {
|
if table.GetColumn(test.name) == nil {
|
||||||
t.Error("Column not found!")
|
t.Error("Column not found!")
|
||||||
|
@ -50,7 +48,6 @@ func TestGetColumn(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetColumnIdx(t *testing.T) {
|
func TestGetColumnIdx(t *testing.T) {
|
||||||
|
|
||||||
for _, test := range testsGetColumn {
|
for _, test := range testsGetColumn {
|
||||||
if table.GetColumnIdx(test.name, test.idx) == nil {
|
if table.GetColumnIdx(test.name, test.idx) == nil {
|
||||||
t.Errorf("Column %s with idx %d not found!", test.name, test.idx)
|
t.Errorf("Column %s with idx %d not found!", test.name, test.idx)
|
||||||
|
@ -59,7 +56,6 @@ func TestGetColumnIdx(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkGetColumnWithToLower(b *testing.B) {
|
func BenchmarkGetColumnWithToLower(b *testing.B) {
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, test := range testsGetColumn {
|
for _, test := range testsGetColumn {
|
||||||
|
|
||||||
|
@ -71,7 +67,6 @@ func BenchmarkGetColumnWithToLower(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkGetColumnIdxWithToLower(b *testing.B) {
|
func BenchmarkGetColumnIdxWithToLower(b *testing.B) {
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, test := range testsGetColumn {
|
for _, test := range testsGetColumn {
|
||||||
|
|
||||||
|
@ -89,7 +84,6 @@ func BenchmarkGetColumnIdxWithToLower(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkGetColumn(b *testing.B) {
|
func BenchmarkGetColumn(b *testing.B) {
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, test := range testsGetColumn {
|
for _, test := range testsGetColumn {
|
||||||
if table.GetColumn(test.name) == nil {
|
if table.GetColumn(test.name) == nil {
|
||||||
|
@ -100,7 +94,6 @@ func BenchmarkGetColumn(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkGetColumnIdx(b *testing.B) {
|
func BenchmarkGetColumnIdx(b *testing.B) {
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, test := range testsGetColumn {
|
for _, test := range testsGetColumn {
|
||||||
if table.GetColumnIdx(test.name, test.idx) == nil {
|
if table.GetColumnIdx(test.name, test.idx) == nil {
|
||||||
|
|
|
@ -375,6 +375,9 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if fieldValue == nil {
|
||||||
|
return nil, ErrFieldIsNotValid{key, table.Name}
|
||||||
|
}
|
||||||
|
|
||||||
if !fieldValue.IsValid() || !fieldValue.CanSet() {
|
if !fieldValue.IsValid() || !fieldValue.CanSet() {
|
||||||
return nil, ErrFieldIsNotValid{key, table.Name}
|
return nil, ErrFieldIsNotValid{key, table.Name}
|
||||||
|
|
|
@ -35,27 +35,20 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time
|
||||||
sd, err := strconv.ParseInt(sdata, 10, 64)
|
sd, err := strconv.ParseInt(sdata, 10, 64)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
x = time.Unix(sd, 0)
|
x = time.Unix(sd, 0)
|
||||||
//session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
|
||||||
} else {
|
|
||||||
//session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
|
||||||
}
|
}
|
||||||
} else if len(sdata) > 19 && strings.Contains(sdata, "-") {
|
} else if len(sdata) > 19 && strings.Contains(sdata, "-") {
|
||||||
x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc)
|
x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc)
|
||||||
session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.Name, x, sdata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc)
|
x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc)
|
||||||
//session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc)
|
x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc)
|
||||||
//session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
|
||||||
}
|
}
|
||||||
} else if len(sdata) == 19 && strings.Contains(sdata, "-") {
|
} else if len(sdata) == 19 && strings.Contains(sdata, "-") {
|
||||||
x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc)
|
x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc)
|
||||||
//session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
|
||||||
} else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' {
|
} else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' {
|
||||||
x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc)
|
x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc)
|
||||||
//session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
|
||||||
} else if col.SQLType.Name == schemas.Time {
|
} else if col.SQLType.Name == schemas.Time {
|
||||||
if strings.Contains(sdata, " ") {
|
if strings.Contains(sdata, " ") {
|
||||||
ssd := strings.Split(sdata, " ")
|
ssd := strings.Split(sdata, " ")
|
||||||
|
@ -69,7 +62,6 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time
|
||||||
|
|
||||||
st := fmt.Sprintf("2006-01-02 %v", sdata)
|
st := fmt.Sprintf("2006-01-02 %v", sdata)
|
||||||
x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc)
|
x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc)
|
||||||
//session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
|
||||||
} else {
|
} else {
|
||||||
outErr = fmt.Errorf("unsupported time format %v", sdata)
|
outErr = fmt.Errorf("unsupported time format %v", sdata)
|
||||||
return
|
return
|
||||||
|
|
|
@ -276,7 +276,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
|
||||||
func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error {
|
func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error {
|
||||||
cols := table.PKColumns()
|
cols := table.PKColumns()
|
||||||
if len(cols) == 1 {
|
if len(cols) == 1 {
|
||||||
return convertAssign(dst, pk[0])
|
return convertAssign(dst, pk[0], nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
dst = pk
|
dst = pk
|
||||||
|
|
267
session_get.go
267
session_get.go
|
@ -6,12 +6,16 @@ package xorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"xorm.io/xorm/caches"
|
"xorm.io/xorm/caches"
|
||||||
|
"xorm.io/xorm/convert"
|
||||||
|
"xorm.io/xorm/core"
|
||||||
"xorm.io/xorm/internal/utils"
|
"xorm.io/xorm/internal/utils"
|
||||||
"xorm.io/xorm/schemas"
|
"xorm.io/xorm/schemas"
|
||||||
)
|
)
|
||||||
|
@ -108,6 +112,17 @@ func (session *Session) get(bean interface{}) (bool, error) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
valuerTypePlaceHolder driver.Valuer
|
||||||
|
valuerType = reflect.TypeOf(&valuerTypePlaceHolder).Elem()
|
||||||
|
|
||||||
|
scannerTypePlaceHolder sql.Scanner
|
||||||
|
scannerType = reflect.TypeOf(&scannerTypePlaceHolder).Elem()
|
||||||
|
|
||||||
|
conversionTypePlaceHolder convert.Conversion
|
||||||
|
conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem()
|
||||||
|
)
|
||||||
|
|
||||||
func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
|
func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
|
||||||
rows, err := session.queryRows(sqlStr, args...)
|
rows, err := session.queryRows(sqlStr, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -122,123 +137,141 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table,
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch bean.(type) {
|
// WARN: Alougth rows return true, but we may also return error.
|
||||||
case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString:
|
types, err := rows.ColumnTypes()
|
||||||
return true, rows.Scan(&bean)
|
if err != nil {
|
||||||
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString:
|
|
||||||
return true, rows.Scan(bean)
|
|
||||||
case *string:
|
|
||||||
var res sql.NullString
|
|
||||||
if err := rows.Scan(&res); err != nil {
|
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
if res.Valid {
|
fields, err := rows.Columns()
|
||||||
*(bean.(*string)) = res.String
|
if err != nil {
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
case *int:
|
|
||||||
var res sql.NullInt64
|
|
||||||
if err := rows.Scan(&res); err != nil {
|
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
if res.Valid {
|
|
||||||
*(bean.(*int)) = int(res.Int64)
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
case *int8:
|
|
||||||
var res sql.NullInt64
|
|
||||||
if err := rows.Scan(&res); err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
if res.Valid {
|
|
||||||
*(bean.(*int8)) = int8(res.Int64)
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
case *int16:
|
|
||||||
var res sql.NullInt64
|
|
||||||
if err := rows.Scan(&res); err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
if res.Valid {
|
|
||||||
*(bean.(*int16)) = int16(res.Int64)
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
case *int32:
|
|
||||||
var res sql.NullInt64
|
|
||||||
if err := rows.Scan(&res); err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
if res.Valid {
|
|
||||||
*(bean.(*int32)) = int32(res.Int64)
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
case *int64:
|
|
||||||
var res sql.NullInt64
|
|
||||||
if err := rows.Scan(&res); err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
if res.Valid {
|
|
||||||
*(bean.(*int64)) = int64(res.Int64)
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
case *uint:
|
|
||||||
var res sql.NullInt64
|
|
||||||
if err := rows.Scan(&res); err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
if res.Valid {
|
|
||||||
*(bean.(*uint)) = uint(res.Int64)
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
case *uint8:
|
|
||||||
var res sql.NullInt64
|
|
||||||
if err := rows.Scan(&res); err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
if res.Valid {
|
|
||||||
*(bean.(*uint8)) = uint8(res.Int64)
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
case *uint16:
|
|
||||||
var res sql.NullInt64
|
|
||||||
if err := rows.Scan(&res); err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
if res.Valid {
|
|
||||||
*(bean.(*uint16)) = uint16(res.Int64)
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
case *uint32:
|
|
||||||
var res sql.NullInt64
|
|
||||||
if err := rows.Scan(&res); err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
if res.Valid {
|
|
||||||
*(bean.(*uint32)) = uint32(res.Int64)
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
case *uint64:
|
|
||||||
var res sql.NullInt64
|
|
||||||
if err := rows.Scan(&res); err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
if res.Valid {
|
|
||||||
*(bean.(*uint64)) = uint64(res.Int64)
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
case *bool:
|
|
||||||
var res sql.NullBool
|
|
||||||
if err := rows.Scan(&res); err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
if res.Valid {
|
|
||||||
*(bean.(*bool)) = res.Bool
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
switch beanKind {
|
switch beanKind {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
|
if _, ok := bean.(*time.Time); ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if _, ok := bean.(sql.Scanner); ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if _, ok := bean.(convert.Conversion); len(types) == 1 && ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return session.getStruct(rows, types, fields, table, bean)
|
||||||
|
case reflect.Slice:
|
||||||
|
return session.getSlice(rows, types, fields, bean)
|
||||||
|
case reflect.Map:
|
||||||
|
return session.getMap(rows, types, fields, bean)
|
||||||
|
}
|
||||||
|
|
||||||
|
return session.getVars(rows, types, fields, bean)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) {
|
||||||
|
switch t := bean.(type) {
|
||||||
|
case *[]string:
|
||||||
|
res, err := session.engine.scanStringInterface(rows, types)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var needAppend = len(*t) == 0 // both support slice is empty or has been initlized
|
||||||
|
for i, r := range res {
|
||||||
|
if needAppend {
|
||||||
|
*t = append(*t, r.(*sql.NullString).String)
|
||||||
|
} else {
|
||||||
|
(*t)[i] = r.(*sql.NullString).String
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
case *[]interface{}:
|
||||||
|
scanResults, err := session.engine.scanInterfaces(rows, types)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
var needAppend = len(*t) == 0
|
||||||
|
for ii := range fields {
|
||||||
|
s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii])
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
if needAppend {
|
||||||
|
*t = append(*t, s)
|
||||||
|
} else {
|
||||||
|
(*t)[ii] = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
default:
|
||||||
|
return true, fmt.Errorf("unspoorted slice type: %t", t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) {
|
||||||
|
switch t := bean.(type) {
|
||||||
|
case *map[string]string:
|
||||||
|
scanResults, err := session.engine.scanStringInterface(rows, types)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
for ii, key := range fields {
|
||||||
|
(*t)[key] = scanResults[ii].(*sql.NullString).String
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
case *map[string]interface{}:
|
||||||
|
scanResults, err := session.engine.scanInterfaces(rows, types)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
for ii, key := range fields {
|
||||||
|
s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii])
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
(*t)[key] = s
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
default:
|
||||||
|
return true, fmt.Errorf("unspoorted map type: %t", t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields []string, beans ...interface{}) (bool, error) {
|
||||||
|
if len(beans) != len(types) {
|
||||||
|
return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans))
|
||||||
|
}
|
||||||
|
var scanResults = make([]interface{}, 0, len(types))
|
||||||
|
var replaceds = make([]bool, 0, len(types))
|
||||||
|
for _, bean := range beans {
|
||||||
|
switch t := bean.(type) {
|
||||||
|
case sql.Scanner:
|
||||||
|
scanResults = append(scanResults, t)
|
||||||
|
replaceds = append(replaceds, false)
|
||||||
|
case convert.Conversion:
|
||||||
|
scanResults = append(scanResults, &sql.RawBytes{})
|
||||||
|
replaceds = append(replaceds, true)
|
||||||
|
default:
|
||||||
|
scanResults = append(scanResults, bean)
|
||||||
|
replaceds = append(replaceds, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := session.engine.scan(rows, fields, types, scanResults...)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
for i, replaced := range replaceds {
|
||||||
|
if replaced {
|
||||||
|
err = convertAssign(beans[i], scanResults[i], session.engine.DatabaseTZ, session.engine.TZLocation)
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) {
|
||||||
fields, err := rows.Columns()
|
fields, err := rows.Columns()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// WARN: Alougth rows return true, but get fields failed
|
// WARN: Alougth rows return true, but get fields failed
|
||||||
|
@ -259,18 +292,6 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table,
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, session.executeProcessors()
|
return true, session.executeProcessors()
|
||||||
case reflect.Slice:
|
|
||||||
err = rows.ScanSlice(bean)
|
|
||||||
case reflect.Map:
|
|
||||||
err = rows.ScanMap(bean)
|
|
||||||
case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
|
||||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
||||||
err = rows.Scan(bean)
|
|
||||||
default:
|
|
||||||
err = rows.Scan(bean)
|
|
||||||
}
|
|
||||||
|
|
||||||
return true, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) {
|
func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) {
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"xorm.io/xorm/internal/utils"
|
"xorm.io/xorm/internal/utils"
|
||||||
"xorm.io/xorm/schemas"
|
"xorm.io/xorm/schemas"
|
||||||
|
@ -374,9 +375,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
||||||
return 1, nil
|
return 1, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
aiValue.Set(int64ToIntValue(id, aiValue.Type()))
|
return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation)
|
||||||
|
|
||||||
return 1, nil
|
|
||||||
} else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES ||
|
} else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES ||
|
||||||
session.engine.dialect.URI().DBType == schemas.MSSQL) {
|
session.engine.dialect.URI().DBType == schemas.MSSQL) {
|
||||||
res, err := session.queryBytes(sqlStr, args...)
|
res, err := session.queryBytes(sqlStr, args...)
|
||||||
|
@ -416,9 +415,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
||||||
return 1, nil
|
return 1, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
aiValue.Set(int64ToIntValue(id, aiValue.Type()))
|
return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation)
|
||||||
|
|
||||||
return 1, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := session.exec(sqlStr, args...)
|
res, err := session.exec(sqlStr, args...)
|
||||||
|
@ -458,7 +455,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
}
|
}
|
||||||
|
|
||||||
aiValue.Set(int64ToIntValue(id, aiValue.Type()))
|
if err := convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
}
|
}
|
||||||
|
@ -499,6 +498,16 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
|
||||||
}
|
}
|
||||||
|
|
||||||
if col.IsDeleted {
|
if col.IsDeleted {
|
||||||
|
colNames = append(colNames, col.Name)
|
||||||
|
if !col.Nullable {
|
||||||
|
if col.SQLType.IsNumeric() {
|
||||||
|
args = append(args, 0)
|
||||||
|
} else {
|
||||||
|
args = append(args, time.Time{}.Format("2006-01-02 15:04:05"))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
args = append(args, nil)
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
158
session_query.go
158
session_query.go
|
@ -5,13 +5,7 @@
|
||||||
package xorm
|
package xorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"xorm.io/xorm/core"
|
"xorm.io/xorm/core"
|
||||||
"xorm.io/xorm/schemas"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Query runs a raw sql and return records as []map[string][]byte
|
// Query runs a raw sql and return records as []map[string][]byte
|
||||||
|
@ -28,116 +22,18 @@ func (session *Session) Query(sqlOrArgs ...interface{}) ([]map[string][]byte, er
|
||||||
return session.queryBytes(sqlStr, args...)
|
return session.queryBytes(sqlStr, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func value2String(rawValue *reflect.Value) (str string, err error) {
|
func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) {
|
||||||
aa := reflect.TypeOf((*rawValue).Interface())
|
|
||||||
vv := reflect.ValueOf((*rawValue).Interface())
|
|
||||||
switch aa.Kind() {
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
||||||
str = strconv.FormatInt(vv.Int(), 10)
|
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
||||||
str = strconv.FormatUint(vv.Uint(), 10)
|
|
||||||
case reflect.Float32, reflect.Float64:
|
|
||||||
str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
|
|
||||||
case reflect.String:
|
|
||||||
str = vv.String()
|
|
||||||
case reflect.Array, reflect.Slice:
|
|
||||||
switch aa.Elem().Kind() {
|
|
||||||
case reflect.Uint8:
|
|
||||||
data := rawValue.Interface().([]byte)
|
|
||||||
str = string(data)
|
|
||||||
if str == "\x00" {
|
|
||||||
str = "0"
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
|
|
||||||
}
|
|
||||||
// time type
|
|
||||||
case reflect.Struct:
|
|
||||||
if aa.ConvertibleTo(schemas.TimeType) {
|
|
||||||
str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano)
|
|
||||||
} else {
|
|
||||||
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
|
|
||||||
}
|
|
||||||
case reflect.Bool:
|
|
||||||
str = strconv.FormatBool(vv.Bool())
|
|
||||||
case reflect.Complex128, reflect.Complex64:
|
|
||||||
str = fmt.Sprintf("%v", vv.Complex())
|
|
||||||
/* TODO: unsupported types below
|
|
||||||
case reflect.Map:
|
|
||||||
case reflect.Ptr:
|
|
||||||
case reflect.Uintptr:
|
|
||||||
case reflect.UnsafePointer:
|
|
||||||
case reflect.Chan, reflect.Func, reflect.Interface:
|
|
||||||
*/
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string, err error) {
|
|
||||||
result := make(map[string]string)
|
|
||||||
scanResultContainers := make([]interface{}, len(fields))
|
|
||||||
for i := 0; i < len(fields); i++ {
|
|
||||||
var scanResultContainer interface{}
|
|
||||||
scanResultContainers[i] = &scanResultContainer
|
|
||||||
}
|
|
||||||
if err := rows.Scan(scanResultContainers...); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for ii, key := range fields {
|
|
||||||
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
|
|
||||||
// if row is null then as empty string
|
|
||||||
if rawValue.Interface() == nil {
|
|
||||||
result[key] = ""
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if data, err := value2String(&rawValue); err == nil {
|
|
||||||
result[key] = data
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func row2sliceStr(rows *core.Rows, fields []string) (results []string, err error) {
|
|
||||||
result := make([]string, 0, len(fields))
|
|
||||||
scanResultContainers := make([]interface{}, len(fields))
|
|
||||||
for i := 0; i < len(fields); i++ {
|
|
||||||
var scanResultContainer interface{}
|
|
||||||
scanResultContainers[i] = &scanResultContainer
|
|
||||||
}
|
|
||||||
if err := rows.Scan(scanResultContainers...); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < len(fields); i++ {
|
|
||||||
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[i]))
|
|
||||||
// if row is null then as empty string
|
|
||||||
if rawValue.Interface() == nil {
|
|
||||||
result = append(result, "")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if data, err := value2String(&rawValue); err == nil {
|
|
||||||
result = append(result, data)
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) {
|
|
||||||
fields, err := rows.Columns()
|
fields, err := rows.Columns()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
types, err := rows.ColumnTypes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
result, err := row2mapStr(rows, fields)
|
result, err := row2mapStr(rows, types, fields)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -147,13 +43,18 @@ func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error)
|
||||||
return resultsSlice, nil
|
return resultsSlice, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func rows2SliceString(rows *core.Rows) (resultsSlice [][]string, err error) {
|
func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]string, err error) {
|
||||||
fields, err := rows.Columns()
|
fields, err := rows.Columns()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
types, err := rows.ColumnTypes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
record, err := row2sliceStr(rows, fields)
|
record, err := session.engine.row2sliceStr(rows, types, fields)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -180,7 +81,7 @@ func (session *Session) QueryString(sqlOrArgs ...interface{}) ([]map[string]stri
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
return rows2Strings(rows)
|
return session.rows2Strings(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
// QuerySliceString runs a raw sql and return records as [][]string
|
// QuerySliceString runs a raw sql and return records as [][]string
|
||||||
|
@ -200,33 +101,20 @@ func (session *Session) QuerySliceString(sqlOrArgs ...interface{}) ([][]string,
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
return rows2SliceString(rows)
|
return session.rows2SliceString(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func row2mapInterface(rows *core.Rows, fields []string) (resultsMap map[string]interface{}, err error) {
|
func (session *Session) rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, err error) {
|
||||||
resultsMap = make(map[string]interface{}, len(fields))
|
|
||||||
scanResultContainers := make([]interface{}, len(fields))
|
|
||||||
for i := 0; i < len(fields); i++ {
|
|
||||||
var scanResultContainer interface{}
|
|
||||||
scanResultContainers[i] = &scanResultContainer
|
|
||||||
}
|
|
||||||
if err := rows.Scan(scanResultContainers...); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for ii, key := range fields {
|
|
||||||
resultsMap[key] = reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])).Interface()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, err error) {
|
|
||||||
fields, err := rows.Columns()
|
fields, err := rows.Columns()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
types, err := rows.ColumnTypes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
result, err := row2mapInterface(rows, fields)
|
result, err := session.engine.row2mapInterface(rows, types, fields)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -253,5 +141,5 @@ func (session *Session) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]i
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
return rows2Interfaces(rows)
|
return session.rows2Interfaces(rows)
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,9 +6,13 @@ package xorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"xorm.io/xorm/core"
|
"xorm.io/xorm/core"
|
||||||
|
"xorm.io/xorm/schemas"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
|
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
|
||||||
|
@ -71,6 +75,53 @@ func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row {
|
||||||
return core.NewRow(session.queryRows(sqlStr, args...))
|
return core.NewRow(session.queryRows(sqlStr, args...))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func value2String(rawValue *reflect.Value) (str string, err error) {
|
||||||
|
aa := reflect.TypeOf((*rawValue).Interface())
|
||||||
|
vv := reflect.ValueOf((*rawValue).Interface())
|
||||||
|
switch aa.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
str = strconv.FormatInt(vv.Int(), 10)
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
str = strconv.FormatUint(vv.Uint(), 10)
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
|
||||||
|
case reflect.String:
|
||||||
|
str = vv.String()
|
||||||
|
case reflect.Array, reflect.Slice:
|
||||||
|
switch aa.Elem().Kind() {
|
||||||
|
case reflect.Uint8:
|
||||||
|
data := rawValue.Interface().([]byte)
|
||||||
|
str = string(data)
|
||||||
|
if str == "\x00" {
|
||||||
|
str = "0"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
|
||||||
|
}
|
||||||
|
// time type
|
||||||
|
case reflect.Struct:
|
||||||
|
if aa.ConvertibleTo(schemas.TimeType) {
|
||||||
|
str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano)
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
|
||||||
|
}
|
||||||
|
case reflect.Bool:
|
||||||
|
str = strconv.FormatBool(vv.Bool())
|
||||||
|
case reflect.Complex128, reflect.Complex64:
|
||||||
|
str = fmt.Sprintf("%v", vv.Complex())
|
||||||
|
/* TODO: unsupported types below
|
||||||
|
case reflect.Map:
|
||||||
|
case reflect.Ptr:
|
||||||
|
case reflect.Uintptr:
|
||||||
|
case reflect.UnsafePointer:
|
||||||
|
case reflect.Chan, reflect.Func, reflect.Interface:
|
||||||
|
*/
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func value2Bytes(rawValue *reflect.Value) ([]byte, error) {
|
func value2Bytes(rawValue *reflect.Value) ([]byte, error) {
|
||||||
str, err := value2String(rawValue)
|
str, err := value2String(rawValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -79,50 +130,6 @@ func value2Bytes(rawValue *reflect.Value) ([]byte, error) {
|
||||||
return []byte(str), nil
|
return []byte(str), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, err error) {
|
|
||||||
result := make(map[string][]byte)
|
|
||||||
scanResultContainers := make([]interface{}, len(fields))
|
|
||||||
for i := 0; i < len(fields); i++ {
|
|
||||||
var scanResultContainer interface{}
|
|
||||||
scanResultContainers[i] = &scanResultContainer
|
|
||||||
}
|
|
||||||
if err := rows.Scan(scanResultContainers...); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for ii, key := range fields {
|
|
||||||
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
|
|
||||||
//if row is null then ignore
|
|
||||||
if rawValue.Interface() == nil {
|
|
||||||
result[key] = []byte{}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if data, err := value2Bytes(&rawValue); err == nil {
|
|
||||||
result[key] = data
|
|
||||||
} else {
|
|
||||||
return nil, err // !nashtsai! REVIEW, should return err or just error log?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) {
|
|
||||||
fields, err := rows.Columns()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for rows.Next() {
|
|
||||||
result, err := row2map(rows, fields)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
resultsSlice = append(resultsSlice, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
return resultsSlice, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) {
|
func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) {
|
||||||
rows, err := session.queryRows(sqlStr, args...)
|
rows, err := session.queryRows(sqlStr, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -280,15 +280,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
k = ct.Elem().Kind()
|
k = ct.Elem().Kind()
|
||||||
}
|
}
|
||||||
if k == reflect.Struct {
|
if k == reflect.Struct {
|
||||||
var refTable = session.statement.RefTable
|
condTable, err := session.engine.TableInfo(condiBean[0])
|
||||||
if refTable == nil {
|
|
||||||
refTable, err = session.engine.TableInfo(condiBean[0])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
var err error
|
autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, false)
|
||||||
autoCond, err = session.statement.BuildConds(refTable, condiBean[0], true, true, false, true, false)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -457,7 +454,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
// FIXME: if bean is a map type, it will panic because map cannot be as map key
|
// FIXME: if bean is a map type, it will panic because map cannot be as map key
|
||||||
session.afterUpdateBeans[bean] = &afterClosures
|
session.afterUpdateBeans[bean] = &afterClosures
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if _, ok := interface{}(bean).(AfterUpdateProcessor); ok {
|
if _, ok := interface{}(bean).(AfterUpdateProcessor); ok {
|
||||||
session.afterUpdateBeans[bean] = nil
|
session.afterUpdateBeans[bean] = nil
|
||||||
|
|
328
tags/parser.go
328
tags/parser.go
|
@ -7,7 +7,6 @@ package tags
|
||||||
import (
|
import (
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -23,7 +22,7 @@ import (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrUnsupportedType represents an unsupported type error
|
// ErrUnsupportedType represents an unsupported type error
|
||||||
ErrUnsupportedType = errors.New("Unsupported type")
|
ErrUnsupportedType = errors.New("unsupported type")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Parser represents a parser for xorm tag
|
// Parser represents a parser for xorm tag
|
||||||
|
@ -125,6 +124,147 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrIgnoreField = errors.New("field will be ignored")
|
||||||
|
|
||||||
|
func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
|
||||||
|
var sqlType schemas.SQLType
|
||||||
|
if fieldValue.CanAddr() {
|
||||||
|
if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
|
||||||
|
sqlType = schemas.SQLType{Name: schemas.Text}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, ok := fieldValue.Interface().(convert.Conversion); ok {
|
||||||
|
sqlType = schemas.SQLType{Name: schemas.Text}
|
||||||
|
} else {
|
||||||
|
sqlType = schemas.Type2SQLType(field.Type)
|
||||||
|
}
|
||||||
|
col := schemas.NewColumn(parser.columnMapper.Obj2Table(field.Name),
|
||||||
|
field.Name, sqlType, sqlType.DefaultLength,
|
||||||
|
sqlType.DefaultLength2, true)
|
||||||
|
col.FieldIndex = []int{fieldIndex}
|
||||||
|
|
||||||
|
if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
|
||||||
|
col.IsAutoIncrement = true
|
||||||
|
col.IsPrimaryKey = true
|
||||||
|
col.Nullable = false
|
||||||
|
}
|
||||||
|
return col, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) {
|
||||||
|
var col = &schemas.Column{
|
||||||
|
FieldName: field.Name,
|
||||||
|
FieldIndex: []int{fieldIndex},
|
||||||
|
Nullable: true,
|
||||||
|
IsPrimaryKey: false,
|
||||||
|
IsAutoIncrement: false,
|
||||||
|
MapType: schemas.TWOSIDES,
|
||||||
|
Indexes: make(map[string]int),
|
||||||
|
DefaultIsEmpty: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var ctx = Context{
|
||||||
|
table: table,
|
||||||
|
col: col,
|
||||||
|
fieldValue: fieldValue,
|
||||||
|
indexNames: make(map[string]int),
|
||||||
|
parser: parser,
|
||||||
|
}
|
||||||
|
|
||||||
|
for j, tag := range tags {
|
||||||
|
if ctx.ignoreNext {
|
||||||
|
ctx.ignoreNext = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.tag = tag
|
||||||
|
ctx.tagUname = strings.ToUpper(tag.name)
|
||||||
|
|
||||||
|
if j > 0 {
|
||||||
|
ctx.preTag = strings.ToUpper(tags[j-1].name)
|
||||||
|
}
|
||||||
|
if j < len(tags)-1 {
|
||||||
|
ctx.nextTag = tags[j+1].name
|
||||||
|
} else {
|
||||||
|
ctx.nextTag = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if h, ok := parser.handlers[ctx.tagUname]; ok {
|
||||||
|
if err := h(&ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if strings.HasPrefix(ctx.tag.name, "'") && strings.HasSuffix(ctx.tag.name, "'") {
|
||||||
|
col.Name = ctx.tag.name[1 : len(ctx.tag.name)-1]
|
||||||
|
} else {
|
||||||
|
col.Name = ctx.tag.name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.hasCacheTag {
|
||||||
|
if parser.cacherMgr.GetDefaultCacher() != nil {
|
||||||
|
parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher())
|
||||||
|
} else {
|
||||||
|
parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ctx.hasNoCacheTag {
|
||||||
|
parser.cacherMgr.SetCacher(table.Name, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if col.SQLType.Name == "" {
|
||||||
|
col.SQLType = schemas.Type2SQLType(field.Type)
|
||||||
|
}
|
||||||
|
parser.dialect.SQLType(col)
|
||||||
|
if col.Length == 0 {
|
||||||
|
col.Length = col.SQLType.DefaultLength
|
||||||
|
}
|
||||||
|
if col.Length2 == 0 {
|
||||||
|
col.Length2 = col.SQLType.DefaultLength2
|
||||||
|
}
|
||||||
|
if col.Name == "" {
|
||||||
|
col.Name = parser.columnMapper.Obj2Table(field.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.isUnique {
|
||||||
|
ctx.indexNames[col.Name] = schemas.UniqueType
|
||||||
|
} else if ctx.isIndex {
|
||||||
|
ctx.indexNames[col.Name] = schemas.IndexType
|
||||||
|
}
|
||||||
|
|
||||||
|
for indexName, indexType := range ctx.indexNames {
|
||||||
|
addIndex(indexName, table, col, indexType)
|
||||||
|
}
|
||||||
|
|
||||||
|
return col, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (parser *Parser) parseField(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
|
||||||
|
var (
|
||||||
|
tag = field.Tag
|
||||||
|
ormTagStr = strings.TrimSpace(tag.Get(parser.identifier))
|
||||||
|
)
|
||||||
|
if ormTagStr == "-" {
|
||||||
|
return nil, ErrIgnoreField
|
||||||
|
}
|
||||||
|
if ormTagStr == "" {
|
||||||
|
return parser.parseFieldWithNoTag(fieldIndex, field, fieldValue)
|
||||||
|
}
|
||||||
|
tags, err := splitTag(ormTagStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return parser.parseFieldWithTags(table, fieldIndex, field, fieldValue, tags)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNotTitle(n string) bool {
|
||||||
|
for _, c := range n {
|
||||||
|
return unicode.IsLower(c)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// Parse parses a struct as a table information
|
// Parse parses a struct as a table information
|
||||||
func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
|
func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
|
||||||
t := v.Type()
|
t := v.Type()
|
||||||
|
@ -140,192 +280,26 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
|
||||||
table.Type = t
|
table.Type = t
|
||||||
table.Name = names.GetTableName(parser.tableMapper, v)
|
table.Name = names.GetTableName(parser.tableMapper, v)
|
||||||
|
|
||||||
var idFieldColName string
|
|
||||||
var hasCacheTag, hasNoCacheTag bool
|
|
||||||
|
|
||||||
for i := 0; i < t.NumField(); i++ {
|
for i := 0; i < t.NumField(); i++ {
|
||||||
var isUnexportField bool
|
var field = t.Field(i)
|
||||||
for _, c := range t.Field(i).Name {
|
if isNotTitle(field.Name) {
|
||||||
if unicode.IsLower(c) {
|
|
||||||
isUnexportField = true
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if isUnexportField {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
tag := t.Field(i).Tag
|
col, err := parser.parseField(table, i, field, v.Field(i))
|
||||||
ormTagStr := tag.Get(parser.identifier)
|
if err == ErrIgnoreField {
|
||||||
var col *schemas.Column
|
|
||||||
fieldValue := v.Field(i)
|
|
||||||
fieldType := fieldValue.Type()
|
|
||||||
|
|
||||||
if ormTagStr != "" {
|
|
||||||
col = &schemas.Column{
|
|
||||||
FieldName: t.Field(i).Name,
|
|
||||||
Nullable: true,
|
|
||||||
IsPrimaryKey: false,
|
|
||||||
IsAutoIncrement: false,
|
|
||||||
MapType: schemas.TWOSIDES,
|
|
||||||
Indexes: make(map[string]int),
|
|
||||||
DefaultIsEmpty: true,
|
|
||||||
}
|
|
||||||
tags := splitTag(ormTagStr)
|
|
||||||
|
|
||||||
if len(tags) > 0 {
|
|
||||||
if tags[0] == "-" {
|
|
||||||
continue
|
continue
|
||||||
}
|
} else if err != nil {
|
||||||
|
|
||||||
var ctx = Context{
|
|
||||||
table: table,
|
|
||||||
col: col,
|
|
||||||
fieldValue: fieldValue,
|
|
||||||
indexNames: make(map[string]int),
|
|
||||||
parser: parser,
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") {
|
|
||||||
pStart := strings.Index(tags[0], "(")
|
|
||||||
if pStart > -1 && strings.HasSuffix(tags[0], ")") {
|
|
||||||
var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool {
|
|
||||||
return r == '\'' || r == '"'
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx.params = []string{tagPrefix}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ExtendsTagHandler(&ctx); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for j, key := range tags {
|
|
||||||
if ctx.ignoreNext {
|
|
||||||
ctx.ignoreNext = false
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
k := strings.ToUpper(key)
|
|
||||||
ctx.tagName = k
|
|
||||||
ctx.params = []string{}
|
|
||||||
|
|
||||||
pStart := strings.Index(k, "(")
|
|
||||||
if pStart == 0 {
|
|
||||||
return nil, errors.New("( could not be the first character")
|
|
||||||
}
|
|
||||||
if pStart > -1 {
|
|
||||||
if !strings.HasSuffix(k, ")") {
|
|
||||||
return nil, fmt.Errorf("field %s tag %s cannot match ) character", col.FieldName, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.tagName = k[:pStart]
|
|
||||||
ctx.params = strings.Split(key[pStart+1:len(k)-1], ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
if j > 0 {
|
|
||||||
ctx.preTag = strings.ToUpper(tags[j-1])
|
|
||||||
}
|
|
||||||
if j < len(tags)-1 {
|
|
||||||
ctx.nextTag = tags[j+1]
|
|
||||||
} else {
|
|
||||||
ctx.nextTag = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
if h, ok := parser.handlers[ctx.tagName]; ok {
|
|
||||||
if err := h(&ctx); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") {
|
|
||||||
col.Name = key[1 : len(key)-1]
|
|
||||||
} else {
|
|
||||||
col.Name = key
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if ctx.hasCacheTag {
|
|
||||||
hasCacheTag = true
|
|
||||||
}
|
|
||||||
if ctx.hasNoCacheTag {
|
|
||||||
hasNoCacheTag = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if col.SQLType.Name == "" {
|
|
||||||
col.SQLType = schemas.Type2SQLType(fieldType)
|
|
||||||
}
|
|
||||||
parser.dialect.SQLType(col)
|
|
||||||
if col.Length == 0 {
|
|
||||||
col.Length = col.SQLType.DefaultLength
|
|
||||||
}
|
|
||||||
if col.Length2 == 0 {
|
|
||||||
col.Length2 = col.SQLType.DefaultLength2
|
|
||||||
}
|
|
||||||
if col.Name == "" {
|
|
||||||
col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ctx.isUnique {
|
|
||||||
ctx.indexNames[col.Name] = schemas.UniqueType
|
|
||||||
} else if ctx.isIndex {
|
|
||||||
ctx.indexNames[col.Name] = schemas.IndexType
|
|
||||||
}
|
|
||||||
|
|
||||||
for indexName, indexType := range ctx.indexNames {
|
|
||||||
addIndex(indexName, table, col, indexType)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
var sqlType schemas.SQLType
|
|
||||||
if fieldValue.CanAddr() {
|
|
||||||
if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
|
|
||||||
sqlType = schemas.SQLType{Name: schemas.Text}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if _, ok := fieldValue.Interface().(convert.Conversion); ok {
|
|
||||||
sqlType = schemas.SQLType{Name: schemas.Text}
|
|
||||||
} else {
|
|
||||||
sqlType = schemas.Type2SQLType(fieldType)
|
|
||||||
}
|
|
||||||
col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name),
|
|
||||||
t.Field(i).Name, sqlType, sqlType.DefaultLength,
|
|
||||||
sqlType.DefaultLength2, true)
|
|
||||||
|
|
||||||
if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
|
|
||||||
idFieldColName = col.Name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if col.IsAutoIncrement {
|
|
||||||
col.Nullable = false
|
|
||||||
}
|
|
||||||
|
|
||||||
table.AddColumn(col)
|
table.AddColumn(col)
|
||||||
} // end for
|
} // end for
|
||||||
|
|
||||||
if idFieldColName != "" && len(table.PrimaryKeys) == 0 {
|
deletedColumn := table.DeletedColumn()
|
||||||
col := table.GetColumn(idFieldColName)
|
// check columns
|
||||||
col.IsPrimaryKey = true
|
if deletedColumn != nil {
|
||||||
col.IsAutoIncrement = true
|
deletedColumn.Nullable = true
|
||||||
col.Nullable = false
|
|
||||||
table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
|
|
||||||
table.AutoIncrement = col.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasCacheTag {
|
|
||||||
if parser.cacherMgr.GetDefaultCacher() != nil { // !nash! use engine's cacher if provided
|
|
||||||
//engine.logger.Info("enable cache on table:", table.Name)
|
|
||||||
parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher())
|
|
||||||
} else {
|
|
||||||
//engine.logger.Info("enable LRU cache on table:", table.Name)
|
|
||||||
parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if hasNoCacheTag {
|
|
||||||
//engine.logger.Info("disable cache on table:", table.Name)
|
|
||||||
parser.cacherMgr.SetCacher(table.Name, nil)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return table, nil
|
return table, nil
|
||||||
|
|
|
@ -6,12 +6,16 @@ package tags
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"xorm.io/xorm/caches"
|
"xorm.io/xorm/caches"
|
||||||
"xorm.io/xorm/dialects"
|
"xorm.io/xorm/dialects"
|
||||||
"xorm.io/xorm/names"
|
"xorm.io/xorm/names"
|
||||||
|
"xorm.io/xorm/schemas"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ParseTableName1 struct{}
|
type ParseTableName1 struct{}
|
||||||
|
@ -80,7 +84,7 @@ func TestParseWithOtherIdentifier(t *testing.T) {
|
||||||
parser := NewParser(
|
parser := NewParser(
|
||||||
"xorm",
|
"xorm",
|
||||||
dialects.QueryDialect("mysql"),
|
dialects.QueryDialect("mysql"),
|
||||||
names.GonicMapper{},
|
names.SameMapper{},
|
||||||
names.SnakeMapper{},
|
names.SnakeMapper{},
|
||||||
caches.NewManager(),
|
caches.NewManager(),
|
||||||
)
|
)
|
||||||
|
@ -88,13 +92,461 @@ func TestParseWithOtherIdentifier(t *testing.T) {
|
||||||
type StructWithDBTag struct {
|
type StructWithDBTag struct {
|
||||||
FieldFoo string `db:"foo"`
|
FieldFoo string `db:"foo"`
|
||||||
}
|
}
|
||||||
|
|
||||||
parser.SetIdentifier("db")
|
parser.SetIdentifier("db")
|
||||||
table, err := parser.Parse(reflect.ValueOf(new(StructWithDBTag)))
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithDBTag)))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.EqualValues(t, "struct_with_db_tag", table.Name)
|
assert.EqualValues(t, "StructWithDBTag", table.Name)
|
||||||
assert.EqualValues(t, 1, len(table.Columns()))
|
assert.EqualValues(t, 1, len(table.Columns()))
|
||||||
|
|
||||||
for _, col := range table.Columns() {
|
for _, col := range table.Columns() {
|
||||||
assert.EqualValues(t, "foo", col.Name)
|
assert.EqualValues(t, "foo", col.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseWithIgnore(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.SameMapper{},
|
||||||
|
names.SnakeMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithIgnoreTag struct {
|
||||||
|
FieldFoo string `db:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithIgnoreTag)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "StructWithIgnoreTag", table.Name)
|
||||||
|
assert.EqualValues(t, 0, len(table.Columns()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithAutoincrement(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.SnakeMapper{},
|
||||||
|
names.GonicMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithAutoIncrement struct {
|
||||||
|
ID int64
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_auto_increment", table.Name)
|
||||||
|
assert.EqualValues(t, 1, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "id", table.Columns()[0].Name)
|
||||||
|
assert.True(t, table.Columns()[0].IsAutoIncrement)
|
||||||
|
assert.True(t, table.Columns()[0].IsPrimaryKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithAutoincrement2(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.SnakeMapper{},
|
||||||
|
names.GonicMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithAutoIncrement2 struct {
|
||||||
|
ID int64 `db:"pk autoincr"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement2)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_auto_increment2", table.Name)
|
||||||
|
assert.EqualValues(t, 1, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "id", table.Columns()[0].Name)
|
||||||
|
assert.True(t, table.Columns()[0].IsAutoIncrement)
|
||||||
|
assert.True(t, table.Columns()[0].IsPrimaryKey)
|
||||||
|
assert.False(t, table.Columns()[0].Nullable)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithNullable(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.SnakeMapper{},
|
||||||
|
names.GonicMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithNullable struct {
|
||||||
|
Name string `db:"notnull"`
|
||||||
|
FullName string `db:"null comment('column comment,字段注释')"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithNullable)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_nullable", table.Name)
|
||||||
|
assert.EqualValues(t, 2, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||||
|
assert.EqualValues(t, "full_name", table.Columns()[1].Name)
|
||||||
|
assert.False(t, table.Columns()[0].Nullable)
|
||||||
|
assert.True(t, table.Columns()[1].Nullable)
|
||||||
|
assert.EqualValues(t, "column comment,字段注释", table.Columns()[1].Comment)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithTimes(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.SnakeMapper{},
|
||||||
|
names.GonicMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithTimes struct {
|
||||||
|
Name string `db:"notnull"`
|
||||||
|
CreatedAt time.Time `db:"created"`
|
||||||
|
UpdatedAt time.Time `db:"updated"`
|
||||||
|
DeletedAt time.Time `db:"deleted"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithTimes)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_times", table.Name)
|
||||||
|
assert.EqualValues(t, 4, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||||
|
assert.EqualValues(t, "created_at", table.Columns()[1].Name)
|
||||||
|
assert.EqualValues(t, "updated_at", table.Columns()[2].Name)
|
||||||
|
assert.EqualValues(t, "deleted_at", table.Columns()[3].Name)
|
||||||
|
assert.False(t, table.Columns()[0].Nullable)
|
||||||
|
assert.True(t, table.Columns()[1].Nullable)
|
||||||
|
assert.True(t, table.Columns()[1].IsCreated)
|
||||||
|
assert.True(t, table.Columns()[2].Nullable)
|
||||||
|
assert.True(t, table.Columns()[2].IsUpdated)
|
||||||
|
assert.True(t, table.Columns()[3].Nullable)
|
||||||
|
assert.True(t, table.Columns()[3].IsDeleted)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithExtends(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.SnakeMapper{},
|
||||||
|
names.GonicMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithEmbed struct {
|
||||||
|
Name string
|
||||||
|
CreatedAt time.Time `db:"created"`
|
||||||
|
UpdatedAt time.Time `db:"updated"`
|
||||||
|
DeletedAt time.Time `db:"deleted"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type StructWithExtends struct {
|
||||||
|
SW StructWithEmbed `db:"extends"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithExtends)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_extends", table.Name)
|
||||||
|
assert.EqualValues(t, 4, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||||
|
assert.EqualValues(t, "created_at", table.Columns()[1].Name)
|
||||||
|
assert.EqualValues(t, "updated_at", table.Columns()[2].Name)
|
||||||
|
assert.EqualValues(t, "deleted_at", table.Columns()[3].Name)
|
||||||
|
assert.True(t, table.Columns()[0].Nullable)
|
||||||
|
assert.True(t, table.Columns()[1].Nullable)
|
||||||
|
assert.True(t, table.Columns()[1].IsCreated)
|
||||||
|
assert.True(t, table.Columns()[2].Nullable)
|
||||||
|
assert.True(t, table.Columns()[2].IsUpdated)
|
||||||
|
assert.True(t, table.Columns()[3].Nullable)
|
||||||
|
assert.True(t, table.Columns()[3].IsDeleted)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithCache(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.SnakeMapper{},
|
||||||
|
names.GonicMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithCache struct {
|
||||||
|
Name string `db:"cache"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithCache)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_cache", table.Name)
|
||||||
|
assert.EqualValues(t, 1, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||||
|
assert.True(t, table.Columns()[0].Nullable)
|
||||||
|
cacher := parser.cacherMgr.GetCacher(table.Name)
|
||||||
|
assert.NotNil(t, cacher)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithNoCache(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.SnakeMapper{},
|
||||||
|
names.GonicMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithNoCache struct {
|
||||||
|
Name string `db:"nocache"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithNoCache)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_no_cache", table.Name)
|
||||||
|
assert.EqualValues(t, 1, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||||
|
assert.True(t, table.Columns()[0].Nullable)
|
||||||
|
cacher := parser.cacherMgr.GetCacher(table.Name)
|
||||||
|
assert.Nil(t, cacher)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithEnum(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.SnakeMapper{},
|
||||||
|
names.GonicMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithEnum struct {
|
||||||
|
Name string `db:"enum('alice', 'bob')"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithEnum)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_enum", table.Name)
|
||||||
|
assert.EqualValues(t, 1, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||||
|
assert.True(t, table.Columns()[0].Nullable)
|
||||||
|
assert.EqualValues(t, schemas.Enum, strings.ToUpper(table.Columns()[0].SQLType.Name))
|
||||||
|
assert.EqualValues(t, map[string]int{
|
||||||
|
"alice": 0,
|
||||||
|
"bob": 1,
|
||||||
|
}, table.Columns()[0].EnumOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithSet(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.SnakeMapper{},
|
||||||
|
names.GonicMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithSet struct {
|
||||||
|
Name string `db:"set('alice', 'bob')"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithSet)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_set", table.Name)
|
||||||
|
assert.EqualValues(t, 1, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||||
|
assert.True(t, table.Columns()[0].Nullable)
|
||||||
|
assert.EqualValues(t, schemas.Set, strings.ToUpper(table.Columns()[0].SQLType.Name))
|
||||||
|
assert.EqualValues(t, map[string]int{
|
||||||
|
"alice": 0,
|
||||||
|
"bob": 1,
|
||||||
|
}, table.Columns()[0].SetOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithIndex(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.SnakeMapper{},
|
||||||
|
names.GonicMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithIndex struct {
|
||||||
|
Name string `db:"index"`
|
||||||
|
Name2 string `db:"index(s)"`
|
||||||
|
Name3 string `db:"unique"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithIndex)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_index", table.Name)
|
||||||
|
assert.EqualValues(t, 3, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||||
|
assert.EqualValues(t, "name2", table.Columns()[1].Name)
|
||||||
|
assert.EqualValues(t, "name3", table.Columns()[2].Name)
|
||||||
|
assert.True(t, table.Columns()[0].Nullable)
|
||||||
|
assert.True(t, table.Columns()[1].Nullable)
|
||||||
|
assert.True(t, table.Columns()[2].Nullable)
|
||||||
|
assert.EqualValues(t, 1, len(table.Columns()[0].Indexes))
|
||||||
|
assert.EqualValues(t, 1, len(table.Columns()[1].Indexes))
|
||||||
|
assert.EqualValues(t, 1, len(table.Columns()[2].Indexes))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithVersion(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.SnakeMapper{},
|
||||||
|
names.GonicMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithVersion struct {
|
||||||
|
Name string
|
||||||
|
Version int `db:"version"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithVersion)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_version", table.Name)
|
||||||
|
assert.EqualValues(t, 2, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||||
|
assert.EqualValues(t, "version", table.Columns()[1].Name)
|
||||||
|
assert.True(t, table.Columns()[0].Nullable)
|
||||||
|
assert.True(t, table.Columns()[1].Nullable)
|
||||||
|
assert.True(t, table.Columns()[1].IsVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithLocale(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.SnakeMapper{},
|
||||||
|
names.GonicMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithLocale struct {
|
||||||
|
UTCLocale time.Time `db:"utc"`
|
||||||
|
LocalLocale time.Time `db:"local"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithLocale)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_locale", table.Name)
|
||||||
|
assert.EqualValues(t, 2, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "utc_locale", table.Columns()[0].Name)
|
||||||
|
assert.EqualValues(t, "local_locale", table.Columns()[1].Name)
|
||||||
|
assert.EqualValues(t, time.UTC, table.Columns()[0].TimeZone)
|
||||||
|
assert.EqualValues(t, time.Local, table.Columns()[1].TimeZone)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithDefault(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.SnakeMapper{},
|
||||||
|
names.GonicMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithDefault struct {
|
||||||
|
Default1 time.Time `db:"default '1970-01-01 00:00:00'"`
|
||||||
|
Default2 time.Time `db:"default(CURRENT_TIMESTAMP)"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithDefault)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_default", table.Name)
|
||||||
|
assert.EqualValues(t, 2, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "default1", table.Columns()[0].Name)
|
||||||
|
assert.EqualValues(t, "default2", table.Columns()[1].Name)
|
||||||
|
assert.EqualValues(t, "'1970-01-01 00:00:00'", table.Columns()[0].Default)
|
||||||
|
assert.EqualValues(t, "CURRENT_TIMESTAMP", table.Columns()[1].Default)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithOnlyToDB(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.GonicMapper{
|
||||||
|
"DB": true,
|
||||||
|
},
|
||||||
|
names.SnakeMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithOnlyToDB struct {
|
||||||
|
Default1 time.Time `db:"->"`
|
||||||
|
Default2 time.Time `db:"<-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithOnlyToDB)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_only_to_db", table.Name)
|
||||||
|
assert.EqualValues(t, 2, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "default1", table.Columns()[0].Name)
|
||||||
|
assert.EqualValues(t, "default2", table.Columns()[1].Name)
|
||||||
|
assert.EqualValues(t, schemas.ONLYTODB, table.Columns()[0].MapType)
|
||||||
|
assert.EqualValues(t, schemas.ONLYFROMDB, table.Columns()[1].MapType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithJSON(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.GonicMapper{
|
||||||
|
"JSON": true,
|
||||||
|
},
|
||||||
|
names.SnakeMapper{},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithJSON struct {
|
||||||
|
Default1 []string `db:"json"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithJSON)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_json", table.Name)
|
||||||
|
assert.EqualValues(t, 1, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "default1", table.Columns()[0].Name)
|
||||||
|
assert.True(t, table.Columns()[0].IsJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithSQLType(t *testing.T) {
|
||||||
|
parser := NewParser(
|
||||||
|
"db",
|
||||||
|
dialects.QueryDialect("mysql"),
|
||||||
|
names.GonicMapper{
|
||||||
|
"SQL": true,
|
||||||
|
},
|
||||||
|
names.GonicMapper{
|
||||||
|
"UUID": true,
|
||||||
|
},
|
||||||
|
caches.NewManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructWithSQLType struct {
|
||||||
|
Col1 string `db:"varchar(32)"`
|
||||||
|
Col2 string `db:"char(32)"`
|
||||||
|
Int int64 `db:"bigint"`
|
||||||
|
DateTime time.Time `db:"datetime"`
|
||||||
|
UUID string `db:"uuid"`
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := parser.Parse(reflect.ValueOf(new(StructWithSQLType)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "struct_with_sql_type", table.Name)
|
||||||
|
assert.EqualValues(t, 5, len(table.Columns()))
|
||||||
|
assert.EqualValues(t, "col1", table.Columns()[0].Name)
|
||||||
|
assert.EqualValues(t, "col2", table.Columns()[1].Name)
|
||||||
|
assert.EqualValues(t, "int", table.Columns()[2].Name)
|
||||||
|
assert.EqualValues(t, "date_time", table.Columns()[3].Name)
|
||||||
|
assert.EqualValues(t, "uuid", table.Columns()[4].Name)
|
||||||
|
|
||||||
|
assert.EqualValues(t, "VARCHAR", table.Columns()[0].SQLType.Name)
|
||||||
|
assert.EqualValues(t, "CHAR", table.Columns()[1].SQLType.Name)
|
||||||
|
assert.EqualValues(t, "BIGINT", table.Columns()[2].SQLType.Name)
|
||||||
|
assert.EqualValues(t, "DATETIME", table.Columns()[3].SQLType.Name)
|
||||||
|
assert.EqualValues(t, "UUID", table.Columns()[4].SQLType.Name)
|
||||||
|
}
|
||||||
|
|
99
tags/tag.go
99
tags/tag.go
|
@ -14,30 +14,74 @@ import (
|
||||||
"xorm.io/xorm/schemas"
|
"xorm.io/xorm/schemas"
|
||||||
)
|
)
|
||||||
|
|
||||||
func splitTag(tag string) (tags []string) {
|
type tag struct {
|
||||||
tag = strings.TrimSpace(tag)
|
name string
|
||||||
var hasQuote = false
|
params []string
|
||||||
var lastIdx = 0
|
}
|
||||||
for i, t := range tag {
|
|
||||||
if t == '\'' {
|
func splitTag(tagStr string) ([]tag, error) {
|
||||||
hasQuote = !hasQuote
|
tagStr = strings.TrimSpace(tagStr)
|
||||||
} else if t == ' ' {
|
var (
|
||||||
if lastIdx < i && !hasQuote {
|
inQuote bool
|
||||||
tags = append(tags, strings.TrimSpace(tag[lastIdx:i]))
|
inBigQuote bool
|
||||||
|
lastIdx int
|
||||||
|
curTag tag
|
||||||
|
paramStart int
|
||||||
|
tags []tag
|
||||||
|
)
|
||||||
|
for i, t := range tagStr {
|
||||||
|
switch t {
|
||||||
|
case '\'':
|
||||||
|
inQuote = !inQuote
|
||||||
|
case ' ':
|
||||||
|
if !inQuote && !inBigQuote {
|
||||||
|
if lastIdx < i {
|
||||||
|
if curTag.name == "" {
|
||||||
|
curTag.name = tagStr[lastIdx:i]
|
||||||
|
}
|
||||||
|
tags = append(tags, curTag)
|
||||||
|
lastIdx = i + 1
|
||||||
|
curTag = tag{}
|
||||||
|
} else if lastIdx == i {
|
||||||
lastIdx = i + 1
|
lastIdx = i + 1
|
||||||
}
|
}
|
||||||
|
} else if inBigQuote && !inQuote {
|
||||||
|
paramStart = i + 1
|
||||||
|
}
|
||||||
|
case ',':
|
||||||
|
if !inQuote && !inBigQuote {
|
||||||
|
return nil, fmt.Errorf("comma[%d] of %s should be in quote or big quote", i, tagStr)
|
||||||
|
}
|
||||||
|
if !inQuote && inBigQuote {
|
||||||
|
curTag.params = append(curTag.params, strings.TrimSpace(tagStr[paramStart:i]))
|
||||||
|
paramStart = i + 1
|
||||||
|
}
|
||||||
|
case '(':
|
||||||
|
inBigQuote = true
|
||||||
|
if !inQuote {
|
||||||
|
curTag.name = tagStr[lastIdx:i]
|
||||||
|
paramStart = i + 1
|
||||||
|
}
|
||||||
|
case ')':
|
||||||
|
inBigQuote = false
|
||||||
|
if !inQuote {
|
||||||
|
curTag.params = append(curTag.params, tagStr[paramStart:i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if lastIdx < len(tag) {
|
|
||||||
tags = append(tags, strings.TrimSpace(tag[lastIdx:]))
|
|
||||||
}
|
}
|
||||||
return
|
if lastIdx < len(tagStr) {
|
||||||
|
if curTag.name == "" {
|
||||||
|
curTag.name = tagStr[lastIdx:]
|
||||||
|
}
|
||||||
|
tags = append(tags, curTag)
|
||||||
|
}
|
||||||
|
return tags, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Context represents a context for xorm tag parse.
|
// Context represents a context for xorm tag parse.
|
||||||
type Context struct {
|
type Context struct {
|
||||||
tagName string
|
tag
|
||||||
params []string
|
tagUname string
|
||||||
preTag, nextTag string
|
preTag, nextTag string
|
||||||
table *schemas.Table
|
table *schemas.Table
|
||||||
col *schemas.Column
|
col *schemas.Column
|
||||||
|
@ -76,6 +120,7 @@ var (
|
||||||
"CACHE": CacheTagHandler,
|
"CACHE": CacheTagHandler,
|
||||||
"NOCACHE": NoCacheTagHandler,
|
"NOCACHE": NoCacheTagHandler,
|
||||||
"COMMENT": CommentTagHandler,
|
"COMMENT": CommentTagHandler,
|
||||||
|
"EXTENDS": ExtendsTagHandler,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -124,6 +169,7 @@ func NotNullTagHandler(ctx *Context) error {
|
||||||
// AutoIncrTagHandler describes autoincr tag handler
|
// AutoIncrTagHandler describes autoincr tag handler
|
||||||
func AutoIncrTagHandler(ctx *Context) error {
|
func AutoIncrTagHandler(ctx *Context) error {
|
||||||
ctx.col.IsAutoIncrement = true
|
ctx.col.IsAutoIncrement = true
|
||||||
|
ctx.col.Nullable = false
|
||||||
/*
|
/*
|
||||||
if len(ctx.params) > 0 {
|
if len(ctx.params) > 0 {
|
||||||
autoStartInt, err := strconv.Atoi(ctx.params[0])
|
autoStartInt, err := strconv.Atoi(ctx.params[0])
|
||||||
|
@ -192,6 +238,7 @@ func UpdatedTagHandler(ctx *Context) error {
|
||||||
// DeletedTagHandler describes deleted tag handler
|
// DeletedTagHandler describes deleted tag handler
|
||||||
func DeletedTagHandler(ctx *Context) error {
|
func DeletedTagHandler(ctx *Context) error {
|
||||||
ctx.col.IsDeleted = true
|
ctx.col.IsDeleted = true
|
||||||
|
ctx.col.Nullable = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -225,26 +272,30 @@ func CommentTagHandler(ctx *Context) error {
|
||||||
|
|
||||||
// SQLTypeTagHandler describes SQL Type tag handler
|
// SQLTypeTagHandler describes SQL Type tag handler
|
||||||
func SQLTypeTagHandler(ctx *Context) error {
|
func SQLTypeTagHandler(ctx *Context) error {
|
||||||
ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName}
|
ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname}
|
||||||
if strings.EqualFold(ctx.tagName, "JSON") {
|
if ctx.tagUname == "JSON" {
|
||||||
ctx.col.IsJSON = true
|
ctx.col.IsJSON = true
|
||||||
}
|
}
|
||||||
if len(ctx.params) > 0 {
|
if len(ctx.params) == 0 {
|
||||||
if ctx.tagName == schemas.Enum {
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ctx.tagUname {
|
||||||
|
case schemas.Enum:
|
||||||
ctx.col.EnumOptions = make(map[string]int)
|
ctx.col.EnumOptions = make(map[string]int)
|
||||||
for k, v := range ctx.params {
|
for k, v := range ctx.params {
|
||||||
v = strings.TrimSpace(v)
|
v = strings.TrimSpace(v)
|
||||||
v = strings.Trim(v, "'")
|
v = strings.Trim(v, "'")
|
||||||
ctx.col.EnumOptions[v] = k
|
ctx.col.EnumOptions[v] = k
|
||||||
}
|
}
|
||||||
} else if ctx.tagName == schemas.Set {
|
case schemas.Set:
|
||||||
ctx.col.SetOptions = make(map[string]int)
|
ctx.col.SetOptions = make(map[string]int)
|
||||||
for k, v := range ctx.params {
|
for k, v := range ctx.params {
|
||||||
v = strings.TrimSpace(v)
|
v = strings.TrimSpace(v)
|
||||||
v = strings.Trim(v, "'")
|
v = strings.Trim(v, "'")
|
||||||
ctx.col.SetOptions[v] = k
|
ctx.col.SetOptions[v] = k
|
||||||
}
|
}
|
||||||
} else {
|
default:
|
||||||
var err error
|
var err error
|
||||||
if len(ctx.params) == 2 {
|
if len(ctx.params) == 2 {
|
||||||
ctx.col.Length, err = strconv.Atoi(ctx.params[0])
|
ctx.col.Length, err = strconv.Atoi(ctx.params[0])
|
||||||
|
@ -262,7 +313,6 @@ func SQLTypeTagHandler(ctx *Context) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -289,11 +339,12 @@ func ExtendsTagHandler(ctx *Context) error {
|
||||||
}
|
}
|
||||||
for _, col := range parentTable.Columns() {
|
for _, col := range parentTable.Columns() {
|
||||||
col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName)
|
col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName)
|
||||||
|
col.FieldIndex = append(ctx.col.FieldIndex, col.FieldIndex...)
|
||||||
|
|
||||||
var tagPrefix = ctx.col.FieldName
|
var tagPrefix = ctx.col.FieldName
|
||||||
if len(ctx.params) > 0 {
|
if len(ctx.params) > 0 {
|
||||||
col.Nullable = isPtr
|
col.Nullable = isPtr
|
||||||
tagPrefix = ctx.params[0]
|
tagPrefix = strings.Trim(ctx.params[0], "'")
|
||||||
if col.IsPrimaryKey {
|
if col.IsPrimaryKey {
|
||||||
col.Name = ctx.col.FieldName
|
col.Name = ctx.col.FieldName
|
||||||
col.IsPrimaryKey = false
|
col.IsPrimaryKey = false
|
||||||
|
@ -315,7 +366,7 @@ func ExtendsTagHandler(ctx *Context) error {
|
||||||
default:
|
default:
|
||||||
//TODO: warning
|
//TODO: warning
|
||||||
}
|
}
|
||||||
return nil
|
return ErrIgnoreField
|
||||||
}
|
}
|
||||||
|
|
||||||
// CacheTagHandler describes cache tag handler
|
// CacheTagHandler describes cache tag handler
|
||||||
|
|
|
@ -7,24 +7,83 @@ package tags
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"xorm.io/xorm/internal/utils"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSplitTag(t *testing.T) {
|
func TestSplitTag(t *testing.T) {
|
||||||
var cases = []struct {
|
var cases = []struct {
|
||||||
tag string
|
tag string
|
||||||
tags []string
|
tags []tag
|
||||||
}{
|
}{
|
||||||
{"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}},
|
{"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{
|
||||||
{"TEXT", []string{"TEXT"}},
|
{
|
||||||
{"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}},
|
name: "not",
|
||||||
{"json binary", []string{"json", "binary"}},
|
},
|
||||||
|
{
|
||||||
|
name: "null",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "'2000-01-01 00:00:00'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TIMESTAMP",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"TEXT", []tag{
|
||||||
|
{
|
||||||
|
name: "TEXT",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"default('2000-01-01 00:00:00')", []tag{
|
||||||
|
{
|
||||||
|
name: "default",
|
||||||
|
params: []string{
|
||||||
|
"'2000-01-01 00:00:00'",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"json binary", []tag{
|
||||||
|
{
|
||||||
|
name: "json",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "binary",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"numeric(10, 2)", []tag{
|
||||||
|
{
|
||||||
|
name: "numeric",
|
||||||
|
params: []string{"10", "2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"numeric(10, 2) notnull", []tag{
|
||||||
|
{
|
||||||
|
name: "numeric",
|
||||||
|
params: []string{"10", "2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "notnull",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, kase := range cases {
|
for _, kase := range cases {
|
||||||
tags := splitTag(kase.tag)
|
t.Run(kase.tag, func(t *testing.T) {
|
||||||
if !utils.SliceEq(tags, kase.tags) {
|
tags, err := splitTag(kase.tag)
|
||||||
t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags)
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, len(tags), len(kase.tags))
|
||||||
|
for i := 0; i < len(tags); i++ {
|
||||||
|
assert.Equal(t, tags[i], kase.tags[i])
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue