From 3a868531e94c98bd00c077b971c675dd577881c6 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 5 Sep 2013 23:20:52 +0800 Subject: [PATCH] add support for slice, array, map, custom types fields of struct & fixed #4 --- base_test.go | 88 ++++++++++---- error.go | 5 +- session.go | 321 +++++++++++++++++++++++++++++++-------------------- statement.go | 56 +++++++-- table.go | 11 ++ 5 files changed, 321 insertions(+), 160 deletions(-) diff --git a/base_test.go b/base_test.go index 98e11e84..1eb20c83 100644 --- a/base_test.go +++ b/base_test.go @@ -657,33 +657,59 @@ type MyUInt uint type MyFloat float64 type MyString string -func (s MyString) FromDB(data []byte) error { - s = MyString(string(data)) +/*func (s *MyString) FromDB(data []byte) error { + reflect. + s MyString(data) return nil } -func (s MyString) ToDB() ([]byte, error) { - return []byte(string(s)), nil -} +func (s *MyString) ToDB() ([]byte, error) { + return []byte(string(*s)), nil +}*/ type MyStruct struct { - Type MyInt - U MyUInt - F MyFloat - //S MyString - //IA []MyInt - //UA []MyUInt - //FA []MyFloat - //SA []MyString - //NameArray []string - Name string - //UIA []uint - UI uint + Type MyInt + U MyUInt + F MyFloat + S MyString + IA []MyInt + UA []MyUInt + FA []MyFloat + SA []MyString + NameArray []string + Name string + UIA []uint + UIA8 []uint8 + UIA16 []uint16 + UIA32 []uint32 + UIA64 []uint64 + UI uint + //C64 complex64 + MSS map[string]string } func testCustomType(engine *Engine, t *testing.T) { - err := engine.CreateTables(&MyStruct{}) + err := engine.DropTables(&MyStruct{}) + if err != nil { + t.Error(err) + panic(err) + return + } + + err = engine.CreateTables(&MyStruct{}) i := MyStruct{Name: "Test", Type: MyInt(1)} + i.U = 23 + i.F = 1.34 + i.S = "fafdsafdsaf" + i.UI = 2 + i.IA = []MyInt{1, 3, 5} + i.UIA = []uint{1, 3} + i.UIA16 = []uint16{2} + i.UIA32 = []uint32{4, 5} + i.UIA64 = []uint64{6, 7, 9} + i.UIA8 = []uint8{1, 2, 3, 4} + i.NameArray = []string{"ssss fsdf", "lllll, ss"} + i.MSS = map[string]string{"s": "sfds,ss ", "x": "lfjljsl"} _, err = engine.Insert(&i) if err != nil { t.Error(err) @@ -691,6 +717,7 @@ func testCustomType(engine *Engine, t *testing.T) { return } + fmt.Println(i) has, err := engine.Get(&i) if err != nil { t.Error(err) @@ -699,9 +726,22 @@ func testCustomType(engine *Engine, t *testing.T) { t.Error(errors.New("should get one record")) panic(err) } -} -func testTrans(engine *Engine, t *testing.T) { + ss := []MyStruct{} + err = engine.Find(&ss) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(ss) + + sss := MyStruct{} + has, err = engine.Get(&sss) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(sss) } type UserCU struct { @@ -713,7 +753,13 @@ type UserCU struct { func testCreatedAndUpdated(engine *Engine, t *testing.T) { u := new(UserCU) - err := engine.CreateTables(u) + err := engine.DropTables(u) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(u) if err != nil { t.Error(err) panic(err) diff --git a/error.go b/error.go index 1b258883..772c8fc8 100644 --- a/error.go +++ b/error.go @@ -5,6 +5,7 @@ import ( ) var ( - ParamsTypeError error = errors.New("params type error") - TableNotFoundError error = errors.New("not found table") + ParamsTypeError error = errors.New("params type error") + TableNotFoundError error = errors.New("not found table") + UnSupportedTypeError error = errors.New("unsupported type error") ) diff --git a/session.go b/session.go index 96fb93fc..7145158f 100644 --- a/session.go +++ b/session.go @@ -2,6 +2,7 @@ package xorm import ( "database/sql" + "encoding/json" "errors" "fmt" "reflect" @@ -35,7 +36,7 @@ func (session *Session) Close() { session.Engine.Pool.ReleaseDB(session.Engine, session.Db) session.Db = nil session.Tx = nil - //session.Init() + session.Init() } }() } @@ -196,127 +197,29 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b if _, ok := table.Columns[key]; !ok { continue } - fieldName := table.Columns[key].FieldName + col := table.Columns[key] + fieldName := col.FieldName fieldPath := strings.Split(fieldName, ".") - var structField reflect.Value + var fieldValue reflect.Value if len(fieldPath) > 2 { session.Engine.LogError("Unsupported mutliderive", fieldName) continue } else if len(fieldPath) == 2 { parentField := dataStruct.FieldByName(fieldPath[0]) if parentField.IsValid() { - structField = parentField.FieldByName(fieldPath[1]) + fieldValue = parentField.FieldByName(fieldPath[1]) } } else { - structField = dataStruct.FieldByName(fieldName) + fieldValue = dataStruct.FieldByName(fieldName) } - if !structField.IsValid() || !structField.CanSet() { + if !fieldValue.IsValid() || !fieldValue.CanSet() { continue } - var v interface{} - - switch structField.Type().Kind() { - case reflect.Slice: - v = data - vv := reflect.ValueOf(v) - if structField.Type().String() == "[]byte" { - fmt.Println("...[]byte...") - } - if vv.Type().Kind() == reflect.Slice { - for i := 0; i < vv.Len(); i++ { - //vv.Index(i) - structField = reflect.AppendSlice(structField, vv) - //reflect.Append(structField, vv.Index(i)) - } - } else { - return errors.New(fmt.Sprintf("unsupported from other %v to %v", vv.Type().Kind(), structField.Type().Kind())) - } - case reflect.Array: - if structField.Type().Elem() == reflect.TypeOf(b) { - v = data - structField.Set(reflect.ValueOf(v)) - } - case reflect.String: - x := string(data) - structField.SetString(x) - case reflect.Bool: - v = (string(data) == "1") - structField.Set(reflect.ValueOf(v)) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - x, err := strconv.ParseInt(string(data), 10, 64) - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - structField.SetInt(x) - case reflect.Float32, reflect.Float64: - x, err := strconv.ParseFloat(string(data), 64) - if err != nil { - return errors.New("arg " + key + " as float64: " + err.Error()) - } - structField.SetFloat(x) - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - structField.SetUint(x) - //Now only support Time type - case reflect.Struct: - if structField.Type().String() == "time.Time" { - x, err := time.Parse("2006-01-02 15:04:05", string(data)) - if err != nil { - x, err = time.Parse("2006-01-02 15:04:05.000 -0700", string(data)) - - if err != nil { - return errors.New("unsupported time format: " + string(data)) - } - } - - v = x - structField.Set(reflect.ValueOf(v)) - } else if structConvert, ok := structField.Addr().Interface().(Conversion); ok { - err := structConvert.FromDB(data) - if err != nil { - return err - } - continue - } else if session.Statement.UseCascade { - table := session.Engine.AutoMapType(structField.Type()) - if table != nil { - x, err := strconv.ParseInt(string(data), 10, 64) - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - if x != 0 { - structInter := reflect.New(structField.Type()) - newsession := session.Engine.NewSession() - defer newsession.Close() - has, err := newsession.Id(x).Get(structInter.Interface()) - if err != nil { - return err - } - if has { - v = structInter.Elem().Interface() - structField.Set(reflect.ValueOf(v)) - } else { - session.Engine.LogError("cascade obj is not exist!") - continue - } - } else { - continue - } - } else { - session.Engine.LogError("unsupported struct type in Scan: " + structField.Type().String()) - continue - } - } else { - continue - } - default: - return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) + err := session.bytes2Value(col, &fieldValue, data) + if err != nil { + return err } - } return nil @@ -458,6 +361,7 @@ func (session *Session) Get(bean interface{}) (bool, error) { args = session.Statement.RawParams session.Engine.AutoMap(bean) } + resultsSlice, err := session.Query(sql, args...) if err != nil { return false, err @@ -467,7 +371,9 @@ func (session *Session) Get(bean interface{}) (bool, error) { } results := resultsSlice[0] + err = session.scanMapIntoStruct(bean, results) + if err != nil { return false, err } @@ -654,7 +560,6 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice } for res.Next() { result := make(map[string][]byte) - //scanResultContainers := make([]interface{}, len(fields)) var scanResultContainers []interface{} for i := 0; i < len(fields); i++ { var scanResultContainer interface{} @@ -668,6 +573,7 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice //if row is null then ignore if rawValue.Interface() == nil { + fmt.Println("ignore ...", key, rawValue) continue } aa := reflect.TypeOf(rawValue.Interface()) @@ -684,9 +590,11 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) result[key] = []byte(str) case reflect.Slice: - if aa.Elem().Kind() == reflect.Uint8 { + switch aa.Elem().Kind() { + case reflect.Uint8: result[key] = rawValue.Interface().([]byte) - break + default: + session.Engine.LogError("Unsupported type") } case reflect.String: str = vv.String() @@ -699,9 +607,9 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice } else { session.Engine.LogError("Unsupported struct type") } + default: + session.Engine.LogError("Unsupported type") } - //default: - } resultsSlice = append(resultsSlice, result) } @@ -816,7 +724,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error if col.IsCreated || col.IsUpdated { args = append(args, time.Now()) } else { - arg, err := session.value2Interface(fieldValue) + arg, err := session.value2Interface(col, fieldValue) if err != nil { return 0, err } @@ -844,7 +752,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error if col.IsCreated || col.IsUpdated { args = append(args, time.Now()) } else { - arg, err := session.value2Interface(fieldValue) + arg, err := session.value2Interface(col, fieldValue) if err != nil { return 0, err } @@ -893,7 +801,127 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { return session.innerInsertMulti(rowsSlicePtr) } -func (session *Session) value2Interface(fieldValue reflect.Value) (interface{}, error) { +// convert a db data([]byte) to a field value +func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data []byte) error { + if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok { + return structConvert.FromDB(data) + } + + var v interface{} + key := col.Name + fieldType := fieldValue.Type() + switch fieldType.Kind() { + case reflect.Complex64, reflect.Complex128: + x := reflect.New(fieldType) + + err := json.Unmarshal(data, x.Interface()) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(x.Elem()) + case reflect.Slice, reflect.Array, reflect.Map: + v = data + t := fieldType.Elem() + k := t.Kind() + if col.SQLType.IsText() { + x := reflect.New(fieldType) + err := json.Unmarshal(data, x.Interface()) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(x.Elem()) + } else if col.SQLType.IsBlob() { + if k == reflect.Uint8 { + fieldValue.Set(reflect.ValueOf(v)) + } else { + x := reflect.New(fieldType) + err := json.Unmarshal(data, x.Interface()) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(x.Elem()) + } + } else { + return UnSupportedTypeError + } + case reflect.String: + fieldValue.SetString(string(data)) + case reflect.Bool: + v, err := strconv.ParseBool(string(data)) + if err != nil { + return errors.New("arg " + key + " as bool: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(v)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.SetInt(x) + case reflect.Float32, reflect.Float64: + x, err := strconv.ParseFloat(string(data), 64) + if err != nil { + return errors.New("arg " + key + " as float64: " + err.Error()) + } + fieldValue.SetFloat(x) + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.SetUint(x) + //Now only support Time type + case reflect.Struct: + if fieldValue.Type().String() == "time.Time" { + x, err := time.Parse("2006-01-02 15:04:05", string(data)) + if err != nil { + x, err = time.Parse("2006-01-02 15:04:05.000 -0700", string(data)) + + if err != nil { + return errors.New("unsupported time format: " + string(data)) + } + } + + v = x + fieldValue.Set(reflect.ValueOf(v)) + } else if session.Statement.UseCascade { + table := session.Engine.AutoMapType(fieldValue.Type()) + if table != nil { + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + if x != 0 { + structInter := reflect.New(fieldValue.Type()) + newsession := session.Engine.NewSession() + defer newsession.Close() + has, err := newsession.Id(x).Get(structInter.Interface()) + if err != nil { + return err + } + if has { + v = structInter.Elem().Interface() + fieldValue.Set(reflect.ValueOf(v)) + } else { + return errors.New("cascade obj is not exist!") + } + } + } else { + return errors.New("unsupported struct type in Scan: " + fieldValue.Type().String()) + } + } + default: + return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) + } + + return nil +} + +// convert a field value of a struct to interface for put into db +func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) (interface{}, error) { if fieldValue.CanAddr() { if fieldConvert, ok := fieldValue.Addr().Interface().(Conversion); ok { data, err := fieldConvert.ToDB() @@ -905,15 +933,22 @@ func (session *Session) value2Interface(fieldValue reflect.Value) (interface{}, } } - if fieldValue.Type().Kind() == reflect.Bool { + k := fieldValue.Type().Kind() + switch k { + case reflect.Bool: if fieldValue.Bool() { return 1, nil } else { return 0, nil } - } else if fieldValue.Type().String() == "time.Time" { - return fieldValue.Interface(), nil - } else if fieldValue.Type().Kind() == reflect.Struct { + case reflect.String: + return fieldValue.String(), nil + case reflect.Struct: + if fieldValue.Type().String() == "time.Time" { + //return fieldValue.Interface().(time.Time).Format(time.RFC3339Nano), nil + //return fieldValue.Interface().(time.Time).Format("2006-01-02 15:04:05 -0700"), nil + return fieldValue.Interface(), nil + } if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok { if fieldTable.PrimaryKey != "" { pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumn().FieldName) @@ -924,12 +959,43 @@ func (session *Session) value2Interface(fieldValue reflect.Value) (interface{}, } else { return 0, errors.New(fmt.Sprintf("Unsupported type %v", fieldValue.Type())) } - } else if fieldValue.Type().Kind() == reflect.Array || - fieldValue.Type().Kind() == reflect.Slice { - data := fmt.Sprintf("%v", fieldValue.Interface()) - //fmt.Println(data, "--------") - return data, nil - } else { + case reflect.Complex64, reflect.Complex128: + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + session.Engine.LogSQL(err) + return 0, err + } + return string(bytes), nil + case reflect.Array, reflect.Slice, reflect.Map: + if !fieldValue.IsValid() { + return fieldValue.Interface(), nil + } + + if col.SQLType.IsText() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + session.Engine.LogSQL(err) + return 0, err + } + return string(bytes), nil + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (k == reflect.Array || k == reflect.Slice) && + (fieldValue.Type().Elem().Kind() == reflect.Uint8) { + bytes = fieldValue.Bytes() + } else { + bytes, err = json.Marshal(fieldValue.Interface()) + if err != nil { + session.Engine.LogSQL(err) + return 0, err + } + } + return bytes, nil + } else { + return nil, UnSupportedTypeError + } + default: return fieldValue.Interface(), nil } } @@ -961,7 +1027,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if col.IsCreated || col.IsUpdated { args = append(args, time.Now()) } else { - arg, err := session.value2Interface(fieldValue) + arg, err := session.value2Interface(col, fieldValue) if err != nil { return 0, err } @@ -1007,7 +1073,6 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { v = int(id) case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: v = uint(id) - } pkValue.Set(reflect.ValueOf(v)) diff --git a/statement.go b/statement.go index 39a786f8..2c8770e9 100644 --- a/statement.go +++ b/statement.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" //"strconv" + "encoding/json" "strings" "time" ) @@ -87,10 +88,15 @@ func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string, fieldType := reflect.TypeOf(fieldValue.Interface()) val := fieldValue.Interface() switch fieldType.Kind() { + case reflect.Bool: case reflect.String: if fieldValue.String() == "" { continue } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + val = fieldValue.String() + } case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: if fieldValue.Int() == 0 { continue @@ -109,23 +115,55 @@ func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string, if t.IsZero() { continue } + val = t } else { engine.AutoMapType(fieldValue.Type()) + if table, ok := engine.Tables[fieldValue.Type()]; ok { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName) + if pkField.Int() != 0 { + val = pkField.Interface() + } else { + continue + } + } + } + case reflect.Array, reflect.Slice, reflect.Map: + if fieldValue == reflect.Zero(fieldType) { + continue + } + if fieldValue.IsNil() || !fieldValue.IsValid() { + continue } - default: - continue - } - if table, ok := engine.Tables[fieldValue.Type()]; ok { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName) - if pkField.Int() != 0 { - args = append(args, pkField.Interface()) + if col.SQLType.IsText() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + engine.LogSQL(err) + continue + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && + fieldType.Elem().Kind() == reflect.Uint8 { + val = fieldValue.Bytes() + } else { + bytes, err = json.Marshal(fieldValue.Interface()) + if err != nil { + engine.LogSQL(err) + continue + } + val = bytes + } } else { continue } - } else { - args = append(args, val) + default: + val = fieldValue.Interface() } + + args = append(args, val) colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name))) } diff --git a/table.go b/table.go index aedbc25f..b93dd075 100644 --- a/table.go +++ b/table.go @@ -13,6 +13,17 @@ type SQLType struct { DefaultLength2 int } +func (s *SQLType) IsText() bool { + return s.Name == Char || s.Name == Varchar || s.Name == TinyText || + s.Name == Text || s.Name == MediumText || s.Name == LongText +} + +func (s *SQLType) IsBlob() bool { + return (s.Name == TinyBlob) || (s.Name == Blob) || + s.Name == MediumBlob || s.Name == LongBlob || + s.Name == Binary || s.Name == VarBinary || s.Name == Bytea +} + var ( Bit = "BIT" TinyInt = "TINYINT"