Fix pk bug (#1602)
Fix pk bug Reviewed-on: https://gitea.com/xorm/xorm/pulls/1602
This commit is contained in:
parent
c56c8e122a
commit
9500b23395
|
@ -0,0 +1,79 @@
|
||||||
|
// Copyright 2017 The Xorm Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package statements
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"xorm.io/builder"
|
||||||
|
"xorm.io/xorm/schemas"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ptrPkType = reflect.TypeOf(&schemas.PK{})
|
||||||
|
pkType = reflect.TypeOf(schemas.PK{})
|
||||||
|
stringType = reflect.TypeOf("")
|
||||||
|
intType = reflect.TypeOf(int64(0))
|
||||||
|
uintType = reflect.TypeOf(uint64(0))
|
||||||
|
)
|
||||||
|
|
||||||
|
// ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
|
||||||
|
func (statement *Statement) ID(id interface{}) *Statement {
|
||||||
|
switch t := id.(type) {
|
||||||
|
case *schemas.PK:
|
||||||
|
statement.idParam = *t
|
||||||
|
case schemas.PK:
|
||||||
|
statement.idParam = t
|
||||||
|
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||||
|
statement.idParam = schemas.PK{id}
|
||||||
|
default:
|
||||||
|
idValue := reflect.ValueOf(id)
|
||||||
|
idType := idValue.Type()
|
||||||
|
|
||||||
|
switch idType.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
statement.idParam = schemas.PK{idValue.Convert(stringType).Interface()}
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
statement.idParam = schemas.PK{idValue.Convert(intType).Interface()}
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
statement.idParam = schemas.PK{idValue.Convert(uintType).Interface()}
|
||||||
|
case reflect.Slice:
|
||||||
|
if idType.ConvertibleTo(pkType) {
|
||||||
|
statement.idParam = idValue.Convert(pkType).Interface().(schemas.PK)
|
||||||
|
}
|
||||||
|
case reflect.Ptr:
|
||||||
|
if idType.ConvertibleTo(ptrPkType) {
|
||||||
|
statement.idParam = idValue.Convert(ptrPkType).Elem().Interface().(schemas.PK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if statement.idParam == nil {
|
||||||
|
statement.LastError = fmt.Errorf("ID param %#v is not supported", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return statement
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) ProcessIDParam() error {
|
||||||
|
if statement.idParam == nil || statement.RefTable == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(statement.RefTable.PrimaryKeys) != len(statement.idParam) {
|
||||||
|
fmt.Println("=====", statement.RefTable.PrimaryKeys, statement.idParam)
|
||||||
|
return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
|
||||||
|
len(statement.RefTable.PrimaryKeys),
|
||||||
|
len(statement.idParam),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, col := range statement.RefTable.PKColumns() {
|
||||||
|
var colName = statement.colName(col, statement.TableName())
|
||||||
|
statement.cond = statement.cond.And(builder.Eq{colName: statement.idParam[i]})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -41,7 +41,7 @@ type Statement struct {
|
||||||
tagParser *tags.Parser
|
tagParser *tags.Parser
|
||||||
Start int
|
Start int
|
||||||
LimitN *int
|
LimitN *int
|
||||||
idParam *schemas.PK
|
idParam schemas.PK
|
||||||
OrderStr string
|
OrderStr string
|
||||||
JoinStr string
|
JoinStr string
|
||||||
joinArgs []interface{}
|
joinArgs []interface{}
|
||||||
|
@ -319,34 +319,6 @@ func (statement *Statement) TableName() string {
|
||||||
return statement.tableName
|
return statement.tableName
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
|
|
||||||
func (statement *Statement) ID(id interface{}) *Statement {
|
|
||||||
idValue := reflect.ValueOf(id)
|
|
||||||
idType := reflect.TypeOf(idValue.Interface())
|
|
||||||
|
|
||||||
switch idType {
|
|
||||||
case ptrPkType:
|
|
||||||
if pkPtr, ok := (id).(*schemas.PK); ok {
|
|
||||||
statement.idParam = pkPtr
|
|
||||||
return statement
|
|
||||||
}
|
|
||||||
case pkType:
|
|
||||||
if pk, ok := (id).(schemas.PK); ok {
|
|
||||||
statement.idParam = &pk
|
|
||||||
return statement
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch idType.Kind() {
|
|
||||||
case reflect.String:
|
|
||||||
statement.idParam = &schemas.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
|
|
||||||
return statement
|
|
||||||
}
|
|
||||||
|
|
||||||
statement.idParam = &schemas.PK{id}
|
|
||||||
return statement
|
|
||||||
}
|
|
||||||
|
|
||||||
// Incr Generate "Update ... Set column = column + arg" statement
|
// Incr Generate "Update ... Set column = column + arg" statement
|
||||||
func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
|
func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
|
||||||
if len(arg) > 0 {
|
if len(arg) > 0 {
|
||||||
|
@ -981,25 +953,6 @@ func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) {
|
||||||
return "", nil, ErrUnSupportedType
|
return "", nil, ErrUnSupportedType
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) ProcessIDParam() error {
|
|
||||||
if statement.idParam == nil || statement.RefTable == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
|
|
||||||
return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
|
|
||||||
len(statement.RefTable.PrimaryKeys),
|
|
||||||
len(*statement.idParam),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, col := range statement.RefTable.PKColumns() {
|
|
||||||
var colName = statement.colName(col, statement.TableName())
|
|
||||||
statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName bool) string {
|
func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName bool) string {
|
||||||
var colnames = make([]string, len(cols))
|
var colnames = make([]string, len(cols))
|
||||||
for i, col := range cols {
|
for i, col := range cols {
|
||||||
|
|
|
@ -1,16 +0,0 @@
|
||||||
// Copyright 2017 The Xorm Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package statements
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"xorm.io/xorm/schemas"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ptrPkType = reflect.TypeOf(&schemas.PK{})
|
|
||||||
pkType = reflect.TypeOf(schemas.PK{})
|
|
||||||
)
|
|
|
@ -18,58 +18,73 @@ import (
|
||||||
"xorm.io/xorm/schemas"
|
"xorm.io/xorm/schemas"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (statement *Statement) ifAddColUpdate(col *schemas.Column, includeVersion, includeUpdated, includeNil,
|
||||||
|
includeAutoIncr, update bool) (bool, error) {
|
||||||
|
columnMap := statement.ColumnMap
|
||||||
|
omitColumnMap := statement.OmitColumnMap
|
||||||
|
unscoped := statement.unscoped
|
||||||
|
|
||||||
|
if !includeVersion && col.IsVersion {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if col.IsCreated && !columnMap.Contain(col.Name) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if !includeUpdated && col.IsUpdated {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if !includeAutoIncr && col.IsAutoIncrement {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if col.IsDeleted && !unscoped {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if omitColumnMap.Contain(col.Name) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if len(columnMap) > 0 && !columnMap.Contain(col.Name) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if col.MapType == schemas.ONLYFROMDB {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if statement.IncrColumns.IsColExist(col.Name) {
|
||||||
|
return false, nil
|
||||||
|
} else if statement.DecrColumns.IsColExist(col.Name) {
|
||||||
|
return false, nil
|
||||||
|
} else if statement.ExprColumns.IsColExist(col.Name) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
// BuildUpdates auto generating update columnes and values according a struct
|
// BuildUpdates auto generating update columnes and values according a struct
|
||||||
func (statement *Statement) BuildUpdates(bean interface{},
|
func (statement *Statement) BuildUpdates(tableValue reflect.Value,
|
||||||
includeVersion, includeUpdated, includeNil,
|
includeVersion, includeUpdated, includeNil,
|
||||||
includeAutoIncr, update bool) ([]string, []interface{}, error) {
|
includeAutoIncr, update bool) ([]string, []interface{}, error) {
|
||||||
//engine := statement.Engine
|
|
||||||
table := statement.RefTable
|
table := statement.RefTable
|
||||||
allUseBool := statement.allUseBool
|
allUseBool := statement.allUseBool
|
||||||
useAllCols := statement.useAllCols
|
useAllCols := statement.useAllCols
|
||||||
mustColumnMap := statement.MustColumnMap
|
mustColumnMap := statement.MustColumnMap
|
||||||
nullableMap := statement.NullableMap
|
nullableMap := statement.NullableMap
|
||||||
columnMap := statement.ColumnMap
|
|
||||||
omitColumnMap := statement.OmitColumnMap
|
|
||||||
unscoped := statement.unscoped
|
|
||||||
|
|
||||||
var colNames = make([]string, 0)
|
var colNames = make([]string, 0)
|
||||||
var args = make([]interface{}, 0)
|
var args = make([]interface{}, 0)
|
||||||
|
|
||||||
for _, col := range table.Columns() {
|
for _, col := range table.Columns() {
|
||||||
if !includeVersion && col.IsVersion {
|
ok, err := statement.ifAddColUpdate(col, includeVersion, includeUpdated, includeNil,
|
||||||
continue
|
includeAutoIncr, update)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if col.IsCreated && !columnMap.Contain(col.Name) {
|
if !ok {
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !includeUpdated && col.IsUpdated {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !includeAutoIncr && col.IsAutoIncrement {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if col.IsDeleted && !unscoped {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if omitColumnMap.Contain(col.Name) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if len(columnMap) > 0 && !columnMap.Contain(col.Name) {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if col.MapType == schemas.ONLYFROMDB {
|
fieldValuePtr, err := col.ValueOfV(&tableValue)
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if statement.IncrColumns.IsColExist(col.Name) {
|
|
||||||
continue
|
|
||||||
} else if statement.DecrColumns.IsColExist(col.Name) {
|
|
||||||
continue
|
|
||||||
} else if statement.ExprColumns.IsColExist(col.Name) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
fieldValuePtr, err := col.ValueOf(bean)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -273,9 +288,6 @@ func (statement *Statement) BuildUpdates(bean interface{},
|
||||||
|
|
||||||
APPEND:
|
APPEND:
|
||||||
args = append(args, val)
|
args = append(args, val)
|
||||||
if col.IsPrimaryKey {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
colNames = append(colNames, fmt.Sprintf("%v = ?", statement.quote(col.Name)))
|
colNames = append(colNames, fmt.Sprintf("%v = ?", statement.quote(col.Name)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ const (
|
||||||
type Column struct {
|
type Column struct {
|
||||||
Name string
|
Name string
|
||||||
TableName string
|
TableName string
|
||||||
FieldName string
|
FieldName string // Avaiable only when parsed from a struct
|
||||||
SQLType SQLType
|
SQLType SQLType
|
||||||
IsJSON bool
|
IsJSON bool
|
||||||
Length int
|
Length int
|
||||||
|
|
|
@ -53,13 +53,9 @@ func (table *Table) ColumnsSeq() []string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *Table) columnsByName(name string) []*Column {
|
func (table *Table) columnsByName(name string) []*Column {
|
||||||
n := len(name)
|
for k, cols := range table.columnsMap {
|
||||||
for k := range table.columnsMap {
|
|
||||||
if len(k) != n {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.EqualFold(k, name) {
|
if strings.EqualFold(k, name) {
|
||||||
return table.columnsMap[k]
|
return cols
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -177,7 +177,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
}
|
}
|
||||||
|
|
||||||
if session.statement.ColumnStr() == "" {
|
if session.statement.ColumnStr() == "" {
|
||||||
colNames, args, err = session.statement.BuildUpdates(bean, false, false,
|
colNames, args, err = session.statement.BuildUpdates(v, false, false,
|
||||||
false, false, true)
|
false, false, true)
|
||||||
} else {
|
} else {
|
||||||
colNames, args, err = session.genUpdateColumns(bean)
|
colNames, args, err = session.genUpdateColumns(bean)
|
||||||
|
|
|
@ -1303,3 +1303,40 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assertGetRecord()
|
assertGetRecord()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateMultiplePK(t *testing.T) {
|
||||||
|
type TestUpdateMultiplePKStruct struct {
|
||||||
|
Id string `xorm:"notnull pk" description:"唯一ID号"`
|
||||||
|
Name string `xorm:"notnull pk" description:"名称"`
|
||||||
|
Value string `xorm:"notnull varchar(4000)" description:"值"`
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, prepareEngine())
|
||||||
|
assertSync(t, new(TestUpdateMultiplePKStruct))
|
||||||
|
|
||||||
|
test := &TestUpdateMultiplePKStruct{
|
||||||
|
Id: "ID1",
|
||||||
|
Name: "Name1",
|
||||||
|
Value: "1",
|
||||||
|
}
|
||||||
|
_, err := testEngine.Insert(test)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
test.Value = "2"
|
||||||
|
_, err = testEngine.Where("`id` = ? And `name` = ?", test.Id, test.Name).Cols("Value").Update(test)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
test.Value = "3"
|
||||||
|
num, err := testEngine.Where("`id` = ? And `name` = ?", test.Id, test.Name).Update(test)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, num)
|
||||||
|
|
||||||
|
test.Value = "4"
|
||||||
|
_, err = testEngine.ID([]interface{}{test.Id, test.Name}).Update(test)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
type MySlice []interface{}
|
||||||
|
test.Value = "5"
|
||||||
|
_, err = testEngine.ID(&MySlice{test.Id, test.Name}).Update(test)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue