Merge branch 'master' into lunny/add_alias_table

This commit is contained in:
Lunny Xiao 2021-07-06 17:10:18 +08:00
commit 3e0887c5c2
44 changed files with 2435 additions and 834 deletions

View File

@ -249,11 +249,11 @@ volumes:
services: services:
- name: mssql - name: mssql
pull: always pull: always
image: microsoft/mssql-server-linux:latest image: mcr.microsoft.com/mssql/server:latest
environment: environment:
ACCEPT_EULA: Y ACCEPT_EULA: Y
SA_PASSWORD: yourStrong(!)Password SA_PASSWORD: yourStrong(!)Password
MSSQL_PID: Developer MSSQL_PID: Standard
--- ---
kind: pipeline kind: pipeline
@ -347,3 +347,19 @@ steps:
image: golang:1.15 image: golang:1.15
commands: commands:
- make coverage - make coverage
---
kind: pipeline
name: release-tag
trigger:
event:
- tag
steps:
- name: release-tag-gitea
pull: always
image: plugins/gitea-release:latest
settings:
base_url: https://gitea.com
title: '${DRONE_TAG} is released'
api_key:
from_secret: gitea_token

1
.gitignore vendored
View File

@ -37,3 +37,4 @@ test.db.sql
test.db test.db
integrations/*.sql integrations/*.sql
integrations/test_sqlite* integrations/test_sqlite*
cover.out

View File

@ -8,20 +8,22 @@ warningCode = 1
[rule.context-as-argument] [rule.context-as-argument]
[rule.context-keys-type] [rule.context-keys-type]
[rule.dot-imports] [rule.dot-imports]
[rule.empty-lines]
[rule.errorf]
[rule.error-return] [rule.error-return]
[rule.error-strings] [rule.error-strings]
[rule.error-naming] [rule.error-naming]
[rule.exported] [rule.exported]
[rule.if-return] [rule.if-return]
[rule.increment-decrement] [rule.increment-decrement]
[rule.var-naming] [rule.indent-error-flow]
arguments = [["ID", "UID", "UUID", "URL", "JSON"], []]
[rule.var-declaration]
[rule.package-comments] [rule.package-comments]
[rule.range] [rule.range]
[rule.receiver-naming] [rule.receiver-naming]
[rule.struct-tag]
[rule.time-naming] [rule.time-naming]
[rule.unexported-return] [rule.unexported-return]
[rule.indent-error-flow] [rule.unnecessary-stmt]
[rule.errorf] [rule.var-declaration]
[rule.struct-tag] [rule.var-naming]
arguments = [["ID", "UID", "UUID", "URL", "JSON"], []]

View File

@ -3,6 +3,41 @@
This changelog goes through all the changes that have been made in each release This changelog goes through all the changes that have been made in each release
without substantial changes to our git log. without substantial changes to our git log.
## [1.1.2](https://gitea.com/xorm/xorm/releases/tag/1.1.2) - 2021-07-04
* BUILD
* Add release tag (#1966)
## [1.1.1](https://gitea.com/xorm/xorm/releases/tag/1.1.1) - 2021-07-03
* BUGFIXES
* Ignore comments when deciding when to replace question marks. #1954 (#1955)
* Fix bug didn't reset statement on update (#1939)
* Fix create table with struct missing columns (#1938)
* Fix #929 (#1936)
* Fix exist (#1921)
* ENHANCEMENTS
* Improve get field value of bean (#1961)
* refactor splitTag function (#1960)
* Fix #1663 (#1952)
* fix pg GetColumns missing comment (#1949)
* Support build flag jsoniter to replace default json (#1916)
* refactor exprParam (#1825)
* Add DBVersion (#1723)
* TESTING
* Add test to confirm #1247 resolved (#1951)
* Add test for dump table with default value (#1950)
* Test for #1486 (#1942)
* Add sync tests to confirm #539 is gone (#1937)
* test for unsigned int32 (#1923)
* Add tests for array store (#1922)
* BUILD
* Remove mymysql from ci (#1928)
* MISC
* fix lint (#1953)
* Compatible with cockroach (#1930)
* Replace goracle with godror (#1914)
## [1.1.0](https://gitea.com/xorm/xorm/releases/tag/1.1.0) - 2021-05-14 ## [1.1.0](https://gitea.com/xorm/xorm/releases/tag/1.1.0) - 2021-05-14
* FEATURES * FEATURES

View File

@ -5,12 +5,15 @@
package xorm package xorm
import ( import (
"database/sql"
"database/sql/driver" "database/sql/driver"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"time" "time"
"xorm.io/xorm/convert"
) )
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
@ -37,6 +40,12 @@ func asString(src interface{}) string {
return v return v
case []byte: case []byte:
return string(v) return string(v)
case *sql.NullString:
return v.String
case *sql.NullInt32:
return fmt.Sprintf("%d", v.Int32)
case *sql.NullInt64:
return fmt.Sprintf("%d", v.Int64)
} }
rv := reflect.ValueOf(src) rv := reflect.ValueOf(src)
switch rv.Kind() { switch rv.Kind() {
@ -54,6 +63,156 @@ func asString(src interface{}) string {
return fmt.Sprintf("%v", src) return fmt.Sprintf("%v", src)
} }
// asInt64 converts src to an int64.
// Concrete integer, string, []byte and *sql.Null* values are handled
// directly; anything else falls back to reflection on the underlying
// kind. String forms must parse as base-10 integers.
func asInt64(src interface{}) (int64, error) {
	switch val := src.(type) {
	case int64:
		return val, nil
	case int:
		return int64(val), nil
	case int8:
		return int64(val), nil
	case int16:
		return int64(val), nil
	case int32:
		return int64(val), nil
	case uint:
		return int64(val), nil
	case uint8:
		return int64(val), nil
	case uint16:
		return int64(val), nil
	case uint32:
		return int64(val), nil
	case uint64:
		return int64(val), nil
	case string:
		return strconv.ParseInt(val, 10, 64)
	case []byte:
		return strconv.ParseInt(string(val), 10, 64)
	case *sql.NullString:
		return strconv.ParseInt(val.String, 10, 64)
	case *sql.NullInt32:
		return int64(val.Int32), nil
	case *sql.NullInt64:
		return val.Int64, nil
	}

	// Named types with a numeric or string underlying kind.
	rval := reflect.ValueOf(src)
	switch rval.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return rval.Int(), nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return int64(rval.Uint()), nil
	case reflect.Float32, reflect.Float64:
		return int64(rval.Float()), nil
	case reflect.String:
		return strconv.ParseInt(rval.String(), 10, 64)
	}
	return 0, fmt.Errorf("unsupported value %T as int64", src)
}
// asUint64 converts src to a uint64.
// Negative signed inputs wrap via Go's integer conversion rules, matching
// a plain uint64(v) cast. String forms must parse as base-10 unsigned
// integers; unknown types fall back to reflection on the underlying kind.
func asUint64(src interface{}) (uint64, error) {
	switch val := src.(type) {
	case uint64:
		return val, nil
	case int:
		return uint64(val), nil
	case int8:
		return uint64(val), nil
	case int16:
		return uint64(val), nil
	case int32:
		return uint64(val), nil
	case int64:
		return uint64(val), nil
	case uint:
		return uint64(val), nil
	case uint8:
		return uint64(val), nil
	case uint16:
		return uint64(val), nil
	case uint32:
		return uint64(val), nil
	case string:
		return strconv.ParseUint(val, 10, 64)
	case []byte:
		return strconv.ParseUint(string(val), 10, 64)
	case *sql.NullString:
		return strconv.ParseUint(val.String, 10, 64)
	case *sql.NullInt32:
		return uint64(val.Int32), nil
	case *sql.NullInt64:
		return uint64(val.Int64), nil
	}

	// Named types with a numeric or string underlying kind.
	rval := reflect.ValueOf(src)
	switch rval.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return uint64(rval.Int()), nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return rval.Uint(), nil
	case reflect.Float32, reflect.Float64:
		return uint64(rval.Float()), nil
	case reflect.String:
		return strconv.ParseUint(rval.String(), 10, 64)
	}
	return 0, fmt.Errorf("unsupported value %T as uint64", src)
}
// asFloat64 converts src to a float64.
// Concrete integer, string, []byte and *sql.Null* values are handled
// directly; anything else falls back to reflection on the underlying kind.
// Fix: the unsupported-type error previously said "as int64" (copy-paste
// from asInt64); it now correctly reports "as float64".
func asFloat64(src interface{}) (float64, error) {
	switch v := src.(type) {
	case int:
		return float64(v), nil
	case int8:
		return float64(v), nil
	case int16:
		return float64(v), nil
	case int32:
		return float64(v), nil
	case int64:
		return float64(v), nil
	case uint:
		return float64(v), nil
	case uint8:
		return float64(v), nil
	case uint16:
		return float64(v), nil
	case uint32:
		return float64(v), nil
	case uint64:
		return float64(v), nil
	case []byte:
		return strconv.ParseFloat(string(v), 64)
	case string:
		return strconv.ParseFloat(v, 64)
	case *sql.NullString:
		return strconv.ParseFloat(v.String, 64)
	case *sql.NullInt32:
		return float64(v.Int32), nil
	case *sql.NullInt64:
		return float64(v.Int64), nil
	}

	// Named types with a numeric or string underlying kind.
	rv := reflect.ValueOf(src)
	switch rv.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return float64(rv.Int()), nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return float64(rv.Uint()), nil
	case reflect.Float32, reflect.Float64:
		return rv.Float(), nil
	case reflect.String:
		return strconv.ParseFloat(rv.String(), 64)
	}
	return 0, fmt.Errorf("unsupported value %T as float64", src)
}
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
switch rv.Kind() { switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@ -76,7 +235,7 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
// convertAssign copies to dest the value in src, converting it if possible. // convertAssign copies to dest the value in src, converting it if possible.
// An error is returned if the copy would result in loss of information. // An error is returned if the copy would result in loss of information.
// dest should be a pointer type. // dest should be a pointer type.
func convertAssign(dest, src interface{}) error { func convertAssign(dest, src interface{}, originalLocation *time.Location, convertedLocation *time.Location) error {
// Common cases, without reflect. // Common cases, without reflect.
switch s := src.(type) { switch s := src.(type) {
case string: case string:
@ -143,6 +302,163 @@ func convertAssign(dest, src interface{}) error {
*d = nil *d = nil
return nil return nil
} }
case *sql.NullString:
switch d := dest.(type) {
case *int:
if s.Valid {
*d, _ = strconv.Atoi(s.String)
}
case *int64:
if s.Valid {
*d, _ = strconv.ParseInt(s.String, 10, 64)
}
case *string:
if s.Valid {
*d = s.String
}
return nil
case *time.Time:
if s.Valid {
var err error
dt, err := convert.String2Time(s.String, originalLocation, convertedLocation)
if err != nil {
return err
}
*d = *dt
}
return nil
case *sql.NullTime:
if s.Valid {
var err error
dt, err := convert.String2Time(s.String, originalLocation, convertedLocation)
if err != nil {
return err
}
d.Valid = true
d.Time = *dt
}
}
case *sql.NullInt32:
switch d := dest.(type) {
case *int:
if s.Valid {
*d = int(s.Int32)
}
return nil
case *int8:
if s.Valid {
*d = int8(s.Int32)
}
return nil
case *int16:
if s.Valid {
*d = int16(s.Int32)
}
return nil
case *int32:
if s.Valid {
*d = s.Int32
}
return nil
case *int64:
if s.Valid {
*d = int64(s.Int32)
}
return nil
}
case *sql.NullInt64:
switch d := dest.(type) {
case *int:
if s.Valid {
*d = int(s.Int64)
}
return nil
case *int8:
if s.Valid {
*d = int8(s.Int64)
}
return nil
case *int16:
if s.Valid {
*d = int16(s.Int64)
}
return nil
case *int32:
if s.Valid {
*d = int32(s.Int64)
}
return nil
case *int64:
if s.Valid {
*d = s.Int64
}
return nil
}
case *sql.NullFloat64:
switch d := dest.(type) {
case *int:
if s.Valid {
*d = int(s.Float64)
}
return nil
case *float64:
if s.Valid {
*d = s.Float64
}
return nil
}
case *sql.NullBool:
switch d := dest.(type) {
case *bool:
if s.Valid {
*d = s.Bool
}
return nil
}
case *sql.NullTime:
switch d := dest.(type) {
case *time.Time:
if s.Valid {
*d = s.Time
}
return nil
case *string:
if s.Valid {
*d = s.Time.In(convertedLocation).Format("2006-01-02 15:04:05")
}
return nil
}
case *NullUint32:
switch d := dest.(type) {
case *uint8:
if s.Valid {
*d = uint8(s.Uint32)
}
return nil
case *uint16:
if s.Valid {
*d = uint16(s.Uint32)
}
return nil
case *uint:
if s.Valid {
*d = uint(s.Uint32)
}
return nil
}
case *NullUint64:
switch d := dest.(type) {
case *uint64:
if s.Valid {
*d = s.Uint64
}
return nil
}
case *sql.RawBytes:
switch d := dest.(type) {
case convert.Conversion:
return d.FromDB(*s)
}
} }
var sv reflect.Value var sv reflect.Value
@ -175,7 +491,10 @@ func convertAssign(dest, src interface{}) error {
return nil return nil
} }
dpv := reflect.ValueOf(dest) return convertAssignV(reflect.ValueOf(dest), src, originalLocation, convertedLocation)
}
func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, convertedLocation *time.Location) error {
if dpv.Kind() != reflect.Ptr { if dpv.Kind() != reflect.Ptr {
return errors.New("destination not a pointer") return errors.New("destination not a pointer")
} }
@ -183,9 +502,7 @@ func convertAssign(dest, src interface{}) error {
return errNilPtr return errNilPtr
} }
if !sv.IsValid() { var sv = reflect.ValueOf(src)
sv = reflect.ValueOf(src)
}
dv := reflect.Indirect(dpv) dv := reflect.Indirect(dpv)
if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
@ -211,31 +528,28 @@ func convertAssign(dest, src interface{}) error {
} }
dv.Set(reflect.New(dv.Type().Elem())) dv.Set(reflect.New(dv.Type().Elem()))
return convertAssign(dv.Interface(), src) return convertAssign(dv.Interface(), src, originalLocation, convertedLocation)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
s := asString(src) i64, err := asInt64(src)
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
if err != nil { if err != nil {
err = strconvErr(err) err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err)
} }
dv.SetInt(i64) dv.SetInt(i64)
return nil return nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
s := asString(src) u64, err := asUint64(src)
u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
if err != nil { if err != nil {
err = strconvErr(err) err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err)
} }
dv.SetUint(u64) dv.SetUint(u64)
return nil return nil
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
s := asString(src) f64, err := asFloat64(src)
f64, err := strconv.ParseFloat(s, dv.Type().Bits())
if err != nil { if err != nil {
err = strconvErr(err) err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err)
} }
dv.SetFloat(f64) dv.SetFloat(f64)
return nil return nil
@ -244,7 +558,7 @@ func convertAssign(dest, src interface{}) error {
return nil return nil
} }
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dpv.Interface())
} }
func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
@ -376,47 +690,79 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) {
return v.Interface(), nil return v.Interface(), nil
} }
func int64ToIntValue(id int64, tp reflect.Type) reflect.Value { var (
var v interface{} _ sql.Scanner = &NullUint64{}
kind := tp.Kind() )
if kind == reflect.Ptr { // NullUint64 represents an uint64 that may be null.
kind = tp.Elem().Kind() // NullUint64 implements the Scanner interface so
} // it can be used as a scan destination, similar to NullString.
type NullUint64 struct {
switch kind { Uint64 uint64
case reflect.Int16: Valid bool
temp := int16(id)
v = &temp
case reflect.Int32:
temp := int32(id)
v = &temp
case reflect.Int:
temp := int(id)
v = &temp
case reflect.Int64:
temp := id
v = &temp
case reflect.Uint16:
temp := uint16(id)
v = &temp
case reflect.Uint32:
temp := uint32(id)
v = &temp
case reflect.Uint64:
temp := uint64(id)
v = &temp
case reflect.Uint:
temp := uint(id)
v = &temp
}
if tp.Kind() == reflect.Ptr {
return reflect.ValueOf(v).Convert(tp)
}
return reflect.ValueOf(v).Elem().Convert(tp)
} }
func int64ToInt(id int64, tp reflect.Type) interface{} { // Scan implements the Scanner interface.
return int64ToIntValue(id, tp).Interface() func (n *NullUint64) Scan(value interface{}) error {
if value == nil {
n.Uint64, n.Valid = 0, false
return nil
}
n.Valid = true
var err error
n.Uint64, err = asUint64(value)
return err
}
// Value implements the driver Valuer interface: NULL maps to nil,
// otherwise the raw uint64 is returned.
func (n NullUint64) Value() (driver.Value, error) {
	if n.Valid {
		return n.Uint64, nil
	}
	return nil, nil
}
var (
	// Compile-time check that *NullUint32 satisfies sql.Scanner.
	_ sql.Scanner = &NullUint32{}
)

// NullUint32 represents an uint32 that may be null.
// NullUint32 implements the Scanner interface so
// it can be used as a scan destination, similar to NullString.
type NullUint32 struct {
	Uint32 uint32
	Valid  bool // Valid is true if Uint32 is not NULL
}
// Scan implements the Scanner interface.
// A nil value stores NULL; anything else is converted with asUint64 and
// truncated to 32 bits.
// Fix: Valid was previously set to true before the conversion, so a
// failed conversion left the destination claiming a valid (stale) value;
// Valid is now only set once the conversion succeeds.
func (n *NullUint32) Scan(value interface{}) error {
	if value == nil {
		n.Uint32, n.Valid = 0, false
		return nil
	}
	i64, err := asUint64(value)
	if err != nil {
		n.Valid = false
		return err
	}
	n.Uint32 = uint32(i64)
	n.Valid = true
	return nil
}
// Value implements the driver Valuer interface: NULL maps to nil,
// otherwise the stored value is widened to int64 (a driver-native type).
func (n NullUint32) Value() (driver.Value, error) {
	if n.Valid {
		return int64(n.Uint32), nil
	}
	return nil, nil
}
var (
	// Compile-time check that *EmptyScanner satisfies sql.Scanner.
	_ sql.Scanner = &EmptyScanner{}
)

// EmptyScanner is a sql.Scanner that discards whatever value is scanned
// into it; useful for selecting columns whose values are not needed.
type EmptyScanner struct{}

// Scan implements the Scanner interface by ignoring the value.
func (EmptyScanner) Scan(value interface{}) error {
	return nil
}

48
convert/interface.go Normal file
View File

@ -0,0 +1,48 @@
// Copyright 2021 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package convert
import (
"database/sql"
"fmt"
"time"
)
// Interface2Interface normalizes a scanned value (typically one of the
// pointer destinations produced by a driver's GenScanResult) into a plain
// value. sql.NullTime values are formatted as "2006-01-02 15:04:05" in
// userLocation; unrecognized types produce an error.
func Interface2Interface(userLocation *time.Location, v interface{}) (interface{}, error) {
	if v == nil {
		return nil, nil
	}
	switch value := v.(type) {
	case *int64:
		return *value, nil
	case *int8:
		return *value, nil
	case *sql.NullString:
		// NOTE: the String field is returned even when !Valid,
		// i.e. a NULL becomes the empty string.
		return value.String, nil
	case *sql.NullInt32:
		return value.Int32, nil
	case *sql.NullInt64:
		return value.Int64, nil
	case *sql.NullFloat64:
		return value.Float64, nil
	case *sql.NullBool:
		if !value.Valid {
			return nil, nil
		}
		return value.Bool, nil
	case *sql.NullTime:
		if !value.Valid {
			return "", nil
		}
		return value.Time.In(userLocation).Format("2006-01-02 15:04:05"), nil
	case *sql.RawBytes:
		if b := []byte(*value); len(b) > 0 {
			return b, nil
		}
		return nil, nil
	default:
		return "", fmt.Errorf("convert assign string unsupported type: %#v", value)
	}
}

30
convert/time.go Normal file
View File

@ -0,0 +1,30 @@
// Copyright 2021 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package convert
import (
"fmt"
"time"
)
// String2Time converts a string to time with original location
func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) {
if len(s) == 19 {
dt, err := time.ParseInLocation("2006-01-02 15:04:05", s, originalLocation)
if err != nil {
return nil, err
}
dt = dt.In(convertedLocation)
return &dt, nil
} else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' {
dt, err := time.ParseInLocation("2006-01-02T15:04:05Z", s, originalLocation)
if err != nil {
return nil, err
}
dt = dt.In(convertedLocation)
return &dt, nil
}
return nil, fmt.Errorf("unsupported convertion from %s to time", s)
}

View File

@ -5,12 +5,29 @@
package dialects package dialects
import ( import (
"database/sql"
"fmt" "fmt"
"time"
"xorm.io/xorm/core"
) )
// ScanContext represents a context when Scan
type ScanContext struct {
	// DBLocation is the time location values are stored in by the database
	// (passed to convert.String2Time as the original location).
	DBLocation *time.Location
	// UserLocation is the time location results are converted to before
	// being handed back to the caller.
	UserLocation *time.Location
}

// DriverFeatures describes optional capabilities of a Driver.
type DriverFeatures struct {
	// SupportNullable reports whether the driver can scan into
	// nullable destinations (sql.NullString and friends).
	SupportNullable bool
}
// Driver represents a database driver // Driver represents a database driver
type Driver interface { type Driver interface {
Parse(string, string) (*URI, error) Parse(string, string) (*URI, error)
Features() DriverFeatures
GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface
Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error
} }
var ( var (
@ -59,3 +76,15 @@ func OpenDialect(driverName, connstr string) (Dialect, error) {
return dialect, nil return dialect, nil
} }
// baseDriver provides default Driver behaviour that concrete drivers
// embed and selectively override.
type baseDriver struct{}

// Scan is the default row scan: it delegates directly to rows.Scan
// with no per-driver value conversion.
func (b *baseDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, v ...interface{}) error {
	return rows.Scan(v...)
}

// Features reports driver capabilities; by default nullable scan
// destinations are supported.
func (b *baseDriver) Features() DriverFeatures {
	return DriverFeatures{
		SupportNullable: true,
	}
}

View File

@ -23,13 +23,45 @@ type SeqFilter struct {
func convertQuestionMark(sql, prefix string, start int) string { func convertQuestionMark(sql, prefix string, start int) string {
var buf strings.Builder var buf strings.Builder
var beginSingleQuote bool var beginSingleQuote bool
var isLineComment bool
var isComment bool
var isMaybeLineComment bool
var isMaybeComment bool
var isMaybeCommentEnd bool
var index = start var index = start
for _, c := range sql { for _, c := range sql {
if !beginSingleQuote && c == '?' { if !beginSingleQuote && !isLineComment && !isComment && c == '?' {
buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) buf.WriteString(fmt.Sprintf("%s%v", prefix, index))
index++ index++
} else { } else {
if c == '\'' { if isMaybeLineComment {
if c == '-' {
isLineComment = true
}
isMaybeLineComment = false
} else if isMaybeComment {
if c == '*' {
isComment = true
}
isMaybeComment = false
} else if isMaybeCommentEnd {
if c == '/' {
isComment = false
}
isMaybeCommentEnd = false
} else if isLineComment {
if c == '\n' {
isLineComment = false
}
} else if isComment {
if c == '*' {
isMaybeCommentEnd = true
}
} else if !beginSingleQuote && c == '-' {
isMaybeLineComment = true
} else if !beginSingleQuote && c == '/' {
isMaybeComment = true
} else if c == '\'' {
beginSingleQuote = !beginSingleQuote beginSingleQuote = !beginSingleQuote
} }
buf.WriteRune(c) buf.WriteRune(c)

View File

@ -19,3 +19,60 @@ func TestSeqFilter(t *testing.T) {
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1))
} }
} }
// TestSeqFilterLineComment verifies that question marks inside `--` line
// comments are left untouched while placeholders in the surrounding SQL
// are still rewritten to sequential $N parameters.
func TestSeqFilterLineComment(t *testing.T) {
	// map of input SQL -> expected rewritten SQL
	var kases = map[string]string{
		`SELECT *
FROM TABLE1
WHERE foo='bar'
AND a=? -- it's a comment
AND b=?`: `SELECT *
FROM TABLE1
WHERE foo='bar'
AND a=$1 -- it's a comment
AND b=$2`,
		`SELECT *
FROM TABLE1
WHERE foo='bar'
AND a=? -- it's a comment?
AND b=?`: `SELECT *
FROM TABLE1
WHERE foo='bar'
AND a=$1 -- it's a comment?
AND b=$2`,
		`SELECT *
FROM TABLE1
WHERE a=? -- it's a comment? and that's okay?
AND b=?`: `SELECT *
FROM TABLE1
WHERE a=$1 -- it's a comment? and that's okay?
AND b=$2`,
	}
	for sql, result := range kases {
		assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1))
	}
}
// TestSeqFilterComment verifies that question marks inside /* ... */
// block comments (including multi-line and empty comments) are left
// untouched while real placeholders are rewritten to sequential $N.
func TestSeqFilterComment(t *testing.T) {
	// map of input SQL -> expected rewritten SQL
	var kases = map[string]string{
		`SELECT *
FROM TABLE1
WHERE a=? /* it's a comment */
AND b=?`: `SELECT *
FROM TABLE1
WHERE a=$1 /* it's a comment */
AND b=$2`,
		`SELECT /* it's a comment * ?
More comment on the next line! */ *
FROM TABLE1
WHERE a=? /**/
AND b=?`: `SELECT /* it's a comment * ?
More comment on the next line! */ *
FROM TABLE1
WHERE a=$1 /**/
AND b=$2`,
	}
	for sql, result := range kases {
		assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1))
	}
}

View File

@ -6,6 +6,7 @@ package dialects
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
@ -624,6 +625,7 @@ func (db *mssql) Filters() []Filter {
} }
type odbcDriver struct { type odbcDriver struct {
baseDriver
} }
func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) { func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) {
@ -652,3 +654,26 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) {
} }
return &URI{DBName: dbName, DBType: schemas.MSSQL}, nil return &URI{DBName: dbName, DBType: schemas.MSSQL}, nil
} }
// GenScanResult returns a pointer to a fresh scan destination suited to
// the given SQL Server column type name. Textual and date/time columns
// are both scanned through sql.NullString; unknown types fall back to
// raw bytes.
func (p *odbcDriver) GenScanResult(colType string) (interface{}, error) {
	switch colType {
	case "VARCHAR", "TEXT", "CHAR", "NVARCHAR", "NCHAR", "NTEXT",
		"DATE", "DATETIME", "DATETIME2", "TIME":
		return new(sql.NullString), nil
	case "FLOAT", "REAL":
		return new(sql.NullFloat64), nil
	case "BIGINT", "DATETIMEOFFSET":
		return new(sql.NullInt64), nil
	case "TINYINT", "SMALLINT", "INT":
		return new(sql.NullInt32), nil
	default:
		return new(sql.RawBytes), nil
	}
}

View File

@ -7,6 +7,7 @@ package dialects
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"database/sql"
"errors" "errors"
"fmt" "fmt"
"regexp" "regexp"
@ -14,6 +15,7 @@ import (
"strings" "strings"
"time" "time"
"xorm.io/xorm/convert"
"xorm.io/xorm/core" "xorm.io/xorm/core"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -645,7 +647,125 @@ func (db *mysql) Filters() []Filter {
return []Filter{} return []Filter{}
} }
// mysqlDriver implements Driver for MySQL-protocol drivers; it embeds
// baseDriver for the default Features and overrides Parse, GenScanResult
// and Scan below.
type mysqlDriver struct {
	baseDriver
}
// mysqlDSNPattern matches a MySQL DSN of the form
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN].
// Compiled once at package level instead of on every Parse call.
var mysqlDSNPattern = regexp.MustCompile(
	`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
		`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
		`\/(?P<dbname>.*?)` + // /dbname
		`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]

// Parse extracts the database name and charset from a MySQL DSN.
// An unmatched DSN yields a URI with only DBType set (no error),
// preserving the original behaviour.
func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
	matches := mysqlDSNPattern.FindStringSubmatch(dataSourceName)
	names := mysqlDSNPattern.SubexpNames()
	uri := &URI{DBType: schemas.MYSQL}
	for i, match := range matches {
		switch names[i] {
		case "dbname":
			uri.DBName = match
		case "params":
			if len(match) > 0 {
				kvs := strings.Split(match, "&")
				for _, kv := range kvs {
					splits := strings.Split(kv, "=")
					if len(splits) == 2 {
						switch splits[0] {
						case "charset":
							uri.Charset = splits[1]
						}
					}
				}
			}
		}
	}
	return uri, nil
}
// GenScanResult returns a pointer to a fresh scan destination suited to
// the given MySQL column type name. Unknown types fall back to raw bytes.
func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) {
	switch colType {
	case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET":
		return new(sql.NullString), nil
	case "DECIMAL", "NUMERIC":
		// Decimals are scanned as strings so no precision is lost.
		return new(sql.NullString), nil
	case "BIGINT":
		return new(sql.NullInt64), nil
	case "TINYINT", "SMALLINT", "MEDIUMINT", "INT":
		return new(sql.NullInt32), nil
	case "FLOAT", "REAL", "DOUBLE PRECISION":
		return new(sql.NullFloat64), nil
	case "DATETIME":
		return new(sql.NullTime), nil
	case "BIT", "BINARY", "VARBINARY", "TINYBLOB", "BLOB", "MEDIUMBLOB", "LONGBLOB":
		return new(sql.RawBytes), nil
	default:
		return new(sql.RawBytes), nil
	}
}
// Scan reads one row into scanResults. Time destinations (*time.Time and
// *sql.NullTime) are first scanned into temporary sql.NullString values
// and then parsed back via convert.String2Time using ctx's locations —
// presumably because the MySQL driver returns DATETIME columns as
// strings unless parseTime is enabled (TODO confirm against the driver
// configuration used by the engine).
func (p *mysqlDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, scanResults ...interface{}) error {
	// v2 mirrors scanResults with time destinations swapped for strings;
	// turnBackIdxes remembers which positions need post-processing.
	var v2 = make([]interface{}, 0, len(scanResults))
	var turnBackIdxes = make([]int, 0, 5)
	for i, vv := range scanResults {
		switch vv.(type) {
		case *time.Time:
			v2 = append(v2, &sql.NullString{})
			turnBackIdxes = append(turnBackIdxes, i)
		case *sql.NullTime:
			v2 = append(v2, &sql.NullString{})
			turnBackIdxes = append(turnBackIdxes, i)
		default:
			v2 = append(v2, scanResults[i])
		}
	}
	if err := rows.Scan(v2...); err != nil {
		return err
	}
	// Second pass: parse the intercepted strings back into the original
	// time destinations.
	for _, i := range turnBackIdxes {
		switch t := scanResults[i].(type) {
		case *time.Time:
			var s = *(v2[i].(*sql.NullString))
			if !s.Valid {
				break // NULL: leave the destination untouched
			}
			dt, err := convert.String2Time(s.String, ctx.DBLocation, ctx.UserLocation)
			if err != nil {
				return err
			}
			*t = *dt
		case *sql.NullTime:
			var s = *(v2[i].(*sql.NullString))
			if !s.Valid {
				break // NULL: leave Valid as false
			}
			dt, err := convert.String2Time(s.String, ctx.DBLocation, ctx.UserLocation)
			if err != nil {
				return err
			}
			t.Time = *dt
			t.Valid = true
		}
	}
	return nil
}
type mymysqlDriver struct { type mymysqlDriver struct {
mysqlDriver
} }
func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
@ -696,41 +816,3 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
return uri, nil return uri, nil
} }
type mysqlDriver struct {
}
func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
matches := dsnPattern.FindStringSubmatch(dataSourceName)
// tlsConfigRegister := make(map[string]*tls.Config)
names := dsnPattern.SubexpNames()
uri := &URI{DBType: schemas.MYSQL}
for i, match := range matches {
switch names[i] {
case "dbname":
uri.DBName = match
case "params":
if len(match) > 0 {
kvs := strings.Split(match, "&")
for _, kv := range kvs {
splits := strings.Split(kv, "=")
if len(splits) == 2 {
switch splits[0] {
case "charset":
uri.Charset = splits[1]
}
}
}
}
}
}
return uri, nil
}

View File

@ -6,6 +6,7 @@ package dialects
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"fmt" "fmt"
"regexp" "regexp"
@ -823,6 +824,7 @@ func (db *oracle) Filters() []Filter {
} }
type godrorDriver struct { type godrorDriver struct {
baseDriver
} }
func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) { func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) {
@ -848,7 +850,28 @@ func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error)
return db, nil return db, nil
} }
// GenScanResult returns a pointer to a fresh scan destination suited to
// the given Oracle column type name. Unknown types fall back to raw bytes.
func (p *godrorDriver) GenScanResult(colType string) (interface{}, error) {
	switch colType {
	case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB":
		return new(sql.NullString), nil
	case "NUMBER":
		// NUMBER is scanned as a string so no precision is lost.
		return new(sql.NullString), nil
	case "DATE":
		return new(sql.NullTime), nil
	case "BLOB":
		return new(sql.RawBytes), nil
	default:
		return new(sql.RawBytes), nil
	}
}
type oci8Driver struct { type oci8Driver struct {
godrorDriver
} }
// dataSourceName=user/password@ipv4:port/dbname // dataSourceName=user/password@ipv4:port/dbname

View File

@ -6,6 +6,7 @@ package dialects
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
@ -1056,12 +1057,13 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab
func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, description,
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
FROM pg_attribute f FROM pg_attribute f
JOIN pg_class c ON c.oid = f.attrelid JOIN pg_type t ON t.oid = f.atttypid JOIN pg_class c ON c.oid = f.attrelid JOIN pg_type t ON t.oid = f.atttypid
LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum
LEFT JOIN pg_description de ON f.attrelid=de.objoid AND f.attnum=de.objsubid
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN pg_class AS g ON p.confrelid = g.oid
@ -1090,9 +1092,9 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
col.Indexes = make(map[string]int) col.Indexes = make(map[string]int)
var colName, isNullable, dataType string var colName, isNullable, dataType string
var maxLenStr, colDefault *string var maxLenStr, colDefault, description *string
var isPK, isUnique bool var isPK, isUnique bool
err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &isPK, &isUnique) err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &description, &isPK, &isUnique)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -1138,6 +1140,10 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
col.DefaultIsEmpty = true col.DefaultIsEmpty = true
} }
if description != nil {
col.Comment = *description
}
if isPK { if isPK {
col.IsPrimaryKey = true col.IsPrimaryKey = true
} }
@ -1305,6 +1311,13 @@ func (db *postgres) Filters() []Filter {
} }
type pqDriver struct { type pqDriver struct {
baseDriver
}
func (b *pqDriver) Features() DriverFeatures {
return DriverFeatures{
SupportNullable: false,
}
} }
type values map[string]string type values map[string]string
@ -1381,6 +1394,36 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
return db, nil return db, nil
} }
func (p *pqDriver) GenScanResult(colType string) (interface{}, error) {
switch colType {
case "VARCHAR", "TEXT":
var s sql.NullString
return &s, nil
case "BIGINT":
var s sql.NullInt64
return &s, nil
case "TINYINT", "INT", "INT8", "INT4":
var s sql.NullInt32
return &s, nil
case "FLOAT", "FLOAT4":
var s sql.NullFloat64
return &s, nil
case "DATETIME", "TIMESTAMP":
var s sql.NullTime
return &s, nil
case "BIT":
var s sql.RawBytes
return &s, nil
case "BOOL":
var s sql.NullBool
return &s, nil
default:
fmt.Printf("unknow postgres database type: %v\n", colType)
var r sql.RawBytes
return &r, nil
}
}
type pqDriverPgx struct { type pqDriverPgx struct {
pqDriver pqDriver
} }

View File

@ -540,6 +540,7 @@ func (db *sqlite3) Filters() []Filter {
} }
type sqlite3Driver struct { type sqlite3Driver struct {
baseDriver
} }
func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) { func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) {
@ -549,3 +550,35 @@ func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) {
return &URI{DBType: schemas.SQLITE, DBName: dataSourceName}, nil return &URI{DBType: schemas.SQLITE, DBName: dataSourceName}, nil
} }
func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) {
switch colType {
case "TEXT":
var s sql.NullString
return &s, nil
case "INTEGER":
var s sql.NullInt64
return &s, nil
case "DATETIME":
var s sql.NullTime
return &s, nil
case "REAL":
var s sql.NullFloat64
return &s, nil
case "NUMERIC":
var s sql.NullString
return &s, nil
case "BLOB":
var s sql.RawBytes
return &s, nil
default:
var r sql.NullString
return &r, nil
}
}
func (b *sqlite3Driver) Features() DriverFeatures {
return DriverFeatures{
SupportNullable: false,
}
}

View File

@ -35,6 +35,7 @@ type Engine struct {
cacherMgr *caches.Manager cacherMgr *caches.Manager
defaultContext context.Context defaultContext context.Context
dialect dialects.Dialect dialect dialects.Dialect
driver dialects.Driver
engineGroup *EngineGroup engineGroup *EngineGroup
logger log.ContextLogger logger log.ContextLogger
tagParser *tags.Parser tagParser *tags.Parser
@ -72,6 +73,7 @@ func newEngine(driverName, dataSourceName string, dialect dialects.Dialect, db *
engine := &Engine{ engine := &Engine{
dialect: dialect, dialect: dialect,
driver: dialects.QueryDriver(driverName),
TZLocation: time.Local, TZLocation: time.Local,
defaultContext: context.Background(), defaultContext: context.Background(),
cacherMgr: cacherMgr, cacherMgr: cacherMgr,
@ -444,7 +446,7 @@ func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return engine.dumpTables(tables, w, tp...) return engine.dumpTables(tables, w, tp...)
} }
func formatColumnValue(dstDialect dialects.Dialect, d interface{}, col *schemas.Column) string { func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d interface{}, col *schemas.Column) string {
if d == nil { if d == nil {
return "NULL" return "NULL"
} }
@ -473,10 +475,8 @@ func formatColumnValue(dstDialect dialects.Dialect, d interface{}, col *schemas.
return "'" + strings.Replace(v, "'", "''", -1) + "'" return "'" + strings.Replace(v, "'", "''", -1) + "'"
} else if col.SQLType.IsTime() { } else if col.SQLType.IsTime() {
if dstDialect.URI().DBType == schemas.MSSQL && col.SQLType.Name == schemas.DateTime { if t, ok := d.(time.Time); ok {
if t, ok := d.(time.Time); ok { return "'" + t.In(dbLocation).Format("2006-01-02 15:04:05") + "'"
return "'" + t.UTC().Format("2006-01-02 15:04:05") + "'"
}
} }
var v = fmt.Sprintf("%s", d) var v = fmt.Sprintf("%s", d)
if strings.HasSuffix(v, " +0000 UTC") { if strings.HasSuffix(v, " +0000 UTC") {
@ -652,12 +652,8 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return errors.New("unknown column error") return errors.New("unknown column error")
} }
fields := strings.Split(col.FieldName, ".") field := dataStruct.FieldByIndex(col.FieldIndex)
field := dataStruct temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, field.Interface(), col)
for _, fieldName := range fields {
field = field.FieldByName(fieldName)
}
temp += "," + formatColumnValue(dstDialect, field.Interface(), col)
} }
_, err = io.WriteString(w, temp[1:]+");\n") _, err = io.WriteString(w, temp[1:]+");\n")
if err != nil { if err != nil {
@ -684,7 +680,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return errors.New("unknow column error") return errors.New("unknow column error")
} }
temp += "," + formatColumnValue(dstDialect, d, col) temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, d, col)
} }
_, err = io.WriteString(w, temp[1:]+");\n") _, err = io.WriteString(w, temp[1:]+");\n")
if err != nil { if err != nil {

View File

@ -176,6 +176,23 @@ func TestDumpTables(t *testing.T) {
} }
} }
func TestDumpTables2(t *testing.T) {
assert.NoError(t, PrepareEngine())
type TestDumpTableStruct2 struct {
Id int64
Created time.Time `xorm:"Default CURRENT_TIMESTAMP"`
}
assertSync(t, new(TestDumpTableStruct2))
fp := fmt.Sprintf("./dump2-%v-table.sql", testEngine.Dialect().URI().DBType)
os.Remove(fp)
tb, err := testEngine.TableInfo(new(TestDumpTableStruct2))
assert.NoError(t, err)
assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, fp))
}
func TestSetSchema(t *testing.T) { func TestSetSchema(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
@ -209,3 +226,39 @@ func TestDBVersion(t *testing.T) {
fmt.Println(testEngine.Dialect().URI().DBType, "version is", version) fmt.Println(testEngine.Dialect().URI().DBType, "version is", version)
} }
func TestGetColumns(t *testing.T) {
if testEngine.Dialect().URI().DBType != schemas.POSTGRES {
t.Skip()
return
}
type TestCommentStruct struct {
HasComment int
NoComment int
}
assertSync(t, new(TestCommentStruct))
comment := "this is a comment"
sql := fmt.Sprintf("comment on column %s.%s is '%s'", testEngine.TableName(new(TestCommentStruct), true), "has_comment", comment)
_, err := testEngine.Exec(sql)
assert.NoError(t, err)
tables, err := testEngine.DBMetas()
assert.NoError(t, err)
tableName := testEngine.GetColumnMapper().Obj2Table("TestCommentStruct")
var hasComment, noComment string
for _, table := range tables {
if table.Name == tableName {
col := table.GetColumn("has_comment")
assert.NotNil(t, col)
hasComment = col.Comment
col2 := table.GetColumn("no_comment")
assert.NotNil(t, col2)
noComment = col2.Comment
break
}
}
assert.Equal(t, comment, hasComment)
assert.Zero(t, noComment)
}

View File

@ -406,16 +406,16 @@ func TestFindMapPtrString(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestFindBit(t *testing.T) { func TestFindBool(t *testing.T) {
type FindBitStruct struct { type FindBoolStruct struct {
Id int64 Id int64
Msg bool `xorm:"bit"` Msg bool
} }
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
assertSync(t, new(FindBitStruct)) assertSync(t, new(FindBoolStruct))
cnt, err := testEngine.Insert([]FindBitStruct{ cnt, err := testEngine.Insert([]FindBoolStruct{
{ {
Msg: false, Msg: false,
}, },
@ -426,14 +426,13 @@ func TestFindBit(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, cnt) assert.EqualValues(t, 2, cnt)
var results = make([]FindBitStruct, 0, 2) var results = make([]FindBoolStruct, 0, 2)
err = testEngine.Find(&results) err = testEngine.Find(&results)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, len(results)) assert.EqualValues(t, 2, len(results))
} }
func TestFindMark(t *testing.T) { func TestFindMark(t *testing.T) {
type Mark struct { type Mark struct {
Mark1 string `xorm:"VARCHAR(1)"` Mark1 string `xorm:"VARCHAR(1)"`
Mark2 string `xorm:"VARCHAR(1)"` Mark2 string `xorm:"VARCHAR(1)"`
@ -468,7 +467,7 @@ func TestFindAndCountOneFunc(t *testing.T) {
type FindAndCountStruct struct { type FindAndCountStruct struct {
Id int64 Id int64
Content string Content string
Msg bool `xorm:"bit"` Msg bool
} }
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())

View File

@ -32,7 +32,6 @@ func TestInsertOne(t *testing.T) {
} }
func TestInsertMulti(t *testing.T) { func TestInsertMulti(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
type TestMulti struct { type TestMulti struct {
Id int64 `xorm:"int(11) pk"` Id int64 `xorm:"int(11) pk"`
@ -78,7 +77,6 @@ func insertMultiDatas(step int, datas interface{}) (num int64, err error) {
} }
func callbackLooper(datas interface{}, step int, actionFunc func(interface{}) error) (err error) { func callbackLooper(datas interface{}, step int, actionFunc func(interface{}) error) (err error) {
sliceValue := reflect.Indirect(reflect.ValueOf(datas)) sliceValue := reflect.Indirect(reflect.ValueOf(datas))
if sliceValue.Kind() != reflect.Slice { if sliceValue.Kind() != reflect.Slice {
return fmt.Errorf("not slice") return fmt.Errorf("not slice")
@ -170,17 +168,17 @@ func TestInsertAutoIncr(t *testing.T) {
assert.Greater(t, user.Uid, int64(0)) assert.Greater(t, user.Uid, int64(0))
} }
type DefaultInsert struct {
Id int64
Status int `xorm:"default -1"`
Name string
Created time.Time `xorm:"created"`
Updated time.Time `xorm:"updated"`
}
func TestInsertDefault(t *testing.T) { func TestInsertDefault(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
type DefaultInsert struct {
Id int64
Status int `xorm:"default -1"`
Name string
Created time.Time `xorm:"created"`
Updated time.Time `xorm:"updated"`
}
di := new(DefaultInsert) di := new(DefaultInsert)
err := testEngine.Sync2(di) err := testEngine.Sync2(di)
assert.NoError(t, err) assert.NoError(t, err)
@ -197,16 +195,16 @@ func TestInsertDefault(t *testing.T) {
assert.EqualValues(t, di2.Created.Unix(), di.Created.Unix()) assert.EqualValues(t, di2.Created.Unix(), di.Created.Unix())
} }
type DefaultInsert2 struct {
Id int64
Name string
Url string `xorm:"text"`
CheckTime time.Time `xorm:"not null default '2000-01-01 00:00:00' TIMESTAMP"`
}
func TestInsertDefault2(t *testing.T) { func TestInsertDefault2(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
type DefaultInsert2 struct {
Id int64
Name string
Url string `xorm:"text"`
CheckTime time.Time `xorm:"not null default '2000-01-01 00:00:00' TIMESTAMP"`
}
di := new(DefaultInsert2) di := new(DefaultInsert2)
err := testEngine.Sync2(di) err := testEngine.Sync2(di)
assert.NoError(t, err) assert.NoError(t, err)
@ -1026,3 +1024,44 @@ func TestInsertIntSlice(t *testing.T) {
assert.True(t, has) assert.True(t, has)
assert.EqualValues(t, v3, v4) assert.EqualValues(t, v3, v4)
} }
func TestInsertDeleted(t *testing.T) {
assert.NoError(t, PrepareEngine())
type InsertDeletedStructNotRight struct {
ID uint64 `xorm:"'ID' pk autoincr"`
DeletedAt time.Time `xorm:"'DELETED_AT' deleted notnull"`
}
// notnull tag will be ignored
err := testEngine.Sync2(new(InsertDeletedStructNotRight))
assert.NoError(t, err)
type InsertDeletedStruct struct {
ID uint64 `xorm:"'ID' pk autoincr"`
DeletedAt time.Time `xorm:"'DELETED_AT' deleted"`
}
assert.NoError(t, testEngine.Sync2(new(InsertDeletedStruct)))
var v InsertDeletedStruct
_, err = testEngine.Insert(&v)
assert.NoError(t, err)
var v2 InsertDeletedStruct
has, err := testEngine.Get(&v2)
assert.NoError(t, err)
assert.True(t, has)
_, err = testEngine.ID(v.ID).Delete(new(InsertDeletedStruct))
assert.NoError(t, err)
var v3 InsertDeletedStruct
has, err = testEngine.Get(&v3)
assert.NoError(t, err)
assert.False(t, has)
var v4 InsertDeletedStruct
has, err = testEngine.Unscoped().Get(&v4)
assert.NoError(t, err)
assert.True(t, has)
}

View File

@ -52,7 +52,7 @@ func TestQueryString2(t *testing.T) {
type GetVar3 struct { type GetVar3 struct {
Id int64 `xorm:"autoincr pk"` Id int64 `xorm:"autoincr pk"`
Msg bool `xorm:"bit"` Msg bool
} }
assert.NoError(t, testEngine.Sync2(new(GetVar3))) assert.NoError(t, testEngine.Sync2(new(GetVar3)))
@ -107,6 +107,16 @@ func toFloat64(i interface{}) float64 {
return 0 return 0
} }
func toBool(i interface{}) bool {
switch t := i.(type) {
case int32:
return t > 0
case bool:
return t
}
return false
}
func TestQueryInterface(t *testing.T) { func TestQueryInterface(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
@ -132,10 +142,10 @@ func TestQueryInterface(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 1, len(records)) assert.Equal(t, 1, len(records))
assert.Equal(t, 5, len(records[0])) assert.Equal(t, 5, len(records[0]))
assert.EqualValues(t, 1, toInt64(records[0]["id"])) assert.EqualValues(t, int64(1), records[0]["id"])
assert.Equal(t, "hi", toString(records[0]["msg"])) assert.Equal(t, "hi", records[0]["msg"])
assert.EqualValues(t, 28, toInt64(records[0]["age"])) assert.EqualValues(t, 28, records[0]["age"])
assert.EqualValues(t, 1.5, toFloat64(records[0]["money"])) assert.EqualValues(t, 1.5, records[0]["money"])
} }
func TestQueryNoParams(t *testing.T) { func TestQueryNoParams(t *testing.T) {
@ -192,7 +202,7 @@ func TestQueryStringNoParam(t *testing.T) {
type GetVar4 struct { type GetVar4 struct {
Id int64 `xorm:"autoincr pk"` Id int64 `xorm:"autoincr pk"`
Msg bool `xorm:"bit"` Msg bool
} }
assert.NoError(t, testEngine.Sync2(new(GetVar4))) assert.NoError(t, testEngine.Sync2(new(GetVar4)))
@ -229,7 +239,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
type GetVar6 struct { type GetVar6 struct {
Id int64 `xorm:"autoincr pk"` Id int64 `xorm:"autoincr pk"`
Msg bool `xorm:"bit"` Msg bool
} }
assert.NoError(t, testEngine.Sync2(new(GetVar6))) assert.NoError(t, testEngine.Sync2(new(GetVar6)))
@ -266,7 +276,7 @@ func TestQueryInterfaceNoParam(t *testing.T) {
type GetVar5 struct { type GetVar5 struct {
Id int64 `xorm:"autoincr pk"` Id int64 `xorm:"autoincr pk"`
Msg bool `xorm:"bit"` Msg bool
} }
assert.NoError(t, testEngine.Sync2(new(GetVar5))) assert.NoError(t, testEngine.Sync2(new(GetVar5)))
@ -280,14 +290,14 @@ func TestQueryInterfaceNoParam(t *testing.T) {
records, err := testEngine.Table("get_var5").Limit(1).QueryInterface() records, err := testEngine.Table("get_var5").Limit(1).QueryInterface()
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, 1, toInt64(records[0]["id"])) assert.EqualValues(t, 1, records[0]["id"])
assert.EqualValues(t, 0, toInt64(records[0]["msg"])) assert.False(t, toBool(records[0]["msg"]))
records, err = testEngine.Table("get_var5").Where(builder.Eq{"id": 1}).QueryInterface() records, err = testEngine.Table("get_var5").Where(builder.Eq{"id": 1}).QueryInterface()
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, 1, toInt64(records[0]["id"])) assert.EqualValues(t, 1, records[0]["id"])
assert.EqualValues(t, 0, toInt64(records[0]["msg"])) assert.False(t, toBool(records[0]["msg"]))
} }
func TestQueryWithBuilder(t *testing.T) { func TestQueryWithBuilder(t *testing.T) {

View File

@ -472,6 +472,11 @@ func TestUpdateIncrDecr(t *testing.T) {
cnt, err = testEngine.ID(col1.Id).Cols(colName).Incr(colName).Update(col1) cnt, err = testEngine.ID(col1.Id).Cols(colName).Incr(colName).Update(col1)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
testEngine.SetColumnMapper(testEngine.GetColumnMapper())
cnt, err = testEngine.Cols(colName).Decr(colName, 2).ID(col1.Id).Update(new(UpdateIncr))
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
} }
type UpdatedUpdate struct { type UpdatedUpdate struct {

View File

@ -27,6 +27,7 @@ type Expr struct {
Arg interface{} Arg interface{}
} }
// WriteArgs writes args to the writer
func (expr *Expr) WriteArgs(w *builder.BytesWriter) error { func (expr *Expr) WriteArgs(w *builder.BytesWriter) error {
switch arg := expr.Arg.(type) { switch arg := expr.Arg.(type) {
case *builder.Builder: case *builder.Builder:

View File

@ -17,7 +17,7 @@ func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schem
if _, err := buf.WriteString(" OUTPUT Inserted."); err != nil { if _, err := buf.WriteString(" OUTPUT Inserted."); err != nil {
return err return err
} }
if _, err := buf.WriteString(table.AutoIncrement); err != nil { if err := statement.dialect.Quoter().QuoteTo(buf, table.AutoIncrement); err != nil {
return err return err
} }
} }

View File

@ -343,7 +343,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
var args []interface{} var args []interface{}
var joinStr string var joinStr string
var err error var err error
var b interface{} = nil var b interface{}
if len(bean) > 0 { if len(bean) > 0 {
b = bean[0] b = bean[0]
beanValue := reflect.ValueOf(bean[0]) beanValue := reflect.ValueOf(bean[0])

View File

@ -208,20 +208,18 @@ func (statement *Statement) quote(s string) string {
// And add Where & and statement // And add Where & and statement
func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
switch query.(type) { switch qr := query.(type) {
case string: case string:
cond := builder.Expr(query.(string), args...) cond := builder.Expr(qr, args...)
statement.cond = statement.cond.And(cond) statement.cond = statement.cond.And(cond)
case map[string]interface{}: case map[string]interface{}:
queryMap := query.(map[string]interface{}) cond := make(builder.Eq)
newMap := make(map[string]interface{}) for k, v := range qr {
for k, v := range queryMap { cond[statement.quote(k)] = v
newMap[statement.quote(k)] = v
} }
statement.cond = statement.cond.And(builder.Eq(newMap))
case builder.Cond:
cond := query.(builder.Cond)
statement.cond = statement.cond.And(cond) statement.cond = statement.cond.And(cond)
case builder.Cond:
statement.cond = statement.cond.And(qr)
for _, v := range args { for _, v := range args {
if vv, ok := v.(builder.Cond); ok { if vv, ok := v.(builder.Cond); ok {
statement.cond = statement.cond.And(vv) statement.cond = statement.cond.And(vv)
@ -236,23 +234,25 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme
// Or add Where & Or statement // Or add Where & Or statement
func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement { func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
switch query.(type) { switch qr := query.(type) {
case string: case string:
cond := builder.Expr(query.(string), args...) cond := builder.Expr(qr, args...)
statement.cond = statement.cond.Or(cond) statement.cond = statement.cond.Or(cond)
case map[string]interface{}: case map[string]interface{}:
cond := builder.Eq(query.(map[string]interface{})) cond := make(builder.Eq)
for k, v := range qr {
cond[statement.quote(k)] = v
}
statement.cond = statement.cond.Or(cond) statement.cond = statement.cond.Or(cond)
case builder.Cond: case builder.Cond:
cond := query.(builder.Cond) statement.cond = statement.cond.Or(qr)
statement.cond = statement.cond.Or(cond)
for _, v := range args { for _, v := range args {
if vv, ok := v.(builder.Cond); ok { if vv, ok := v.(builder.Cond); ok {
statement.cond = statement.cond.Or(vv) statement.cond = statement.cond.Or(vv)
} }
} }
default: default:
// TODO: not support condition type statement.LastError = ErrConditionType
} }
return statement return statement
} }
@ -734,6 +734,8 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
//engine.logger.Warn(err) //engine.logger.Warn(err)
} }
continue continue
} else if fieldValuePtr == nil {
continue
} }
if col.IsDeleted && !unscoped { // tag "deleted" is enabled if col.IsDeleted && !unscoped { // tag "deleted" is enabled
@ -976,7 +978,7 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName
// CondDeleted returns the conditions whether a record is soft deleted. // CondDeleted returns the conditions whether a record is soft deleted.
func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond {
var colName = col.Name var colName = statement.quote(col.Name)
if statement.JoinStr != "" { if statement.JoinStr != "" {
var prefix string var prefix string
if statement.TableAlias != "" { if statement.TableAlias != "" {

View File

@ -78,7 +78,6 @@ func TestColumnsStringGeneration(t *testing.T) {
} }
func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
b.StopTimer() b.StopTimer()
mapCols := make(map[string]bool) mapCols := make(map[string]bool)
@ -101,9 +100,7 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, col := range cols { for _, col := range cols {
if _, ok := getFlagForColumn(mapCols, col); !ok { if _, ok := getFlagForColumn(mapCols, col); !ok {
b.Fatal("Unexpected result") b.Fatal("Unexpected result")
} }
@ -112,7 +109,6 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
} }
func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) { func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) {
b.StopTimer() b.StopTimer()
mapCols := make(map[string]bool) mapCols := make(map[string]bool)
@ -131,9 +127,7 @@ func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) {
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, col := range cols { for _, col := range cols {
if _, ok := getFlagForColumn(mapCols, col); ok { if _, ok := getFlagForColumn(mapCols, col); ok {
b.Fatal("Unexpected result") b.Fatal("Unexpected result")
} }

View File

@ -88,6 +88,9 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value,
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if fieldValuePtr == nil {
continue
}
fieldValue := *fieldValuePtr fieldValue := *fieldValuePtr
fieldType := reflect.TypeOf(fieldValue.Interface()) fieldType := reflect.TypeOf(fieldValue.Interface())

View File

@ -132,7 +132,6 @@ func (s *SimpleLogger) Error(v ...interface{}) {
if s.level <= LOG_ERR { if s.level <= LOG_ERR {
s.ERR.Output(2, fmt.Sprintln(v...)) s.ERR.Output(2, fmt.Sprintln(v...))
} }
return
} }
// Errorf implement ILogger // Errorf implement ILogger
@ -140,7 +139,6 @@ func (s *SimpleLogger) Errorf(format string, v ...interface{}) {
if s.level <= LOG_ERR { if s.level <= LOG_ERR {
s.ERR.Output(2, fmt.Sprintf(format, v...)) s.ERR.Output(2, fmt.Sprintf(format, v...))
} }
return
} }
// Debug implement ILogger // Debug implement ILogger
@ -148,7 +146,6 @@ func (s *SimpleLogger) Debug(v ...interface{}) {
if s.level <= LOG_DEBUG { if s.level <= LOG_DEBUG {
s.DEBUG.Output(2, fmt.Sprintln(v...)) s.DEBUG.Output(2, fmt.Sprintln(v...))
} }
return
} }
// Debugf implement ILogger // Debugf implement ILogger
@ -156,7 +153,6 @@ func (s *SimpleLogger) Debugf(format string, v ...interface{}) {
if s.level <= LOG_DEBUG { if s.level <= LOG_DEBUG {
s.DEBUG.Output(2, fmt.Sprintf(format, v...)) s.DEBUG.Output(2, fmt.Sprintf(format, v...))
} }
return
} }
// Info implement ILogger // Info implement ILogger
@ -164,7 +160,6 @@ func (s *SimpleLogger) Info(v ...interface{}) {
if s.level <= LOG_INFO { if s.level <= LOG_INFO {
s.INFO.Output(2, fmt.Sprintln(v...)) s.INFO.Output(2, fmt.Sprintln(v...))
} }
return
} }
// Infof implement ILogger // Infof implement ILogger
@ -172,7 +167,6 @@ func (s *SimpleLogger) Infof(format string, v ...interface{}) {
if s.level <= LOG_INFO { if s.level <= LOG_INFO {
s.INFO.Output(2, fmt.Sprintf(format, v...)) s.INFO.Output(2, fmt.Sprintf(format, v...))
} }
return
} }
// Warn implement ILogger // Warn implement ILogger
@ -180,7 +174,6 @@ func (s *SimpleLogger) Warn(v ...interface{}) {
if s.level <= LOG_WARNING { if s.level <= LOG_WARNING {
s.WARN.Output(2, fmt.Sprintln(v...)) s.WARN.Output(2, fmt.Sprintln(v...))
} }
return
} }
// Warnf implement ILogger // Warnf implement ILogger
@ -188,7 +181,6 @@ func (s *SimpleLogger) Warnf(format string, v ...interface{}) {
if s.level <= LOG_WARNING { if s.level <= LOG_WARNING {
s.WARN.Output(2, fmt.Sprintf(format, v...)) s.WARN.Output(2, fmt.Sprintf(format, v...))
} }
return
} }
// Level implement ILogger // Level implement ILogger
@ -199,7 +191,6 @@ func (s *SimpleLogger) Level() LogLevel {
// SetLevel implement ILogger // SetLevel implement ILogger
func (s *SimpleLogger) SetLevel(l LogLevel) { func (s *SimpleLogger) SetLevel(l LogLevel) {
s.level = l s.level = l
return
} }
// ShowSQL implement ILogger // ShowSQL implement ILogger

303
scan.go Normal file
View File

@ -0,0 +1,303 @@
// Copyright 2021 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"database/sql"
"fmt"
"reflect"
"time"
"xorm.io/xorm/convert"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
)
// genScanResultsByBeanNullabale generates scan result
func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) {
switch t := bean.(type) {
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes:
return t, false, nil
case *time.Time:
return &sql.NullTime{}, true, nil
case *string:
return &sql.NullString{}, true, nil
case *int, *int8, *int16, *int32:
return &sql.NullInt32{}, true, nil
case *int64:
return &sql.NullInt64{}, true, nil
case *uint, *uint8, *uint16, *uint32:
return &NullUint32{}, true, nil
case *uint64:
return &NullUint64{}, true, nil
case *float32, *float64:
return &sql.NullFloat64{}, true, nil
case *bool:
return &sql.NullBool{}, true, nil
case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString,
time.Time,
string,
int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64,
float32, float64,
bool:
return nil, false, fmt.Errorf("unsupported scan type: %t", t)
case convert.Conversion:
return &sql.RawBytes{}, true, nil
}
tp := reflect.TypeOf(bean).Elem()
switch tp.Kind() {
case reflect.String:
return &sql.NullString{}, true, nil
case reflect.Int64:
return &sql.NullInt64{}, true, nil
case reflect.Int32, reflect.Int, reflect.Int16, reflect.Int8:
return &sql.NullInt32{}, true, nil
case reflect.Uint64:
return &NullUint64{}, true, nil
case reflect.Uint32, reflect.Uint, reflect.Uint16, reflect.Uint8:
return &NullUint32{}, true, nil
default:
return nil, false, fmt.Errorf("unsupported type: %#v", bean)
}
}
func genScanResultsByBean(bean interface{}) (interface{}, bool, error) {
switch t := bean.(type) {
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString,
*string,
*int, *int8, *int16, *int32, *int64,
*uint, *uint8, *uint16, *uint32, *uint64,
*float32, *float64,
*bool:
return t, false, nil
case *time.Time:
return &sql.NullTime{}, true, nil
case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString,
time.Time,
string,
int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64,
bool:
return nil, false, fmt.Errorf("unsupported scan type: %t", t)
case convert.Conversion:
return &sql.RawBytes{}, true, nil
}
tp := reflect.TypeOf(bean).Elem()
switch tp.Kind() {
case reflect.String:
return new(string), true, nil
case reflect.Int64:
return new(int64), true, nil
case reflect.Int32:
return new(int32), true, nil
case reflect.Int:
return new(int32), true, nil
case reflect.Int16:
return new(int32), true, nil
case reflect.Int8:
return new(int32), true, nil
case reflect.Uint64:
return new(uint64), true, nil
case reflect.Uint32:
return new(uint32), true, nil
case reflect.Uint:
return new(uint), true, nil
case reflect.Uint16:
return new(uint16), true, nil
case reflect.Uint8:
return new(uint8), true, nil
case reflect.Float32:
return new(float32), true, nil
case reflect.Float64:
return new(float64), true, nil
default:
return nil, false, fmt.Errorf("unsupported type: %#v", bean)
}
}
func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) {
var scanResults = make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var s sql.NullString
scanResults[i] = &s
}
if err := rows.Scan(scanResults...); err != nil {
return nil, err
}
result := make(map[string]string, len(fields))
for ii, key := range fields {
s := scanResults[ii].(*sql.NullString)
result[key] = s.String
}
return result, nil
}
func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string][]byte, error) {
var scanResults = make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var s sql.NullString
scanResults[i] = &s
}
if err := rows.Scan(scanResults...); err != nil {
return nil, err
}
result := make(map[string][]byte, len(fields))
for ii, key := range fields {
s := scanResults[ii].(*sql.NullString)
result[key] = []byte(s.String)
}
return result, nil
}
func (engine *Engine) scanStringInterface(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) {
var scanResults = make([]interface{}, len(types))
for i := 0; i < len(types); i++ {
var s sql.NullString
scanResults[i] = &s
}
if err := engine.driver.Scan(&dialects.ScanContext{
DBLocation: engine.DatabaseTZ,
UserLocation: engine.TZLocation,
}, rows, types, scanResults...); err != nil {
return nil, err
}
return scanResults, nil
}
// scan is a wrap of driver.Scan but will automatically change the input values according requirements
func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.ColumnType, vv ...interface{}) error {
var scanResults = make([]interface{}, 0, len(types))
var replaces = make([]bool, 0, len(types))
var err error
for _, v := range vv {
var replaced bool
var scanResult interface{}
if _, ok := v.(sql.Scanner); !ok {
var useNullable = true
if engine.driver.Features().SupportNullable {
nullable, ok := types[0].Nullable()
useNullable = ok && nullable
}
if useNullable {
scanResult, replaced, err = genScanResultsByBeanNullable(v)
} else {
scanResult, replaced, err = genScanResultsByBean(v)
}
if err != nil {
return err
}
} else {
scanResult = v
}
scanResults = append(scanResults, scanResult)
replaces = append(replaces, replaced)
}
var scanCtx = dialects.ScanContext{
DBLocation: engine.DatabaseTZ,
UserLocation: engine.TZLocation,
}
if err = engine.driver.Scan(&scanCtx, rows, types, scanResults...); err != nil {
return err
}
for i, replaced := range replaces {
if replaced {
if err = convertAssign(vv[i], scanResults[i], scanCtx.DBLocation, engine.TZLocation); err != nil {
return err
}
}
}
return nil
}
func (engine *Engine) scanInterfaces(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) {
var scanResultContainers = make([]interface{}, len(types))
for i := 0; i < len(types); i++ {
scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName())
if err != nil {
return nil, err
}
scanResultContainers[i] = scanResult
}
if err := engine.driver.Scan(&dialects.ScanContext{
DBLocation: engine.DatabaseTZ,
UserLocation: engine.TZLocation,
}, rows, types, scanResultContainers...); err != nil {
return nil, err
}
return scanResultContainers, nil
}
// row2sliceStr scans the current row and returns every column value as a
// string, in column order; NULLs become the empty string via sql.NullString.
func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) {
	scanned, err := engine.scanStringInterface(rows, types)
	if err != nil {
		return nil, err
	}

	results := make([]string, len(fields))
	for idx := range fields {
		results[idx] = scanned[idx].(*sql.NullString).String
	}
	return results, nil
}
// rows2maps iterates over all remaining rows and converts each one into a map
// from column name to the raw bytes of its value.
//
// Unlike the previous version, it reports an error encountered during row
// iteration (rows.Err()) instead of silently returning a truncated result.
func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) {
	fields, err := rows.Columns()
	if err != nil {
		return nil, err
	}
	types, err := rows.ColumnTypes()
	if err != nil {
		return nil, err
	}
	for rows.Next() {
		result, err := row2mapBytes(rows, types, fields)
		if err != nil {
			return nil, err
		}
		resultsSlice = append(resultsSlice, result)
	}
	// rows.Next() returns false both at end-of-set and on error; distinguish
	// the two so callers don't mistake a failed iteration for success.
	// (core.Rows embeds *sql.Rows — NOTE(review): confirm Err() is exposed.)
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return resultsSlice, nil
}
// row2mapInterface scans the current row into driver-generated containers and
// returns a map from column name to the converted value, where conversion is
// performed with the engine's user time-zone location.
func (engine *Engine) row2mapInterface(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]interface{}, error) {
	containers := make([]interface{}, len(fields))
	for idx := 0; idx < len(fields); idx++ {
		container, err := engine.driver.GenScanResult(types[idx].DatabaseTypeName())
		if err != nil {
			return nil, err
		}
		containers[idx] = container
	}

	scanCtx := &dialects.ScanContext{
		DBLocation:   engine.DatabaseTZ,
		UserLocation: engine.TZLocation,
	}
	if err := engine.driver.Scan(scanCtx, rows, types, containers...); err != nil {
		return nil, err
	}

	resultsMap := make(map[string]interface{}, len(fields))
	for idx, name := range fields {
		value, err := convert.Interface2Interface(engine.TZLocation, containers[idx])
		if err != nil {
			return nil, err
		}
		resultsMap[name] = value
	}
	return resultsMap, nil
}

View File

@ -6,10 +6,8 @@ package schemas
import ( import (
"errors" "errors"
"fmt"
"reflect" "reflect"
"strconv" "strconv"
"strings"
"time" "time"
) )
@ -25,6 +23,7 @@ type Column struct {
Name string Name string
TableName string TableName string
FieldName string // Available only when parsed from a struct FieldName string // Available only when parsed from a struct
FieldIndex []int // Available only when parsed from a struct
SQLType SQLType SQLType SQLType
IsJSON bool IsJSON bool
Length int Length int
@ -83,41 +82,17 @@ func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) {
// ValueOfV returns column's filed of struct's value accept reflevt value // ValueOfV returns column's filed of struct's value accept reflevt value
func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
var fieldValue reflect.Value var v = *dataStruct
fieldPath := strings.Split(col.FieldName, ".") for _, i := range col.FieldIndex {
if v.Kind() == reflect.Ptr {
if dataStruct.Type().Kind() == reflect.Map { if v.IsNil() {
keyValue := reflect.ValueOf(fieldPath[len(fieldPath)-1]) v.Set(reflect.New(v.Type().Elem()))
fieldValue = dataStruct.MapIndex(keyValue)
return &fieldValue, nil
} else if dataStruct.Type().Kind() == reflect.Interface {
structValue := reflect.ValueOf(dataStruct.Interface())
dataStruct = &structValue
}
level := len(fieldPath)
fieldValue = dataStruct.FieldByName(fieldPath[0])
for i := 0; i < level-1; i++ {
if !fieldValue.IsValid() {
break
}
if fieldValue.Kind() == reflect.Struct {
fieldValue = fieldValue.FieldByName(fieldPath[i+1])
} else if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
} }
fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1]) v = v.Elem()
} else {
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
} }
v = v.FieldByIndex([]int{i})
} }
return &v, nil
if !fieldValue.IsValid() {
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
}
return &fieldValue, nil
} }
// ConvertID converts id content to suitable type according column type // ConvertID converts id content to suitable type according column type

View File

@ -5,7 +5,6 @@
package schemas package schemas
import ( import (
"fmt"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
@ -159,24 +158,8 @@ func (table *Table) IDOfV(rv reflect.Value) (PK, error) {
for i, col := range table.PKColumns() { for i, col := range table.PKColumns() {
var err error var err error
fieldName := col.FieldName pkField := v.FieldByIndex(col.FieldIndex)
for {
parts := strings.SplitN(fieldName, ".", 2)
if len(parts) == 1 {
break
}
v = v.FieldByName(parts[0])
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return nil, fmt.Errorf("Unsupported read value of column %s from field %s", col.Name, col.FieldName)
}
fieldName = parts[1]
}
pkField := v.FieldByName(fieldName)
switch pkField.Kind() { switch pkField.Kind() {
case reflect.String: case reflect.String:
pk[i], err = col.ConvertID(pkField.String()) pk[i], err = col.ConvertID(pkField.String())

View File

@ -27,7 +27,6 @@ var testsGetColumn = []struct {
var table *Table var table *Table
func init() { func init() {
table = NewEmptyTable() table = NewEmptyTable()
var name string var name string
@ -41,7 +40,6 @@ func init() {
} }
func TestGetColumn(t *testing.T) { func TestGetColumn(t *testing.T) {
for _, test := range testsGetColumn { for _, test := range testsGetColumn {
if table.GetColumn(test.name) == nil { if table.GetColumn(test.name) == nil {
t.Error("Column not found!") t.Error("Column not found!")
@ -50,7 +48,6 @@ func TestGetColumn(t *testing.T) {
} }
func TestGetColumnIdx(t *testing.T) { func TestGetColumnIdx(t *testing.T) {
for _, test := range testsGetColumn { for _, test := range testsGetColumn {
if table.GetColumnIdx(test.name, test.idx) == nil { if table.GetColumnIdx(test.name, test.idx) == nil {
t.Errorf("Column %s with idx %d not found!", test.name, test.idx) t.Errorf("Column %s with idx %d not found!", test.name, test.idx)
@ -59,7 +56,6 @@ func TestGetColumnIdx(t *testing.T) {
} }
func BenchmarkGetColumnWithToLower(b *testing.B) { func BenchmarkGetColumnWithToLower(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, test := range testsGetColumn { for _, test := range testsGetColumn {
@ -71,7 +67,6 @@ func BenchmarkGetColumnWithToLower(b *testing.B) {
} }
func BenchmarkGetColumnIdxWithToLower(b *testing.B) { func BenchmarkGetColumnIdxWithToLower(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, test := range testsGetColumn { for _, test := range testsGetColumn {
@ -89,7 +84,6 @@ func BenchmarkGetColumnIdxWithToLower(b *testing.B) {
} }
func BenchmarkGetColumn(b *testing.B) { func BenchmarkGetColumn(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, test := range testsGetColumn { for _, test := range testsGetColumn {
if table.GetColumn(test.name) == nil { if table.GetColumn(test.name) == nil {
@ -100,7 +94,6 @@ func BenchmarkGetColumn(b *testing.B) {
} }
func BenchmarkGetColumnIdx(b *testing.B) { func BenchmarkGetColumnIdx(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, test := range testsGetColumn { for _, test := range testsGetColumn {
if table.GetColumnIdx(test.name, test.idx) == nil { if table.GetColumnIdx(test.name, test.idx) == nil {

View File

@ -375,6 +375,9 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s
if err != nil { if err != nil {
return nil, err return nil, err
} }
if fieldValue == nil {
return nil, ErrFieldIsNotValid{key, table.Name}
}
if !fieldValue.IsValid() || !fieldValue.CanSet() { if !fieldValue.IsValid() || !fieldValue.CanSet() {
return nil, ErrFieldIsNotValid{key, table.Name} return nil, ErrFieldIsNotValid{key, table.Name}

View File

@ -35,27 +35,20 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time
sd, err := strconv.ParseInt(sdata, 10, 64) sd, err := strconv.ParseInt(sdata, 10, 64)
if err == nil { if err == nil {
x = time.Unix(sd, 0) x = time.Unix(sd, 0)
//session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else {
//session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} }
} else if len(sdata) > 19 && strings.Contains(sdata, "-") { } else if len(sdata) > 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc) x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc)
session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.Name, x, sdata)
if err != nil { if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc) x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc)
//session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} }
if err != nil { if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc) x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc)
//session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} }
} else if len(sdata) == 19 && strings.Contains(sdata, "-") { } else if len(sdata) == 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc) x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc)
//session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' {
x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc) x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc)
//session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if col.SQLType.Name == schemas.Time { } else if col.SQLType.Name == schemas.Time {
if strings.Contains(sdata, " ") { if strings.Contains(sdata, " ") {
ssd := strings.Split(sdata, " ") ssd := strings.Split(sdata, " ")
@ -69,7 +62,6 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time
st := fmt.Sprintf("2006-01-02 %v", sdata) st := fmt.Sprintf("2006-01-02 %v", sdata)
x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc) x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc)
//session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else { } else {
outErr = fmt.Errorf("unsupported time format %v", sdata) outErr = fmt.Errorf("unsupported time format %v", sdata)
return return

View File

@ -276,7 +276,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error { func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error {
cols := table.PKColumns() cols := table.PKColumns()
if len(cols) == 1 { if len(cols) == 1 {
return convertAssign(dst, pk[0]) return convertAssign(dst, pk[0], nil, nil)
} }
dst = pk dst = pk

View File

@ -6,12 +6,16 @@ package xorm
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"time"
"xorm.io/xorm/caches" "xorm.io/xorm/caches"
"xorm.io/xorm/convert"
"xorm.io/xorm/core"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -108,6 +112,17 @@ func (session *Session) get(bean interface{}) (bool, error) {
return true, nil return true, nil
} }
var (
valuerTypePlaceHolder driver.Valuer
valuerType = reflect.TypeOf(&valuerTypePlaceHolder).Elem()
scannerTypePlaceHolder sql.Scanner
scannerType = reflect.TypeOf(&scannerTypePlaceHolder).Elem()
conversionTypePlaceHolder convert.Conversion
conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem()
)
func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
rows, err := session.queryRows(sqlStr, args...) rows, err := session.queryRows(sqlStr, args...)
if err != nil { if err != nil {
@ -122,155 +137,161 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table,
return false, nil return false, nil
} }
switch bean.(type) { // WARN: Alougth rows return true, but we may also return error.
case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString: types, err := rows.ColumnTypes()
return true, rows.Scan(&bean) if err != nil {
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString: return true, err
return true, rows.Scan(bean) }
case *string: fields, err := rows.Columns()
var res sql.NullString if err != nil {
if err := rows.Scan(&res); err != nil { return true, err
return true, err
}
if res.Valid {
*(bean.(*string)) = res.String
}
return true, nil
case *int:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*int)) = int(res.Int64)
}
return true, nil
case *int8:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*int8)) = int8(res.Int64)
}
return true, nil
case *int16:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*int16)) = int16(res.Int64)
}
return true, nil
case *int32:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*int32)) = int32(res.Int64)
}
return true, nil
case *int64:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*int64)) = int64(res.Int64)
}
return true, nil
case *uint:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*uint)) = uint(res.Int64)
}
return true, nil
case *uint8:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*uint8)) = uint8(res.Int64)
}
return true, nil
case *uint16:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*uint16)) = uint16(res.Int64)
}
return true, nil
case *uint32:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*uint32)) = uint32(res.Int64)
}
return true, nil
case *uint64:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*uint64)) = uint64(res.Int64)
}
return true, nil
case *bool:
var res sql.NullBool
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*bool)) = res.Bool
}
return true, nil
} }
switch beanKind { switch beanKind {
case reflect.Struct: case reflect.Struct:
fields, err := rows.Columns() if _, ok := bean.(*time.Time); ok {
if err != nil { break
// WARN: Alougth rows return true, but get fields failed
return true, err
} }
if _, ok := bean.(sql.Scanner); ok {
scanResults, err := session.row2Slice(rows, fields, bean) break
if err != nil {
return false, err
} }
// close it before convert data if _, ok := bean.(convert.Conversion); len(types) == 1 && ok {
rows.Close() break
dataStruct := utils.ReflectValue(bean)
_, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table)
if err != nil {
return true, err
} }
return session.getStruct(rows, types, fields, table, bean)
return true, session.executeProcessors()
case reflect.Slice: case reflect.Slice:
err = rows.ScanSlice(bean) return session.getSlice(rows, types, fields, bean)
case reflect.Map: case reflect.Map:
err = rows.ScanMap(bean) return session.getMap(rows, types, fields, bean)
case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
err = rows.Scan(bean)
default:
err = rows.Scan(bean)
} }
return true, err return session.getVars(rows, types, fields, bean)
}
func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) {
switch t := bean.(type) {
case *[]string:
res, err := session.engine.scanStringInterface(rows, types)
if err != nil {
return true, err
}
var needAppend = len(*t) == 0 // both support slice is empty or has been initlized
for i, r := range res {
if needAppend {
*t = append(*t, r.(*sql.NullString).String)
} else {
(*t)[i] = r.(*sql.NullString).String
}
}
return true, nil
case *[]interface{}:
scanResults, err := session.engine.scanInterfaces(rows, types)
if err != nil {
return true, err
}
var needAppend = len(*t) == 0
for ii := range fields {
s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii])
if err != nil {
return true, err
}
if needAppend {
*t = append(*t, s)
} else {
(*t)[ii] = s
}
}
return true, nil
default:
return true, fmt.Errorf("unspoorted slice type: %t", t)
}
}
func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) {
switch t := bean.(type) {
case *map[string]string:
scanResults, err := session.engine.scanStringInterface(rows, types)
if err != nil {
return true, err
}
for ii, key := range fields {
(*t)[key] = scanResults[ii].(*sql.NullString).String
}
return true, nil
case *map[string]interface{}:
scanResults, err := session.engine.scanInterfaces(rows, types)
if err != nil {
return true, err
}
for ii, key := range fields {
s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii])
if err != nil {
return true, err
}
(*t)[key] = s
}
return true, nil
default:
return true, fmt.Errorf("unspoorted map type: %t", t)
}
}
func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields []string, beans ...interface{}) (bool, error) {
if len(beans) != len(types) {
return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans))
}
var scanResults = make([]interface{}, 0, len(types))
var replaceds = make([]bool, 0, len(types))
for _, bean := range beans {
switch t := bean.(type) {
case sql.Scanner:
scanResults = append(scanResults, t)
replaceds = append(replaceds, false)
case convert.Conversion:
scanResults = append(scanResults, &sql.RawBytes{})
replaceds = append(replaceds, true)
default:
scanResults = append(scanResults, bean)
replaceds = append(replaceds, false)
}
}
err := session.engine.scan(rows, fields, types, scanResults...)
if err != nil {
return true, err
}
for i, replaced := range replaceds {
if replaced {
err = convertAssign(beans[i], scanResults[i], session.engine.DatabaseTZ, session.engine.TZLocation)
if err != nil {
return true, err
}
}
}
return true, nil
}
func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) {
fields, err := rows.Columns()
if err != nil {
// WARN: Alougth rows return true, but get fields failed
return true, err
}
scanResults, err := session.row2Slice(rows, fields, bean)
if err != nil {
return false, err
}
// close it before convert data
rows.Close()
dataStruct := utils.ReflectValue(bean)
_, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table)
if err != nil {
return true, err
}
return true, session.executeProcessors()
} }
func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) {

View File

@ -11,6 +11,7 @@ import (
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"time"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
@ -374,9 +375,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 1, nil return 1, nil
} }
aiValue.Set(int64ToIntValue(id, aiValue.Type())) return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation)
return 1, nil
} else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES ||
session.engine.dialect.URI().DBType == schemas.MSSQL) { session.engine.dialect.URI().DBType == schemas.MSSQL) {
res, err := session.queryBytes(sqlStr, args...) res, err := session.queryBytes(sqlStr, args...)
@ -416,9 +415,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 1, nil return 1, nil
} }
aiValue.Set(int64ToIntValue(id, aiValue.Type())) return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation)
return 1, nil
} }
res, err := session.exec(sqlStr, args...) res, err := session.exec(sqlStr, args...)
@ -458,7 +455,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return res.RowsAffected() return res.RowsAffected()
} }
aiValue.Set(int64ToIntValue(id, aiValue.Type())) if err := convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation); err != nil {
return 0, err
}
return res.RowsAffected() return res.RowsAffected()
} }
@ -499,6 +498,16 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
} }
if col.IsDeleted { if col.IsDeleted {
colNames = append(colNames, col.Name)
if !col.Nullable {
if col.SQLType.IsNumeric() {
args = append(args, 0)
} else {
args = append(args, time.Time{}.Format("2006-01-02 15:04:05"))
}
} else {
args = append(args, nil)
}
continue continue
} }

View File

@ -5,13 +5,7 @@
package xorm package xorm
import ( import (
"fmt"
"reflect"
"strconv"
"time"
"xorm.io/xorm/core" "xorm.io/xorm/core"
"xorm.io/xorm/schemas"
) )
// Query runs a raw sql and return records as []map[string][]byte // Query runs a raw sql and return records as []map[string][]byte
@ -28,116 +22,18 @@ func (session *Session) Query(sqlOrArgs ...interface{}) ([]map[string][]byte, er
return session.queryBytes(sqlStr, args...) return session.queryBytes(sqlStr, args...)
} }
func value2String(rawValue *reflect.Value) (str string, err error) { func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) {
aa := reflect.TypeOf((*rawValue).Interface())
vv := reflect.ValueOf((*rawValue).Interface())
switch aa.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
str = strconv.FormatInt(vv.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
str = strconv.FormatUint(vv.Uint(), 10)
case reflect.Float32, reflect.Float64:
str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
case reflect.String:
str = vv.String()
case reflect.Array, reflect.Slice:
switch aa.Elem().Kind() {
case reflect.Uint8:
data := rawValue.Interface().([]byte)
str = string(data)
if str == "\x00" {
str = "0"
}
default:
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
// time type
case reflect.Struct:
if aa.ConvertibleTo(schemas.TimeType) {
str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano)
} else {
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
case reflect.Bool:
str = strconv.FormatBool(vv.Bool())
case reflect.Complex128, reflect.Complex64:
str = fmt.Sprintf("%v", vv.Complex())
/* TODO: unsupported types below
case reflect.Map:
case reflect.Ptr:
case reflect.Uintptr:
case reflect.UnsafePointer:
case reflect.Chan, reflect.Func, reflect.Interface:
*/
default:
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
return
}
func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string, err error) {
result := make(map[string]string)
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
// if row is null then as empty string
if rawValue.Interface() == nil {
result[key] = ""
continue
}
if data, err := value2String(&rawValue); err == nil {
result[key] = data
} else {
return nil, err
}
}
return result, nil
}
func row2sliceStr(rows *core.Rows, fields []string) (results []string, err error) {
result := make([]string, 0, len(fields))
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for i := 0; i < len(fields); i++ {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[i]))
// if row is null then as empty string
if rawValue.Interface() == nil {
result = append(result, "")
continue
}
if data, err := value2String(&rawValue); err == nil {
result = append(result, data)
} else {
return nil, err
}
}
return result, nil
}
func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) {
fields, err := rows.Columns() fields, err := rows.Columns()
if err != nil { if err != nil {
return nil, err return nil, err
} }
types, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
for rows.Next() { for rows.Next() {
result, err := row2mapStr(rows, fields) result, err := row2mapStr(rows, types, fields)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -147,13 +43,18 @@ func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error)
return resultsSlice, nil return resultsSlice, nil
} }
func rows2SliceString(rows *core.Rows) (resultsSlice [][]string, err error) { func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]string, err error) {
fields, err := rows.Columns() fields, err := rows.Columns()
if err != nil { if err != nil {
return nil, err return nil, err
} }
types, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
for rows.Next() { for rows.Next() {
record, err := row2sliceStr(rows, fields) record, err := session.engine.row2sliceStr(rows, types, fields)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -180,7 +81,7 @@ func (session *Session) QueryString(sqlOrArgs ...interface{}) ([]map[string]stri
} }
defer rows.Close() defer rows.Close()
return rows2Strings(rows) return session.rows2Strings(rows)
} }
// QuerySliceString runs a raw sql and return records as [][]string // QuerySliceString runs a raw sql and return records as [][]string
@ -200,33 +101,20 @@ func (session *Session) QuerySliceString(sqlOrArgs ...interface{}) ([][]string,
} }
defer rows.Close() defer rows.Close()
return rows2SliceString(rows) return session.rows2SliceString(rows)
} }
func row2mapInterface(rows *core.Rows, fields []string) (resultsMap map[string]interface{}, err error) { func (session *Session) rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, err error) {
resultsMap = make(map[string]interface{}, len(fields))
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
resultsMap[key] = reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])).Interface()
}
return
}
func rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, err error) {
fields, err := rows.Columns() fields, err := rows.Columns()
if err != nil { if err != nil {
return nil, err return nil, err
} }
types, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
for rows.Next() { for rows.Next() {
result, err := row2mapInterface(rows, fields) result, err := session.engine.row2mapInterface(rows, types, fields)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -253,5 +141,5 @@ func (session *Session) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]i
} }
defer rows.Close() defer rows.Close()
return rows2Interfaces(rows) return session.rows2Interfaces(rows)
} }

View File

@ -6,9 +6,13 @@ package xorm
import ( import (
"database/sql" "database/sql"
"fmt"
"reflect" "reflect"
"strconv"
"time"
"xorm.io/xorm/core" "xorm.io/xorm/core"
"xorm.io/xorm/schemas"
) )
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
@ -71,6 +75,53 @@ func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row {
return core.NewRow(session.queryRows(sqlStr, args...)) return core.NewRow(session.queryRows(sqlStr, args...))
} }
func value2String(rawValue *reflect.Value) (str string, err error) {
aa := reflect.TypeOf((*rawValue).Interface())
vv := reflect.ValueOf((*rawValue).Interface())
switch aa.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
str = strconv.FormatInt(vv.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
str = strconv.FormatUint(vv.Uint(), 10)
case reflect.Float32, reflect.Float64:
str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
case reflect.String:
str = vv.String()
case reflect.Array, reflect.Slice:
switch aa.Elem().Kind() {
case reflect.Uint8:
data := rawValue.Interface().([]byte)
str = string(data)
if str == "\x00" {
str = "0"
}
default:
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
// time type
case reflect.Struct:
if aa.ConvertibleTo(schemas.TimeType) {
str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano)
} else {
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
case reflect.Bool:
str = strconv.FormatBool(vv.Bool())
case reflect.Complex128, reflect.Complex64:
str = fmt.Sprintf("%v", vv.Complex())
/* TODO: unsupported types below
case reflect.Map:
case reflect.Ptr:
case reflect.Uintptr:
case reflect.UnsafePointer:
case reflect.Chan, reflect.Func, reflect.Interface:
*/
default:
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}
return
}
func value2Bytes(rawValue *reflect.Value) ([]byte, error) { func value2Bytes(rawValue *reflect.Value) ([]byte, error) {
str, err := value2String(rawValue) str, err := value2String(rawValue)
if err != nil { if err != nil {
@ -79,50 +130,6 @@ func value2Bytes(rawValue *reflect.Value) ([]byte, error) {
return []byte(str), nil return []byte(str), nil
} }
func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, err error) {
result := make(map[string][]byte)
scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers[i] = &scanResultContainer
}
if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
result[key] = []byte{}
continue
}
if data, err := value2Bytes(&rawValue); err == nil {
result[key] = data
} else {
return nil, err // !nashtsai! REVIEW, should return err or just error log?
}
}
return result, nil
}
func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) {
fields, err := rows.Columns()
if err != nil {
return nil, err
}
for rows.Next() {
result, err := row2map(rows, fields)
if err != nil {
return nil, err
}
resultsSlice = append(resultsSlice, result)
}
return resultsSlice, nil
}
func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) { func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) {
rows, err := session.queryRows(sqlStr, args...) rows, err := session.queryRows(sqlStr, args...)
if err != nil { if err != nil {

View File

@ -280,15 +280,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
k = ct.Elem().Kind() k = ct.Elem().Kind()
} }
if k == reflect.Struct { if k == reflect.Struct {
var refTable = session.statement.RefTable condTable, err := session.engine.TableInfo(condiBean[0])
if refTable == nil { if err != nil {
refTable, err = session.engine.TableInfo(condiBean[0]) return 0, err
if err != nil {
return 0, err
}
} }
var err error
autoCond, err = session.statement.BuildConds(refTable, condiBean[0], true, true, false, true, false) autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, false)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -457,7 +454,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
// FIXME: if bean is a map type, it will panic because map cannot be as map key // FIXME: if bean is a map type, it will panic because map cannot be as map key
session.afterUpdateBeans[bean] = &afterClosures session.afterUpdateBeans[bean] = &afterClosures
} }
} else { } else {
if _, ok := interface{}(bean).(AfterUpdateProcessor); ok { if _, ok := interface{}(bean).(AfterUpdateProcessor); ok {
session.afterUpdateBeans[bean] = nil session.afterUpdateBeans[bean] = nil

View File

@ -7,7 +7,6 @@ package tags
import ( import (
"encoding/gob" "encoding/gob"
"errors" "errors"
"fmt"
"reflect" "reflect"
"strings" "strings"
"sync" "sync"
@ -23,7 +22,7 @@ import (
var ( var (
// ErrUnsupportedType represents an unsupported type error // ErrUnsupportedType represents an unsupported type error
ErrUnsupportedType = errors.New("Unsupported type") ErrUnsupportedType = errors.New("unsupported type")
) )
// Parser represents a parser for xorm tag // Parser represents a parser for xorm tag
@ -125,6 +124,147 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index
} }
} }
var ErrIgnoreField = errors.New("field will be ignored")
// parseFieldWithNoTag builds a column for a struct field that carries no orm
// tag: the column name comes from the column mapper, the SQL type from the
// field's Go type (or TEXT when the field implements convert.Conversion),
// and an int64 field named ID (or whose field name ends in ".ID") becomes an
// auto-increment primary key by convention.
func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
	var sqlType schemas.SQLType
	if fieldValue.CanAddr() {
		if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
			sqlType = schemas.SQLType{Name: schemas.Text}
		}
	}
	// NOTE(review): when only the pointer receiver implements Conversion, the
	// TEXT assignment above is overwritten by the else branch below — confirm
	// this is intended before relying on TEXT for addressable-only implementers.
	if _, ok := fieldValue.Interface().(convert.Conversion); ok {
		sqlType = schemas.SQLType{Name: schemas.Text}
	} else {
		sqlType = schemas.Type2SQLType(field.Type)
	}
	col := schemas.NewColumn(parser.columnMapper.Obj2Table(field.Name),
		field.Name, sqlType, sqlType.DefaultLength,
		sqlType.DefaultLength2, true)
	col.FieldIndex = []int{fieldIndex}
	// xorm convention: an int64 "ID" field is an implicit auto-increment PK.
	if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
		col.IsAutoIncrement = true
		col.IsPrimaryKey = true
		col.Nullable = false
	}
	return col, nil
}
// parseFieldWithTags builds a column for a struct field whose orm tag has
// already been split into tags. Each tag name (upper-cased) is looked up in
// parser.handlers; names without a handler are taken as the column name,
// optionally wrapped in single quotes. Cache/nocache tags configure the
// table's cacher, and index/unique markers collected in the Context are
// registered on the table before returning.
func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) {
	var col = &schemas.Column{
		FieldName:       field.Name,
		FieldIndex:      []int{fieldIndex},
		Nullable:        true,
		IsPrimaryKey:    false,
		IsAutoIncrement: false,
		MapType:         schemas.TWOSIDES,
		Indexes:         make(map[string]int),
		DefaultIsEmpty:  true,
	}
	// The context threads per-field state (current/prev/next tag, index
	// names, cache flags) between successive tag handlers.
	var ctx = Context{
		table:      table,
		col:        col,
		fieldValue: fieldValue,
		indexNames: make(map[string]int),
		parser:     parser,
	}
	for j, tag := range tags {
		// A handler may consume its successor by setting ignoreNext.
		if ctx.ignoreNext {
			ctx.ignoreNext = false
			continue
		}
		ctx.tag = tag
		ctx.tagUname = strings.ToUpper(tag.name)
		if j > 0 {
			ctx.preTag = strings.ToUpper(tags[j-1].name)
		}
		if j < len(tags)-1 {
			ctx.nextTag = tags[j+1].name
		} else {
			ctx.nextTag = ""
		}
		if h, ok := parser.handlers[ctx.tagUname]; ok {
			if err := h(&ctx); err != nil {
				return nil, err
			}
		} else {
			// No handler: the tag itself is the column name ('quoted' or bare).
			if strings.HasPrefix(ctx.tag.name, "'") && strings.HasSuffix(ctx.tag.name, "'") {
				col.Name = ctx.tag.name[1 : len(ctx.tag.name)-1]
			} else {
				col.Name = ctx.tag.name
			}
		}
		if ctx.hasCacheTag {
			if parser.cacherMgr.GetDefaultCacher() != nil {
				parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher())
			} else {
				parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000))
			}
		}
		if ctx.hasNoCacheTag {
			parser.cacherMgr.SetCacher(table.Name, nil)
		}
	}
	// Fill in defaults the tags did not set explicitly.
	if col.SQLType.Name == "" {
		col.SQLType = schemas.Type2SQLType(field.Type)
	}
	parser.dialect.SQLType(col)
	if col.Length == 0 {
		col.Length = col.SQLType.DefaultLength
	}
	if col.Length2 == 0 {
		col.Length2 = col.SQLType.DefaultLength2
	}
	if col.Name == "" {
		col.Name = parser.columnMapper.Obj2Table(field.Name)
	}
	if ctx.isUnique {
		ctx.indexNames[col.Name] = schemas.UniqueType
	} else if ctx.isIndex {
		ctx.indexNames[col.Name] = schemas.IndexType
	}
	for indexName, indexType := range ctx.indexNames {
		addIndex(indexName, table, col, indexType)
	}
	return col, nil
}
// parseField converts a single struct field into a *schemas.Column. A tag of
// "-" yields ErrIgnoreField, an absent tag falls back to the no-tag defaults,
// and anything else is split into tags and handed to parseFieldWithTags.
func (parser *Parser) parseField(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
	ormTagStr := strings.TrimSpace(field.Tag.Get(parser.identifier))
	switch ormTagStr {
	case "-":
		return nil, ErrIgnoreField
	case "":
		return parser.parseFieldWithNoTag(fieldIndex, field, fieldValue)
	}
	tags, err := splitTag(ormTagStr)
	if err != nil {
		return nil, err
	}
	return parser.parseFieldWithTags(table, fieldIndex, field, fieldValue, tags)
}
// isNotTitle reports whether n does not begin with a title-cased (exported)
// character: the empty string counts as not-title, otherwise the answer is
// whether the first rune is lower case.
func isNotTitle(n string) bool {
	if len(n) == 0 {
		return true
	}
	first := []rune(n)[0]
	return unicode.IsLower(first)
}
// Parse parses a struct as a table information // Parse parses a struct as a table information
func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
t := v.Type() t := v.Type()
@ -140,192 +280,26 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
table.Type = t table.Type = t
table.Name = names.GetTableName(parser.tableMapper, v) table.Name = names.GetTableName(parser.tableMapper, v)
var idFieldColName string
var hasCacheTag, hasNoCacheTag bool
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
var isUnexportField bool var field = t.Field(i)
for _, c := range t.Field(i).Name { if isNotTitle(field.Name) {
if unicode.IsLower(c) {
isUnexportField = true
}
break
}
if isUnexportField {
continue continue
} }
tag := t.Field(i).Tag col, err := parser.parseField(table, i, field, v.Field(i))
ormTagStr := tag.Get(parser.identifier) if err == ErrIgnoreField {
var col *schemas.Column continue
fieldValue := v.Field(i) } else if err != nil {
fieldType := fieldValue.Type() return nil, err
if ormTagStr != "" {
col = &schemas.Column{
FieldName: t.Field(i).Name,
Nullable: true,
IsPrimaryKey: false,
IsAutoIncrement: false,
MapType: schemas.TWOSIDES,
Indexes: make(map[string]int),
DefaultIsEmpty: true,
}
tags := splitTag(ormTagStr)
if len(tags) > 0 {
if tags[0] == "-" {
continue
}
var ctx = Context{
table: table,
col: col,
fieldValue: fieldValue,
indexNames: make(map[string]int),
parser: parser,
}
if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") {
pStart := strings.Index(tags[0], "(")
if pStart > -1 && strings.HasSuffix(tags[0], ")") {
var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool {
return r == '\'' || r == '"'
})
ctx.params = []string{tagPrefix}
}
if err := ExtendsTagHandler(&ctx); err != nil {
return nil, err
}
continue
}
for j, key := range tags {
if ctx.ignoreNext {
ctx.ignoreNext = false
continue
}
k := strings.ToUpper(key)
ctx.tagName = k
ctx.params = []string{}
pStart := strings.Index(k, "(")
if pStart == 0 {
return nil, errors.New("( could not be the first character")
}
if pStart > -1 {
if !strings.HasSuffix(k, ")") {
return nil, fmt.Errorf("field %s tag %s cannot match ) character", col.FieldName, key)
}
ctx.tagName = k[:pStart]
ctx.params = strings.Split(key[pStart+1:len(k)-1], ",")
}
if j > 0 {
ctx.preTag = strings.ToUpper(tags[j-1])
}
if j < len(tags)-1 {
ctx.nextTag = tags[j+1]
} else {
ctx.nextTag = ""
}
if h, ok := parser.handlers[ctx.tagName]; ok {
if err := h(&ctx); err != nil {
return nil, err
}
} else {
if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") {
col.Name = key[1 : len(key)-1]
} else {
col.Name = key
}
}
if ctx.hasCacheTag {
hasCacheTag = true
}
if ctx.hasNoCacheTag {
hasNoCacheTag = true
}
}
if col.SQLType.Name == "" {
col.SQLType = schemas.Type2SQLType(fieldType)
}
parser.dialect.SQLType(col)
if col.Length == 0 {
col.Length = col.SQLType.DefaultLength
}
if col.Length2 == 0 {
col.Length2 = col.SQLType.DefaultLength2
}
if col.Name == "" {
col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name)
}
if ctx.isUnique {
ctx.indexNames[col.Name] = schemas.UniqueType
} else if ctx.isIndex {
ctx.indexNames[col.Name] = schemas.IndexType
}
for indexName, indexType := range ctx.indexNames {
addIndex(indexName, table, col, indexType)
}
}
} else {
var sqlType schemas.SQLType
if fieldValue.CanAddr() {
if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
sqlType = schemas.SQLType{Name: schemas.Text}
}
}
if _, ok := fieldValue.Interface().(convert.Conversion); ok {
sqlType = schemas.SQLType{Name: schemas.Text}
} else {
sqlType = schemas.Type2SQLType(fieldType)
}
col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name),
t.Field(i).Name, sqlType, sqlType.DefaultLength,
sqlType.DefaultLength2, true)
if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
idFieldColName = col.Name
}
}
if col.IsAutoIncrement {
col.Nullable = false
} }
table.AddColumn(col) table.AddColumn(col)
} // end for } // end for
if idFieldColName != "" && len(table.PrimaryKeys) == 0 { deletedColumn := table.DeletedColumn()
col := table.GetColumn(idFieldColName) // check columns
col.IsPrimaryKey = true if deletedColumn != nil {
col.IsAutoIncrement = true deletedColumn.Nullable = true
col.Nullable = false
table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
table.AutoIncrement = col.Name
}
if hasCacheTag {
if parser.cacherMgr.GetDefaultCacher() != nil { // !nash! use engine's cacher if provided
//engine.logger.Info("enable cache on table:", table.Name)
parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher())
} else {
//engine.logger.Info("enable LRU cache on table:", table.Name)
parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000))
}
}
if hasNoCacheTag {
//engine.logger.Info("disable cache on table:", table.Name)
parser.cacherMgr.SetCacher(table.Name, nil)
} }
return table, nil return table, nil

View File

@ -6,12 +6,16 @@ package tags
import ( import (
"reflect" "reflect"
"strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert"
"xorm.io/xorm/caches" "xorm.io/xorm/caches"
"xorm.io/xorm/dialects" "xorm.io/xorm/dialects"
"xorm.io/xorm/names" "xorm.io/xorm/names"
"xorm.io/xorm/schemas"
"github.com/stretchr/testify/assert"
) )
type ParseTableName1 struct{} type ParseTableName1 struct{}
@ -80,7 +84,7 @@ func TestParseWithOtherIdentifier(t *testing.T) {
parser := NewParser( parser := NewParser(
"xorm", "xorm",
dialects.QueryDialect("mysql"), dialects.QueryDialect("mysql"),
names.GonicMapper{}, names.SameMapper{},
names.SnakeMapper{}, names.SnakeMapper{},
caches.NewManager(), caches.NewManager(),
) )
@ -88,13 +92,461 @@ func TestParseWithOtherIdentifier(t *testing.T) {
type StructWithDBTag struct { type StructWithDBTag struct {
FieldFoo string `db:"foo"` FieldFoo string `db:"foo"`
} }
parser.SetIdentifier("db") parser.SetIdentifier("db")
table, err := parser.Parse(reflect.ValueOf(new(StructWithDBTag))) table, err := parser.Parse(reflect.ValueOf(new(StructWithDBTag)))
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, "struct_with_db_tag", table.Name) assert.EqualValues(t, "StructWithDBTag", table.Name)
assert.EqualValues(t, 1, len(table.Columns())) assert.EqualValues(t, 1, len(table.Columns()))
for _, col := range table.Columns() { for _, col := range table.Columns() {
assert.EqualValues(t, "foo", col.Name) assert.EqualValues(t, "foo", col.Name)
} }
} }
// TestParseWithIgnore verifies that a field tagged "-" produces no column.
func TestParseWithIgnore(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SameMapper{},
		names.SnakeMapper{},
		caches.NewManager(),
	)
	type StructWithIgnoreTag struct {
		FieldFoo string `db:"-"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithIgnoreTag)))
	assert.NoError(t, err)
	assert.EqualValues(t, "StructWithIgnoreTag", table.Name)
	assert.EqualValues(t, 0, len(table.Columns()))
}
// TestParseWithAutoincrement verifies the implicit PK convention: an untagged
// int64 ID field becomes an auto-increment primary key.
func TestParseWithAutoincrement(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)
	type StructWithAutoIncrement struct {
		ID int64
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_auto_increment", table.Name)
	assert.EqualValues(t, 1, len(table.Columns()))
	assert.EqualValues(t, "id", table.Columns()[0].Name)
	assert.True(t, table.Columns()[0].IsAutoIncrement)
	assert.True(t, table.Columns()[0].IsPrimaryKey)
}
// TestParseWithAutoincrement2 verifies explicit "pk autoincr" tags mark the
// column primary-key, auto-increment, and not nullable.
func TestParseWithAutoincrement2(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)
	type StructWithAutoIncrement2 struct {
		ID int64 `db:"pk autoincr"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement2)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_auto_increment2", table.Name)
	assert.EqualValues(t, 1, len(table.Columns()))
	assert.EqualValues(t, "id", table.Columns()[0].Name)
	assert.True(t, table.Columns()[0].IsAutoIncrement)
	assert.True(t, table.Columns()[0].IsPrimaryKey)
	assert.False(t, table.Columns()[0].Nullable)
}
// TestParseWithNullable verifies notnull/null tags and that a comment() tag
// (including a comma inside quotes) is captured verbatim.
func TestParseWithNullable(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)
	type StructWithNullable struct {
		Name     string `db:"notnull"`
		FullName string `db:"null comment('column comment,字段注释')"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithNullable)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_nullable", table.Name)
	assert.EqualValues(t, 2, len(table.Columns()))
	assert.EqualValues(t, "name", table.Columns()[0].Name)
	assert.EqualValues(t, "full_name", table.Columns()[1].Name)
	assert.False(t, table.Columns()[0].Nullable)
	assert.True(t, table.Columns()[1].Nullable)
	assert.EqualValues(t, "column comment,字段注释", table.Columns()[1].Comment)
}
// TestParseWithTimes verifies the created/updated/deleted time tags set the
// corresponding IsCreated/IsUpdated/IsDeleted column flags.
func TestParseWithTimes(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)
	type StructWithTimes struct {
		Name      string    `db:"notnull"`
		CreatedAt time.Time `db:"created"`
		UpdatedAt time.Time `db:"updated"`
		DeletedAt time.Time `db:"deleted"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithTimes)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_times", table.Name)
	assert.EqualValues(t, 4, len(table.Columns()))
	assert.EqualValues(t, "name", table.Columns()[0].Name)
	assert.EqualValues(t, "created_at", table.Columns()[1].Name)
	assert.EqualValues(t, "updated_at", table.Columns()[2].Name)
	assert.EqualValues(t, "deleted_at", table.Columns()[3].Name)
	assert.False(t, table.Columns()[0].Nullable)
	assert.True(t, table.Columns()[1].Nullable)
	assert.True(t, table.Columns()[1].IsCreated)
	assert.True(t, table.Columns()[2].Nullable)
	assert.True(t, table.Columns()[2].IsUpdated)
	assert.True(t, table.Columns()[3].Nullable)
	assert.True(t, table.Columns()[3].IsDeleted)
}
// TestParseWithExtends verifies that an embedded struct tagged "extends" has
// its columns flattened into the outer table with their tags preserved.
func TestParseWithExtends(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)
	type StructWithEmbed struct {
		Name      string
		CreatedAt time.Time `db:"created"`
		UpdatedAt time.Time `db:"updated"`
		DeletedAt time.Time `db:"deleted"`
	}
	type StructWithExtends struct {
		SW StructWithEmbed `db:"extends"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithExtends)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_extends", table.Name)
	assert.EqualValues(t, 4, len(table.Columns()))
	assert.EqualValues(t, "name", table.Columns()[0].Name)
	assert.EqualValues(t, "created_at", table.Columns()[1].Name)
	assert.EqualValues(t, "updated_at", table.Columns()[2].Name)
	assert.EqualValues(t, "deleted_at", table.Columns()[3].Name)
	assert.True(t, table.Columns()[0].Nullable)
	assert.True(t, table.Columns()[1].Nullable)
	assert.True(t, table.Columns()[1].IsCreated)
	assert.True(t, table.Columns()[2].Nullable)
	assert.True(t, table.Columns()[2].IsUpdated)
	assert.True(t, table.Columns()[3].Nullable)
	assert.True(t, table.Columns()[3].IsDeleted)
}
// TestParseWithCache verifies the "cache" tag registers a cacher for the table.
func TestParseWithCache(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)
	type StructWithCache struct {
		Name string `db:"cache"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithCache)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_cache", table.Name)
	assert.EqualValues(t, 1, len(table.Columns()))
	assert.EqualValues(t, "name", table.Columns()[0].Name)
	assert.True(t, table.Columns()[0].Nullable)
	cacher := parser.cacherMgr.GetCacher(table.Name)
	assert.NotNil(t, cacher)
}
// TestParseWithNoCache verifies the "nocache" tag leaves the table uncached.
func TestParseWithNoCache(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)
	type StructWithNoCache struct {
		Name string `db:"nocache"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithNoCache)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_no_cache", table.Name)
	assert.EqualValues(t, 1, len(table.Columns()))
	assert.EqualValues(t, "name", table.Columns()[0].Name)
	assert.True(t, table.Columns()[0].Nullable)
	cacher := parser.cacherMgr.GetCacher(table.Name)
	assert.Nil(t, cacher)
}
// TestParseWithEnum verifies an enum(...) tag yields an ENUM SQL type with
// the quoted options mapped to their positional indexes.
func TestParseWithEnum(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)
	type StructWithEnum struct {
		Name string `db:"enum('alice', 'bob')"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithEnum)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_enum", table.Name)
	assert.EqualValues(t, 1, len(table.Columns()))
	assert.EqualValues(t, "name", table.Columns()[0].Name)
	assert.True(t, table.Columns()[0].Nullable)
	assert.EqualValues(t, schemas.Enum, strings.ToUpper(table.Columns()[0].SQLType.Name))
	assert.EqualValues(t, map[string]int{
		"alice": 0,
		"bob":   1,
	}, table.Columns()[0].EnumOptions)
}
// TestParseWithSet verifies a set(...) tag yields a SET SQL type with the
// quoted options mapped to their positional indexes.
func TestParseWithSet(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)
	type StructWithSet struct {
		Name string `db:"set('alice', 'bob')"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithSet)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_set", table.Name)
	assert.EqualValues(t, 1, len(table.Columns()))
	assert.EqualValues(t, "name", table.Columns()[0].Name)
	assert.True(t, table.Columns()[0].Nullable)
	assert.EqualValues(t, schemas.Set, strings.ToUpper(table.Columns()[0].SQLType.Name))
	assert.EqualValues(t, map[string]int{
		"alice": 0,
		"bob":   1,
	}, table.Columns()[0].SetOptions)
}
// TestParseWithIndex verifies bare "index", named "index(name)", and "unique"
// tags each register exactly one index on their column.
func TestParseWithIndex(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)
	type StructWithIndex struct {
		Name  string `db:"index"`
		Name2 string `db:"index(s)"`
		Name3 string `db:"unique"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithIndex)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_index", table.Name)
	assert.EqualValues(t, 3, len(table.Columns()))
	assert.EqualValues(t, "name", table.Columns()[0].Name)
	assert.EqualValues(t, "name2", table.Columns()[1].Name)
	assert.EqualValues(t, "name3", table.Columns()[2].Name)
	assert.True(t, table.Columns()[0].Nullable)
	assert.True(t, table.Columns()[1].Nullable)
	assert.True(t, table.Columns()[2].Nullable)
	assert.EqualValues(t, 1, len(table.Columns()[0].Indexes))
	assert.EqualValues(t, 1, len(table.Columns()[1].Indexes))
	assert.EqualValues(t, 1, len(table.Columns()[2].Indexes))
}
// TestParseWithVersion verifies the "version" tag sets IsVersion on the column.
func TestParseWithVersion(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)
	type StructWithVersion struct {
		Name    string
		Version int `db:"version"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithVersion)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_version", table.Name)
	assert.EqualValues(t, 2, len(table.Columns()))
	assert.EqualValues(t, "name", table.Columns()[0].Name)
	assert.EqualValues(t, "version", table.Columns()[1].Name)
	assert.True(t, table.Columns()[0].Nullable)
	assert.True(t, table.Columns()[1].Nullable)
	assert.True(t, table.Columns()[1].IsVersion)
}
// TestParseWithLocale verifies the "utc" and "local" tags record the matching
// time zone on the column.
func TestParseWithLocale(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)
	type StructWithLocale struct {
		UTCLocale   time.Time `db:"utc"`
		LocalLocale time.Time `db:"local"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithLocale)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_locale", table.Name)
	assert.EqualValues(t, 2, len(table.Columns()))
	assert.EqualValues(t, "utc_locale", table.Columns()[0].Name)
	assert.EqualValues(t, "local_locale", table.Columns()[1].Name)
	assert.EqualValues(t, time.UTC, table.Columns()[0].TimeZone)
	assert.EqualValues(t, time.Local, table.Columns()[1].TimeZone)
}
// TestParseWithDefault verifies both default forms — "default 'value'" and
// "default(expr)" — are captured as the column default.
func TestParseWithDefault(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)
	type StructWithDefault struct {
		Default1 time.Time `db:"default '1970-01-01 00:00:00'"`
		Default2 time.Time `db:"default(CURRENT_TIMESTAMP)"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithDefault)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_default", table.Name)
	assert.EqualValues(t, 2, len(table.Columns()))
	assert.EqualValues(t, "default1", table.Columns()[0].Name)
	assert.EqualValues(t, "default2", table.Columns()[1].Name)
	assert.EqualValues(t, "'1970-01-01 00:00:00'", table.Columns()[0].Default)
	assert.EqualValues(t, "CURRENT_TIMESTAMP", table.Columns()[1].Default)
}
// TestParseWithOnlyToDB verifies the "->" and "<-" tags set the column map
// type to write-only (ONLYTODB) and read-only (ONLYFROMDB) respectively.
func TestParseWithOnlyToDB(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.GonicMapper{
			"DB": true,
		},
		names.SnakeMapper{},
		caches.NewManager(),
	)
	type StructWithOnlyToDB struct {
		Default1 time.Time `db:"->"`
		Default2 time.Time `db:"<-"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithOnlyToDB)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_only_to_db", table.Name)
	assert.EqualValues(t, 2, len(table.Columns()))
	assert.EqualValues(t, "default1", table.Columns()[0].Name)
	assert.EqualValues(t, "default2", table.Columns()[1].Name)
	assert.EqualValues(t, schemas.ONLYTODB, table.Columns()[0].MapType)
	assert.EqualValues(t, schemas.ONLYFROMDB, table.Columns()[1].MapType)
}
// TestParseWithJSON verifies the "json" tag marks the column as JSON.
func TestParseWithJSON(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.GonicMapper{
			"JSON": true,
		},
		names.SnakeMapper{},
		caches.NewManager(),
	)
	type StructWithJSON struct {
		Default1 []string `db:"json"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithJSON)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_json", table.Name)
	assert.EqualValues(t, 1, len(table.Columns()))
	assert.EqualValues(t, "default1", table.Columns()[0].Name)
	assert.True(t, table.Columns()[0].IsJSON)
}
// TestParseWithSQLType verifies that explicit SQL-type tags (varchar(32),
// char(32), bigint, datetime, uuid) override the Go-type-derived SQL type.
func TestParseWithSQLType(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.GonicMapper{
			"SQL": true,
		},
		names.GonicMapper{
			"UUID": true,
		},
		caches.NewManager(),
	)
	type StructWithSQLType struct {
		Col1     string    `db:"varchar(32)"`
		Col2     string    `db:"char(32)"`
		Int      int64     `db:"bigint"`
		DateTime time.Time `db:"datetime"`
		UUID     string    `db:"uuid"`
	}
	table, err := parser.Parse(reflect.ValueOf(new(StructWithSQLType)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_sql_type", table.Name)
	assert.EqualValues(t, 5, len(table.Columns()))
	assert.EqualValues(t, "col1", table.Columns()[0].Name)
	assert.EqualValues(t, "col2", table.Columns()[1].Name)
	assert.EqualValues(t, "int", table.Columns()[2].Name)
	assert.EqualValues(t, "date_time", table.Columns()[3].Name)
	assert.EqualValues(t, "uuid", table.Columns()[4].Name)
	assert.EqualValues(t, "VARCHAR", table.Columns()[0].SQLType.Name)
	assert.EqualValues(t, "CHAR", table.Columns()[1].SQLType.Name)
	assert.EqualValues(t, "BIGINT", table.Columns()[2].SQLType.Name)
	assert.EqualValues(t, "DATETIME", table.Columns()[3].SQLType.Name)
	assert.EqualValues(t, "UUID", table.Columns()[4].SQLType.Name)
}

View File

@ -14,30 +14,74 @@ import (
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
func splitTag(tag string) (tags []string) { type tag struct {
tag = strings.TrimSpace(tag) name string
var hasQuote = false params []string
var lastIdx = 0 }
for i, t := range tag {
if t == '\'' { func splitTag(tagStr string) ([]tag, error) {
hasQuote = !hasQuote tagStr = strings.TrimSpace(tagStr)
} else if t == ' ' { var (
if lastIdx < i && !hasQuote { inQuote bool
tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) inBigQuote bool
lastIdx = i + 1 lastIdx int
curTag tag
paramStart int
tags []tag
)
for i, t := range tagStr {
switch t {
case '\'':
inQuote = !inQuote
case ' ':
if !inQuote && !inBigQuote {
if lastIdx < i {
if curTag.name == "" {
curTag.name = tagStr[lastIdx:i]
}
tags = append(tags, curTag)
lastIdx = i + 1
curTag = tag{}
} else if lastIdx == i {
lastIdx = i + 1
}
} else if inBigQuote && !inQuote {
paramStart = i + 1
}
case ',':
if !inQuote && !inBigQuote {
return nil, fmt.Errorf("comma[%d] of %s should be in quote or big quote", i, tagStr)
}
if !inQuote && inBigQuote {
curTag.params = append(curTag.params, strings.TrimSpace(tagStr[paramStart:i]))
paramStart = i + 1
}
case '(':
inBigQuote = true
if !inQuote {
curTag.name = tagStr[lastIdx:i]
paramStart = i + 1
}
case ')':
inBigQuote = false
if !inQuote {
curTag.params = append(curTag.params, tagStr[paramStart:i])
} }
} }
} }
if lastIdx < len(tag) { if lastIdx < len(tagStr) {
tags = append(tags, strings.TrimSpace(tag[lastIdx:])) if curTag.name == "" {
curTag.name = tagStr[lastIdx:]
}
tags = append(tags, curTag)
} }
return return tags, nil
} }
// Context represents a context for xorm tag parse. // Context represents a context for xorm tag parse.
type Context struct { type Context struct {
tagName string tag
params []string tagUname string
preTag, nextTag string preTag, nextTag string
table *schemas.Table table *schemas.Table
col *schemas.Column col *schemas.Column
@ -76,6 +120,7 @@ var (
"CACHE": CacheTagHandler, "CACHE": CacheTagHandler,
"NOCACHE": NoCacheTagHandler, "NOCACHE": NoCacheTagHandler,
"COMMENT": CommentTagHandler, "COMMENT": CommentTagHandler,
"EXTENDS": ExtendsTagHandler,
} }
) )
@ -124,6 +169,7 @@ func NotNullTagHandler(ctx *Context) error {
// AutoIncrTagHandler describes autoincr tag handler // AutoIncrTagHandler describes autoincr tag handler
func AutoIncrTagHandler(ctx *Context) error { func AutoIncrTagHandler(ctx *Context) error {
ctx.col.IsAutoIncrement = true ctx.col.IsAutoIncrement = true
ctx.col.Nullable = false
/* /*
if len(ctx.params) > 0 { if len(ctx.params) > 0 {
autoStartInt, err := strconv.Atoi(ctx.params[0]) autoStartInt, err := strconv.Atoi(ctx.params[0])
@ -192,6 +238,7 @@ func UpdatedTagHandler(ctx *Context) error {
// DeletedTagHandler describes deleted tag handler // DeletedTagHandler describes deleted tag handler
func DeletedTagHandler(ctx *Context) error { func DeletedTagHandler(ctx *Context) error {
ctx.col.IsDeleted = true ctx.col.IsDeleted = true
ctx.col.Nullable = true
return nil return nil
} }
@ -225,41 +272,44 @@ func CommentTagHandler(ctx *Context) error {
// SQLTypeTagHandler describes SQL Type tag handler // SQLTypeTagHandler describes SQL Type tag handler
func SQLTypeTagHandler(ctx *Context) error { func SQLTypeTagHandler(ctx *Context) error {
ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName} ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname}
if strings.EqualFold(ctx.tagName, "JSON") { if ctx.tagUname == "JSON" {
ctx.col.IsJSON = true ctx.col.IsJSON = true
} }
if len(ctx.params) > 0 { if len(ctx.params) == 0 {
if ctx.tagName == schemas.Enum { return nil
ctx.col.EnumOptions = make(map[string]int) }
for k, v := range ctx.params {
v = strings.TrimSpace(v) switch ctx.tagUname {
v = strings.Trim(v, "'") case schemas.Enum:
ctx.col.EnumOptions[v] = k ctx.col.EnumOptions = make(map[string]int)
for k, v := range ctx.params {
v = strings.TrimSpace(v)
v = strings.Trim(v, "'")
ctx.col.EnumOptions[v] = k
}
case schemas.Set:
ctx.col.SetOptions = make(map[string]int)
for k, v := range ctx.params {
v = strings.TrimSpace(v)
v = strings.Trim(v, "'")
ctx.col.SetOptions[v] = k
}
default:
var err error
if len(ctx.params) == 2 {
ctx.col.Length, err = strconv.Atoi(ctx.params[0])
if err != nil {
return err
} }
} else if ctx.tagName == schemas.Set { ctx.col.Length2, err = strconv.Atoi(ctx.params[1])
ctx.col.SetOptions = make(map[string]int) if err != nil {
for k, v := range ctx.params { return err
v = strings.TrimSpace(v)
v = strings.Trim(v, "'")
ctx.col.SetOptions[v] = k
} }
} else { } else if len(ctx.params) == 1 {
var err error ctx.col.Length, err = strconv.Atoi(ctx.params[0])
if len(ctx.params) == 2 { if err != nil {
ctx.col.Length, err = strconv.Atoi(ctx.params[0]) return err
if err != nil {
return err
}
ctx.col.Length2, err = strconv.Atoi(ctx.params[1])
if err != nil {
return err
}
} else if len(ctx.params) == 1 {
ctx.col.Length, err = strconv.Atoi(ctx.params[0])
if err != nil {
return err
}
} }
} }
} }
@ -289,11 +339,12 @@ func ExtendsTagHandler(ctx *Context) error {
} }
for _, col := range parentTable.Columns() { for _, col := range parentTable.Columns() {
col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName) col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName)
col.FieldIndex = append(ctx.col.FieldIndex, col.FieldIndex...)
var tagPrefix = ctx.col.FieldName var tagPrefix = ctx.col.FieldName
if len(ctx.params) > 0 { if len(ctx.params) > 0 {
col.Nullable = isPtr col.Nullable = isPtr
tagPrefix = ctx.params[0] tagPrefix = strings.Trim(ctx.params[0], "'")
if col.IsPrimaryKey { if col.IsPrimaryKey {
col.Name = ctx.col.FieldName col.Name = ctx.col.FieldName
col.IsPrimaryKey = false col.IsPrimaryKey = false
@ -315,7 +366,7 @@ func ExtendsTagHandler(ctx *Context) error {
default: default:
//TODO: warning //TODO: warning
} }
return nil return ErrIgnoreField
} }
// CacheTagHandler describes cache tag handler // CacheTagHandler describes cache tag handler

View File

@ -7,24 +7,83 @@ package tags
import ( import (
"testing" "testing"
"xorm.io/xorm/internal/utils" "github.com/stretchr/testify/assert"
) )
func TestSplitTag(t *testing.T) { func TestSplitTag(t *testing.T) {
var cases = []struct { var cases = []struct {
tag string tag string
tags []string tags []tag
}{ }{
{"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, {"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{
{"TEXT", []string{"TEXT"}}, {
{"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, name: "not",
{"json binary", []string{"json", "binary"}}, },
{
name: "null",
},
{
name: "default",
},
{
name: "'2000-01-01 00:00:00'",
},
{
name: "TIMESTAMP",
},
},
},
{"TEXT", []tag{
{
name: "TEXT",
},
},
},
{"default('2000-01-01 00:00:00')", []tag{
{
name: "default",
params: []string{
"'2000-01-01 00:00:00'",
},
},
},
},
{"json binary", []tag{
{
name: "json",
},
{
name: "binary",
},
},
},
{"numeric(10, 2)", []tag{
{
name: "numeric",
params: []string{"10", "2"},
},
},
},
{"numeric(10, 2) notnull", []tag{
{
name: "numeric",
params: []string{"10", "2"},
},
{
name: "notnull",
},
},
},
} }
for _, kase := range cases { for _, kase := range cases {
tags := splitTag(kase.tag) t.Run(kase.tag, func(t *testing.T) {
if !utils.SliceEq(tags, kase.tags) { tags, err := splitTag(kase.tag)
t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) assert.NoError(t, err)
} assert.EqualValues(t, len(tags), len(kase.tags))
for i := 0; i < len(tags); i++ {
assert.Equal(t, tags[i], kase.tags[i])
}
})
} }
} }