Update() add map[string]interface{} as update columns

This commit is contained in:
Lunny Xiao 2013-08-29 13:18:02 +08:00
parent d2507a900f
commit f817b30f28
5 changed files with 59 additions and 12 deletions

View File

@ -150,6 +150,8 @@ func insertTwoTable(engine *Engine, t *testing.T) {
}
}
type Condi map[string]interface{}
func update(engine *Engine, t *testing.T) {
// update by id
user := Userinfo{Username: "xxx", Height: 1.2}
@ -159,6 +161,13 @@ func update(engine *Engine, t *testing.T) {
panic(err)
}
condi := Condi{"username": "zzz", "height": 0.0, "departname": ""}
_, err = engine.Table(&user).Id(1).Update(condi)
if err != nil {
t.Error(err)
panic(err)
}
_, err = engine.Update(&Userinfo{Username: "yyy"}, &user)
if err != nil {
t.Error(err)
@ -431,7 +440,7 @@ func createMultiTables(engine *Engine, t *testing.T) {
for i := 0; i < 10; i++ {
tableName := fmt.Sprintf("user_%v", i)
err = engine.DropTables(tableName)
err = session.DropTable(tableName)
if err != nil {
session.Rollback()
t.Error(err)

View File

@ -161,9 +161,9 @@ func (engine *Engine) In(column string, args ...interface{}) *Session {
return session.In(column, args...)
}
func (engine *Engine) Table(tableName string) *Session {
func (engine *Engine) Table(tableNameOrBean interface{}) *Session {
session := engine.NewSession()
return session.Table(tableName)
return session.Table(tableNameOrBean)
}
func (engine *Engine) Limit(limit int, start ...int) *Session {

10
error.go Normal file
View File

@ -0,0 +1,10 @@
package xorm
import (
"errors"
)
var (
ParamsTypeError error = errors.New("params type error")
TableNotFoundError error = errors.New("not found table")
)

View File

@ -60,8 +60,8 @@ func (session *Session) Id(id int64) *Session {
return session
}
func (session *Session) Table(tableName string) *Session {
session.Statement.Table(tableName)
func (session *Session) Table(tableNameOrBean interface{}) *Session {
session.Statement.Table(tableNameOrBean)
return session
}
@ -870,7 +870,7 @@ func (session *Session) value2Interface(fieldValue reflect.Value) (interface{},
} else if fieldValue.Type().Kind() == reflect.Array ||
fieldValue.Type().Kind() == reflect.Slice {
data := fmt.Sprintf("%v", fieldValue.Interface())
fmt.Println(data, "--------")
//fmt.Println(data, "--------")
return data, nil
} else {
return fieldValue.Interface(), nil
@ -974,14 +974,37 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, err
}
table := session.Engine.AutoMap(bean)
session.Statement.RefTable = table
colNames, args := BuildConditions(session.Engine, table, bean)
t := Type(bean)
var colNames []string
var args []interface{}
if t.Kind() == reflect.Struct {
table := session.Engine.AutoMap(bean)
session.Statement.RefTable = table
colNames, args = BuildConditions(session.Engine, table, bean)
} else if t.Kind() == reflect.Map {
if session.Statement.RefTable == nil {
return -1, TableNotFoundError
}
colNames = make([]string, 0)
args = make([]interface{}, 0)
bValue := reflect.ValueOf(bean)
for _, v := range bValue.MapKeys() {
colNames = append(colNames, session.Engine.Quote(v.String())+" = ?")
args = append(args, bValue.MapIndex(v).Interface())
}
} else {
return -1, ParamsTypeError
}
var condiColNames []string
var condiArgs []interface{}
if len(condiBean) > 0 {
condiColNames, condiArgs = BuildConditions(session.Engine, table, condiBean[0])
condiColNames, condiArgs = BuildConditions(session.Engine, session.Statement.RefTable, condiBean[0])
}
var condition = ""

View File

@ -77,8 +77,13 @@ func (statement *Statement) Where(querystring string, args ...interface{}) {
statement.Params = args
}
func (statement *Statement) Table(tableName string) {
statement.AltTableName = tableName
func (statement *Statement) Table(tableNameOrBean interface{}) {
t := Type(tableNameOrBean)
if t.Kind() == reflect.String {
statement.AltTableName = tableNameOrBean.(string)
} else if t.Kind() == reflect.Struct {
statement.RefTable = statement.Engine.AutoMapType(t)
}
}
func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string, []interface{}) {