From 9500b233954f79ea89fe090990745019ccb6d8f1 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 13 Mar 2020 08:57:34 +0000 Subject: [PATCH] Fix pk bug (#1602) Fix pk bug Reviewed-on: https://gitea.com/xorm/xorm/pulls/1602 --- internal/statements/pk.go | 79 ++++++++++++++++++++++++++++ internal/statements/statement.go | 49 +---------------- internal/statements/types.go | 16 ------ internal/statements/update.go | 90 ++++++++++++++++++-------------- schemas/column.go | 2 +- schemas/table.go | 8 +-- session_update.go | 2 +- session_update_test.go | 37 +++++++++++++ 8 files changed, 172 insertions(+), 111 deletions(-) create mode 100644 internal/statements/pk.go delete mode 100644 internal/statements/types.go diff --git a/internal/statements/pk.go b/internal/statements/pk.go new file mode 100644 index 00000000..b6ae0f23 --- /dev/null +++ b/internal/statements/pk.go @@ -0,0 +1,79 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "reflect" + + "xorm.io/builder" + "xorm.io/xorm/schemas" +) + +var ( + ptrPkType = reflect.TypeOf(&schemas.PK{}) + pkType = reflect.TypeOf(schemas.PK{}) + stringType = reflect.TypeOf("") + intType = reflect.TypeOf(int64(0)) + uintType = reflect.TypeOf(uint64(0)) +) + +// ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?" +func (statement *Statement) ID(id interface{}) *Statement { + switch t := id.(type) { + case *schemas.PK: + statement.idParam = *t + case schemas.PK: + statement.idParam = t + case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + statement.idParam = schemas.PK{id} + default: + idValue := reflect.ValueOf(id) + idType := idValue.Type() + + switch idType.Kind() { + case reflect.String: + statement.idParam = schemas.PK{idValue.Convert(stringType).Interface()} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + statement.idParam = schemas.PK{idValue.Convert(intType).Interface()} + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + statement.idParam = schemas.PK{idValue.Convert(uintType).Interface()} + case reflect.Slice: + if idType.ConvertibleTo(pkType) { + statement.idParam = idValue.Convert(pkType).Interface().(schemas.PK) + } + case reflect.Ptr: + if idType.ConvertibleTo(ptrPkType) { + statement.idParam = idValue.Convert(ptrPkType).Elem().Interface().(schemas.PK) + } + } + } + + if statement.idParam == nil { + statement.LastError = fmt.Errorf("ID param %#v is not supported", id) + } + + return statement +} + +func (statement *Statement) ProcessIDParam() error { + if statement.idParam == nil || statement.RefTable == nil { + return nil + } + + if len(statement.RefTable.PrimaryKeys) != len(statement.idParam) { + fmt.Println("=====", statement.RefTable.PrimaryKeys, statement.idParam) + return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d", + len(statement.RefTable.PrimaryKeys), + len(statement.idParam), + ) + } + + for i, col := range statement.RefTable.PKColumns() { + var colName = statement.colName(col, statement.TableName()) + statement.cond = statement.cond.And(builder.Eq{colName: statement.idParam[i]}) + } + return nil +} diff --git a/internal/statements/statement.go b/internal/statements/statement.go index d1fcaf59..af94a9d9 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -41,7 +41,7 @@ type Statement struct { tagParser *tags.Parser Start int LimitN *int - idParam *schemas.PK + idParam schemas.PK OrderStr string JoinStr string joinArgs []interface{} @@ -319,34 +319,6 @@ func (statement *Statement) TableName() string { return statement.tableName } -// ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?" -func (statement *Statement) ID(id interface{}) *Statement { - idValue := reflect.ValueOf(id) - idType := reflect.TypeOf(idValue.Interface()) - - switch idType { - case ptrPkType: - if pkPtr, ok := (id).(*schemas.PK); ok { - statement.idParam = pkPtr - return statement - } - case pkType: - if pk, ok := (id).(schemas.PK); ok { - statement.idParam = &pk - return statement - } - } - - switch idType.Kind() { - case reflect.String: - statement.idParam = &schemas.PK{idValue.Convert(reflect.TypeOf("")).Interface()} - return statement - } - - statement.idParam = &schemas.PK{id} - return statement -} - // Incr Generate "Update ... Set column = column + arg" statement func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { if len(arg) > 0 { @@ -981,25 +953,6 @@ func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { return "", nil, ErrUnSupportedType } -func (statement *Statement) ProcessIDParam() error { - if statement.idParam == nil || statement.RefTable == nil { - return nil - } - - if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) { - return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d", - len(statement.RefTable.PrimaryKeys), - len(*statement.idParam), - ) - } - - for i, col := range statement.RefTable.PKColumns() { - var colName = statement.colName(col, statement.TableName()) - statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]}) - } - return nil -} - func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName bool) string { var colnames = make([]string, len(cols)) for i, col := range cols { diff --git a/internal/statements/types.go b/internal/statements/types.go deleted file mode 100644 index 0ff36f35..00000000 --- a/internal/statements/types.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package statements - -import ( - "reflect" - - "xorm.io/xorm/schemas" -) - -var ( - ptrPkType = reflect.TypeOf(&schemas.PK{}) - pkType = reflect.TypeOf(schemas.PK{}) -) diff --git a/internal/statements/update.go b/internal/statements/update.go index 2e502243..2bd7ddd3 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -18,58 +18,73 @@ import ( "xorm.io/xorm/schemas" ) +func (statement *Statement) ifAddColUpdate(col *schemas.Column, includeVersion, includeUpdated, includeNil, + includeAutoIncr, update bool) (bool, error) { + columnMap := statement.ColumnMap + omitColumnMap := statement.OmitColumnMap + unscoped := statement.unscoped + + if !includeVersion && col.IsVersion { + return false, nil + } + if col.IsCreated && !columnMap.Contain(col.Name) { + return false, nil + } + if !includeUpdated && col.IsUpdated { + return false, nil + } + if !includeAutoIncr && col.IsAutoIncrement { + return false, nil + } + if col.IsDeleted && !unscoped { + return false, nil + } + if omitColumnMap.Contain(col.Name) { + return false, nil + } + if len(columnMap) > 0 && !columnMap.Contain(col.Name) { + return false, nil + } + + if col.MapType == schemas.ONLYFROMDB { + return false, nil + } + + if statement.IncrColumns.IsColExist(col.Name) { + return false, nil + } else if statement.DecrColumns.IsColExist(col.Name) { + return false, nil + } else if statement.ExprColumns.IsColExist(col.Name) { + return false, nil + } + + return true, nil +} + // BuildUpdates auto generating update columnes and values according a struct -func (statement *Statement) BuildUpdates(bean interface{}, +func (statement *Statement) BuildUpdates(tableValue reflect.Value, includeVersion, includeUpdated, includeNil, includeAutoIncr, update bool) ([]string, []interface{}, error) { - //engine := statement.Engine table := statement.RefTable allUseBool := statement.allUseBool useAllCols := statement.useAllCols mustColumnMap := statement.MustColumnMap nullableMap := statement.NullableMap - columnMap := statement.ColumnMap - omitColumnMap := statement.OmitColumnMap - unscoped := statement.unscoped var colNames = make([]string, 0) var args = make([]interface{}, 0) + for _, col := range table.Columns() { - if !includeVersion && col.IsVersion { - continue + ok, err := statement.ifAddColUpdate(col, includeVersion, includeUpdated, includeNil, + includeAutoIncr, update) + if err != nil { + return nil, nil, err } - if col.IsCreated && !columnMap.Contain(col.Name) { - continue - } - if !includeUpdated && col.IsUpdated { - continue - } - if !includeAutoIncr && col.IsAutoIncrement { - continue - } - if col.IsDeleted && !unscoped { - continue - } - if omitColumnMap.Contain(col.Name) { - continue - } - if len(columnMap) > 0 && !columnMap.Contain(col.Name) { + if !ok { continue } - if col.MapType == schemas.ONLYFROMDB { - continue - } - - if statement.IncrColumns.IsColExist(col.Name) { - continue - } else if statement.DecrColumns.IsColExist(col.Name) { - continue - } else if statement.ExprColumns.IsColExist(col.Name) { - continue - } - - fieldValuePtr, err := col.ValueOf(bean) + fieldValuePtr, err := col.ValueOfV(&tableValue) if err != nil { return nil, nil, err } @@ -273,9 +288,6 @@ func (statement *Statement) BuildUpdates(bean interface{}, APPEND: args = append(args, val) - if col.IsPrimaryKey { - continue - } colNames = append(colNames, fmt.Sprintf("%v = ?", statement.quote(col.Name))) } diff --git a/schemas/column.go b/schemas/column.go index 9466f6a5..418629ac 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -21,7 +21,7 @@ const ( type Column struct { Name string TableName string - FieldName string + FieldName string // Avaiable only when parsed from a struct SQLType SQLType IsJSON bool Length int diff --git a/schemas/table.go b/schemas/table.go index 2dac3ea2..38596991 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -53,13 +53,9 @@ func (table *Table) ColumnsSeq() []string { } func (table *Table) columnsByName(name string) []*Column { - n := len(name) - for k := range table.columnsMap { - if len(k) != n { - continue - } + for k, cols := range table.columnsMap { if strings.EqualFold(k, name) { - return table.columnsMap[k] + return cols } } return nil diff --git a/session_update.go b/session_update.go index dadfaaca..62116c47 100644 --- a/session_update.go +++ b/session_update.go @@ -177,7 +177,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } if session.statement.ColumnStr() == "" { - colNames, args, err = session.statement.BuildUpdates(bean, false, false, + colNames, args, err = session.statement.BuildUpdates(v, false, false, false, false, true) } else { colNames, args, err = session.genUpdateColumns(bean) diff --git a/session_update_test.go b/session_update_test.go index d65e1207..5111222a 100644 --- a/session_update_test.go +++ b/session_update_test.go @@ -1303,3 +1303,40 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) { assert.NoError(t, err) assertGetRecord() } + +func TestUpdateMultiplePK(t *testing.T) { + type TestUpdateMultiplePKStruct struct { + Id string `xorm:"notnull pk" description:"唯一ID号"` + Name string `xorm:"notnull pk" description:"名称"` + Value string `xorm:"notnull varchar(4000)" description:"值"` + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(TestUpdateMultiplePKStruct)) + + test := &TestUpdateMultiplePKStruct{ + Id: "ID1", + Name: "Name1", + Value: "1", + } + _, err := testEngine.Insert(test) + assert.NoError(t, err) + + test.Value = "2" + _, err = testEngine.Where("`id` = ? And `name` = ?", test.Id, test.Name).Cols("Value").Update(test) + assert.NoError(t, err) + + test.Value = "3" + num, err := testEngine.Where("`id` = ? And `name` = ?", test.Id, test.Name).Update(test) + assert.NoError(t, err) + assert.EqualValues(t, 1, num) + + test.Value = "4" + _, err = testEngine.ID([]interface{}{test.Id, test.Name}).Update(test) + assert.NoError(t, err) + + type MySlice []interface{} + test.Value = "5" + _, err = testEngine.ID(&MySlice{test.Id, test.Name}).Update(test) + assert.NoError(t, err) +}