add support for slice, array, map, custom types fields of struct & fixed #4

This commit is contained in:
Lunny Xiao 2013-09-05 23:20:52 +08:00
parent 695b89c35f
commit 3a868531e9
5 changed files with 321 additions and 160 deletions

View File

@ -657,33 +657,59 @@ type MyUInt uint
type MyFloat float64 type MyFloat float64
type MyString string type MyString string
func (s MyString) FromDB(data []byte) error { /*func (s *MyString) FromDB(data []byte) error {
s = MyString(string(data)) reflect.
s MyString(data)
return nil return nil
} }
func (s MyString) ToDB() ([]byte, error) { func (s *MyString) ToDB() ([]byte, error) {
return []byte(string(s)), nil return []byte(string(*s)), nil
} }*/
type MyStruct struct { type MyStruct struct {
Type MyInt Type MyInt
U MyUInt U MyUInt
F MyFloat F MyFloat
//S MyString S MyString
//IA []MyInt IA []MyInt
//UA []MyUInt UA []MyUInt
//FA []MyFloat FA []MyFloat
//SA []MyString SA []MyString
//NameArray []string NameArray []string
Name string Name string
//UIA []uint UIA []uint
UI uint UIA8 []uint8
UIA16 []uint16
UIA32 []uint32
UIA64 []uint64
UI uint
//C64 complex64
MSS map[string]string
} }
func testCustomType(engine *Engine, t *testing.T) { func testCustomType(engine *Engine, t *testing.T) {
err := engine.CreateTables(&MyStruct{}) err := engine.DropTables(&MyStruct{})
if err != nil {
t.Error(err)
panic(err)
return
}
err = engine.CreateTables(&MyStruct{})
i := MyStruct{Name: "Test", Type: MyInt(1)} i := MyStruct{Name: "Test", Type: MyInt(1)}
i.U = 23
i.F = 1.34
i.S = "fafdsafdsaf"
i.UI = 2
i.IA = []MyInt{1, 3, 5}
i.UIA = []uint{1, 3}
i.UIA16 = []uint16{2}
i.UIA32 = []uint32{4, 5}
i.UIA64 = []uint64{6, 7, 9}
i.UIA8 = []uint8{1, 2, 3, 4}
i.NameArray = []string{"ssss fsdf", "lllll, ss"}
i.MSS = map[string]string{"s": "sfds,ss ", "x": "lfjljsl"}
_, err = engine.Insert(&i) _, err = engine.Insert(&i)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -691,6 +717,7 @@ func testCustomType(engine *Engine, t *testing.T) {
return return
} }
fmt.Println(i)
has, err := engine.Get(&i) has, err := engine.Get(&i)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -699,9 +726,22 @@ func testCustomType(engine *Engine, t *testing.T) {
t.Error(errors.New("should get one record")) t.Error(errors.New("should get one record"))
panic(err) panic(err)
} }
}
func testTrans(engine *Engine, t *testing.T) { ss := []MyStruct{}
err = engine.Find(&ss)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(ss)
sss := MyStruct{}
has, err = engine.Get(&sss)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(sss)
} }
type UserCU struct { type UserCU struct {
@ -713,7 +753,13 @@ type UserCU struct {
func testCreatedAndUpdated(engine *Engine, t *testing.T) { func testCreatedAndUpdated(engine *Engine, t *testing.T) {
u := new(UserCU) u := new(UserCU)
err := engine.CreateTables(u) err := engine.DropTables(u)
if err != nil {
t.Error(err)
panic(err)
}
err = engine.CreateTables(u)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)

View File

@ -5,6 +5,7 @@ import (
) )
var ( var (
ParamsTypeError error = errors.New("params type error") ParamsTypeError error = errors.New("params type error")
TableNotFoundError error = errors.New("not found table") TableNotFoundError error = errors.New("not found table")
UnSupportedTypeError error = errors.New("unsupported type error")
) )

View File

@ -2,6 +2,7 @@ package xorm
import ( import (
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@ -35,7 +36,7 @@ func (session *Session) Close() {
session.Engine.Pool.ReleaseDB(session.Engine, session.Db) session.Engine.Pool.ReleaseDB(session.Engine, session.Db)
session.Db = nil session.Db = nil
session.Tx = nil session.Tx = nil
//session.Init() session.Init()
} }
}() }()
} }
@ -196,127 +197,29 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
if _, ok := table.Columns[key]; !ok { if _, ok := table.Columns[key]; !ok {
continue continue
} }
fieldName := table.Columns[key].FieldName col := table.Columns[key]
fieldName := col.FieldName
fieldPath := strings.Split(fieldName, ".") fieldPath := strings.Split(fieldName, ".")
var structField reflect.Value var fieldValue reflect.Value
if len(fieldPath) > 2 { if len(fieldPath) > 2 {
session.Engine.LogError("Unsupported mutliderive", fieldName) session.Engine.LogError("Unsupported mutliderive", fieldName)
continue continue
} else if len(fieldPath) == 2 { } else if len(fieldPath) == 2 {
parentField := dataStruct.FieldByName(fieldPath[0]) parentField := dataStruct.FieldByName(fieldPath[0])
if parentField.IsValid() { if parentField.IsValid() {
structField = parentField.FieldByName(fieldPath[1]) fieldValue = parentField.FieldByName(fieldPath[1])
} }
} else { } else {
structField = dataStruct.FieldByName(fieldName) fieldValue = dataStruct.FieldByName(fieldName)
} }
if !structField.IsValid() || !structField.CanSet() { if !fieldValue.IsValid() || !fieldValue.CanSet() {
continue continue
} }
var v interface{} err := session.bytes2Value(col, &fieldValue, data)
if err != nil {
switch structField.Type().Kind() { return err
case reflect.Slice:
v = data
vv := reflect.ValueOf(v)
if structField.Type().String() == "[]byte" {
fmt.Println("...[]byte...")
}
if vv.Type().Kind() == reflect.Slice {
for i := 0; i < vv.Len(); i++ {
//vv.Index(i)
structField = reflect.AppendSlice(structField, vv)
//reflect.Append(structField, vv.Index(i))
}
} else {
return errors.New(fmt.Sprintf("unsupported from other %v to %v", vv.Type().Kind(), structField.Type().Kind()))
}
case reflect.Array:
if structField.Type().Elem() == reflect.TypeOf(b) {
v = data
structField.Set(reflect.ValueOf(v))
}
case reflect.String:
x := string(data)
structField.SetString(x)
case reflect.Bool:
v = (string(data) == "1")
structField.Set(reflect.ValueOf(v))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
structField.SetInt(x)
case reflect.Float32, reflect.Float64:
x, err := strconv.ParseFloat(string(data), 64)
if err != nil {
return errors.New("arg " + key + " as float64: " + err.Error())
}
structField.SetFloat(x)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
structField.SetUint(x)
//Now only support Time type
case reflect.Struct:
if structField.Type().String() == "time.Time" {
x, err := time.Parse("2006-01-02 15:04:05", string(data))
if err != nil {
x, err = time.Parse("2006-01-02 15:04:05.000 -0700", string(data))
if err != nil {
return errors.New("unsupported time format: " + string(data))
}
}
v = x
structField.Set(reflect.ValueOf(v))
} else if structConvert, ok := structField.Addr().Interface().(Conversion); ok {
err := structConvert.FromDB(data)
if err != nil {
return err
}
continue
} else if session.Statement.UseCascade {
table := session.Engine.AutoMapType(structField.Type())
if table != nil {
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
if x != 0 {
structInter := reflect.New(structField.Type())
newsession := session.Engine.NewSession()
defer newsession.Close()
has, err := newsession.Id(x).Get(structInter.Interface())
if err != nil {
return err
}
if has {
v = structInter.Elem().Interface()
structField.Set(reflect.ValueOf(v))
} else {
session.Engine.LogError("cascade obj is not exist!")
continue
}
} else {
continue
}
} else {
session.Engine.LogError("unsupported struct type in Scan: " + structField.Type().String())
continue
}
} else {
continue
}
default:
return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String())
} }
} }
return nil return nil
@ -458,6 +361,7 @@ func (session *Session) Get(bean interface{}) (bool, error) {
args = session.Statement.RawParams args = session.Statement.RawParams
session.Engine.AutoMap(bean) session.Engine.AutoMap(bean)
} }
resultsSlice, err := session.Query(sql, args...) resultsSlice, err := session.Query(sql, args...)
if err != nil { if err != nil {
return false, err return false, err
@ -467,7 +371,9 @@ func (session *Session) Get(bean interface{}) (bool, error) {
} }
results := resultsSlice[0] results := resultsSlice[0]
err = session.scanMapIntoStruct(bean, results) err = session.scanMapIntoStruct(bean, results)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -654,7 +560,6 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice
} }
for res.Next() { for res.Next() {
result := make(map[string][]byte) result := make(map[string][]byte)
//scanResultContainers := make([]interface{}, len(fields))
var scanResultContainers []interface{} var scanResultContainers []interface{}
for i := 0; i < len(fields); i++ { for i := 0; i < len(fields); i++ {
var scanResultContainer interface{} var scanResultContainer interface{}
@ -668,6 +573,7 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice
//if row is null then ignore //if row is null then ignore
if rawValue.Interface() == nil { if rawValue.Interface() == nil {
fmt.Println("ignore ...", key, rawValue)
continue continue
} }
aa := reflect.TypeOf(rawValue.Interface()) aa := reflect.TypeOf(rawValue.Interface())
@ -684,9 +590,11 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice
str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
result[key] = []byte(str) result[key] = []byte(str)
case reflect.Slice: case reflect.Slice:
if aa.Elem().Kind() == reflect.Uint8 { switch aa.Elem().Kind() {
case reflect.Uint8:
result[key] = rawValue.Interface().([]byte) result[key] = rawValue.Interface().([]byte)
break default:
session.Engine.LogError("Unsupported type")
} }
case reflect.String: case reflect.String:
str = vv.String() str = vv.String()
@ -699,9 +607,9 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice
} else { } else {
session.Engine.LogError("Unsupported struct type") session.Engine.LogError("Unsupported struct type")
} }
default:
session.Engine.LogError("Unsupported type")
} }
//default:
} }
resultsSlice = append(resultsSlice, result) resultsSlice = append(resultsSlice, result)
} }
@ -816,7 +724,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsCreated || col.IsUpdated { if col.IsCreated || col.IsUpdated {
args = append(args, time.Now()) args = append(args, time.Now())
} else { } else {
arg, err := session.value2Interface(fieldValue) arg, err := session.value2Interface(col, fieldValue)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -844,7 +752,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsCreated || col.IsUpdated { if col.IsCreated || col.IsUpdated {
args = append(args, time.Now()) args = append(args, time.Now())
} else { } else {
arg, err := session.value2Interface(fieldValue) arg, err := session.value2Interface(col, fieldValue)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -893,7 +801,127 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
return session.innerInsertMulti(rowsSlicePtr) return session.innerInsertMulti(rowsSlicePtr)
} }
func (session *Session) value2Interface(fieldValue reflect.Value) (interface{}, error) { // convert a db data([]byte) to a field value
func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data []byte) error {
if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok {
return structConvert.FromDB(data)
}
var v interface{}
key := col.Name
fieldType := fieldValue.Type()
switch fieldType.Kind() {
case reflect.Complex64, reflect.Complex128:
x := reflect.New(fieldType)
err := json.Unmarshal(data, x.Interface())
if err != nil {
session.Engine.LogSQL(err)
return err
}
fieldValue.Set(x.Elem())
case reflect.Slice, reflect.Array, reflect.Map:
v = data
t := fieldType.Elem()
k := t.Kind()
if col.SQLType.IsText() {
x := reflect.New(fieldType)
err := json.Unmarshal(data, x.Interface())
if err != nil {
session.Engine.LogSQL(err)
return err
}
fieldValue.Set(x.Elem())
} else if col.SQLType.IsBlob() {
if k == reflect.Uint8 {
fieldValue.Set(reflect.ValueOf(v))
} else {
x := reflect.New(fieldType)
err := json.Unmarshal(data, x.Interface())
if err != nil {
session.Engine.LogSQL(err)
return err
}
fieldValue.Set(x.Elem())
}
} else {
return UnSupportedTypeError
}
case reflect.String:
fieldValue.SetString(string(data))
case reflect.Bool:
v, err := strconv.ParseBool(string(data))
if err != nil {
return errors.New("arg " + key + " as bool: " + err.Error())
}
fieldValue.Set(reflect.ValueOf(v))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
fieldValue.SetInt(x)
case reflect.Float32, reflect.Float64:
x, err := strconv.ParseFloat(string(data), 64)
if err != nil {
return errors.New("arg " + key + " as float64: " + err.Error())
}
fieldValue.SetFloat(x)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
fieldValue.SetUint(x)
//Now only support Time type
case reflect.Struct:
if fieldValue.Type().String() == "time.Time" {
x, err := time.Parse("2006-01-02 15:04:05", string(data))
if err != nil {
x, err = time.Parse("2006-01-02 15:04:05.000 -0700", string(data))
if err != nil {
return errors.New("unsupported time format: " + string(data))
}
}
v = x
fieldValue.Set(reflect.ValueOf(v))
} else if session.Statement.UseCascade {
table := session.Engine.AutoMapType(fieldValue.Type())
if table != nil {
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
if x != 0 {
structInter := reflect.New(fieldValue.Type())
newsession := session.Engine.NewSession()
defer newsession.Close()
has, err := newsession.Id(x).Get(structInter.Interface())
if err != nil {
return err
}
if has {
v = structInter.Elem().Interface()
fieldValue.Set(reflect.ValueOf(v))
} else {
return errors.New("cascade obj is not exist!")
}
}
} else {
return errors.New("unsupported struct type in Scan: " + fieldValue.Type().String())
}
}
default:
return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String())
}
return nil
}
// convert a field value of a struct to interface for put into db
func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) (interface{}, error) {
if fieldValue.CanAddr() { if fieldValue.CanAddr() {
if fieldConvert, ok := fieldValue.Addr().Interface().(Conversion); ok { if fieldConvert, ok := fieldValue.Addr().Interface().(Conversion); ok {
data, err := fieldConvert.ToDB() data, err := fieldConvert.ToDB()
@ -905,15 +933,22 @@ func (session *Session) value2Interface(fieldValue reflect.Value) (interface{},
} }
} }
if fieldValue.Type().Kind() == reflect.Bool { k := fieldValue.Type().Kind()
switch k {
case reflect.Bool:
if fieldValue.Bool() { if fieldValue.Bool() {
return 1, nil return 1, nil
} else { } else {
return 0, nil return 0, nil
} }
} else if fieldValue.Type().String() == "time.Time" { case reflect.String:
return fieldValue.Interface(), nil return fieldValue.String(), nil
} else if fieldValue.Type().Kind() == reflect.Struct { case reflect.Struct:
if fieldValue.Type().String() == "time.Time" {
//return fieldValue.Interface().(time.Time).Format(time.RFC3339Nano), nil
//return fieldValue.Interface().(time.Time).Format("2006-01-02 15:04:05 -0700"), nil
return fieldValue.Interface(), nil
}
if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok { if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok {
if fieldTable.PrimaryKey != "" { if fieldTable.PrimaryKey != "" {
pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumn().FieldName) pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumn().FieldName)
@ -924,12 +959,43 @@ func (session *Session) value2Interface(fieldValue reflect.Value) (interface{},
} else { } else {
return 0, errors.New(fmt.Sprintf("Unsupported type %v", fieldValue.Type())) return 0, errors.New(fmt.Sprintf("Unsupported type %v", fieldValue.Type()))
} }
} else if fieldValue.Type().Kind() == reflect.Array || case reflect.Complex64, reflect.Complex128:
fieldValue.Type().Kind() == reflect.Slice { bytes, err := json.Marshal(fieldValue.Interface())
data := fmt.Sprintf("%v", fieldValue.Interface()) if err != nil {
//fmt.Println(data, "--------") session.Engine.LogSQL(err)
return data, nil return 0, err
} else { }
return string(bytes), nil
case reflect.Array, reflect.Slice, reflect.Map:
if !fieldValue.IsValid() {
return fieldValue.Interface(), nil
}
if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.LogSQL(err)
return 0, err
}
return string(bytes), nil
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
if (k == reflect.Array || k == reflect.Slice) &&
(fieldValue.Type().Elem().Kind() == reflect.Uint8) {
bytes = fieldValue.Bytes()
} else {
bytes, err = json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.LogSQL(err)
return 0, err
}
}
return bytes, nil
} else {
return nil, UnSupportedTypeError
}
default:
return fieldValue.Interface(), nil return fieldValue.Interface(), nil
} }
} }
@ -961,7 +1027,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if col.IsCreated || col.IsUpdated { if col.IsCreated || col.IsUpdated {
args = append(args, time.Now()) args = append(args, time.Now())
} else { } else {
arg, err := session.value2Interface(fieldValue) arg, err := session.value2Interface(col, fieldValue)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -1007,7 +1073,6 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
v = int(id) v = int(id)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
v = uint(id) v = uint(id)
} }
pkValue.Set(reflect.ValueOf(v)) pkValue.Set(reflect.ValueOf(v))

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
//"strconv" //"strconv"
"encoding/json"
"strings" "strings"
"time" "time"
) )
@ -87,10 +88,15 @@ func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string,
fieldType := reflect.TypeOf(fieldValue.Interface()) fieldType := reflect.TypeOf(fieldValue.Interface())
val := fieldValue.Interface() val := fieldValue.Interface()
switch fieldType.Kind() { switch fieldType.Kind() {
case reflect.Bool:
case reflect.String: case reflect.String:
if fieldValue.String() == "" { if fieldValue.String() == "" {
continue continue
} }
// for MyString, should convert to string or panic
if fieldType.String() != reflect.String.String() {
val = fieldValue.String()
}
case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
if fieldValue.Int() == 0 { if fieldValue.Int() == 0 {
continue continue
@ -109,23 +115,55 @@ func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string,
if t.IsZero() { if t.IsZero() {
continue continue
} }
val = t
} else { } else {
engine.AutoMapType(fieldValue.Type()) engine.AutoMapType(fieldValue.Type())
if table, ok := engine.Tables[fieldValue.Type()]; ok {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName)
if pkField.Int() != 0 {
val = pkField.Interface()
} else {
continue
}
}
}
case reflect.Array, reflect.Slice, reflect.Map:
if fieldValue == reflect.Zero(fieldType) {
continue
}
if fieldValue.IsNil() || !fieldValue.IsValid() {
continue
} }
default:
continue
}
if table, ok := engine.Tables[fieldValue.Type()]; ok { if col.SQLType.IsText() {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName) bytes, err := json.Marshal(fieldValue.Interface())
if pkField.Int() != 0 { if err != nil {
args = append(args, pkField.Interface()) engine.LogSQL(err)
continue
}
val = string(bytes)
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
fieldType.Elem().Kind() == reflect.Uint8 {
val = fieldValue.Bytes()
} else {
bytes, err = json.Marshal(fieldValue.Interface())
if err != nil {
engine.LogSQL(err)
continue
}
val = bytes
}
} else { } else {
continue continue
} }
} else { default:
args = append(args, val) val = fieldValue.Interface()
} }
args = append(args, val)
colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name))) colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
} }

View File

@ -13,6 +13,17 @@ type SQLType struct {
DefaultLength2 int DefaultLength2 int
} }
func (s *SQLType) IsText() bool {
return s.Name == Char || s.Name == Varchar || s.Name == TinyText ||
s.Name == Text || s.Name == MediumText || s.Name == LongText
}
func (s *SQLType) IsBlob() bool {
return (s.Name == TinyBlob) || (s.Name == Blob) ||
s.Name == MediumBlob || s.Name == LongBlob ||
s.Name == Binary || s.Name == VarBinary || s.Name == Bytea
}
var ( var (
Bit = "BIT" Bit = "BIT"
TinyInt = "TINYINT" TinyInt = "TINYINT"