Update() add map[string]interface{} as update columns

This commit is contained in:
Lunny Xiao 2013-08-29 13:18:02 +08:00
parent d2507a900f
commit f817b30f28
5 changed files with 59 additions and 12 deletions

View File

@ -150,6 +150,8 @@ func insertTwoTable(engine *Engine, t *testing.T) {
} }
} }
type Condi map[string]interface{}
func update(engine *Engine, t *testing.T) { func update(engine *Engine, t *testing.T) {
// update by id // update by id
user := Userinfo{Username: "xxx", Height: 1.2} user := Userinfo{Username: "xxx", Height: 1.2}
@ -159,6 +161,13 @@ func update(engine *Engine, t *testing.T) {
panic(err) panic(err)
} }
condi := Condi{"username": "zzz", "height": 0.0, "departname": ""}
_, err = engine.Table(&user).Id(1).Update(condi)
if err != nil {
t.Error(err)
panic(err)
}
_, err = engine.Update(&Userinfo{Username: "yyy"}, &user) _, err = engine.Update(&Userinfo{Username: "yyy"}, &user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -431,7 +440,7 @@ func createMultiTables(engine *Engine, t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
tableName := fmt.Sprintf("user_%v", i) tableName := fmt.Sprintf("user_%v", i)
err = engine.DropTables(tableName) err = session.DropTable(tableName)
if err != nil { if err != nil {
session.Rollback() session.Rollback()
t.Error(err) t.Error(err)

View File

@ -161,9 +161,9 @@ func (engine *Engine) In(column string, args ...interface{}) *Session {
return session.In(column, args...) return session.In(column, args...)
} }
func (engine *Engine) Table(tableName string) *Session { func (engine *Engine) Table(tableNameOrBean interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
return session.Table(tableName) return session.Table(tableNameOrBean)
} }
func (engine *Engine) Limit(limit int, start ...int) *Session { func (engine *Engine) Limit(limit int, start ...int) *Session {

10
error.go Normal file
View File

@ -0,0 +1,10 @@
package xorm
import (
"errors"
)
var (
ParamsTypeError error = errors.New("params type error")
TableNotFoundError error = errors.New("not found table")
)

View File

@ -60,8 +60,8 @@ func (session *Session) Id(id int64) *Session {
return session return session
} }
func (session *Session) Table(tableName string) *Session { func (session *Session) Table(tableNameOrBean interface{}) *Session {
session.Statement.Table(tableName) session.Statement.Table(tableNameOrBean)
return session return session
} }
@ -870,7 +870,7 @@ func (session *Session) value2Interface(fieldValue reflect.Value) (interface{},
} else if fieldValue.Type().Kind() == reflect.Array || } else if fieldValue.Type().Kind() == reflect.Array ||
fieldValue.Type().Kind() == reflect.Slice { fieldValue.Type().Kind() == reflect.Slice {
data := fmt.Sprintf("%v", fieldValue.Interface()) data := fmt.Sprintf("%v", fieldValue.Interface())
fmt.Println(data, "--------") //fmt.Println(data, "--------")
return data, nil return data, nil
} else { } else {
return fieldValue.Interface(), nil return fieldValue.Interface(), nil
@ -974,14 +974,37 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, err return 0, err
} }
t := Type(bean)
var colNames []string
var args []interface{}
if t.Kind() == reflect.Struct {
table := session.Engine.AutoMap(bean) table := session.Engine.AutoMap(bean)
session.Statement.RefTable = table session.Statement.RefTable = table
colNames, args := BuildConditions(session.Engine, table, bean) colNames, args = BuildConditions(session.Engine, table, bean)
} else if t.Kind() == reflect.Map {
if session.Statement.RefTable == nil {
return -1, TableNotFoundError
}
colNames = make([]string, 0)
args = make([]interface{}, 0)
bValue := reflect.ValueOf(bean)
for _, v := range bValue.MapKeys() {
colNames = append(colNames, session.Engine.Quote(v.String())+" = ?")
args = append(args, bValue.MapIndex(v).Interface())
}
} else {
return -1, ParamsTypeError
}
var condiColNames []string var condiColNames []string
var condiArgs []interface{} var condiArgs []interface{}
if len(condiBean) > 0 { if len(condiBean) > 0 {
condiColNames, condiArgs = BuildConditions(session.Engine, table, condiBean[0]) condiColNames, condiArgs = BuildConditions(session.Engine, session.Statement.RefTable, condiBean[0])
} }
var condition = "" var condition = ""

View File

@ -77,8 +77,13 @@ func (statement *Statement) Where(querystring string, args ...interface{}) {
statement.Params = args statement.Params = args
} }
func (statement *Statement) Table(tableName string) { func (statement *Statement) Table(tableNameOrBean interface{}) {
statement.AltTableName = tableName t := Type(tableNameOrBean)
if t.Kind() == reflect.String {
statement.AltTableName = tableNameOrBean.(string)
} else if t.Kind() == reflect.Struct {
statement.RefTable = statement.Engine.AutoMapType(t)
}
} }
func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string, []interface{}) { func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string, []interface{}) {