From fd4a80bcad50c5f91c11ccea4119cf11d352e235 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 11 Apr 2018 11:52:16 +0800 Subject: [PATCH] fix id condition not used bug (#882) --- helpers.go | 162 -------------------------------------------- session_cols.go | 107 +++++++++++++++++++++++++++++ session_get_test.go | 17 +++-- session_insert.go | 117 +++++++++++++++++++++++++++----- session_query.go | 4 ++ session_update.go | 96 +++++++++++++++++++++++++- statement.go | 70 +++++++------------ 7 files changed, 342 insertions(+), 231 deletions(-) diff --git a/helpers.go b/helpers.go index f39ed472..f1705782 100644 --- a/helpers.go +++ b/helpers.go @@ -11,7 +11,6 @@ import ( "sort" "strconv" "strings" - "time" "github.com/go-xorm/core" ) @@ -293,19 +292,6 @@ func structName(v reflect.Type) string { return v.Name() } -func col2NewCols(columns ...string) []string { - newColumns := make([]string, 0, len(columns)) - for _, col := range columns { - col = strings.Replace(col, "`", "", -1) - col = strings.Replace(col, `"`, "", -1) - ccols := strings.Split(col, ",") - for _, c := range ccols { - newColumns = append(newColumns, strings.TrimSpace(c)) - } - } - return newColumns -} - func sliceEq(left, right []string) bool { if len(left) != len(right) { return false @@ -320,154 +306,6 @@ func sliceEq(left, right []string) bool { return true } -func setColumnInt(bean interface{}, col *core.Column, t int64) { - v, err := col.ValueOf(bean) - if err != nil { - return - } - if v.CanSet() { - switch v.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32: - v.SetInt(t) - case reflect.Uint, reflect.Uint64, reflect.Uint32: - v.SetUint(uint64(t)) - } - } -} - -func setColumnTime(bean interface{}, col *core.Column, t time.Time) { - v, err := col.ValueOf(bean) - if err != nil { - return - } - if v.CanSet() { - switch v.Type().Kind() { - case reflect.Struct: - v.Set(reflect.ValueOf(t).Convert(v.Type())) - case reflect.Int, reflect.Int64, reflect.Int32: - v.SetInt(t.Unix()) - case reflect.Uint, reflect.Uint64, reflect.Uint32: - v.SetUint(uint64(t.Unix())) - } - } -} - -func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) { - colNames := make([]string, 0, len(table.ColumnsSeq())) - args := make([]interface{}, 0, len(table.ColumnsSeq())) - - for _, col := range table.Columns() { - if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated { - if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { - continue - } - } - if col.MapType == core.ONLYFROMDB { - continue - } - - fieldValuePtr, err := col.ValueOf(bean) - if err != nil { - return nil, nil, err - } - fieldValue := *fieldValuePtr - - if col.IsAutoIncrement { - switch fieldValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: - if fieldValue.Int() == 0 { - continue - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: - if fieldValue.Uint() == 0 { - continue - } - case reflect.String: - if len(fieldValue.String()) == 0 { - continue - } - case reflect.Ptr: - if fieldValue.Pointer() == 0 { - continue - } - } - } - - if col.IsDeleted { - continue - } - - if session.statement.ColumnStr != "" { - if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { - continue - } else if _, ok := session.statement.incrColumns[col.Name]; ok { - continue - } else if _, ok := session.statement.decrColumns[col.Name]; ok { - continue - } - } - if session.statement.OmitStr != "" { - if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { - continue - } - } - - // !evalphobia! set fieldValue as nil when column is nullable and zero-value - if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { - if col.Nullable && isZero(fieldValue.Interface()) { - var nilValue *int - fieldValue = reflect.ValueOf(nilValue) - } - } - - if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { - // if time is non-empty, then set to auto time - val, t := session.engine.nowTime(col) - args = append(args, val) - - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnTime(bean, col, t) - }) - } else if col.IsVersion && session.statement.checkVersion { - args = append(args, 1) - } else { - arg, err := session.value2Interface(col, fieldValue) - if err != nil { - return colNames, args, err - } - args = append(args, arg) - } - - if includeQuote { - colNames = append(colNames, session.engine.Quote(col.Name)+" = ?") - } else { - colNames = append(colNames, col.Name) - } - } - return colNames, args, nil -} - func indexName(tableName, idxName string) string { return fmt.Sprintf("IDX_%v_%v", tableName, idxName) } - -func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) { - if len(m) == 0 { - return false, false - } - - n := len(col.Name) - - for mk := range m { - if len(mk) != n { - continue - } - if strings.EqualFold(mk, col.Name) { - return m[mk], true - } - } - - return false, false -} diff --git a/session_cols.go b/session_cols.go index 9972cb0a..1c2b023d 100644 --- a/session_cols.go +++ b/session_cols.go @@ -4,6 +4,113 @@ package xorm +import ( + "reflect" + "strings" + "time" + + "github.com/go-xorm/core" +) + +type incrParam struct { + colName string + arg interface{} +} + +type decrParam struct { + colName string + arg interface{} +} + +type exprParam struct { + colName string + expr string +} + +type columnMap []string + +func (m columnMap) contain(colName string) bool { + if len(m) == 0 { + return false + } + + n := len(colName) + for _, mk := range m { + if len(mk) != n { + continue + } + if strings.EqualFold(mk, colName) { + return true + } + } + + return false +} + +func setColumnInt(bean interface{}, col *core.Column, t int64) { + v, err := col.ValueOf(bean) + if err != nil { + return + } + if v.CanSet() { + switch v.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32: + v.SetInt(t) + case reflect.Uint, reflect.Uint64, reflect.Uint32: + v.SetUint(uint64(t)) + } + } +} + +func setColumnTime(bean interface{}, col *core.Column, t time.Time) { + v, err := col.ValueOf(bean) + if err != nil { + return + } + if v.CanSet() { + switch v.Type().Kind() { + case reflect.Struct: + v.Set(reflect.ValueOf(t).Convert(v.Type())) + case reflect.Int, reflect.Int64, reflect.Int32: + v.SetInt(t.Unix()) + case reflect.Uint, reflect.Uint64, reflect.Uint32: + v.SetUint(uint64(t.Unix())) + } + } +} + +func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) { + if len(m) == 0 { + return false, false + } + + n := len(col.Name) + + for mk := range m { + if len(mk) != n { + continue + } + if strings.EqualFold(mk, col.Name) { + return m[mk], true + } + } + + return false, false +} + +func col2NewCols(columns ...string) []string { + newColumns := make([]string, 0, len(columns)) + for _, col := range columns { + col = strings.Replace(col, "`", "", -1) + col = strings.Replace(col, `"`, "", -1) + ccols := strings.Split(col, ",") + for _, c := range ccols { + newColumns = append(newColumns, strings.TrimSpace(c)) + } + } + return newColumns +} + // Incr provides a query string like "count = count + 1" func (session *Session) Incr(column string, arg ...interface{}) *Session { session.statement.Incr(column, arg...) diff --git a/session_get_test.go b/session_get_test.go index 0437b971..4ec7cf02 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -278,11 +278,20 @@ func TestGetActionMapping(t *testing.T) { assertSync(t, new(ActionMapping)) - var valuesSlice = make([]string, 2) - _, err := testEngine.Table(new(ActionMapping)). - Cols("script_id", "rollback_id"). - ID(1).Get(&valuesSlice) + _, err := testEngine.Insert(&ActionMapping{ + ActionId: "1", + ScriptId: "2", + }) assert.NoError(t, err) + + var valuesSlice = make([]string, 2) + has, err := testEngine.Table(new(ActionMapping)). + Cols("script_id", "rollback_id"). + ID("1").Get(&valuesSlice) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "2", valuesSlice[0]) + assert.EqualValues(t, "", valuesSlice[1]) } func TestGetStructId(t *testing.T) { diff --git a/session_insert.go b/session_insert.go index 8609b80c..31bee5bc 100644 --- a/session_insert.go +++ b/session_insert.go @@ -115,15 +115,11 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error if col.IsDeleted { continue } - if session.statement.ColumnStr != "" { - if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { - continue - } + if session.statement.omitColumnMap.contain(col.Name) { + continue } - if session.statement.OmitStr != "" { - if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { - continue - } + if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + continue } if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { val, t := session.engine.nowTime(col) @@ -170,15 +166,11 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error if col.IsDeleted { continue } - if session.statement.ColumnStr != "" { - if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { - continue - } + if session.statement.omitColumnMap.contain(col.Name) { + continue } - if session.statement.OmitStr != "" { - if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { - continue - } + if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + continue } if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { val, t := session.engine.nowTime(col) @@ -316,8 +308,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok { processor.BeforeInsert() } - // -- - colNames, args, err := genCols(session.statement.RefTable, session, bean, false, false) + + colNames, args, err := session.genInsertColumns(bean) if err != nil { return 0, err } @@ -552,3 +544,92 @@ func (session *Session) cacheInsert(table *core.Table, tables ...string) error { return nil } + +// genInsertColumns generates insert needed columns +func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) { + table := session.statement.RefTable + colNames := make([]string, 0, len(table.ColumnsSeq())) + args := make([]interface{}, 0, len(table.ColumnsSeq())) + + for _, col := range table.Columns() { + if col.MapType == core.ONLYFROMDB { + continue + } + + if col.IsDeleted { + continue + } + + if session.statement.omitColumnMap.contain(col.Name) { + continue + } + + if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + continue + } + + if _, ok := session.statement.incrColumns[col.Name]; ok { + continue + } else if _, ok := session.statement.decrColumns[col.Name]; ok { + continue + } + + fieldValuePtr, err := col.ValueOf(bean) + if err != nil { + return nil, nil, err + } + fieldValue := *fieldValuePtr + + if col.IsAutoIncrement { + switch fieldValue.Type().Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: + if fieldValue.Int() == 0 { + continue + } + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: + if fieldValue.Uint() == 0 { + continue + } + case reflect.String: + if len(fieldValue.String()) == 0 { + continue + } + case reflect.Ptr: + if fieldValue.Pointer() == 0 { + continue + } + } + } + + // !evalphobia! set fieldValue as nil when column is nullable and zero-value + if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { + if col.Nullable && isZero(fieldValue.Interface()) { + var nilValue *int + fieldValue = reflect.ValueOf(nilValue) + } + } + + if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { + // if time is non-empty, then set to auto time + val, t := session.engine.nowTime(col) + args = append(args, val) + + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t) + }) + } else if col.IsVersion && session.statement.checkVersion { + args = append(args, 1) + } else { + arg, err := session.value2Interface(col, fieldValue) + if err != nil { + return colNames, args, err + } + args = append(args, arg) + } + + colNames = append(colNames, col.Name) + } + return colNames, args, nil +} diff --git a/session_query.go b/session_query.go index acfefd3f..5c9aeb39 100644 --- a/session_query.go +++ b/session_query.go @@ -64,6 +64,10 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa } } + if err := session.statement.processIDParam(); err != nil { + return "", nil, err + } + condSQL, condArgs, err := builder.ToSQL(session.statement.cond) if err != nil { return "", nil, err diff --git a/session_update.go b/session_update.go index 11264a61..aba3906e 100644 --- a/session_update.go +++ b/session_update.go @@ -179,9 +179,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 colNames, args = buildUpdates(session.engine, session.statement.RefTable, bean, false, false, false, false, session.statement.allUseBool, session.statement.useAllCols, session.statement.mustColumnMap, session.statement.nullableMap, - session.statement.columnMap, true, session.statement.unscoped) + session.statement.columnMap, session.statement.omitColumnMap, true, session.statement.unscoped) } else { - colNames, args, err = genCols(session.statement.RefTable, session, bean, true, true) + colNames, args, err = session.genUpdateColumns(bean) if err != nil { return 0, err } @@ -202,7 +202,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 table := session.statement.RefTable if session.statement.UseAutoTime && table != nil && table.Updated != "" { - if _, ok := session.statement.columnMap[strings.ToLower(table.Updated)]; !ok { + if !session.statement.columnMap.contain(table.Updated) && + !session.statement.omitColumnMap.contain(table.Updated) { colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?") col := table.UpdatedColumn() val, t := session.engine.nowTime(col) @@ -402,3 +403,92 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return res.RowsAffected() } + +func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interface{}, error) { + table := session.statement.RefTable + colNames := make([]string, 0, len(table.ColumnsSeq())) + args := make([]interface{}, 0, len(table.ColumnsSeq())) + + for _, col := range table.Columns() { + if !col.IsVersion && !col.IsCreated && !col.IsUpdated { + if session.statement.omitColumnMap.contain(col.Name) { + continue + } + } + if col.MapType == core.ONLYFROMDB { + continue + } + + fieldValuePtr, err := col.ValueOf(bean) + if err != nil { + return nil, nil, err + } + fieldValue := *fieldValuePtr + + if col.IsAutoIncrement { + switch fieldValue.Type().Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: + if fieldValue.Int() == 0 { + continue + } + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: + if fieldValue.Uint() == 0 { + continue + } + case reflect.String: + if len(fieldValue.String()) == 0 { + continue + } + case reflect.Ptr: + if fieldValue.Pointer() == 0 { + continue + } + } + } + + if col.IsDeleted || col.IsCreated { + continue + } + + if len(session.statement.columnMap) > 0 { + if !session.statement.columnMap.contain(col.Name) { + continue + } else if _, ok := session.statement.incrColumns[col.Name]; ok { + continue + } else if _, ok := session.statement.decrColumns[col.Name]; ok { + continue + } + } + + // !evalphobia! set fieldValue as nil when column is nullable and zero-value + if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { + if col.Nullable && isZero(fieldValue.Interface()) { + var nilValue *int + fieldValue = reflect.ValueOf(nilValue) + } + } + + if col.IsUpdated && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { + // if time is non-empty, then set to auto time + val, t := session.engine.nowTime(col) + args = append(args, val) + + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t) + }) + } else if col.IsVersion && session.statement.checkVersion { + args = append(args, 1) + } else { + arg, err := session.value2Interface(col, fieldValue) + if err != nil { + return colNames, args, err + } + args = append(args, arg) + } + + colNames = append(colNames, session.engine.Quote(col.Name)+" = ?") + } + return colNames, args, nil +} diff --git a/statement.go b/statement.go index ee74e35b..07f48037 100644 --- a/statement.go +++ b/statement.go @@ -18,21 +18,6 @@ import ( "github.com/go-xorm/core" ) -type incrParam struct { - colName string - arg interface{} -} - -type decrParam struct { - colName string - arg interface{} -} - -type exprParam struct { - colName string - expr string -} - // Statement save all the sql info for executing SQL type Statement struct { RefTable *core.Table @@ -47,7 +32,6 @@ type Statement struct { HavingStr string ColumnStr string selectStr string - columnMap map[string]bool useAllCols bool OmitStr string AltTableName string @@ -67,6 +51,8 @@ type Statement struct { allUseBool bool checkVersion bool unscoped bool + columnMap columnMap + omitColumnMap columnMap mustColumnMap map[string]bool nullableMap map[string]bool incrColumns map[string]incrParam @@ -89,7 +75,8 @@ func (statement *Statement) Init() { statement.HavingStr = "" statement.ColumnStr = "" statement.OmitStr = "" - statement.columnMap = make(map[string]bool) + statement.columnMap = columnMap{} + statement.omitColumnMap = columnMap{} statement.AltTableName = "" statement.tableName = "" statement.idParam = nil @@ -240,7 +227,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, mustColumnMap map[string]bool, nullableMap map[string]bool, - columnMap map[string]bool, update, unscoped bool) ([]string, []interface{}) { + columnMap columnMap, omitColumnMap columnMap, update, unscoped bool) ([]string, []interface{}) { var colNames = make([]string, 0) var args = make([]interface{}, 0) @@ -260,7 +247,10 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, if col.IsDeleted && !unscoped { continue } - if use, ok := columnMap[strings.ToLower(col.Name)]; ok && !use { + if omitColumnMap.contain(col.Name) { + continue + } + if len(columnMap) > 0 && !columnMap.contain(col.Name) { continue } @@ -596,17 +586,10 @@ func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { } func (statement *Statement) colmap2NewColsWithQuote() []string { - newColumns := make([]string, 0, len(statement.columnMap)) - for col := range statement.columnMap { - fields := strings.Split(strings.TrimSpace(col), ".") - if len(fields) == 1 { - newColumns = append(newColumns, statement.Engine.quote(fields[0])) - } else if len(fields) == 2 { - newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+ - statement.Engine.quote(fields[1])) - } else { - panic(errors.New("unwanted colnames")) - } + newColumns := make([]string, len(statement.columnMap), len(statement.columnMap)) + copy(newColumns, statement.columnMap) + for i := 0; i < len(statement.columnMap); i++ { + newColumns[i] = statement.Engine.Quote(newColumns[i]) } return newColumns } @@ -634,10 +617,11 @@ func (statement *Statement) Select(str string) *Statement { func (statement *Statement) Cols(columns ...string) *Statement { cols := col2NewCols(columns...) for _, nc := range cols { - statement.columnMap[strings.ToLower(nc)] = true + statement.columnMap = append(statement.columnMap, nc) } newColumns := statement.colmap2NewColsWithQuote() + statement.ColumnStr = strings.Join(newColumns, ", ") statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) return statement @@ -672,7 +656,7 @@ func (statement *Statement) UseBool(columns ...string) *Statement { func (statement *Statement) Omit(columns ...string) { newColumns := col2NewCols(columns...) for _, nc := range newColumns { - statement.columnMap[strings.ToLower(nc)] = false + statement.omitColumnMap = append(statement.omitColumnMap, nc) } statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) } @@ -796,10 +780,12 @@ func (statement *Statement) genColumnStr() string { columns := statement.RefTable.Columns() for _, col := range columns { - if statement.OmitStr != "" { - if _, ok := getFlagForColumn(statement.columnMap, col); ok { - continue - } + if statement.omitColumnMap.contain(col.Name) { + continue + } + + if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) { + continue } if col.MapType == core.ONLYTODB { @@ -810,10 +796,6 @@ func (statement *Statement) genColumnStr() string { buf.WriteString(", ") } - if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" { - buf.WriteString("id() AS ") - } - if statement.JoinStr != "" { if statement.TableAlias != "" { buf.WriteString(statement.TableAlias) @@ -961,6 +943,10 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, columnStr = "*" } + if err := statement.processIDParam(); err != nil { + return "", nil, err + } + if isStruct { if err := statement.mergeConds(bean); err != nil { return "", nil, err @@ -1045,10 +1031,6 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n var top string var mssqlCondi string - if err := statement.processIDParam(); err != nil { - return "", err - } - var buf bytes.Buffer if len(condSQL) > 0 { fmt.Fprintf(&buf, " WHERE %v", condSQL)