diff --git a/.drone.yml b/.drone.yml index 9b4ffe9a..8a9f8877 100644 --- a/.drone.yml +++ b/.drone.yml @@ -249,11 +249,11 @@ volumes: services: - name: mssql pull: always - image: microsoft/mssql-server-linux:latest + image: mcr.microsoft.com/mssql/server:latest environment: ACCEPT_EULA: Y SA_PASSWORD: yourStrong(!)Password - MSSQL_PID: Developer + MSSQL_PID: Standard --- kind: pipeline @@ -347,3 +347,19 @@ steps: image: golang:1.15 commands: - make coverage + +--- +kind: pipeline +name: release-tag +trigger: + event: + - tag +steps: +- name: release-tag-gitea + pull: always + image: plugins/gitea-release:latest + settings: + base_url: https://gitea.com + title: '${DRONE_TAG} is released' + api_key: + from_secret: gitea_token \ No newline at end of file diff --git a/.gitignore b/.gitignore index a3fbadd4..a183a295 100644 --- a/.gitignore +++ b/.gitignore @@ -36,4 +36,5 @@ test.db.sql *coverage.out test.db integrations/*.sql -integrations/test_sqlite* \ No newline at end of file +integrations/test_sqlite* +cover.out \ No newline at end of file diff --git a/.revive.toml b/.revive.toml index 6dec7465..9e3b629d 100644 --- a/.revive.toml +++ b/.revive.toml @@ -8,20 +8,22 @@ warningCode = 1 [rule.context-as-argument] [rule.context-keys-type] [rule.dot-imports] +[rule.empty-lines] +[rule.errorf] [rule.error-return] [rule.error-strings] [rule.error-naming] [rule.exported] [rule.if-return] [rule.increment-decrement] -[rule.var-naming] - arguments = [["ID", "UID", "UUID", "URL", "JSON"], []] -[rule.var-declaration] +[rule.indent-error-flow] [rule.package-comments] [rule.range] [rule.receiver-naming] +[rule.struct-tag] [rule.time-naming] [rule.unexported-return] -[rule.indent-error-flow] -[rule.errorf] -[rule.struct-tag] \ No newline at end of file +[rule.unnecessary-stmt] +[rule.var-declaration] +[rule.var-naming] + arguments = [["ID", "UID", "UUID", "URL", "JSON"], []] \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 13e721ec..cd567b27 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,41 @@ This changelog goes through all the changes that have been made in each release without substantial changes to our git log. +## [1.1.2](https://gitea.com/xorm/xorm/releases/tag/1.1.2) - 2021-07-04 + +* BUILD + * Add release tag (#1966) + +## [1.1.1](https://gitea.com/xorm/xorm/releases/tag/1.1.1) - 2021-07-03 + +* BUGFIXES + * Ignore comments when deciding when to replace question marks. #1954 (#1955) + * Fix bug didn't reset statement on update (#1939) + * Fix create table with struct missing columns (#1938) + * Fix #929 (#1936) + * Fix exist (#1921) +* ENHANCEMENTS + * Improve get field value of bean (#1961) + * refactor splitTag function (#1960) + * Fix #1663 (#1952) + * fix pg GetColumns missing comment (#1949) + * Support build flag jsoniter to replace default json (#1916) + * refactor exprParam (#1825) + * Add DBVersion (#1723) +* TESTING + * Add test to confirm #1247 resolved (#1951) + * Add test for dump table with default value (#1950) + * Test for #1486 (#1942) + * Add sync tests to confirm #539 is gone (#1937) + * test for unsigned int32 (#1923) + * Add tests for array store (#1922) +* BUILD + * Remove mymysql from ci (#1928) +* MISC + * fix lint (#1953) + * Compitable with cockroach (#1930) + * Replace goracle with godror (#1914) + ## [1.1.0](https://gitea.com/xorm/xorm/releases/tag/1.1.0) - 2021-05-14 * FEATURES diff --git a/README.md b/README.md index 67380839..40826f13 100644 --- a/README.md +++ b/README.md @@ -245,35 +245,38 @@ for rows.Next() { ```Go affected, err := engine.ID(1).Update(&user) -// UPDATE user SET ... Where id = ? +// UPDATE user SET ... WHERE id = ? affected, err := engine.Update(&user, &User{Name:name}) -// UPDATE user SET ... Where name = ? +// UPDATE user SET ... WHERE name = ? var ids = []int64{1, 2, 3} affected, err := engine.In("id", ids).Update(&user) -// UPDATE user SET ... Where id IN (?, ?, ?) +// UPDATE user SET ... WHERE id IN (?, ?, ?) // force update indicated columns by Cols affected, err := engine.ID(1).Cols("age").Update(&User{Name:name, Age: 12}) -// UPDATE user SET age = ?, updated=? Where id = ? +// UPDATE user SET age = ?, updated=? WHERE id = ? // force NOT update indicated columns by Omit affected, err := engine.ID(1).Omit("name").Update(&User{Name:name, Age: 12}) -// UPDATE user SET age = ?, updated=? Where id = ? +// UPDATE user SET age = ?, updated=? WHERE id = ? affected, err := engine.ID(1).AllCols().Update(&user) -// UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? Where id = ? +// UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? WHERE id = ? ``` * `Delete` delete one or more records, Delete MUST have condition ```Go affected, err := engine.Where(...).Delete(&user) -// DELETE FROM user Where ... +// DELETE FROM user WHERE ... affected, err := engine.ID(2).Delete(&user) -// DELETE FROM user Where id = ? +// DELETE FROM user WHERE id = ? + +affected, err := engine.Table("user").Where(...).Delete() +// DELETE FROM user WHERE ... ``` * `Count` count records diff --git a/README_CN.md b/README_CN.md index 80245dd3..06706417 100644 --- a/README_CN.md +++ b/README_CN.md @@ -271,6 +271,9 @@ affected, err := engine.Where(...).Delete(&user) affected, err := engine.ID(2).Delete(&user) // DELETE FROM user Where id = ? + +affected, err := engine.Table("user").Where(...).Delete() +// DELETE FROM user WHERE ... ``` * `Count` 获取记录条数 diff --git a/convert.go b/convert.go index c19d30e0..f7d733ad 100644 --- a/convert.go +++ b/convert.go @@ -5,12 +5,15 @@ package xorm import ( + "database/sql" "database/sql/driver" "errors" "fmt" "reflect" "strconv" "time" + + "xorm.io/xorm/convert" ) var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error @@ -37,6 +40,12 @@ func asString(src interface{}) string { return v case []byte: return string(v) + case *sql.NullString: + return v.String + case *sql.NullInt32: + return fmt.Sprintf("%d", v.Int32) + case *sql.NullInt64: + return fmt.Sprintf("%d", v.Int64) } rv := reflect.ValueOf(src) switch rv.Kind() { @@ -54,6 +63,156 @@ func asString(src interface{}) string { return fmt.Sprintf("%v", src) } +func asInt64(src interface{}) (int64, error) { + switch v := src.(type) { + case int: + return int64(v), nil + case int16: + return int64(v), nil + case int32: + return int64(v), nil + case int8: + return int64(v), nil + case int64: + return v, nil + case uint: + return int64(v), nil + case uint8: + return int64(v), nil + case uint16: + return int64(v), nil + case uint32: + return int64(v), nil + case uint64: + return int64(v), nil + case []byte: + return strconv.ParseInt(string(v), 10, 64) + case string: + return strconv.ParseInt(v, 10, 64) + case *sql.NullString: + return strconv.ParseInt(v.String, 10, 64) + case *sql.NullInt32: + return int64(v.Int32), nil + case *sql.NullInt64: + return int64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return int64(rv.Uint()), nil + case reflect.Float64: + return int64(rv.Float()), nil + case reflect.Float32: + return int64(rv.Float()), nil + case reflect.String: + return strconv.ParseInt(rv.String(), 10, 64) + } + return 0, fmt.Errorf("unsupported value %T as int64", src) +} + +func asUint64(src interface{}) (uint64, error) { + switch v := src.(type) { + case int: + return uint64(v), nil + case int16: + return uint64(v), nil + case int32: + return uint64(v), nil + case int8: + return uint64(v), nil + case int64: + return uint64(v), nil + case uint: + return uint64(v), nil + case uint8: + return uint64(v), nil + case uint16: + return uint64(v), nil + case uint32: + return uint64(v), nil + case uint64: + return v, nil + case []byte: + return strconv.ParseUint(string(v), 10, 64) + case string: + return strconv.ParseUint(v, 10, 64) + case *sql.NullString: + return strconv.ParseUint(v.String, 10, 64) + case *sql.NullInt32: + return uint64(v.Int32), nil + case *sql.NullInt64: + return uint64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return uint64(rv.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return uint64(rv.Uint()), nil + case reflect.Float64: + return uint64(rv.Float()), nil + case reflect.Float32: + return uint64(rv.Float()), nil + case reflect.String: + return strconv.ParseUint(rv.String(), 10, 64) + } + return 0, fmt.Errorf("unsupported value %T as uint64", src) +} + +func asFloat64(src interface{}) (float64, error) { + switch v := src.(type) { + case int: + return float64(v), nil + case int16: + return float64(v), nil + case int32: + return float64(v), nil + case int8: + return float64(v), nil + case int64: + return float64(v), nil + case uint: + return float64(v), nil + case uint8: + return float64(v), nil + case uint16: + return float64(v), nil + case uint32: + return float64(v), nil + case uint64: + return float64(v), nil + case []byte: + return strconv.ParseFloat(string(v), 64) + case string: + return strconv.ParseFloat(v, 64) + case *sql.NullString: + return strconv.ParseFloat(v.String, 64) + case *sql.NullInt32: + return float64(v.Int32), nil + case *sql.NullInt64: + return float64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return float64(rv.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return float64(rv.Uint()), nil + case reflect.Float64: + return float64(rv.Float()), nil + case reflect.Float32: + return float64(rv.Float()), nil + case reflect.String: + return strconv.ParseFloat(rv.String(), 64) + } + return 0, fmt.Errorf("unsupported value %T as int64", src) +} + func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { switch rv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -76,7 +235,7 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { // convertAssign copies to dest the value in src, converting it if possible. // An error is returned if the copy would result in loss of information. // dest should be a pointer type. -func convertAssign(dest, src interface{}) error { +func convertAssign(dest, src interface{}, originalLocation *time.Location, convertedLocation *time.Location) error { // Common cases, without reflect. switch s := src.(type) { case string: @@ -143,6 +302,163 @@ func convertAssign(dest, src interface{}) error { *d = nil return nil } + case *sql.NullString: + switch d := dest.(type) { + case *int: + if s.Valid { + *d, _ = strconv.Atoi(s.String) + } + case *int64: + if s.Valid { + *d, _ = strconv.ParseInt(s.String, 10, 64) + } + case *string: + if s.Valid { + *d = s.String + } + return nil + case *time.Time: + if s.Valid { + var err error + dt, err := convert.String2Time(s.String, originalLocation, convertedLocation) + if err != nil { + return err + } + *d = *dt + } + return nil + case *sql.NullTime: + if s.Valid { + var err error + dt, err := convert.String2Time(s.String, originalLocation, convertedLocation) + if err != nil { + return err + } + d.Valid = true + d.Time = *dt + } + } + case *sql.NullInt32: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Int32) + } + return nil + case *int8: + if s.Valid { + *d = int8(s.Int32) + } + return nil + case *int16: + if s.Valid { + *d = int16(s.Int32) + } + return nil + case *int32: + if s.Valid { + *d = s.Int32 + } + return nil + case *int64: + if s.Valid { + *d = int64(s.Int32) + } + return nil + } + case *sql.NullInt64: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Int64) + } + return nil + case *int8: + if s.Valid { + *d = int8(s.Int64) + } + return nil + case *int16: + if s.Valid { + *d = int16(s.Int64) + } + return nil + case *int32: + if s.Valid { + *d = int32(s.Int64) + } + return nil + case *int64: + if s.Valid { + *d = s.Int64 + } + return nil + } + case *sql.NullFloat64: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Float64) + } + return nil + case *float64: + if s.Valid { + *d = s.Float64 + } + return nil + } + case *sql.NullBool: + switch d := dest.(type) { + case *bool: + if s.Valid { + *d = s.Bool + } + return nil + } + case *sql.NullTime: + switch d := dest.(type) { + case *time.Time: + if s.Valid { + *d = s.Time + } + return nil + case *string: + if s.Valid { + *d = s.Time.In(convertedLocation).Format("2006-01-02 15:04:05") + } + return nil + } + case *NullUint32: + switch d := dest.(type) { + case *uint8: + if s.Valid { + *d = uint8(s.Uint32) + } + return nil + case *uint16: + if s.Valid { + *d = uint16(s.Uint32) + } + return nil + case *uint: + if s.Valid { + *d = uint(s.Uint32) + } + return nil + } + case *NullUint64: + switch d := dest.(type) { + case *uint64: + if s.Valid { + *d = s.Uint64 + } + return nil + } + case *sql.RawBytes: + switch d := dest.(type) { + case convert.Conversion: + return d.FromDB(*s) + } } var sv reflect.Value @@ -175,7 +491,10 @@ func convertAssign(dest, src interface{}) error { return nil } - dpv := reflect.ValueOf(dest) + return convertAssignV(reflect.ValueOf(dest), src, originalLocation, convertedLocation) +} + +func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, convertedLocation *time.Location) error { if dpv.Kind() != reflect.Ptr { return errors.New("destination not a pointer") } @@ -183,9 +502,7 @@ func convertAssign(dest, src interface{}) error { return errNilPtr } - if !sv.IsValid() { - sv = reflect.ValueOf(src) - } + var sv = reflect.ValueOf(src) dv := reflect.Indirect(dpv) if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { @@ -211,31 +528,28 @@ func convertAssign(dest, src interface{}) error { } dv.Set(reflect.New(dv.Type().Elem())) - return convertAssign(dv.Interface(), src) + return convertAssign(dv.Interface(), src, originalLocation, convertedLocation) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - s := asString(src) - i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + i64, err := asInt64(src) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) } dv.SetInt(i64) return nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - s := asString(src) - u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + u64, err := asUint64(src) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) } dv.SetUint(u64) return nil case reflect.Float32, reflect.Float64: - s := asString(src) - f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + f64, err := asFloat64(src) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) } dv.SetFloat(f64) return nil @@ -244,7 +558,7 @@ func convertAssign(dest, src interface{}) error { return nil } - return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dpv.Interface()) } func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { @@ -376,47 +690,79 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) { return v.Interface(), nil } -func int64ToIntValue(id int64, tp reflect.Type) reflect.Value { - var v interface{} - kind := tp.Kind() +var ( + _ sql.Scanner = &NullUint64{} +) - if kind == reflect.Ptr { - kind = tp.Elem().Kind() - } - - switch kind { - case reflect.Int16: - temp := int16(id) - v = &temp - case reflect.Int32: - temp := int32(id) - v = &temp - case reflect.Int: - temp := int(id) - v = &temp - case reflect.Int64: - temp := id - v = &temp - case reflect.Uint16: - temp := uint16(id) - v = &temp - case reflect.Uint32: - temp := uint32(id) - v = &temp - case reflect.Uint64: - temp := uint64(id) - v = &temp - case reflect.Uint: - temp := uint(id) - v = &temp - } - - if tp.Kind() == reflect.Ptr { - return reflect.ValueOf(v).Convert(tp) - } - return reflect.ValueOf(v).Elem().Convert(tp) +// NullUint64 represents an uint64 that may be null. +// NullUint64 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullUint64 struct { + Uint64 uint64 + Valid bool } -func int64ToInt(id int64, tp reflect.Type) interface{} { - return int64ToIntValue(id, tp).Interface() +// Scan implements the Scanner interface. +func (n *NullUint64) Scan(value interface{}) error { + if value == nil { + n.Uint64, n.Valid = 0, false + return nil + } + n.Valid = true + var err error + n.Uint64, err = asUint64(value) + return err +} + +// Value implements the driver Valuer interface. +func (n NullUint64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Uint64, nil +} + +var ( + _ sql.Scanner = &NullUint32{} +) + +// NullUint32 represents an uint32 that may be null. +// NullUint32 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullUint32 struct { + Uint32 uint32 + Valid bool // Valid is true if Uint32 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullUint32) Scan(value interface{}) error { + if value == nil { + n.Uint32, n.Valid = 0, false + return nil + } + n.Valid = true + i64, err := asUint64(value) + if err != nil { + return err + } + n.Uint32 = uint32(i64) + return nil +} + +// Value implements the driver Valuer interface. +func (n NullUint32) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return int64(n.Uint32), nil +} + +var ( + _ sql.Scanner = &EmptyScanner{} +) + +type EmptyScanner struct{} + +func (EmptyScanner) Scan(value interface{}) error { + return nil } diff --git a/convert/interface.go b/convert/interface.go new file mode 100644 index 00000000..2b055253 --- /dev/null +++ b/convert/interface.go @@ -0,0 +1,48 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "database/sql" + "fmt" + "time" +) + +func Interface2Interface(userLocation *time.Location, v interface{}) (interface{}, error) { + if v == nil { + return nil, nil + } + switch vv := v.(type) { + case *int64: + return *vv, nil + case *int8: + return *vv, nil + case *sql.NullString: + return vv.String, nil + case *sql.RawBytes: + if len([]byte(*vv)) > 0 { + return []byte(*vv), nil + } + return nil, nil + case *sql.NullInt32: + return vv.Int32, nil + case *sql.NullInt64: + return vv.Int64, nil + case *sql.NullFloat64: + return vv.Float64, nil + case *sql.NullBool: + if vv.Valid { + return vv.Bool, nil + } + return nil, nil + case *sql.NullTime: + if vv.Valid { + return vv.Time.In(userLocation).Format("2006-01-02 15:04:05"), nil + } + return "", nil + default: + return "", fmt.Errorf("convert assign string unsupported type: %#v", vv) + } +} diff --git a/convert/time.go b/convert/time.go new file mode 100644 index 00000000..8901279b --- /dev/null +++ b/convert/time.go @@ -0,0 +1,30 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "fmt" + "time" +) + +// String2Time converts a string to time with original location +func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) { + if len(s) == 19 { + dt, err := time.ParseInLocation("2006-01-02 15:04:05", s, originalLocation) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' { + dt, err := time.ParseInLocation("2006-01-02T15:04:05Z", s, originalLocation) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } + return nil, fmt.Errorf("unsupported convertion from %s to time", s) +} diff --git a/dialects/driver.go b/dialects/driver.go index bb46a936..0b6187d3 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -5,12 +5,29 @@ package dialects import ( + "database/sql" "fmt" + "time" + + "xorm.io/xorm/core" ) +// ScanContext represents a context when Scan +type ScanContext struct { + DBLocation *time.Location + UserLocation *time.Location +} + +type DriverFeatures struct { + SupportNullable bool +} + // Driver represents a database driver type Driver interface { Parse(string, string) (*URI, error) + Features() DriverFeatures + GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface + Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error } var ( @@ -59,3 +76,15 @@ func OpenDialect(driverName, connstr string) (Dialect, error) { return dialect, nil } + +type baseDriver struct{} + +func (b *baseDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, v ...interface{}) error { + return rows.Scan(v...) +} + +func (b *baseDriver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: true, + } +} diff --git a/dialects/filter.go b/dialects/filter.go index 2a36a731..bfe2e93e 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -23,13 +23,45 @@ type SeqFilter struct { func convertQuestionMark(sql, prefix string, start int) string { var buf strings.Builder var beginSingleQuote bool + var isLineComment bool + var isComment bool + var isMaybeLineComment bool + var isMaybeComment bool + var isMaybeCommentEnd bool var index = start for _, c := range sql { - if !beginSingleQuote && c == '?' { + if !beginSingleQuote && !isLineComment && !isComment && c == '?' { buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) index++ } else { - if c == '\'' { + if isMaybeLineComment { + if c == '-' { + isLineComment = true + } + isMaybeLineComment = false + } else if isMaybeComment { + if c == '*' { + isComment = true + } + isMaybeComment = false + } else if isMaybeCommentEnd { + if c == '/' { + isComment = false + } + isMaybeCommentEnd = false + } else if isLineComment { + if c == '\n' { + isLineComment = false + } + } else if isComment { + if c == '*' { + isMaybeCommentEnd = true + } + } else if !beginSingleQuote && c == '-' { + isMaybeLineComment = true + } else if !beginSingleQuote && c == '/' { + isMaybeComment = true + } else if c == '\'' { beginSingleQuote = !beginSingleQuote } buf.WriteRune(c) diff --git a/dialects/filter_test.go b/dialects/filter_test.go index 7e2ef0a2..15050656 100644 --- a/dialects/filter_test.go +++ b/dialects/filter_test.go @@ -19,3 +19,60 @@ func TestSeqFilter(t *testing.T) { assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) } } + +func TestSeqFilterLineComment(t *testing.T) { + var kases = map[string]string{ + `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=? -- it's a comment + AND b=?`: `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=$1 -- it's a comment + AND b=$2`, + `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=? -- it's a comment? + AND b=?`: `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=$1 -- it's a comment? + AND b=$2`, + `SELECT * + FROM TABLE1 + WHERE a=? -- it's a comment? and that's okay? + AND b=?`: `SELECT * + FROM TABLE1 + WHERE a=$1 -- it's a comment? and that's okay? + AND b=$2`, + } + for sql, result := range kases { + assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) + } +} + +func TestSeqFilterComment(t *testing.T) { + var kases = map[string]string{ + `SELECT * + FROM TABLE1 + WHERE a=? /* it's a comment */ + AND b=?`: `SELECT * + FROM TABLE1 + WHERE a=$1 /* it's a comment */ + AND b=$2`, + `SELECT /* it's a comment * ? + More comment on the next line! */ * + FROM TABLE1 + WHERE a=? /**/ + AND b=?`: `SELECT /* it's a comment * ? + More comment on the next line! */ * + FROM TABLE1 + WHERE a=$1 /**/ + AND b=$2`, + } + for sql, result := range kases { + assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) + } +} diff --git a/dialects/mssql.go b/dialects/mssql.go index 7e922e62..c3c15077 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -6,6 +6,7 @@ package dialects import ( "context" + "database/sql" "errors" "fmt" "net/url" @@ -624,6 +625,7 @@ func (db *mssql) Filters() []Filter { } type odbcDriver struct { + baseDriver } func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) { @@ -652,3 +654,26 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) { } return &URI{DBName: dbName, DBType: schemas.MSSQL}, nil } + +func (p *odbcDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "VARCHAR", "TEXT", "CHAR", "NVARCHAR", "NCHAR", "NTEXT": + fallthrough + case "DATE", "DATETIME", "DATETIME2", "TIME": + var s sql.NullString + return &s, nil + case "FLOAT", "REAL": + var s sql.NullFloat64 + return &s, nil + case "BIGINT", "DATETIMEOFFSET": + var s sql.NullInt64 + return &s, nil + case "TINYINT", "SMALLINT", "INT": + var s sql.NullInt32 + return &s, nil + + default: + var r sql.RawBytes + return &r, nil + } +} diff --git a/dialects/mysql.go b/dialects/mysql.go index a169b901..a341ce05 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -7,6 +7,7 @@ package dialects import ( "context" "crypto/tls" + "database/sql" "errors" "fmt" "regexp" @@ -14,6 +15,7 @@ import ( "strings" "time" + "xorm.io/xorm/convert" "xorm.io/xorm/core" "xorm.io/xorm/schemas" ) @@ -630,7 +632,125 @@ func (db *mysql) Filters() []Filter { return []Filter{} } +type mysqlDriver struct { + baseDriver +} + +func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { + dsnPattern := regexp.MustCompile( + `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] + `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] + `\/(?P.*?)` + // /dbname + `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] + matches := dsnPattern.FindStringSubmatch(dataSourceName) + // tlsConfigRegister := make(map[string]*tls.Config) + names := dsnPattern.SubexpNames() + + uri := &URI{DBType: schemas.MYSQL} + + for i, match := range matches { + switch names[i] { + case "dbname": + uri.DBName = match + case "params": + if len(match) > 0 { + kvs := strings.Split(match, "&") + for _, kv := range kvs { + splits := strings.Split(kv, "=") + if len(splits) == 2 { + switch splits[0] { + case "charset": + uri.Charset = splits[1] + } + } + } + } + + } + } + return uri, nil +} + +func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET": + var s sql.NullString + return &s, nil + case "BIGINT": + var s sql.NullInt64 + return &s, nil + case "TINYINT", "SMALLINT", "MEDIUMINT", "INT": + var s sql.NullInt32 + return &s, nil + case "FLOAT", "REAL", "DOUBLE PRECISION": + var s sql.NullFloat64 + return &s, nil + case "DECIMAL", "NUMERIC": + var s sql.NullString + return &s, nil + case "DATETIME": + var s sql.NullTime + return &s, nil + case "BIT": + var s sql.RawBytes + return &s, nil + case "BINARY", "VARBINARY", "TINYBLOB", "BLOB", "MEDIUMBLOB", "LONGBLOB": + var r sql.RawBytes + return &r, nil + default: + var r sql.RawBytes + return &r, nil + } +} + +func (p *mysqlDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, scanResults ...interface{}) error { + var v2 = make([]interface{}, 0, len(scanResults)) + var turnBackIdxes = make([]int, 0, 5) + for i, vv := range scanResults { + switch vv.(type) { + case *time.Time: + v2 = append(v2, &sql.NullString{}) + turnBackIdxes = append(turnBackIdxes, i) + case *sql.NullTime: + v2 = append(v2, &sql.NullString{}) + turnBackIdxes = append(turnBackIdxes, i) + default: + v2 = append(v2, scanResults[i]) + } + } + if err := rows.Scan(v2...); err != nil { + return err + } + for _, i := range turnBackIdxes { + switch t := scanResults[i].(type) { + case *time.Time: + var s = *(v2[i].(*sql.NullString)) + if !s.Valid { + break + } + dt, err := convert.String2Time(s.String, ctx.DBLocation, ctx.UserLocation) + if err != nil { + return err + } + *t = *dt + case *sql.NullTime: + var s = *(v2[i].(*sql.NullString)) + if !s.Valid { + break + } + dt, err := convert.String2Time(s.String, ctx.DBLocation, ctx.UserLocation) + if err != nil { + return err + } + t.Time = *dt + t.Valid = true + } + } + return nil +} + type mymysqlDriver struct { + mysqlDriver } func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { @@ -681,41 +801,3 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { return uri, nil } - -type mysqlDriver struct { -} - -func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { - dsnPattern := regexp.MustCompile( - `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] - `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] - `\/(?P.*?)` + // /dbname - `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] - matches := dsnPattern.FindStringSubmatch(dataSourceName) - // tlsConfigRegister := make(map[string]*tls.Config) - names := dsnPattern.SubexpNames() - - uri := &URI{DBType: schemas.MYSQL} - - for i, match := range matches { - switch names[i] { - case "dbname": - uri.DBName = match - case "params": - if len(match) > 0 { - kvs := strings.Split(match, "&") - for _, kv := range kvs { - splits := strings.Split(kv, "=") - if len(splits) == 2 { - switch splits[0] { - case "charset": - uri.Charset = splits[1] - } - } - } - } - - } - } - return uri, nil -} diff --git a/dialects/oracle.go b/dialects/oracle.go index 0b06c4c6..7043972b 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -6,6 +6,7 @@ package dialects import ( "context" + "database/sql" "errors" "fmt" "regexp" @@ -823,6 +824,7 @@ func (db *oracle) Filters() []Filter { } type godrorDriver struct { + baseDriver } func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) { @@ -848,7 +850,28 @@ func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) return db, nil } +func (p *godrorDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB": + var s sql.NullString + return &s, nil + case "NUMBER": + var s sql.NullString + return &s, nil + case "DATE": + var s sql.NullTime + return &s, nil + case "BLOB": + var r sql.RawBytes + return &r, nil + default: + var r sql.RawBytes + return &r, nil + } +} + type oci8Driver struct { + godrorDriver } // dataSourceName=user/password@ipv4:port/dbname diff --git a/dialects/postgres.go b/dialects/postgres.go index 52c88567..a2611c60 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -6,6 +6,7 @@ package dialects import ( "context" + "database/sql" "errors" "fmt" "net/url" @@ -1044,12 +1045,13 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} - s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, + s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, description, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey FROM pg_attribute f JOIN pg_class c ON c.oid = f.attrelid JOIN pg_type t ON t.oid = f.atttypid LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum + LEFT JOIN pg_description de ON f.attrelid=de.objoid AND f.attnum=de.objsubid LEFT JOIN pg_namespace n ON n.oid = c.relnamespace LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid @@ -1078,9 +1080,9 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A col.Indexes = make(map[string]int) var colName, isNullable, dataType string - var maxLenStr, colDefault *string + var maxLenStr, colDefault, description *string var isPK, isUnique bool - err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &isPK, &isUnique) + err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &description, &isPK, &isUnique) if err != nil { return nil, nil, err } @@ -1126,6 +1128,10 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A col.DefaultIsEmpty = true } + if description != nil { + col.Comment = *description + } + if isPK { col.IsPrimaryKey = true } @@ -1293,6 +1299,13 @@ func (db *postgres) Filters() []Filter { } type pqDriver struct { + baseDriver +} + +func (b *pqDriver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: false, + } } type values map[string]string @@ -1369,6 +1382,36 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { return db, nil } +func (p *pqDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "VARCHAR", "TEXT": + var s sql.NullString + return &s, nil + case "BIGINT": + var s sql.NullInt64 + return &s, nil + case "TINYINT", "INT", "INT8", "INT4": + var s sql.NullInt32 + return &s, nil + case "FLOAT", "FLOAT4": + var s sql.NullFloat64 + return &s, nil + case "DATETIME", "TIMESTAMP": + var s sql.NullTime + return &s, nil + case "BIT": + var s sql.RawBytes + return &s, nil + case "BOOL": + var s sql.NullBool + return &s, nil + default: + fmt.Printf("unknow postgres database type: %v\n", colType) + var r sql.RawBytes + return &r, nil + } +} + type pqDriverPgx struct { pqDriver } diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index a42aad48..1bc0b218 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -540,6 +540,7 @@ func (db *sqlite3) Filters() []Filter { } type sqlite3Driver struct { + baseDriver } func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) { @@ -549,3 +550,35 @@ func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) { return &URI{DBType: schemas.SQLITE, DBName: dataSourceName}, nil } + +func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "TEXT": + var s sql.NullString + return &s, nil + case "INTEGER": + var s sql.NullInt64 + return &s, nil + case "DATETIME": + var s sql.NullTime + return &s, nil + case "REAL": + var s sql.NullFloat64 + return &s, nil + case "NUMERIC": + var s sql.NullString + return &s, nil + case "BLOB": + var s sql.RawBytes + return &s, nil + default: + var r sql.NullString + return &r, nil + } +} + +func (b *sqlite3Driver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: false, + } +} diff --git a/engine.go b/engine.go index 649ec1a2..a45771a2 100644 --- a/engine.go +++ b/engine.go @@ -35,6 +35,7 @@ type Engine struct { cacherMgr *caches.Manager defaultContext context.Context dialect dialects.Dialect + driver dialects.Driver engineGroup *EngineGroup logger log.ContextLogger tagParser *tags.Parser @@ -72,6 +73,7 @@ func newEngine(driverName, dataSourceName string, dialect dialects.Dialect, db * engine := &Engine{ dialect: dialect, + driver: dialects.QueryDriver(driverName), TZLocation: time.Local, defaultContext: context.Background(), cacherMgr: cacherMgr, @@ -444,7 +446,7 @@ func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return engine.dumpTables(tables, w, tp...) } -func formatColumnValue(dstDialect dialects.Dialect, d interface{}, col *schemas.Column) string { +func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d interface{}, col *schemas.Column) string { if d == nil { return "NULL" } @@ -473,10 +475,8 @@ func formatColumnValue(dstDialect dialects.Dialect, d interface{}, col *schemas. return "'" + strings.Replace(v, "'", "''", -1) + "'" } else if col.SQLType.IsTime() { - if dstDialect.URI().DBType == schemas.MSSQL && col.SQLType.Name == schemas.DateTime { - if t, ok := d.(time.Time); ok { - return "'" + t.UTC().Format("2006-01-02 15:04:05") + "'" - } + if t, ok := d.(time.Time); ok { + return "'" + t.In(dbLocation).Format("2006-01-02 15:04:05") + "'" } var v = fmt.Sprintf("%s", d) if strings.HasSuffix(v, " +0000 UTC") { @@ -652,12 +652,8 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return errors.New("unknown column error") } - fields := strings.Split(col.FieldName, ".") - field := dataStruct - for _, fieldName := range fields { - field = field.FieldByName(fieldName) - } - temp += "," + formatColumnValue(dstDialect, field.Interface(), col) + field := dataStruct.FieldByIndex(col.FieldIndex) + temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, field.Interface(), col) } _, err = io.WriteString(w, temp[1:]+");\n") if err != nil { @@ -684,7 +680,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return errors.New("unknow column error") } - temp += "," + formatColumnValue(dstDialect, d, col) + temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, d, col) } _, err = io.WriteString(w, temp[1:]+");\n") if err != nil { @@ -1206,10 +1202,10 @@ func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64 } // Delete records, bean's non-empty fields are conditions -func (engine *Engine) Delete(bean interface{}) (int64, error) { +func (engine *Engine) Delete(beans ...interface{}) (int64, error) { session := engine.NewSession() defer session.Close() - return session.Delete(bean) + return session.Delete(beans...) } // Get retrieve one record from table, bean's non-empty fields diff --git a/integrations/engine_test.go b/integrations/engine_test.go index 9b70f9b5..a06d91aa 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -176,6 +176,23 @@ func TestDumpTables(t *testing.T) { } } +func TestDumpTables2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestDumpTableStruct2 struct { + Id int64 + Created time.Time `xorm:"Default CURRENT_TIMESTAMP"` + } + + assertSync(t, new(TestDumpTableStruct2)) + + fp := fmt.Sprintf("./dump2-%v-table.sql", testEngine.Dialect().URI().DBType) + os.Remove(fp) + tb, err := testEngine.TableInfo(new(TestDumpTableStruct2)) + assert.NoError(t, err) + assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, fp)) +} + func TestSetSchema(t *testing.T) { assert.NoError(t, PrepareEngine()) @@ -209,3 +226,39 @@ func TestDBVersion(t *testing.T) { fmt.Println(testEngine.Dialect().URI().DBType, "version is", version) } + +func TestGetColumns(t *testing.T) { + if testEngine.Dialect().URI().DBType != schemas.POSTGRES { + t.Skip() + return + } + type TestCommentStruct struct { + HasComment int + NoComment int + } + + assertSync(t, new(TestCommentStruct)) + + comment := "this is a comment" + sql := fmt.Sprintf("comment on column %s.%s is '%s'", testEngine.TableName(new(TestCommentStruct), true), "has_comment", comment) + _, err := testEngine.Exec(sql) + assert.NoError(t, err) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + tableName := testEngine.GetColumnMapper().Obj2Table("TestCommentStruct") + var hasComment, noComment string + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("has_comment") + assert.NotNil(t, col) + hasComment = col.Comment + col2 := table.GetColumn("no_comment") + assert.NotNil(t, col2) + noComment = col2.Comment + break + } + } + assert.Equal(t, comment, hasComment) + assert.Zero(t, noComment) +} diff --git a/integrations/session_delete_test.go b/integrations/session_delete_test.go index cc7e861d..56f6f5b8 100644 --- a/integrations/session_delete_test.go +++ b/integrations/session_delete_test.go @@ -241,3 +241,28 @@ func TestUnscopeDelete(t *testing.T) { assert.NoError(t, err) assert.False(t, has) } + +func TestDelete2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UserinfoDelete2 struct { + Uid int64 `xorm:"id pk not null autoincr"` + IsMan bool + } + + assert.NoError(t, testEngine.Sync2(new(UserinfoDelete2))) + + user := UserinfoDelete2{} + cnt, err := testEngine.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.Table("userinfo_delete2").In("id", []int{1}).Delete() + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + user2 := UserinfoDelete2{} + has, err := testEngine.ID(1).Get(&user2) + assert.NoError(t, err) + assert.False(t, has) +} diff --git a/integrations/session_find_test.go b/integrations/session_find_test.go index 0ea12e26..80f3b72c 100644 --- a/integrations/session_find_test.go +++ b/integrations/session_find_test.go @@ -406,16 +406,16 @@ func TestFindMapPtrString(t *testing.T) { assert.NoError(t, err) } -func TestFindBit(t *testing.T) { - type FindBitStruct struct { +func TestFindBool(t *testing.T) { + type FindBoolStruct struct { Id int64 - Msg bool `xorm:"bit"` + Msg bool } assert.NoError(t, PrepareEngine()) - assertSync(t, new(FindBitStruct)) + assertSync(t, new(FindBoolStruct)) - cnt, err := testEngine.Insert([]FindBitStruct{ + cnt, err := testEngine.Insert([]FindBoolStruct{ { Msg: false, }, @@ -426,14 +426,13 @@ func TestFindBit(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, cnt) - var results = make([]FindBitStruct, 0, 2) + var results = make([]FindBoolStruct, 0, 2) err = testEngine.Find(&results) assert.NoError(t, err) assert.EqualValues(t, 2, len(results)) } func TestFindMark(t *testing.T) { - type Mark struct { Mark1 string `xorm:"VARCHAR(1)"` Mark2 string `xorm:"VARCHAR(1)"` @@ -468,7 +467,7 @@ func TestFindAndCountOneFunc(t *testing.T) { type FindAndCountStruct struct { Id int64 Content string - Msg bool `xorm:"bit"` + Msg bool } assert.NoError(t, PrepareEngine()) diff --git a/integrations/session_insert_test.go b/integrations/session_insert_test.go index eaa1b2c7..a023ab72 100644 --- a/integrations/session_insert_test.go +++ b/integrations/session_insert_test.go @@ -32,7 +32,6 @@ func TestInsertOne(t *testing.T) { } func TestInsertMulti(t *testing.T) { - assert.NoError(t, PrepareEngine()) type TestMulti struct { Id int64 `xorm:"int(11) pk"` @@ -78,7 +77,6 @@ func insertMultiDatas(step int, datas interface{}) (num int64, err error) { } func callbackLooper(datas interface{}, step int, actionFunc func(interface{}) error) (err error) { - sliceValue := reflect.Indirect(reflect.ValueOf(datas)) if sliceValue.Kind() != reflect.Slice { return fmt.Errorf("not slice") @@ -170,17 +168,17 @@ func TestInsertAutoIncr(t *testing.T) { assert.Greater(t, user.Uid, int64(0)) } -type DefaultInsert struct { - Id int64 - Status int `xorm:"default -1"` - Name string - Created time.Time `xorm:"created"` - Updated time.Time `xorm:"updated"` -} - func TestInsertDefault(t *testing.T) { assert.NoError(t, PrepareEngine()) + type DefaultInsert struct { + Id int64 + Status int `xorm:"default -1"` + Name string + Created time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` + } + di := new(DefaultInsert) err := testEngine.Sync2(di) assert.NoError(t, err) @@ -197,16 +195,16 @@ func TestInsertDefault(t *testing.T) { assert.EqualValues(t, di2.Created.Unix(), di.Created.Unix()) } -type DefaultInsert2 struct { - Id int64 - Name string - Url string `xorm:"text"` - CheckTime time.Time `xorm:"not null default '2000-01-01 00:00:00' TIMESTAMP"` -} - func TestInsertDefault2(t *testing.T) { assert.NoError(t, PrepareEngine()) + type DefaultInsert2 struct { + Id int64 + Name string + Url string `xorm:"text"` + CheckTime time.Time `xorm:"not null default '2000-01-01 00:00:00' TIMESTAMP"` + } + di := new(DefaultInsert2) err := testEngine.Sync2(di) assert.NoError(t, err) @@ -1026,3 +1024,44 @@ func TestInsertIntSlice(t *testing.T) { assert.True(t, has) assert.EqualValues(t, v3, v4) } + +func TestInsertDeleted(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type InsertDeletedStructNotRight struct { + ID uint64 `xorm:"'ID' pk autoincr"` + DeletedAt time.Time `xorm:"'DELETED_AT' deleted notnull"` + } + // notnull tag will be ignored + err := testEngine.Sync2(new(InsertDeletedStructNotRight)) + assert.NoError(t, err) + + type InsertDeletedStruct struct { + ID uint64 `xorm:"'ID' pk autoincr"` + DeletedAt time.Time `xorm:"'DELETED_AT' deleted"` + } + + assert.NoError(t, testEngine.Sync2(new(InsertDeletedStruct))) + + var v InsertDeletedStruct + _, err = testEngine.Insert(&v) + assert.NoError(t, err) + + var v2 InsertDeletedStruct + has, err := testEngine.Get(&v2) + assert.NoError(t, err) + assert.True(t, has) + + _, err = testEngine.ID(v.ID).Delete(new(InsertDeletedStruct)) + assert.NoError(t, err) + + var v3 InsertDeletedStruct + has, err = testEngine.Get(&v3) + assert.NoError(t, err) + assert.False(t, has) + + var v4 InsertDeletedStruct + has, err = testEngine.Unscoped().Get(&v4) + assert.NoError(t, err) + assert.True(t, has) +} diff --git a/integrations/session_query_test.go b/integrations/session_query_test.go index 30f2e6ab..ed03ff3e 100644 --- a/integrations/session_query_test.go +++ b/integrations/session_query_test.go @@ -52,7 +52,7 @@ func TestQueryString2(t *testing.T) { type GetVar3 struct { Id int64 `xorm:"autoincr pk"` - Msg bool `xorm:"bit"` + Msg bool } assert.NoError(t, testEngine.Sync2(new(GetVar3))) @@ -107,6 +107,16 @@ func toFloat64(i interface{}) float64 { return 0 } +func toBool(i interface{}) bool { + switch t := i.(type) { + case int32: + return t > 0 + case bool: + return t + } + return false +} + func TestQueryInterface(t *testing.T) { assert.NoError(t, PrepareEngine()) @@ -132,10 +142,10 @@ func TestQueryInterface(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 5, len(records[0])) - assert.EqualValues(t, 1, toInt64(records[0]["id"])) - assert.Equal(t, "hi", toString(records[0]["msg"])) - assert.EqualValues(t, 28, toInt64(records[0]["age"])) - assert.EqualValues(t, 1.5, toFloat64(records[0]["money"])) + assert.EqualValues(t, int64(1), records[0]["id"]) + assert.Equal(t, "hi", records[0]["msg"]) + assert.EqualValues(t, 28, records[0]["age"]) + assert.EqualValues(t, 1.5, records[0]["money"]) } func TestQueryNoParams(t *testing.T) { @@ -192,7 +202,7 @@ func TestQueryStringNoParam(t *testing.T) { type GetVar4 struct { Id int64 `xorm:"autoincr pk"` - Msg bool `xorm:"bit"` + Msg bool } assert.NoError(t, testEngine.Sync2(new(GetVar4))) @@ -229,7 +239,7 @@ func TestQuerySliceStringNoParam(t *testing.T) { type GetVar6 struct { Id int64 `xorm:"autoincr pk"` - Msg bool `xorm:"bit"` + Msg bool } assert.NoError(t, testEngine.Sync2(new(GetVar6))) @@ -266,7 +276,7 @@ func TestQueryInterfaceNoParam(t *testing.T) { type GetVar5 struct { Id int64 `xorm:"autoincr pk"` - Msg bool `xorm:"bit"` + Msg bool } assert.NoError(t, testEngine.Sync2(new(GetVar5))) @@ -280,14 +290,14 @@ func TestQueryInterfaceNoParam(t *testing.T) { records, err := testEngine.Table("get_var5").Limit(1).QueryInterface() assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) - assert.EqualValues(t, 1, toInt64(records[0]["id"])) - assert.EqualValues(t, 0, toInt64(records[0]["msg"])) + assert.EqualValues(t, 1, records[0]["id"]) + assert.False(t, toBool(records[0]["msg"])) records, err = testEngine.Table("get_var5").Where(builder.Eq{"id": 1}).QueryInterface() assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) - assert.EqualValues(t, 1, toInt64(records[0]["id"])) - assert.EqualValues(t, 0, toInt64(records[0]["msg"])) + assert.EqualValues(t, 1, records[0]["id"]) + assert.False(t, toBool(records[0]["msg"])) } func TestQueryWithBuilder(t *testing.T) { diff --git a/integrations/session_update_test.go b/integrations/session_update_test.go index 15d2f694..796bfa0a 100644 --- a/integrations/session_update_test.go +++ b/integrations/session_update_test.go @@ -472,6 +472,11 @@ func TestUpdateIncrDecr(t *testing.T) { cnt, err = testEngine.ID(col1.Id).Cols(colName).Incr(colName).Update(col1) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + + testEngine.SetColumnMapper(testEngine.GetColumnMapper()) + cnt, err = testEngine.Cols(colName).Decr(colName, 2).ID(col1.Id).Update(new(UpdateIncr)) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) } type UpdatedUpdate struct { diff --git a/interface.go b/interface.go index 24f4ccf3..5d68f536 100644 --- a/interface.go +++ b/interface.go @@ -30,7 +30,7 @@ type Interface interface { CreateUniques(bean interface{}) error Decr(column string, arg ...interface{}) *Session Desc(...string) *Session - Delete(interface{}) (int64, error) + Delete(...interface{}) (int64, error) Distinct(columns ...string) *Session DropIndexes(bean interface{}) error Exec(sqlOrArgs ...interface{}) (sql.Result, error) diff --git a/internal/statements/expr.go b/internal/statements/expr.go index b44c96ca..c2a2e1cc 100644 --- a/internal/statements/expr.go +++ b/internal/statements/expr.go @@ -27,6 +27,7 @@ type Expr struct { Arg interface{} } +// WriteArgs writes args to the writer func (expr *Expr) WriteArgs(w *builder.BytesWriter) error { switch arg := expr.Arg.(type) { case *builder.Builder: diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 367dbdc9..4e43c5bd 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -17,7 +17,7 @@ func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schem if _, err := buf.WriteString(" OUTPUT Inserted."); err != nil { return err } - if _, err := buf.WriteString(table.AutoIncrement); err != nil { + if err := statement.dialect.Quoter().QuoteTo(buf, table.AutoIncrement); err != nil { return err } } diff --git a/internal/statements/query.go b/internal/statements/query.go index e1091e9f..a972a8e0 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -343,7 +343,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac var args []interface{} var joinStr string var err error - var b interface{} = nil + var b interface{} if len(bean) > 0 { b = bean[0] beanValue := reflect.ValueOf(bean[0]) diff --git a/internal/statements/statement.go b/internal/statements/statement.go index a52c6ca2..2d173b87 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -208,20 +208,18 @@ func (statement *Statement) quote(s string) string { // And add Where & and statement func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { - switch query.(type) { + switch qr := query.(type) { case string: - cond := builder.Expr(query.(string), args...) + cond := builder.Expr(qr, args...) statement.cond = statement.cond.And(cond) case map[string]interface{}: - queryMap := query.(map[string]interface{}) - newMap := make(map[string]interface{}) - for k, v := range queryMap { - newMap[statement.quote(k)] = v + cond := make(builder.Eq) + for k, v := range qr { + cond[statement.quote(k)] = v } - statement.cond = statement.cond.And(builder.Eq(newMap)) - case builder.Cond: - cond := query.(builder.Cond) statement.cond = statement.cond.And(cond) + case builder.Cond: + statement.cond = statement.cond.And(qr) for _, v := range args { if vv, ok := v.(builder.Cond); ok { statement.cond = statement.cond.And(vv) @@ -236,23 +234,25 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme // Or add Where & Or statement func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement { - switch query.(type) { + switch qr := query.(type) { case string: - cond := builder.Expr(query.(string), args...) + cond := builder.Expr(qr, args...) statement.cond = statement.cond.Or(cond) case map[string]interface{}: - cond := builder.Eq(query.(map[string]interface{})) + cond := make(builder.Eq) + for k, v := range qr { + cond[statement.quote(k)] = v + } statement.cond = statement.cond.Or(cond) case builder.Cond: - cond := query.(builder.Cond) - statement.cond = statement.cond.Or(cond) + statement.cond = statement.cond.Or(qr) for _, v := range args { if vv, ok := v.(builder.Cond); ok { statement.cond = statement.cond.Or(vv) } } default: - // TODO: not support condition type + statement.LastError = ErrConditionType } return statement } @@ -734,6 +734,8 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, //engine.logger.Warn(err) } continue + } else if fieldValuePtr == nil { + continue } if col.IsDeleted && !unscoped { // tag "deleted" is enabled @@ -976,7 +978,7 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName // CondDeleted returns the conditions whether a record is soft deleted. func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { - var colName = col.Name + var colName = statement.quote(col.Name) if statement.JoinStr != "" { var prefix string if statement.TableAlias != "" { diff --git a/internal/statements/statement_test.go b/internal/statements/statement_test.go index 15f446f4..ba92330e 100644 --- a/internal/statements/statement_test.go +++ b/internal/statements/statement_test.go @@ -78,7 +78,6 @@ func TestColumnsStringGeneration(t *testing.T) { } func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { - b.StopTimer() mapCols := make(map[string]bool) @@ -101,9 +100,7 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - for _, col := range cols { - if _, ok := getFlagForColumn(mapCols, col); !ok { b.Fatal("Unexpected result") } @@ -112,7 +109,6 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { } func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) { - b.StopTimer() mapCols := make(map[string]bool) @@ -131,9 +127,7 @@ func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - for _, col := range cols { - if _, ok := getFlagForColumn(mapCols, col); ok { b.Fatal("Unexpected result") } diff --git a/internal/statements/update.go b/internal/statements/update.go index 251880b2..06cf0689 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -88,6 +88,9 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value, if err != nil { return nil, nil, err } + if fieldValuePtr == nil { + continue + } fieldValue := *fieldValuePtr fieldType := reflect.TypeOf(fieldValue.Interface()) diff --git a/log/logger.go b/log/logger.go index eeb63693..3b6db34e 100644 --- a/log/logger.go +++ b/log/logger.go @@ -132,7 +132,6 @@ func (s *SimpleLogger) Error(v ...interface{}) { if s.level <= LOG_ERR { s.ERR.Output(2, fmt.Sprintln(v...)) } - return } // Errorf implement ILogger @@ -140,7 +139,6 @@ func (s *SimpleLogger) Errorf(format string, v ...interface{}) { if s.level <= LOG_ERR { s.ERR.Output(2, fmt.Sprintf(format, v...)) } - return } // Debug implement ILogger @@ -148,7 +146,6 @@ func (s *SimpleLogger) Debug(v ...interface{}) { if s.level <= LOG_DEBUG { s.DEBUG.Output(2, fmt.Sprintln(v...)) } - return } // Debugf implement ILogger @@ -156,7 +153,6 @@ func (s *SimpleLogger) Debugf(format string, v ...interface{}) { if s.level <= LOG_DEBUG { s.DEBUG.Output(2, fmt.Sprintf(format, v...)) } - return } // Info implement ILogger @@ -164,7 +160,6 @@ func (s *SimpleLogger) Info(v ...interface{}) { if s.level <= LOG_INFO { s.INFO.Output(2, fmt.Sprintln(v...)) } - return } // Infof implement ILogger @@ -172,7 +167,6 @@ func (s *SimpleLogger) Infof(format string, v ...interface{}) { if s.level <= LOG_INFO { s.INFO.Output(2, fmt.Sprintf(format, v...)) } - return } // Warn implement ILogger @@ -180,7 +174,6 @@ func (s *SimpleLogger) Warn(v ...interface{}) { if s.level <= LOG_WARNING { s.WARN.Output(2, fmt.Sprintln(v...)) } - return } // Warnf implement ILogger @@ -188,7 +181,6 @@ func (s *SimpleLogger) Warnf(format string, v ...interface{}) { if s.level <= LOG_WARNING { s.WARN.Output(2, fmt.Sprintf(format, v...)) } - return } // Level implement ILogger @@ -199,7 +191,6 @@ func (s *SimpleLogger) Level() LogLevel { // SetLevel implement ILogger func (s *SimpleLogger) SetLevel(l LogLevel) { s.level = l - return } // ShowSQL implement ILogger diff --git a/scan.go b/scan.go new file mode 100644 index 00000000..c5cb77ff --- /dev/null +++ b/scan.go @@ -0,0 +1,303 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "database/sql" + "fmt" + "reflect" + "time" + + "xorm.io/xorm/convert" + "xorm.io/xorm/core" + "xorm.io/xorm/dialects" +) + +// genScanResultsByBeanNullabale generates scan result +func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { + switch t := bean.(type) { + case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes: + return t, false, nil + case *time.Time: + return &sql.NullTime{}, true, nil + case *string: + return &sql.NullString{}, true, nil + case *int, *int8, *int16, *int32: + return &sql.NullInt32{}, true, nil + case *int64: + return &sql.NullInt64{}, true, nil + case *uint, *uint8, *uint16, *uint32: + return &NullUint32{}, true, nil + case *uint64: + return &NullUint64{}, true, nil + case *float32, *float64: + return &sql.NullFloat64{}, true, nil + case *bool: + return &sql.NullBool{}, true, nil + case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString, + time.Time, + string, + int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + float32, float64, + bool: + return nil, false, fmt.Errorf("unsupported scan type: %t", t) + case convert.Conversion: + return &sql.RawBytes{}, true, nil + } + + tp := reflect.TypeOf(bean).Elem() + switch tp.Kind() { + case reflect.String: + return &sql.NullString{}, true, nil + case reflect.Int64: + return &sql.NullInt64{}, true, nil + case reflect.Int32, reflect.Int, reflect.Int16, reflect.Int8: + return &sql.NullInt32{}, true, nil + case reflect.Uint64: + return &NullUint64{}, true, nil + case reflect.Uint32, reflect.Uint, reflect.Uint16, reflect.Uint8: + return &NullUint32{}, true, nil + default: + return nil, false, fmt.Errorf("unsupported type: %#v", bean) + } +} + +func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { + switch t := bean.(type) { + case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, + *string, + *int, *int8, *int16, *int32, *int64, + *uint, *uint8, *uint16, *uint32, *uint64, + *float32, *float64, + *bool: + return t, false, nil + case *time.Time: + return &sql.NullTime{}, true, nil + case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString, + time.Time, + string, + int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + bool: + return nil, false, fmt.Errorf("unsupported scan type: %t", t) + case convert.Conversion: + return &sql.RawBytes{}, true, nil + } + + tp := reflect.TypeOf(bean).Elem() + switch tp.Kind() { + case reflect.String: + return new(string), true, nil + case reflect.Int64: + return new(int64), true, nil + case reflect.Int32: + return new(int32), true, nil + case reflect.Int: + return new(int32), true, nil + case reflect.Int16: + return new(int32), true, nil + case reflect.Int8: + return new(int32), true, nil + case reflect.Uint64: + return new(uint64), true, nil + case reflect.Uint32: + return new(uint32), true, nil + case reflect.Uint: + return new(uint), true, nil + case reflect.Uint16: + return new(uint16), true, nil + case reflect.Uint8: + return new(uint8), true, nil + case reflect.Float32: + return new(float32), true, nil + case reflect.Float64: + return new(float64), true, nil + default: + return nil, false, fmt.Errorf("unsupported type: %#v", bean) + } +} + +func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { + var scanResults = make([]interface{}, len(fields)) + for i := 0; i < len(fields); i++ { + var s sql.NullString + scanResults[i] = &s + } + + if err := rows.Scan(scanResults...); err != nil { + return nil, err + } + + result := make(map[string]string, len(fields)) + for ii, key := range fields { + s := scanResults[ii].(*sql.NullString) + result[key] = s.String + } + return result, nil +} + +func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string][]byte, error) { + var scanResults = make([]interface{}, len(fields)) + for i := 0; i < len(fields); i++ { + var s sql.NullString + scanResults[i] = &s + } + + if err := rows.Scan(scanResults...); err != nil { + return nil, err + } + + result := make(map[string][]byte, len(fields)) + for ii, key := range fields { + s := scanResults[ii].(*sql.NullString) + result[key] = []byte(s.String) + } + return result, nil +} + +func (engine *Engine) scanStringInterface(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { + var scanResults = make([]interface{}, len(types)) + for i := 0; i < len(types); i++ { + var s sql.NullString + scanResults[i] = &s + } + + if err := engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResults...); err != nil { + return nil, err + } + return scanResults, nil +} + +// scan is a wrap of driver.Scan but will automatically change the input values according requirements +func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.ColumnType, vv ...interface{}) error { + var scanResults = make([]interface{}, 0, len(types)) + var replaces = make([]bool, 0, len(types)) + var err error + for _, v := range vv { + var replaced bool + var scanResult interface{} + if _, ok := v.(sql.Scanner); !ok { + var useNullable = true + if engine.driver.Features().SupportNullable { + nullable, ok := types[0].Nullable() + useNullable = ok && nullable + } + + if useNullable { + scanResult, replaced, err = genScanResultsByBeanNullable(v) + } else { + scanResult, replaced, err = genScanResultsByBean(v) + } + if err != nil { + return err + } + } else { + scanResult = v + } + scanResults = append(scanResults, scanResult) + replaces = append(replaces, replaced) + } + + var scanCtx = dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + } + + if err = engine.driver.Scan(&scanCtx, rows, types, scanResults...); err != nil { + return err + } + + for i, replaced := range replaces { + if replaced { + if err = convertAssign(vv[i], scanResults[i], scanCtx.DBLocation, engine.TZLocation); err != nil { + return err + } + } + } + + return nil +} + +func (engine *Engine) scanInterfaces(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { + var scanResultContainers = make([]interface{}, len(types)) + for i := 0; i < len(types); i++ { + scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) + if err != nil { + return nil, err + } + scanResultContainers[i] = scanResult + } + if err := engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResultContainers...); err != nil { + return nil, err + } + return scanResultContainers, nil +} + +func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { + scanResults, err := engine.scanStringInterface(rows, types) + if err != nil { + return nil, err + } + + var results = make([]string, 0, len(fields)) + for i := 0; i < len(fields); i++ { + results = append(results, scanResults[i].(*sql.NullString).String) + } + return results, nil +} + +func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + for rows.Next() { + result, err := row2mapBytes(rows, types, fields) + if err != nil { + return nil, err + } + resultsSlice = append(resultsSlice, result) + } + + return resultsSlice, nil +} + +func (engine *Engine) row2mapInterface(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]interface{}, error) { + var resultsMap = make(map[string]interface{}, len(fields)) + var scanResultContainers = make([]interface{}, len(fields)) + for i := 0; i < len(fields); i++ { + scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) + if err != nil { + return nil, err + } + scanResultContainers[i] = scanResult + } + if err := engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResultContainers...); err != nil { + return nil, err + } + + for ii, key := range fields { + res, err := convert.Interface2Interface(engine.TZLocation, scanResultContainers[ii]) + if err != nil { + return nil, err + } + resultsMap[key] = res + } + return resultsMap, nil +} diff --git a/schemas/column.go b/schemas/column.go index 24b53802..4bbb6c2d 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -6,10 +6,8 @@ package schemas import ( "errors" - "fmt" "reflect" "strconv" - "strings" "time" ) @@ -25,6 +23,7 @@ type Column struct { Name string TableName string FieldName string // Available only when parsed from a struct + FieldIndex []int // Available only when parsed from a struct SQLType SQLType IsJSON bool Length int @@ -83,41 +82,17 @@ func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) { // ValueOfV returns column's filed of struct's value accept reflevt value func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { - var fieldValue reflect.Value - fieldPath := strings.Split(col.FieldName, ".") - - if dataStruct.Type().Kind() == reflect.Map { - keyValue := reflect.ValueOf(fieldPath[len(fieldPath)-1]) - fieldValue = dataStruct.MapIndex(keyValue) - return &fieldValue, nil - } else if dataStruct.Type().Kind() == reflect.Interface { - structValue := reflect.ValueOf(dataStruct.Interface()) - dataStruct = &structValue - } - - level := len(fieldPath) - fieldValue = dataStruct.FieldByName(fieldPath[0]) - for i := 0; i < level-1; i++ { - if !fieldValue.IsValid() { - break - } - if fieldValue.Kind() == reflect.Struct { - fieldValue = fieldValue.FieldByName(fieldPath[i+1]) - } else if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + var v = *dataStruct + for _, i := range col.FieldIndex { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) } - fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1]) - } else { - return nil, fmt.Errorf("field %v is not valid", col.FieldName) + v = v.Elem() } + v = v.FieldByIndex([]int{i}) } - - if !fieldValue.IsValid() { - return nil, fmt.Errorf("field %v is not valid", col.FieldName) - } - - return &fieldValue, nil + return &v, nil } // ConvertID converts id content to suitable type according column type diff --git a/schemas/table.go b/schemas/table.go index bfa517aa..91b33e06 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -5,7 +5,6 @@ package schemas import ( - "fmt" "reflect" "strconv" "strings" @@ -159,24 +158,8 @@ func (table *Table) IDOfV(rv reflect.Value) (PK, error) { for i, col := range table.PKColumns() { var err error - fieldName := col.FieldName - for { - parts := strings.SplitN(fieldName, ".", 2) - if len(parts) == 1 { - break - } + pkField := v.FieldByIndex(col.FieldIndex) - v = v.FieldByName(parts[0]) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - if v.Kind() != reflect.Struct { - return nil, fmt.Errorf("Unsupported read value of column %s from field %s", col.Name, col.FieldName) - } - fieldName = parts[1] - } - - pkField := v.FieldByName(fieldName) switch pkField.Kind() { case reflect.String: pk[i], err = col.ConvertID(pkField.String()) diff --git a/schemas/table_test.go b/schemas/table_test.go index 9bf10e33..0e35193f 100644 --- a/schemas/table_test.go +++ b/schemas/table_test.go @@ -27,7 +27,6 @@ var testsGetColumn = []struct { var table *Table func init() { - table = NewEmptyTable() var name string @@ -41,7 +40,6 @@ func init() { } func TestGetColumn(t *testing.T) { - for _, test := range testsGetColumn { if table.GetColumn(test.name) == nil { t.Error("Column not found!") @@ -50,7 +48,6 @@ func TestGetColumn(t *testing.T) { } func TestGetColumnIdx(t *testing.T) { - for _, test := range testsGetColumn { if table.GetColumnIdx(test.name, test.idx) == nil { t.Errorf("Column %s with idx %d not found!", test.name, test.idx) @@ -59,7 +56,6 @@ func TestGetColumnIdx(t *testing.T) { } func BenchmarkGetColumnWithToLower(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { @@ -71,7 +67,6 @@ func BenchmarkGetColumnWithToLower(b *testing.B) { } func BenchmarkGetColumnIdxWithToLower(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { @@ -89,7 +84,6 @@ func BenchmarkGetColumnIdxWithToLower(b *testing.B) { } func BenchmarkGetColumn(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { if table.GetColumn(test.name) == nil { @@ -100,7 +94,6 @@ func BenchmarkGetColumn(b *testing.B) { } func BenchmarkGetColumnIdx(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { if table.GetColumnIdx(test.name, test.idx) == nil { diff --git a/session.go b/session.go index d5ccb6dc..6df9e20d 100644 --- a/session.go +++ b/session.go @@ -375,6 +375,9 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s if err != nil { return nil, err } + if fieldValue == nil { + return nil, ErrFieldIsNotValid{key, table.Name} + } if !fieldValue.IsValid() || !fieldValue.CanSet() { return nil, ErrFieldIsNotValid{key, table.Name} diff --git a/session_convert.go b/session_convert.go index a6839947..b8218a77 100644 --- a/session_convert.go +++ b/session_convert.go @@ -35,27 +35,20 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time sd, err := strconv.ParseInt(sdata, 10, 64) if err == nil { x = time.Unix(sd, 0) - //session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) - } else { - //session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } } else if len(sdata) > 19 && strings.Contains(sdata, "-") { x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc) - session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) + session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.Name, x, sdata) if err != nil { x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc) - //session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } if err != nil { x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc) - //session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } } else if len(sdata) == 19 && strings.Contains(sdata, "-") { x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc) - //session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc) - //session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else if col.SQLType.Name == schemas.Time { if strings.Contains(sdata, " ") { ssd := strings.Split(sdata, " ") @@ -69,7 +62,6 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time st := fmt.Sprintf("2006-01-02 %v", sdata) x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc) - //session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else { outErr = fmt.Errorf("unsupported time format %v", sdata) return diff --git a/session_delete.go b/session_delete.go index 13bf791f..baabb558 100644 --- a/session_delete.go +++ b/session_delete.go @@ -83,7 +83,7 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri } // Delete records, bean's non-empty fields are conditions -func (session *Session) Delete(bean interface{}) (int64, error) { +func (session *Session) Delete(beans ...interface{}) (int64, error) { if session.isAutoClose { defer session.Close() } @@ -92,20 +92,32 @@ func (session *Session) Delete(bean interface{}) (int64, error) { return 0, session.statement.LastError } - if err := session.statement.SetRefBean(bean); err != nil { - return 0, err + var ( + condSQL string + condArgs []interface{} + err error + bean interface{} + ) + if len(beans) > 0 { + bean = beans[0] + if err = session.statement.SetRefBean(bean); err != nil { + return 0, err + } + + executeBeforeClosures(session, bean) + + if processor, ok := interface{}(bean).(BeforeDeleteProcessor); ok { + processor.BeforeDelete() + } + + condSQL, condArgs, err = session.statement.GenConds(bean) + } else { + condSQL, condArgs, err = session.statement.GenCondSQL(session.statement.Conds()) } - - executeBeforeClosures(session, bean) - - if processor, ok := interface{}(bean).(BeforeDeleteProcessor); ok { - processor.BeforeDelete() - } - - condSQL, condArgs, err := session.statement.GenConds(bean) if err != nil { return 0, err } + pLimitN := session.statement.LimitN if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { return 0, ErrNeedDeletedCond @@ -156,7 +168,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { var realSQL string argsForCache := make([]interface{}, 0, len(condArgs)*2) - if session.statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled + if session.statement.GetUnscoped() || table == nil || table.DeletedColumn() == nil { // tag "deleted" is disabled realSQL = deleteSQL copy(argsForCache, condArgs) argsForCache = append(condArgs, argsForCache...) @@ -220,27 +232,29 @@ func (session *Session) Delete(bean interface{}) (int64, error) { return 0, err } - // handle after delete processors - if session.isAutoCommit { - for _, closure := range session.afterClosures { - closure(bean) - } - if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok { - processor.AfterDelete() - } - } else { - lenAfterClosures := len(session.afterClosures) - if lenAfterClosures > 0 { - if value, has := session.afterDeleteBeans[bean]; has && value != nil { - *value = append(*value, session.afterClosures...) - } else { - afterClosures := make([]func(interface{}), lenAfterClosures) - copy(afterClosures, session.afterClosures) - session.afterDeleteBeans[bean] = &afterClosures + if bean != nil { + // handle after delete processors + if session.isAutoCommit { + for _, closure := range session.afterClosures { + closure(bean) + } + if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok { + processor.AfterDelete() } } else { - if _, ok := interface{}(bean).(AfterDeleteProcessor); ok { - session.afterDeleteBeans[bean] = nil + lenAfterClosures := len(session.afterClosures) + if lenAfterClosures > 0 && len(beans) > 0 { + if value, has := session.afterDeleteBeans[beans[0]]; has && value != nil { + *value = append(*value, session.afterClosures...) + } else { + afterClosures := make([]func(interface{}), lenAfterClosures) + copy(afterClosures, session.afterClosures) + session.afterDeleteBeans[bean] = &afterClosures + } + } else { + if _, ok := interface{}(bean).(AfterDeleteProcessor); ok { + session.afterDeleteBeans[bean] = nil + } } } } diff --git a/session_find.go b/session_find.go index 0daea005..261e6b7f 100644 --- a/session_find.go +++ b/session_find.go @@ -276,7 +276,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error { cols := table.PKColumns() if len(cols) == 1 { - return convertAssign(dst, pk[0]) + return convertAssign(dst, pk[0], nil, nil) } dst = pk diff --git a/session_get.go b/session_get.go index e303176d..cb2bda75 100644 --- a/session_get.go +++ b/session_get.go @@ -6,12 +6,16 @@ package xorm import ( "database/sql" + "database/sql/driver" "errors" "fmt" "reflect" "strconv" + "time" "xorm.io/xorm/caches" + "xorm.io/xorm/convert" + "xorm.io/xorm/core" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -108,6 +112,17 @@ func (session *Session) get(bean interface{}) (bool, error) { return true, nil } +var ( + valuerTypePlaceHolder driver.Valuer + valuerType = reflect.TypeOf(&valuerTypePlaceHolder).Elem() + + scannerTypePlaceHolder sql.Scanner + scannerType = reflect.TypeOf(&scannerTypePlaceHolder).Elem() + + conversionTypePlaceHolder convert.Conversion + conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem() +) + func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil { @@ -122,155 +137,161 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, return false, nil } - switch bean.(type) { - case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString: - return true, rows.Scan(&bean) - case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString: - return true, rows.Scan(bean) - case *string: - var res sql.NullString - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*string)) = res.String - } - return true, nil - case *int: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int)) = int(res.Int64) - } - return true, nil - case *int8: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int8)) = int8(res.Int64) - } - return true, nil - case *int16: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int16)) = int16(res.Int64) - } - return true, nil - case *int32: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int32)) = int32(res.Int64) - } - return true, nil - case *int64: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int64)) = int64(res.Int64) - } - return true, nil - case *uint: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint)) = uint(res.Int64) - } - return true, nil - case *uint8: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint8)) = uint8(res.Int64) - } - return true, nil - case *uint16: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint16)) = uint16(res.Int64) - } - return true, nil - case *uint32: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint32)) = uint32(res.Int64) - } - return true, nil - case *uint64: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint64)) = uint64(res.Int64) - } - return true, nil - case *bool: - var res sql.NullBool - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*bool)) = res.Bool - } - return true, nil + // WARN: Alougth rows return true, but we may also return error. + types, err := rows.ColumnTypes() + if err != nil { + return true, err + } + fields, err := rows.Columns() + if err != nil { + return true, err } - switch beanKind { case reflect.Struct: - fields, err := rows.Columns() - if err != nil { - // WARN: Alougth rows return true, but get fields failed - return true, err + if _, ok := bean.(*time.Time); ok { + break } - - scanResults, err := session.row2Slice(rows, fields, bean) - if err != nil { - return false, err + if _, ok := bean.(sql.Scanner); ok { + break } - // close it before convert data - rows.Close() - - dataStruct := utils.ReflectValue(bean) - _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) - if err != nil { - return true, err + if _, ok := bean.(convert.Conversion); len(types) == 1 && ok { + break } - - return true, session.executeProcessors() + return session.getStruct(rows, types, fields, table, bean) case reflect.Slice: - err = rows.ScanSlice(bean) + return session.getSlice(rows, types, fields, bean) case reflect.Map: - err = rows.ScanMap(bean) - case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - err = rows.Scan(bean) - default: - err = rows.Scan(bean) + return session.getMap(rows, types, fields, bean) } - return true, err + return session.getVars(rows, types, fields, bean) +} + +func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { + switch t := bean.(type) { + case *[]string: + res, err := session.engine.scanStringInterface(rows, types) + if err != nil { + return true, err + } + + var needAppend = len(*t) == 0 // both support slice is empty or has been initlized + for i, r := range res { + if needAppend { + *t = append(*t, r.(*sql.NullString).String) + } else { + (*t)[i] = r.(*sql.NullString).String + } + } + return true, nil + case *[]interface{}: + scanResults, err := session.engine.scanInterfaces(rows, types) + if err != nil { + return true, err + } + var needAppend = len(*t) == 0 + for ii := range fields { + s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii]) + if err != nil { + return true, err + } + if needAppend { + *t = append(*t, s) + } else { + (*t)[ii] = s + } + } + return true, nil + default: + return true, fmt.Errorf("unspoorted slice type: %t", t) + } +} + +func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { + switch t := bean.(type) { + case *map[string]string: + scanResults, err := session.engine.scanStringInterface(rows, types) + if err != nil { + return true, err + } + for ii, key := range fields { + (*t)[key] = scanResults[ii].(*sql.NullString).String + } + return true, nil + case *map[string]interface{}: + scanResults, err := session.engine.scanInterfaces(rows, types) + if err != nil { + return true, err + } + for ii, key := range fields { + s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii]) + if err != nil { + return true, err + } + (*t)[key] = s + } + return true, nil + default: + return true, fmt.Errorf("unspoorted map type: %t", t) + } +} + +func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields []string, beans ...interface{}) (bool, error) { + if len(beans) != len(types) { + return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans)) + } + var scanResults = make([]interface{}, 0, len(types)) + var replaceds = make([]bool, 0, len(types)) + for _, bean := range beans { + switch t := bean.(type) { + case sql.Scanner: + scanResults = append(scanResults, t) + replaceds = append(replaceds, false) + case convert.Conversion: + scanResults = append(scanResults, &sql.RawBytes{}) + replaceds = append(replaceds, true) + default: + scanResults = append(scanResults, bean) + replaceds = append(replaceds, false) + } + } + + err := session.engine.scan(rows, fields, types, scanResults...) + if err != nil { + return true, err + } + for i, replaced := range replaceds { + if replaced { + err = convertAssign(beans[i], scanResults[i], session.engine.DatabaseTZ, session.engine.TZLocation) + if err != nil { + return true, err + } + } + } + return true, nil +} + +func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { + fields, err := rows.Columns() + if err != nil { + // WARN: Alougth rows return true, but get fields failed + return true, err + } + + scanResults, err := session.row2Slice(rows, fields, bean) + if err != nil { + return false, err + } + // close it before convert data + rows.Close() + + dataStruct := utils.ReflectValue(bean) + _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) + if err != nil { + return true, err + } + + return true, session.executeProcessors() } func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { diff --git a/session_insert.go b/session_insert.go index 5f968151..7f8f3008 100644 --- a/session_insert.go +++ b/session_insert.go @@ -11,6 +11,7 @@ import ( "sort" "strconv" "strings" + "time" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" @@ -374,9 +375,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) - - return 1, nil + return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation) } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || session.engine.dialect.URI().DBType == schemas.MSSQL) { res, err := session.queryBytes(sqlStr, args...) @@ -416,9 +415,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) - - return 1, nil + return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation) } res, err := session.exec(sqlStr, args...) @@ -458,7 +455,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) + if err := convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation); err != nil { + return 0, err + } return res.RowsAffected() } @@ -499,6 +498,16 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } if col.IsDeleted { + colNames = append(colNames, col.Name) + if !col.Nullable { + if col.SQLType.IsNumeric() { + args = append(args, 0) + } else { + args = append(args, time.Time{}.Format("2006-01-02 15:04:05")) + } + } else { + args = append(args, nil) + } continue } diff --git a/session_query.go b/session_query.go index 12136466..fa33496d 100644 --- a/session_query.go +++ b/session_query.go @@ -5,13 +5,7 @@ package xorm import ( - "fmt" - "reflect" - "strconv" - "time" - "xorm.io/xorm/core" - "xorm.io/xorm/schemas" ) // Query runs a raw sql and return records as []map[string][]byte @@ -28,116 +22,18 @@ func (session *Session) Query(sqlOrArgs ...interface{}) ([]map[string][]byte, er return session.queryBytes(sqlStr, args...) } -func value2String(rawValue *reflect.Value) (str string, err error) { - aa := reflect.TypeOf((*rawValue).Interface()) - vv := reflect.ValueOf((*rawValue).Interface()) - switch aa.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - str = strconv.FormatInt(vv.Int(), 10) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - str = strconv.FormatUint(vv.Uint(), 10) - case reflect.Float32, reflect.Float64: - str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) - case reflect.String: - str = vv.String() - case reflect.Array, reflect.Slice: - switch aa.Elem().Kind() { - case reflect.Uint8: - data := rawValue.Interface().([]byte) - str = string(data) - if str == "\x00" { - str = "0" - } - default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - // time type - case reflect.Struct: - if aa.ConvertibleTo(schemas.TimeType) { - str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) - } else { - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - case reflect.Bool: - str = strconv.FormatBool(vv.Bool()) - case reflect.Complex128, reflect.Complex64: - str = fmt.Sprintf("%v", vv.Complex()) - /* TODO: unsupported types below - case reflect.Map: - case reflect.Ptr: - case reflect.Uintptr: - case reflect.UnsafePointer: - case reflect.Chan, reflect.Func, reflect.Interface: - */ - default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - return -} - -func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string, err error) { - result := make(map[string]string) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for ii, key := range fields { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) - // if row is null then as empty string - if rawValue.Interface() == nil { - result[key] = "" - continue - } - - if data, err := value2String(&rawValue); err == nil { - result[key] = data - } else { - return nil, err - } - } - return result, nil -} - -func row2sliceStr(rows *core.Rows, fields []string) (results []string, err error) { - result := make([]string, 0, len(fields)) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for i := 0; i < len(fields); i++ { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[i])) - // if row is null then as empty string - if rawValue.Interface() == nil { - result = append(result, "") - continue - } - - if data, err := value2String(&rawValue); err == nil { - result = append(result, data) - } else { - return nil, err - } - } - return result, nil -} - -func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) { +func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) { fields, err := rows.Columns() if err != nil { return nil, err } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + for rows.Next() { - result, err := row2mapStr(rows, fields) + result, err := row2mapStr(rows, types, fields) if err != nil { return nil, err } @@ -147,13 +43,18 @@ func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) return resultsSlice, nil } -func rows2SliceString(rows *core.Rows) (resultsSlice [][]string, err error) { +func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]string, err error) { fields, err := rows.Columns() if err != nil { return nil, err } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + for rows.Next() { - record, err := row2sliceStr(rows, fields) + record, err := session.engine.row2sliceStr(rows, types, fields) if err != nil { return nil, err } @@ -180,7 +81,7 @@ func (session *Session) QueryString(sqlOrArgs ...interface{}) ([]map[string]stri } defer rows.Close() - return rows2Strings(rows) + return session.rows2Strings(rows) } // QuerySliceString runs a raw sql and return records as [][]string @@ -200,33 +101,20 @@ func (session *Session) QuerySliceString(sqlOrArgs ...interface{}) ([][]string, } defer rows.Close() - return rows2SliceString(rows) + return session.rows2SliceString(rows) } -func row2mapInterface(rows *core.Rows, fields []string) (resultsMap map[string]interface{}, err error) { - resultsMap = make(map[string]interface{}, len(fields)) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for ii, key := range fields { - resultsMap[key] = reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])).Interface() - } - return -} - -func rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, err error) { +func (session *Session) rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, err error) { fields, err := rows.Columns() if err != nil { return nil, err } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } for rows.Next() { - result, err := row2mapInterface(rows, fields) + result, err := session.engine.row2mapInterface(rows, types, fields) if err != nil { return nil, err } @@ -253,5 +141,5 @@ func (session *Session) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]i } defer rows.Close() - return rows2Interfaces(rows) + return session.rows2Interfaces(rows) } diff --git a/session_raw.go b/session_raw.go index 4cfe297a..bf32c6ed 100644 --- a/session_raw.go +++ b/session_raw.go @@ -6,9 +6,13 @@ package xorm import ( "database/sql" + "fmt" "reflect" + "strconv" + "time" "xorm.io/xorm/core" + "xorm.io/xorm/schemas" ) func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { @@ -71,6 +75,53 @@ func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row { return core.NewRow(session.queryRows(sqlStr, args...)) } +func value2String(rawValue *reflect.Value) (str string, err error) { + aa := reflect.TypeOf((*rawValue).Interface()) + vv := reflect.ValueOf((*rawValue).Interface()) + switch aa.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + str = strconv.FormatInt(vv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + str = strconv.FormatUint(vv.Uint(), 10) + case reflect.Float32, reflect.Float64: + str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) + case reflect.String: + str = vv.String() + case reflect.Array, reflect.Slice: + switch aa.Elem().Kind() { + case reflect.Uint8: + data := rawValue.Interface().([]byte) + str = string(data) + if str == "\x00" { + str = "0" + } + default: + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + // time type + case reflect.Struct: + if aa.ConvertibleTo(schemas.TimeType) { + str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) + } else { + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + case reflect.Bool: + str = strconv.FormatBool(vv.Bool()) + case reflect.Complex128, reflect.Complex64: + str = fmt.Sprintf("%v", vv.Complex()) + /* TODO: unsupported types below + case reflect.Map: + case reflect.Ptr: + case reflect.Uintptr: + case reflect.UnsafePointer: + case reflect.Chan, reflect.Func, reflect.Interface: + */ + default: + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + return +} + func value2Bytes(rawValue *reflect.Value) ([]byte, error) { str, err := value2String(rawValue) if err != nil { @@ -79,50 +130,6 @@ func value2Bytes(rawValue *reflect.Value) ([]byte, error) { return []byte(str), nil } -func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, err error) { - result := make(map[string][]byte) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for ii, key := range fields { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) - //if row is null then ignore - if rawValue.Interface() == nil { - result[key] = []byte{} - continue - } - - if data, err := value2Bytes(&rawValue); err == nil { - result[key] = data - } else { - return nil, err // !nashtsai! REVIEW, should return err or just error log? - } - } - return result, nil -} - -func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { - fields, err := rows.Columns() - if err != nil { - return nil, err - } - for rows.Next() { - result, err := row2map(rows, fields) - if err != nil { - return nil, err - } - resultsSlice = append(resultsSlice, result) - } - - return resultsSlice, nil -} - func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil { diff --git a/session_update.go b/session_update.go index f791bb2d..78907e43 100644 --- a/session_update.go +++ b/session_update.go @@ -280,15 +280,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 k = ct.Elem().Kind() } if k == reflect.Struct { - var refTable = session.statement.RefTable - if refTable == nil { - refTable, err = session.engine.TableInfo(condiBean[0]) - if err != nil { - return 0, err - } + condTable, err := session.engine.TableInfo(condiBean[0]) + if err != nil { + return 0, err } - var err error - autoCond, err = session.statement.BuildConds(refTable, condiBean[0], true, true, false, true, false) + + autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, false) if err != nil { return 0, err } @@ -457,7 +454,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 // FIXME: if bean is a map type, it will panic because map cannot be as map key session.afterUpdateBeans[bean] = &afterClosures } - } else { if _, ok := interface{}(bean).(AfterUpdateProcessor); ok { session.afterUpdateBeans[bean] = nil diff --git a/tags/parser.go b/tags/parser.go index ff329daa..b793a8f1 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -7,7 +7,6 @@ package tags import ( "encoding/gob" "errors" - "fmt" "reflect" "strings" "sync" @@ -23,7 +22,7 @@ import ( var ( // ErrUnsupportedType represents an unsupported type error - ErrUnsupportedType = errors.New("Unsupported type") + ErrUnsupportedType = errors.New("unsupported type") ) // Parser represents a parser for xorm tag @@ -125,6 +124,147 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index } } +var ErrIgnoreField = errors.New("field will be ignored") + +func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { + var sqlType schemas.SQLType + if fieldValue.CanAddr() { + if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { + sqlType = schemas.SQLType{Name: schemas.Text} + } + } + if _, ok := fieldValue.Interface().(convert.Conversion); ok { + sqlType = schemas.SQLType{Name: schemas.Text} + } else { + sqlType = schemas.Type2SQLType(field.Type) + } + col := schemas.NewColumn(parser.columnMapper.Obj2Table(field.Name), + field.Name, sqlType, sqlType.DefaultLength, + sqlType.DefaultLength2, true) + col.FieldIndex = []int{fieldIndex} + + if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { + col.IsAutoIncrement = true + col.IsPrimaryKey = true + col.Nullable = false + } + return col, nil +} + +func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) { + var col = &schemas.Column{ + FieldName: field.Name, + FieldIndex: []int{fieldIndex}, + Nullable: true, + IsPrimaryKey: false, + IsAutoIncrement: false, + MapType: schemas.TWOSIDES, + Indexes: make(map[string]int), + DefaultIsEmpty: true, + } + + var ctx = Context{ + table: table, + col: col, + fieldValue: fieldValue, + indexNames: make(map[string]int), + parser: parser, + } + + for j, tag := range tags { + if ctx.ignoreNext { + ctx.ignoreNext = false + continue + } + + ctx.tag = tag + ctx.tagUname = strings.ToUpper(tag.name) + + if j > 0 { + ctx.preTag = strings.ToUpper(tags[j-1].name) + } + if j < len(tags)-1 { + ctx.nextTag = tags[j+1].name + } else { + ctx.nextTag = "" + } + + if h, ok := parser.handlers[ctx.tagUname]; ok { + if err := h(&ctx); err != nil { + return nil, err + } + } else { + if strings.HasPrefix(ctx.tag.name, "'") && strings.HasSuffix(ctx.tag.name, "'") { + col.Name = ctx.tag.name[1 : len(ctx.tag.name)-1] + } else { + col.Name = ctx.tag.name + } + } + + if ctx.hasCacheTag { + if parser.cacherMgr.GetDefaultCacher() != nil { + parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher()) + } else { + parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)) + } + } + if ctx.hasNoCacheTag { + parser.cacherMgr.SetCacher(table.Name, nil) + } + } + + if col.SQLType.Name == "" { + col.SQLType = schemas.Type2SQLType(field.Type) + } + parser.dialect.SQLType(col) + if col.Length == 0 { + col.Length = col.SQLType.DefaultLength + } + if col.Length2 == 0 { + col.Length2 = col.SQLType.DefaultLength2 + } + if col.Name == "" { + col.Name = parser.columnMapper.Obj2Table(field.Name) + } + + if ctx.isUnique { + ctx.indexNames[col.Name] = schemas.UniqueType + } else if ctx.isIndex { + ctx.indexNames[col.Name] = schemas.IndexType + } + + for indexName, indexType := range ctx.indexNames { + addIndex(indexName, table, col, indexType) + } + + return col, nil +} + +func (parser *Parser) parseField(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { + var ( + tag = field.Tag + ormTagStr = strings.TrimSpace(tag.Get(parser.identifier)) + ) + if ormTagStr == "-" { + return nil, ErrIgnoreField + } + if ormTagStr == "" { + return parser.parseFieldWithNoTag(fieldIndex, field, fieldValue) + } + tags, err := splitTag(ormTagStr) + if err != nil { + return nil, err + } + return parser.parseFieldWithTags(table, fieldIndex, field, fieldValue, tags) +} + +func isNotTitle(n string) bool { + for _, c := range n { + return unicode.IsLower(c) + } + return true +} + // Parse parses a struct as a table information func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { t := v.Type() @@ -140,192 +280,26 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { table.Type = t table.Name = names.GetTableName(parser.tableMapper, v) - var idFieldColName string - var hasCacheTag, hasNoCacheTag bool - for i := 0; i < t.NumField(); i++ { - var isUnexportField bool - for _, c := range t.Field(i).Name { - if unicode.IsLower(c) { - isUnexportField = true - } - break - } - if isUnexportField { + var field = t.Field(i) + if isNotTitle(field.Name) { continue } - tag := t.Field(i).Tag - ormTagStr := tag.Get(parser.identifier) - var col *schemas.Column - fieldValue := v.Field(i) - fieldType := fieldValue.Type() - - if ormTagStr != "" { - col = &schemas.Column{ - FieldName: t.Field(i).Name, - Nullable: true, - IsPrimaryKey: false, - IsAutoIncrement: false, - MapType: schemas.TWOSIDES, - Indexes: make(map[string]int), - DefaultIsEmpty: true, - } - tags := splitTag(ormTagStr) - - if len(tags) > 0 { - if tags[0] == "-" { - continue - } - - var ctx = Context{ - table: table, - col: col, - fieldValue: fieldValue, - indexNames: make(map[string]int), - parser: parser, - } - - if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") { - pStart := strings.Index(tags[0], "(") - if pStart > -1 && strings.HasSuffix(tags[0], ")") { - var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool { - return r == '\'' || r == '"' - }) - - ctx.params = []string{tagPrefix} - } - - if err := ExtendsTagHandler(&ctx); err != nil { - return nil, err - } - continue - } - - for j, key := range tags { - if ctx.ignoreNext { - ctx.ignoreNext = false - continue - } - - k := strings.ToUpper(key) - ctx.tagName = k - ctx.params = []string{} - - pStart := strings.Index(k, "(") - if pStart == 0 { - return nil, errors.New("( could not be the first character") - } - if pStart > -1 { - if !strings.HasSuffix(k, ")") { - return nil, fmt.Errorf("field %s tag %s cannot match ) character", col.FieldName, key) - } - - ctx.tagName = k[:pStart] - ctx.params = strings.Split(key[pStart+1:len(k)-1], ",") - } - - if j > 0 { - ctx.preTag = strings.ToUpper(tags[j-1]) - } - if j < len(tags)-1 { - ctx.nextTag = tags[j+1] - } else { - ctx.nextTag = "" - } - - if h, ok := parser.handlers[ctx.tagName]; ok { - if err := h(&ctx); err != nil { - return nil, err - } - } else { - if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") { - col.Name = key[1 : len(key)-1] - } else { - col.Name = key - } - } - - if ctx.hasCacheTag { - hasCacheTag = true - } - if ctx.hasNoCacheTag { - hasNoCacheTag = true - } - } - - if col.SQLType.Name == "" { - col.SQLType = schemas.Type2SQLType(fieldType) - } - parser.dialect.SQLType(col) - if col.Length == 0 { - col.Length = col.SQLType.DefaultLength - } - if col.Length2 == 0 { - col.Length2 = col.SQLType.DefaultLength2 - } - if col.Name == "" { - col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name) - } - - if ctx.isUnique { - ctx.indexNames[col.Name] = schemas.UniqueType - } else if ctx.isIndex { - ctx.indexNames[col.Name] = schemas.IndexType - } - - for indexName, indexType := range ctx.indexNames { - addIndex(indexName, table, col, indexType) - } - } - } else { - var sqlType schemas.SQLType - if fieldValue.CanAddr() { - if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - sqlType = schemas.SQLType{Name: schemas.Text} - } - } - if _, ok := fieldValue.Interface().(convert.Conversion); ok { - sqlType = schemas.SQLType{Name: schemas.Text} - } else { - sqlType = schemas.Type2SQLType(fieldType) - } - col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name), - t.Field(i).Name, sqlType, sqlType.DefaultLength, - sqlType.DefaultLength2, true) - - if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { - idFieldColName = col.Name - } - } - if col.IsAutoIncrement { - col.Nullable = false + col, err := parser.parseField(table, i, field, v.Field(i)) + if err == ErrIgnoreField { + continue + } else if err != nil { + return nil, err } table.AddColumn(col) } // end for - if idFieldColName != "" && len(table.PrimaryKeys) == 0 { - col := table.GetColumn(idFieldColName) - col.IsPrimaryKey = true - col.IsAutoIncrement = true - col.Nullable = false - table.PrimaryKeys = append(table.PrimaryKeys, col.Name) - table.AutoIncrement = col.Name - } - - if hasCacheTag { - if parser.cacherMgr.GetDefaultCacher() != nil { // !nash! use engine's cacher if provided - //engine.logger.Info("enable cache on table:", table.Name) - parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher()) - } else { - //engine.logger.Info("enable LRU cache on table:", table.Name) - parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)) - } - } - if hasNoCacheTag { - //engine.logger.Info("disable cache on table:", table.Name) - parser.cacherMgr.SetCacher(table.Name, nil) + deletedColumn := table.DeletedColumn() + // check columns + if deletedColumn != nil { + deletedColumn.Nullable = true } return table, nil diff --git a/tags/parser_test.go b/tags/parser_test.go index 5add1e13..70c57692 100644 --- a/tags/parser_test.go +++ b/tags/parser_test.go @@ -6,12 +6,16 @@ package tags import ( "reflect" + "strings" "testing" + "time" - "github.com/stretchr/testify/assert" "xorm.io/xorm/caches" "xorm.io/xorm/dialects" "xorm.io/xorm/names" + "xorm.io/xorm/schemas" + + "github.com/stretchr/testify/assert" ) type ParseTableName1 struct{} @@ -80,7 +84,7 @@ func TestParseWithOtherIdentifier(t *testing.T) { parser := NewParser( "xorm", dialects.QueryDialect("mysql"), - names.GonicMapper{}, + names.SameMapper{}, names.SnakeMapper{}, caches.NewManager(), ) @@ -88,13 +92,461 @@ func TestParseWithOtherIdentifier(t *testing.T) { type StructWithDBTag struct { FieldFoo string `db:"foo"` } + parser.SetIdentifier("db") table, err := parser.Parse(reflect.ValueOf(new(StructWithDBTag))) assert.NoError(t, err) - assert.EqualValues(t, "struct_with_db_tag", table.Name) + assert.EqualValues(t, "StructWithDBTag", table.Name) assert.EqualValues(t, 1, len(table.Columns())) for _, col := range table.Columns() { assert.EqualValues(t, "foo", col.Name) } } + +func TestParseWithIgnore(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SameMapper{}, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithIgnoreTag struct { + FieldFoo string `db:"-"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithIgnoreTag))) + assert.NoError(t, err) + assert.EqualValues(t, "StructWithIgnoreTag", table.Name) + assert.EqualValues(t, 0, len(table.Columns())) +} + +func TestParseWithAutoincrement(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithAutoIncrement struct { + ID int64 + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_auto_increment", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "id", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsAutoIncrement) + assert.True(t, table.Columns()[0].IsPrimaryKey) +} + +func TestParseWithAutoincrement2(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithAutoIncrement2 struct { + ID int64 `db:"pk autoincr"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement2))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_auto_increment2", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "id", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsAutoIncrement) + assert.True(t, table.Columns()[0].IsPrimaryKey) + assert.False(t, table.Columns()[0].Nullable) +} + +func TestParseWithNullable(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithNullable struct { + Name string `db:"notnull"` + FullName string `db:"null comment('column comment,字段注释')"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithNullable))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_nullable", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "full_name", table.Columns()[1].Name) + assert.False(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.EqualValues(t, "column comment,字段注释", table.Columns()[1].Comment) +} + +func TestParseWithTimes(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithTimes struct { + Name string `db:"notnull"` + CreatedAt time.Time `db:"created"` + UpdatedAt time.Time `db:"updated"` + DeletedAt time.Time `db:"deleted"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithTimes))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_times", table.Name) + assert.EqualValues(t, 4, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "created_at", table.Columns()[1].Name) + assert.EqualValues(t, "updated_at", table.Columns()[2].Name) + assert.EqualValues(t, "deleted_at", table.Columns()[3].Name) + assert.False(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[1].IsCreated) + assert.True(t, table.Columns()[2].Nullable) + assert.True(t, table.Columns()[2].IsUpdated) + assert.True(t, table.Columns()[3].Nullable) + assert.True(t, table.Columns()[3].IsDeleted) +} + +func TestParseWithExtends(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithEmbed struct { + Name string + CreatedAt time.Time `db:"created"` + UpdatedAt time.Time `db:"updated"` + DeletedAt time.Time `db:"deleted"` + } + + type StructWithExtends struct { + SW StructWithEmbed `db:"extends"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithExtends))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_extends", table.Name) + assert.EqualValues(t, 4, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "created_at", table.Columns()[1].Name) + assert.EqualValues(t, "updated_at", table.Columns()[2].Name) + assert.EqualValues(t, "deleted_at", table.Columns()[3].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[1].IsCreated) + assert.True(t, table.Columns()[2].Nullable) + assert.True(t, table.Columns()[2].IsUpdated) + assert.True(t, table.Columns()[3].Nullable) + assert.True(t, table.Columns()[3].IsDeleted) +} + +func TestParseWithCache(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithCache struct { + Name string `db:"cache"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithCache))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_cache", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + cacher := parser.cacherMgr.GetCacher(table.Name) + assert.NotNil(t, cacher) +} + +func TestParseWithNoCache(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithNoCache struct { + Name string `db:"nocache"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithNoCache))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_no_cache", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + cacher := parser.cacherMgr.GetCacher(table.Name) + assert.Nil(t, cacher) +} + +func TestParseWithEnum(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithEnum struct { + Name string `db:"enum('alice', 'bob')"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithEnum))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_enum", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.EqualValues(t, schemas.Enum, strings.ToUpper(table.Columns()[0].SQLType.Name)) + assert.EqualValues(t, map[string]int{ + "alice": 0, + "bob": 1, + }, table.Columns()[0].EnumOptions) +} + +func TestParseWithSet(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithSet struct { + Name string `db:"set('alice', 'bob')"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithSet))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_set", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.EqualValues(t, schemas.Set, strings.ToUpper(table.Columns()[0].SQLType.Name)) + assert.EqualValues(t, map[string]int{ + "alice": 0, + "bob": 1, + }, table.Columns()[0].SetOptions) +} + +func TestParseWithIndex(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithIndex struct { + Name string `db:"index"` + Name2 string `db:"index(s)"` + Name3 string `db:"unique"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithIndex))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_index", table.Name) + assert.EqualValues(t, 3, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "name2", table.Columns()[1].Name) + assert.EqualValues(t, "name3", table.Columns()[2].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[2].Nullable) + assert.EqualValues(t, 1, len(table.Columns()[0].Indexes)) + assert.EqualValues(t, 1, len(table.Columns()[1].Indexes)) + assert.EqualValues(t, 1, len(table.Columns()[2].Indexes)) +} + +func TestParseWithVersion(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithVersion struct { + Name string + Version int `db:"version"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithVersion))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_version", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "version", table.Columns()[1].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[1].IsVersion) +} + +func TestParseWithLocale(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithLocale struct { + UTCLocale time.Time `db:"utc"` + LocalLocale time.Time `db:"local"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithLocale))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_locale", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "utc_locale", table.Columns()[0].Name) + assert.EqualValues(t, "local_locale", table.Columns()[1].Name) + assert.EqualValues(t, time.UTC, table.Columns()[0].TimeZone) + assert.EqualValues(t, time.Local, table.Columns()[1].TimeZone) +} + +func TestParseWithDefault(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithDefault struct { + Default1 time.Time `db:"default '1970-01-01 00:00:00'"` + Default2 time.Time `db:"default(CURRENT_TIMESTAMP)"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithDefault))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_default", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.EqualValues(t, "default2", table.Columns()[1].Name) + assert.EqualValues(t, "'1970-01-01 00:00:00'", table.Columns()[0].Default) + assert.EqualValues(t, "CURRENT_TIMESTAMP", table.Columns()[1].Default) +} + +func TestParseWithOnlyToDB(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.GonicMapper{ + "DB": true, + }, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithOnlyToDB struct { + Default1 time.Time `db:"->"` + Default2 time.Time `db:"<-"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithOnlyToDB))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_only_to_db", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.EqualValues(t, "default2", table.Columns()[1].Name) + assert.EqualValues(t, schemas.ONLYTODB, table.Columns()[0].MapType) + assert.EqualValues(t, schemas.ONLYFROMDB, table.Columns()[1].MapType) +} + +func TestParseWithJSON(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.GonicMapper{ + "JSON": true, + }, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithJSON struct { + Default1 []string `db:"json"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithJSON))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_json", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsJSON) +} + +func TestParseWithSQLType(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.GonicMapper{ + "SQL": true, + }, + names.GonicMapper{ + "UUID": true, + }, + caches.NewManager(), + ) + + type StructWithSQLType struct { + Col1 string `db:"varchar(32)"` + Col2 string `db:"char(32)"` + Int int64 `db:"bigint"` + DateTime time.Time `db:"datetime"` + UUID string `db:"uuid"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithSQLType))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_sql_type", table.Name) + assert.EqualValues(t, 5, len(table.Columns())) + assert.EqualValues(t, "col1", table.Columns()[0].Name) + assert.EqualValues(t, "col2", table.Columns()[1].Name) + assert.EqualValues(t, "int", table.Columns()[2].Name) + assert.EqualValues(t, "date_time", table.Columns()[3].Name) + assert.EqualValues(t, "uuid", table.Columns()[4].Name) + + assert.EqualValues(t, "VARCHAR", table.Columns()[0].SQLType.Name) + assert.EqualValues(t, "CHAR", table.Columns()[1].SQLType.Name) + assert.EqualValues(t, "BIGINT", table.Columns()[2].SQLType.Name) + assert.EqualValues(t, "DATETIME", table.Columns()[3].SQLType.Name) + assert.EqualValues(t, "UUID", table.Columns()[4].SQLType.Name) +} diff --git a/tags/tag.go b/tags/tag.go index bb5b5838..641b8c52 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -14,30 +14,74 @@ import ( "xorm.io/xorm/schemas" ) -func splitTag(tag string) (tags []string) { - tag = strings.TrimSpace(tag) - var hasQuote = false - var lastIdx = 0 - for i, t := range tag { - if t == '\'' { - hasQuote = !hasQuote - } else if t == ' ' { - if lastIdx < i && !hasQuote { - tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) - lastIdx = i + 1 +type tag struct { + name string + params []string +} + +func splitTag(tagStr string) ([]tag, error) { + tagStr = strings.TrimSpace(tagStr) + var ( + inQuote bool + inBigQuote bool + lastIdx int + curTag tag + paramStart int + tags []tag + ) + for i, t := range tagStr { + switch t { + case '\'': + inQuote = !inQuote + case ' ': + if !inQuote && !inBigQuote { + if lastIdx < i { + if curTag.name == "" { + curTag.name = tagStr[lastIdx:i] + } + tags = append(tags, curTag) + lastIdx = i + 1 + curTag = tag{} + } else if lastIdx == i { + lastIdx = i + 1 + } + } else if inBigQuote && !inQuote { + paramStart = i + 1 + } + case ',': + if !inQuote && !inBigQuote { + return nil, fmt.Errorf("comma[%d] of %s should be in quote or big quote", i, tagStr) + } + if !inQuote && inBigQuote { + curTag.params = append(curTag.params, strings.TrimSpace(tagStr[paramStart:i])) + paramStart = i + 1 + } + case '(': + inBigQuote = true + if !inQuote { + curTag.name = tagStr[lastIdx:i] + paramStart = i + 1 + } + case ')': + inBigQuote = false + if !inQuote { + curTag.params = append(curTag.params, tagStr[paramStart:i]) } } } - if lastIdx < len(tag) { - tags = append(tags, strings.TrimSpace(tag[lastIdx:])) + if lastIdx < len(tagStr) { + if curTag.name == "" { + curTag.name = tagStr[lastIdx:] + } + tags = append(tags, curTag) } - return + return tags, nil } // Context represents a context for xorm tag parse. type Context struct { - tagName string - params []string + tag + tagUname string preTag, nextTag string table *schemas.Table col *schemas.Column @@ -76,6 +120,7 @@ var ( "CACHE": CacheTagHandler, "NOCACHE": NoCacheTagHandler, "COMMENT": CommentTagHandler, + "EXTENDS": ExtendsTagHandler, } ) @@ -124,6 +169,7 @@ func NotNullTagHandler(ctx *Context) error { // AutoIncrTagHandler describes autoincr tag handler func AutoIncrTagHandler(ctx *Context) error { ctx.col.IsAutoIncrement = true + ctx.col.Nullable = false /* if len(ctx.params) > 0 { autoStartInt, err := strconv.Atoi(ctx.params[0]) @@ -192,6 +238,7 @@ func UpdatedTagHandler(ctx *Context) error { // DeletedTagHandler describes deleted tag handler func DeletedTagHandler(ctx *Context) error { ctx.col.IsDeleted = true + ctx.col.Nullable = true return nil } @@ -225,41 +272,44 @@ func CommentTagHandler(ctx *Context) error { // SQLTypeTagHandler describes SQL Type tag handler func SQLTypeTagHandler(ctx *Context) error { - ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName} - if strings.EqualFold(ctx.tagName, "JSON") { + ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname} + if ctx.tagUname == "JSON" { ctx.col.IsJSON = true } - if len(ctx.params) > 0 { - if ctx.tagName == schemas.Enum { - ctx.col.EnumOptions = make(map[string]int) - for k, v := range ctx.params { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - ctx.col.EnumOptions[v] = k + if len(ctx.params) == 0 { + return nil + } + + switch ctx.tagUname { + case schemas.Enum: + ctx.col.EnumOptions = make(map[string]int) + for k, v := range ctx.params { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + ctx.col.EnumOptions[v] = k + } + case schemas.Set: + ctx.col.SetOptions = make(map[string]int) + for k, v := range ctx.params { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + ctx.col.SetOptions[v] = k + } + default: + var err error + if len(ctx.params) == 2 { + ctx.col.Length, err = strconv.Atoi(ctx.params[0]) + if err != nil { + return err } - } else if ctx.tagName == schemas.Set { - ctx.col.SetOptions = make(map[string]int) - for k, v := range ctx.params { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - ctx.col.SetOptions[v] = k + ctx.col.Length2, err = strconv.Atoi(ctx.params[1]) + if err != nil { + return err } - } else { - var err error - if len(ctx.params) == 2 { - ctx.col.Length, err = strconv.Atoi(ctx.params[0]) - if err != nil { - return err - } - ctx.col.Length2, err = strconv.Atoi(ctx.params[1]) - if err != nil { - return err - } - } else if len(ctx.params) == 1 { - ctx.col.Length, err = strconv.Atoi(ctx.params[0]) - if err != nil { - return err - } + } else if len(ctx.params) == 1 { + ctx.col.Length, err = strconv.Atoi(ctx.params[0]) + if err != nil { + return err } } } @@ -289,11 +339,12 @@ func ExtendsTagHandler(ctx *Context) error { } for _, col := range parentTable.Columns() { col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName) + col.FieldIndex = append(ctx.col.FieldIndex, col.FieldIndex...) var tagPrefix = ctx.col.FieldName if len(ctx.params) > 0 { col.Nullable = isPtr - tagPrefix = ctx.params[0] + tagPrefix = strings.Trim(ctx.params[0], "'") if col.IsPrimaryKey { col.Name = ctx.col.FieldName col.IsPrimaryKey = false @@ -315,7 +366,7 @@ func ExtendsTagHandler(ctx *Context) error { default: //TODO: warning } - return nil + return ErrIgnoreField } // CacheTagHandler describes cache tag handler diff --git a/tags/tag_test.go b/tags/tag_test.go index 5775b40a..3ceeefd1 100644 --- a/tags/tag_test.go +++ b/tags/tag_test.go @@ -7,24 +7,83 @@ package tags import ( "testing" - "xorm.io/xorm/internal/utils" + "github.com/stretchr/testify/assert" ) func TestSplitTag(t *testing.T) { var cases = []struct { tag string - tags []string + tags []tag }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, - {"TEXT", []string{"TEXT"}}, - {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, - {"json binary", []string{"json", "binary"}}, + {"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{ + { + name: "not", + }, + { + name: "null", + }, + { + name: "default", + }, + { + name: "'2000-01-01 00:00:00'", + }, + { + name: "TIMESTAMP", + }, + }, + }, + {"TEXT", []tag{ + { + name: "TEXT", + }, + }, + }, + {"default('2000-01-01 00:00:00')", []tag{ + { + name: "default", + params: []string{ + "'2000-01-01 00:00:00'", + }, + }, + }, + }, + {"json binary", []tag{ + { + name: "json", + }, + { + name: "binary", + }, + }, + }, + {"numeric(10, 2)", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + }, + }, + {"numeric(10, 2) notnull", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + { + name: "notnull", + }, + }, + }, } for _, kase := range cases { - tags := splitTag(kase.tag) - if !utils.SliceEq(tags, kase.tags) { - t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) - } + t.Run(kase.tag, func(t *testing.T) { + tags, err := splitTag(kase.tag) + assert.NoError(t, err) + assert.EqualValues(t, len(tags), len(kase.tags)) + for i := 0; i < len(tags); i++ { + assert.Equal(t, tags[i], kase.tags[i]) + } + }) } }