add support for slice, array, map, custom types fields of struct & fixed #4

This commit is contained in:
Lunny Xiao 2013-09-05 23:20:52 +08:00
parent 695b89c35f
commit 3a868531e9
5 changed files with 321 additions and 160 deletions

View File

@ -657,33 +657,59 @@ type MyUInt uint
type MyFloat float64
type MyString string
func (s MyString) FromDB(data []byte) error {
s = MyString(string(data))
/*func (s *MyString) FromDB(data []byte) error {
reflect.
s MyString(data)
return nil
}
func (s MyString) ToDB() ([]byte, error) {
return []byte(string(s)), nil
}
func (s *MyString) ToDB() ([]byte, error) {
return []byte(string(*s)), nil
}*/
type MyStruct struct {
Type MyInt
U MyUInt
F MyFloat
//S MyString
//IA []MyInt
//UA []MyUInt
//FA []MyFloat
//SA []MyString
//NameArray []string
S MyString
IA []MyInt
UA []MyUInt
FA []MyFloat
SA []MyString
NameArray []string
Name string
//UIA []uint
UIA []uint
UIA8 []uint8
UIA16 []uint16
UIA32 []uint32
UIA64 []uint64
UI uint
//C64 complex64
MSS map[string]string
}
func testCustomType(engine *Engine, t *testing.T) {
err := engine.CreateTables(&MyStruct{})
err := engine.DropTables(&MyStruct{})
if err != nil {
t.Error(err)
panic(err)
return
}
err = engine.CreateTables(&MyStruct{})
i := MyStruct{Name: "Test", Type: MyInt(1)}
i.U = 23
i.F = 1.34
i.S = "fafdsafdsaf"
i.UI = 2
i.IA = []MyInt{1, 3, 5}
i.UIA = []uint{1, 3}
i.UIA16 = []uint16{2}
i.UIA32 = []uint32{4, 5}
i.UIA64 = []uint64{6, 7, 9}
i.UIA8 = []uint8{1, 2, 3, 4}
i.NameArray = []string{"ssss fsdf", "lllll, ss"}
i.MSS = map[string]string{"s": "sfds,ss ", "x": "lfjljsl"}
_, err = engine.Insert(&i)
if err != nil {
t.Error(err)
@ -691,6 +717,7 @@ func testCustomType(engine *Engine, t *testing.T) {
return
}
fmt.Println(i)
has, err := engine.Get(&i)
if err != nil {
t.Error(err)
@ -699,9 +726,22 @@ func testCustomType(engine *Engine, t *testing.T) {
t.Error(errors.New("should get one record"))
panic(err)
}
}
func testTrans(engine *Engine, t *testing.T) {
ss := []MyStruct{}
err = engine.Find(&ss)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(ss)
sss := MyStruct{}
has, err = engine.Get(&sss)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(sss)
}
type UserCU struct {
@ -713,7 +753,13 @@ type UserCU struct {
func testCreatedAndUpdated(engine *Engine, t *testing.T) {
u := new(UserCU)
err := engine.CreateTables(u)
err := engine.DropTables(u)
if err != nil {
t.Error(err)
panic(err)
}
err = engine.CreateTables(u)
if err != nil {
t.Error(err)
panic(err)

View File

@ -7,4 +7,5 @@ import (
var (
ParamsTypeError error = errors.New("params type error")
TableNotFoundError error = errors.New("not found table")
UnSupportedTypeError error = errors.New("unsupported type error")
)

View File

@ -2,6 +2,7 @@ package xorm
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"reflect"
@ -35,7 +36,7 @@ func (session *Session) Close() {
session.Engine.Pool.ReleaseDB(session.Engine, session.Db)
session.Db = nil
session.Tx = nil
//session.Init()
session.Init()
}
}()
}
@ -196,127 +197,29 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
if _, ok := table.Columns[key]; !ok {
continue
}
fieldName := table.Columns[key].FieldName
col := table.Columns[key]
fieldName := col.FieldName
fieldPath := strings.Split(fieldName, ".")
var structField reflect.Value
var fieldValue reflect.Value
if len(fieldPath) > 2 {
session.Engine.LogError("Unsupported mutliderive", fieldName)
continue
} else if len(fieldPath) == 2 {
parentField := dataStruct.FieldByName(fieldPath[0])
if parentField.IsValid() {
structField = parentField.FieldByName(fieldPath[1])
fieldValue = parentField.FieldByName(fieldPath[1])
}
} else {
structField = dataStruct.FieldByName(fieldName)
fieldValue = dataStruct.FieldByName(fieldName)
}
if !structField.IsValid() || !structField.CanSet() {
if !fieldValue.IsValid() || !fieldValue.CanSet() {
continue
}
var v interface{}
switch structField.Type().Kind() {
case reflect.Slice:
v = data
vv := reflect.ValueOf(v)
if structField.Type().String() == "[]byte" {
fmt.Println("...[]byte...")
}
if vv.Type().Kind() == reflect.Slice {
for i := 0; i < vv.Len(); i++ {
//vv.Index(i)
structField = reflect.AppendSlice(structField, vv)
//reflect.Append(structField, vv.Index(i))
}
} else {
return errors.New(fmt.Sprintf("unsupported from other %v to %v", vv.Type().Kind(), structField.Type().Kind()))
}
case reflect.Array:
if structField.Type().Elem() == reflect.TypeOf(b) {
v = data
structField.Set(reflect.ValueOf(v))
}
case reflect.String:
x := string(data)
structField.SetString(x)
case reflect.Bool:
v = (string(data) == "1")
structField.Set(reflect.ValueOf(v))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
structField.SetInt(x)
case reflect.Float32, reflect.Float64:
x, err := strconv.ParseFloat(string(data), 64)
if err != nil {
return errors.New("arg " + key + " as float64: " + err.Error())
}
structField.SetFloat(x)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
structField.SetUint(x)
//Now only support Time type
case reflect.Struct:
if structField.Type().String() == "time.Time" {
x, err := time.Parse("2006-01-02 15:04:05", string(data))
if err != nil {
x, err = time.Parse("2006-01-02 15:04:05.000 -0700", string(data))
if err != nil {
return errors.New("unsupported time format: " + string(data))
}
}
v = x
structField.Set(reflect.ValueOf(v))
} else if structConvert, ok := structField.Addr().Interface().(Conversion); ok {
err := structConvert.FromDB(data)
err := session.bytes2Value(col, &fieldValue, data)
if err != nil {
return err
}
continue
} else if session.Statement.UseCascade {
table := session.Engine.AutoMapType(structField.Type())
if table != nil {
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
if x != 0 {
structInter := reflect.New(structField.Type())
newsession := session.Engine.NewSession()
defer newsession.Close()
has, err := newsession.Id(x).Get(structInter.Interface())
if err != nil {
return err
}
if has {
v = structInter.Elem().Interface()
structField.Set(reflect.ValueOf(v))
} else {
session.Engine.LogError("cascade obj is not exist!")
continue
}
} else {
continue
}
} else {
session.Engine.LogError("unsupported struct type in Scan: " + structField.Type().String())
continue
}
} else {
continue
}
default:
return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String())
}
}
return nil
@ -458,6 +361,7 @@ func (session *Session) Get(bean interface{}) (bool, error) {
args = session.Statement.RawParams
session.Engine.AutoMap(bean)
}
resultsSlice, err := session.Query(sql, args...)
if err != nil {
return false, err
@ -467,7 +371,9 @@ func (session *Session) Get(bean interface{}) (bool, error) {
}
results := resultsSlice[0]
err = session.scanMapIntoStruct(bean, results)
if err != nil {
return false, err
}
@ -654,7 +560,6 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice
}
for res.Next() {
result := make(map[string][]byte)
//scanResultContainers := make([]interface{}, len(fields))
var scanResultContainers []interface{}
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
@ -668,6 +573,7 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice
//if row is null then ignore
if rawValue.Interface() == nil {
fmt.Println("ignore ...", key, rawValue)
continue
}
aa := reflect.TypeOf(rawValue.Interface())
@ -684,9 +590,11 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice
str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
result[key] = []byte(str)
case reflect.Slice:
if aa.Elem().Kind() == reflect.Uint8 {
switch aa.Elem().Kind() {
case reflect.Uint8:
result[key] = rawValue.Interface().([]byte)
break
default:
session.Engine.LogError("Unsupported type")
}
case reflect.String:
str = vv.String()
@ -699,9 +607,9 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice
} else {
session.Engine.LogError("Unsupported struct type")
}
default:
session.Engine.LogError("Unsupported type")
}
//default:
}
resultsSlice = append(resultsSlice, result)
}
@ -816,7 +724,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsCreated || col.IsUpdated {
args = append(args, time.Now())
} else {
arg, err := session.value2Interface(fieldValue)
arg, err := session.value2Interface(col, fieldValue)
if err != nil {
return 0, err
}
@ -844,7 +752,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsCreated || col.IsUpdated {
args = append(args, time.Now())
} else {
arg, err := session.value2Interface(fieldValue)
arg, err := session.value2Interface(col, fieldValue)
if err != nil {
return 0, err
}
@ -893,7 +801,127 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
return session.innerInsertMulti(rowsSlicePtr)
}
func (session *Session) value2Interface(fieldValue reflect.Value) (interface{}, error) {
// convert a db data([]byte) to a field value
func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data []byte) error {
if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok {
return structConvert.FromDB(data)
}
var v interface{}
key := col.Name
fieldType := fieldValue.Type()
switch fieldType.Kind() {
case reflect.Complex64, reflect.Complex128:
x := reflect.New(fieldType)
err := json.Unmarshal(data, x.Interface())
if err != nil {
session.Engine.LogSQL(err)
return err
}
fieldValue.Set(x.Elem())
case reflect.Slice, reflect.Array, reflect.Map:
v = data
t := fieldType.Elem()
k := t.Kind()
if col.SQLType.IsText() {
x := reflect.New(fieldType)
err := json.Unmarshal(data, x.Interface())
if err != nil {
session.Engine.LogSQL(err)
return err
}
fieldValue.Set(x.Elem())
} else if col.SQLType.IsBlob() {
if k == reflect.Uint8 {
fieldValue.Set(reflect.ValueOf(v))
} else {
x := reflect.New(fieldType)
err := json.Unmarshal(data, x.Interface())
if err != nil {
session.Engine.LogSQL(err)
return err
}
fieldValue.Set(x.Elem())
}
} else {
return UnSupportedTypeError
}
case reflect.String:
fieldValue.SetString(string(data))
case reflect.Bool:
v, err := strconv.ParseBool(string(data))
if err != nil {
return errors.New("arg " + key + " as bool: " + err.Error())
}
fieldValue.Set(reflect.ValueOf(v))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
fieldValue.SetInt(x)
case reflect.Float32, reflect.Float64:
x, err := strconv.ParseFloat(string(data), 64)
if err != nil {
return errors.New("arg " + key + " as float64: " + err.Error())
}
fieldValue.SetFloat(x)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
fieldValue.SetUint(x)
//Now only support Time type
case reflect.Struct:
if fieldValue.Type().String() == "time.Time" {
x, err := time.Parse("2006-01-02 15:04:05", string(data))
if err != nil {
x, err = time.Parse("2006-01-02 15:04:05.000 -0700", string(data))
if err != nil {
return errors.New("unsupported time format: " + string(data))
}
}
v = x
fieldValue.Set(reflect.ValueOf(v))
} else if session.Statement.UseCascade {
table := session.Engine.AutoMapType(fieldValue.Type())
if table != nil {
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
if x != 0 {
structInter := reflect.New(fieldValue.Type())
newsession := session.Engine.NewSession()
defer newsession.Close()
has, err := newsession.Id(x).Get(structInter.Interface())
if err != nil {
return err
}
if has {
v = structInter.Elem().Interface()
fieldValue.Set(reflect.ValueOf(v))
} else {
return errors.New("cascade obj is not exist!")
}
}
} else {
return errors.New("unsupported struct type in Scan: " + fieldValue.Type().String())
}
}
default:
return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String())
}
return nil
}
// convert a field value of a struct to interface for put into db
func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) (interface{}, error) {
if fieldValue.CanAddr() {
if fieldConvert, ok := fieldValue.Addr().Interface().(Conversion); ok {
data, err := fieldConvert.ToDB()
@ -905,15 +933,22 @@ func (session *Session) value2Interface(fieldValue reflect.Value) (interface{},
}
}
if fieldValue.Type().Kind() == reflect.Bool {
k := fieldValue.Type().Kind()
switch k {
case reflect.Bool:
if fieldValue.Bool() {
return 1, nil
} else {
return 0, nil
}
} else if fieldValue.Type().String() == "time.Time" {
case reflect.String:
return fieldValue.String(), nil
case reflect.Struct:
if fieldValue.Type().String() == "time.Time" {
//return fieldValue.Interface().(time.Time).Format(time.RFC3339Nano), nil
//return fieldValue.Interface().(time.Time).Format("2006-01-02 15:04:05 -0700"), nil
return fieldValue.Interface(), nil
} else if fieldValue.Type().Kind() == reflect.Struct {
}
if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok {
if fieldTable.PrimaryKey != "" {
pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumn().FieldName)
@ -924,12 +959,43 @@ func (session *Session) value2Interface(fieldValue reflect.Value) (interface{},
} else {
return 0, errors.New(fmt.Sprintf("Unsupported type %v", fieldValue.Type()))
}
} else if fieldValue.Type().Kind() == reflect.Array ||
fieldValue.Type().Kind() == reflect.Slice {
data := fmt.Sprintf("%v", fieldValue.Interface())
//fmt.Println(data, "--------")
return data, nil
case reflect.Complex64, reflect.Complex128:
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.LogSQL(err)
return 0, err
}
return string(bytes), nil
case reflect.Array, reflect.Slice, reflect.Map:
if !fieldValue.IsValid() {
return fieldValue.Interface(), nil
}
if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.LogSQL(err)
return 0, err
}
return string(bytes), nil
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
if (k == reflect.Array || k == reflect.Slice) &&
(fieldValue.Type().Elem().Kind() == reflect.Uint8) {
bytes = fieldValue.Bytes()
} else {
bytes, err = json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.LogSQL(err)
return 0, err
}
}
return bytes, nil
} else {
return nil, UnSupportedTypeError
}
default:
return fieldValue.Interface(), nil
}
}
@ -961,7 +1027,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if col.IsCreated || col.IsUpdated {
args = append(args, time.Now())
} else {
arg, err := session.value2Interface(fieldValue)
arg, err := session.value2Interface(col, fieldValue)
if err != nil {
return 0, err
}
@ -1007,7 +1073,6 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
v = int(id)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
v = uint(id)
}
pkValue.Set(reflect.ValueOf(v))

View File

@ -4,6 +4,7 @@ import (
"fmt"
"reflect"
//"strconv"
"encoding/json"
"strings"
"time"
)
@ -87,10 +88,15 @@ func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string,
fieldType := reflect.TypeOf(fieldValue.Interface())
val := fieldValue.Interface()
switch fieldType.Kind() {
case reflect.Bool:
case reflect.String:
if fieldValue.String() == "" {
continue
}
// for MyString, should convert to string or panic
if fieldType.String() != reflect.String.String() {
val = fieldValue.String()
}
case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
if fieldValue.Int() == 0 {
continue
@ -109,23 +115,55 @@ func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string,
if t.IsZero() {
continue
}
val = t
} else {
engine.AutoMapType(fieldValue.Type())
}
default:
continue
}
if table, ok := engine.Tables[fieldValue.Type()]; ok {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName)
if pkField.Int() != 0 {
args = append(args, pkField.Interface())
val = pkField.Interface()
} else {
continue
}
} else {
args = append(args, val)
}
}
case reflect.Array, reflect.Slice, reflect.Map:
if fieldValue == reflect.Zero(fieldType) {
continue
}
if fieldValue.IsNil() || !fieldValue.IsValid() {
continue
}
if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
engine.LogSQL(err)
continue
}
val = string(bytes)
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
fieldType.Elem().Kind() == reflect.Uint8 {
val = fieldValue.Bytes()
} else {
bytes, err = json.Marshal(fieldValue.Interface())
if err != nil {
engine.LogSQL(err)
continue
}
val = bytes
}
} else {
continue
}
default:
val = fieldValue.Interface()
}
args = append(args, val)
colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
}

View File

@ -13,6 +13,17 @@ type SQLType struct {
DefaultLength2 int
}
func (s *SQLType) IsText() bool {
return s.Name == Char || s.Name == Varchar || s.Name == TinyText ||
s.Name == Text || s.Name == MediumText || s.Name == LongText
}
func (s *SQLType) IsBlob() bool {
return (s.Name == TinyBlob) || (s.Name == Blob) ||
s.Name == MediumBlob || s.Name == LongBlob ||
s.Name == Binary || s.Name == VarBinary || s.Name == Bytea
}
var (
Bit = "BIT"
TinyInt = "TINYINT"