From f817b30f288d3226438a500daa68869cd8d86c5f Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 29 Aug 2013 13:18:02 +0800 Subject: [PATCH] Update() add map[string]interface{} as update columns --- base_test.go | 11 ++++++++++- engine.go | 4 ++-- error.go | 10 ++++++++++ session.go | 37 ++++++++++++++++++++++++++++++------- statement.go | 9 +++++++-- 5 files changed, 59 insertions(+), 12 deletions(-) create mode 100644 error.go diff --git a/base_test.go b/base_test.go index 418bc39d..260f2dcc 100644 --- a/base_test.go +++ b/base_test.go @@ -150,6 +150,8 @@ func insertTwoTable(engine *Engine, t *testing.T) { } } +type Condi map[string]interface{} + func update(engine *Engine, t *testing.T) { // update by id user := Userinfo{Username: "xxx", Height: 1.2} @@ -159,6 +161,13 @@ func update(engine *Engine, t *testing.T) { panic(err) } + condi := Condi{"username": "zzz", "height": 0.0, "departname": ""} + _, err = engine.Table(&user).Id(1).Update(condi) + if err != nil { + t.Error(err) + panic(err) + } + _, err = engine.Update(&Userinfo{Username: "yyy"}, &user) if err != nil { t.Error(err) @@ -431,7 +440,7 @@ func createMultiTables(engine *Engine, t *testing.T) { for i := 0; i < 10; i++ { tableName := fmt.Sprintf("user_%v", i) - err = engine.DropTables(tableName) + err = session.DropTable(tableName) if err != nil { session.Rollback() t.Error(err) diff --git a/engine.go b/engine.go index 970dcd0b..8d75dfe3 100644 --- a/engine.go +++ b/engine.go @@ -161,9 +161,9 @@ func (engine *Engine) In(column string, args ...interface{}) *Session { return session.In(column, args...) } -func (engine *Engine) Table(tableName string) *Session { +func (engine *Engine) Table(tableNameOrBean interface{}) *Session { session := engine.NewSession() - return session.Table(tableName) + return session.Table(tableNameOrBean) } func (engine *Engine) Limit(limit int, start ...int) *Session { diff --git a/error.go b/error.go new file mode 100644 index 00000000..1b258883 --- /dev/null +++ b/error.go @@ -0,0 +1,10 @@ +package xorm + +import ( + "errors" +) + +var ( + ParamsTypeError error = errors.New("params type error") + TableNotFoundError error = errors.New("not found table") +) diff --git a/session.go b/session.go index 19acfe5e..1c6a96b7 100644 --- a/session.go +++ b/session.go @@ -60,8 +60,8 @@ func (session *Session) Id(id int64) *Session { return session } -func (session *Session) Table(tableName string) *Session { - session.Statement.Table(tableName) +func (session *Session) Table(tableNameOrBean interface{}) *Session { + session.Statement.Table(tableNameOrBean) return session } @@ -870,7 +870,7 @@ func (session *Session) value2Interface(fieldValue reflect.Value) (interface{}, } else if fieldValue.Type().Kind() == reflect.Array || fieldValue.Type().Kind() == reflect.Slice { data := fmt.Sprintf("%v", fieldValue.Interface()) - fmt.Println(data, "--------") + //fmt.Println(data, "--------") return data, nil } else { return fieldValue.Interface(), nil @@ -974,14 +974,37 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return 0, err } - table := session.Engine.AutoMap(bean) - session.Statement.RefTable = table - colNames, args := BuildConditions(session.Engine, table, bean) + t := Type(bean) + + var colNames []string + var args []interface{} + + if t.Kind() == reflect.Struct { + table := session.Engine.AutoMap(bean) + session.Statement.RefTable = table + colNames, args = BuildConditions(session.Engine, table, bean) + } else if t.Kind() == reflect.Map { + if session.Statement.RefTable == nil { + return -1, TableNotFoundError + } + colNames = make([]string, 0) + args = make([]interface{}, 0) + bValue := reflect.ValueOf(bean) + + for _, v := range bValue.MapKeys() { + colNames = append(colNames, session.Engine.Quote(v.String())+" = ?") + args = append(args, bValue.MapIndex(v).Interface()) + } + + } else { + return -1, ParamsTypeError + } + var condiColNames []string var condiArgs []interface{} if len(condiBean) > 0 { - condiColNames, condiArgs = BuildConditions(session.Engine, table, condiBean[0]) + condiColNames, condiArgs = BuildConditions(session.Engine, session.Statement.RefTable, condiBean[0]) } var condition = "" diff --git a/statement.go b/statement.go index 35a13bc7..4f044283 100644 --- a/statement.go +++ b/statement.go @@ -77,8 +77,13 @@ func (statement *Statement) Where(querystring string, args ...interface{}) { statement.Params = args } -func (statement *Statement) Table(tableName string) { - statement.AltTableName = tableName +func (statement *Statement) Table(tableNameOrBean interface{}) { + t := Type(tableNameOrBean) + if t.Kind() == reflect.String { + statement.AltTableName = tableNameOrBean.(string) + } else if t.Kind() == reflect.Struct { + statement.RefTable = statement.Engine.AutoMapType(t) + } } func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string, []interface{}) {