From 865979f7160bacccdf910a398b132583891c1173 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 30 Oct 2017 11:07:56 +0800 Subject: [PATCH] add support for map[string]interface{} as condition on Update and Where (#764) --- error.go | 2 ++ session_update.go | 21 +++++++++++++++++---- session_update_test.go | 31 +++++++++++++++++++++++++++++++ statement.go | 6 ++++++ 4 files changed, 56 insertions(+), 4 deletions(-) diff --git a/error.go b/error.go index 2a334f47..cfeefc31 100644 --- a/error.go +++ b/error.go @@ -23,4 +23,6 @@ var ( ErrNeedDeletedCond = errors.New("Delete need at least one condition") // ErrNotImplemented not implemented ErrNotImplemented = errors.New("Not implemented") + // ErrConditionType condition type unsupported + ErrConditionType = errors.New("Unsupported conditon type") ) diff --git a/session_update.go b/session_update.go index ca062981..f5587456 100644 --- a/session_update.go +++ b/session_update.go @@ -242,10 +242,23 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var autoCond builder.Cond if !session.statement.noAutoCondition && len(condiBean) > 0 { - var err error - autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false) - if err != nil { - return 0, err + if c, ok := condiBean[0].(map[string]interface{}); ok { + autoCond = builder.Eq(c) + } else { + ct := reflect.TypeOf(condiBean[0]) + k := ct.Kind() + if k == reflect.Ptr { + k = ct.Elem().Kind() + } + if k == reflect.Struct { + var err error + autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false) + if err != nil { + return 0, err + } + } else { + return 0, ErrConditionType + } } } diff --git a/session_update_test.go b/session_update_test.go index a978e566..d1bc47bc 100644 --- a/session_update_test.go +++ b/session_update_test.go @@ -1215,3 +1215,34 @@ func TestCreatedUpdated2(t *testing.T) { assert.True(t, s2.UpdateAt.Unix() > s.UpdateAt.Unix()) assert.True(t, s2.UpdateAt.Unix() > s2.CreateAt.Unix()) } + +func TestUpdateMapCondition(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UpdateMapCondition struct { + Id int64 + String string + } + + assertSync(t, new(UpdateMapCondition)) + + var c = UpdateMapCondition{ + String: "string", + } + _, err := testEngine.Insert(&c) + assert.NoError(t, err) + + cnt, err := testEngine.Update(&UpdateMapCondition{ + String: "string1", + }, map[string]interface{}{ + "id": c.Id, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var c2 UpdateMapCondition + has, err := testEngine.ID(c.Id).Get(&c2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "string1", c2.String) +} diff --git a/statement.go b/statement.go index 23346c71..69f96983 100644 --- a/statement.go +++ b/statement.go @@ -160,6 +160,9 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme case string: cond := builder.Expr(query.(string), args...) statement.cond = statement.cond.And(cond) + case map[string]interface{}: + cond := builder.Eq(query.(map[string]interface{})) + statement.cond = statement.cond.And(cond) case builder.Cond: cond := query.(builder.Cond) statement.cond = statement.cond.And(cond) @@ -181,6 +184,9 @@ func (statement *Statement) Or(query interface{}, args ...interface{}) *Statemen case string: cond := builder.Expr(query.(string), args...) statement.cond = statement.cond.Or(cond) + case map[string]interface{}: + cond := builder.Eq(query.(map[string]interface{})) + statement.cond = statement.cond.Or(cond) case builder.Cond: cond := query.(builder.Cond) statement.cond = statement.cond.Or(cond)