diff --git a/blongs_to_test.go b/blongs_to_test.go index 4de7a3d0..eb0d93e1 100644 --- a/blongs_to_test.go +++ b/blongs_to_test.go @@ -43,16 +43,21 @@ func TestBelongsTo_Get(t *testing.T) { assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, cfgNose.Id, nose.Id) - // FIXME: the id should be set back to the field - //assert.Equal(t, cfgNose.Face.Id, nose.Face.Id) + assert.Equal(t, cfgNose.Face.Id, nose.Face.Id) assert.Equal(t, "", cfgNose.Face.Name) + err = testEngine.EagerLoad(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, cfgNose.Id, nose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "face1", cfgNose.Face.Name) + var cfgNose2 Nose has, err = testEngine.Cascade().Get(&cfgNose2) assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, cfgNose2.Id, nose.Id) - assert.Equal(t, cfgNose2.Face.Id, nose.Face.Id) + assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id) assert.Equal(t, "face1", cfgNose2.Face.Name) } @@ -92,16 +97,21 @@ func TestBelongsTo_GetPtr(t *testing.T) { has, err = testEngine.Get(&cfgNose) assert.NoError(t, err) assert.Equal(t, true, has) - assert.Equal(t, cfgNose.Id, nose.Id) - // FIXME: the id should be set back to the field - //assert.Equal(t, cfgNose.Face.Id, nose.Face.Id) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + + err = testEngine.EagerLoad(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "face1", cfgNose.Face.Name) var cfgNose2 Nose has, err = testEngine.Cascade().Get(&cfgNose2) assert.NoError(t, err) assert.Equal(t, true, has) - assert.Equal(t, cfgNose2.Id, nose.Id) - assert.Equal(t, cfgNose2.Face.Id, nose.Face.Id) + assert.Equal(t, nose.Id, cfgNose2.Id) + assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id) assert.Equal(t, "face1", cfgNose2.Face.Name) } diff --git a/convert.go b/convert.go index 0504bef1..db6c7fab 100644 --- a/convert.go +++ b/convert.go @@ -25,11 +25,10 @@ func strconvErr(err error) error { func cloneBytes(b []byte) []byte { if b == nil { return nil - } else { - c := make([]byte, len(b)) - copy(c, b) - return c } + c := make([]byte, len(b)) + copy(c, b) + return c } func asString(src interface{}) string { @@ -274,11 +273,12 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { return vv.String(), nil case reflect.Slice: if tp.Elem().Kind() == reflect.Uint8 { - v, err := strconv.ParseInt(string(vv.Interface().([]byte)), 10, 64) + return string(vv.Interface().([]byte)), nil + /*v, err := strconv.ParseInt(string(vv.Interface().([]byte)), 10, 64) if err != nil { return nil, err } - return v, nil + return v, nil*/ } } diff --git a/engine.go b/engine.go index 5ca568e1..c3e0d8d0 100644 --- a/engine.go +++ b/engine.go @@ -1498,6 +1498,13 @@ func (engine *Engine) Exist(bean ...interface{}) (bool, error) { return session.Exist(bean...) } +// EagerLoad loads bean's belongs to tag field immedicatlly +func (engine *Engine) EagerLoad(bean interface{}, cols ...string) error { + session := engine.NewSession() + defer session.Close() + return session.EagerLoad(bean, cols...) +} + // Find retrieve records from table, condiBeans's non-empty fields // are conditions. beans could be []Struct, []*Struct, map[int64]Struct // map[int64]*Struct diff --git a/interface.go b/interface.go index 4f94750b..4adee92e 100644 --- a/interface.go +++ b/interface.go @@ -18,6 +18,7 @@ type Interface interface { Alias(alias string) *Session Asc(colNames ...string) *Session BufferSize(size int) *Session + Cascade(...bool) *Session Cols(columns ...string) *Session Count(...interface{}) (int64, error) CreateIndexes(bean interface{}) error @@ -27,6 +28,7 @@ type Interface interface { Delete(interface{}) (int64, error) Distinct(columns ...string) *Session DropIndexes(bean interface{}) error + EagerLoad(interface{}, ...string) error Exec(string, ...interface{}) (sql.Result, error) Exist(bean ...interface{}) (bool, error) Find(interface{}, ...interface{}) error diff --git a/session.go b/session.go index 789ec35c..d7677163 100644 --- a/session.go +++ b/session.go @@ -636,28 +636,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b session.engine.logger.Error("sql.Sanner error:", err.Error()) hasAssigned = false } - /*} else if col.SQLType.IsJson() { - if rawValueType.Kind() == reflect.String { - hasAssigned = true - x := reflect.New(fieldType) - if len([]byte(vv.String())) > 0 { - err := json.Unmarshal([]byte(vv.String()), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - } else if rawValueType.Kind() == reflect.Slice { - hasAssigned = true - x := reflect.New(fieldType) - if len(vv.Bytes()) > 0 { - err := json.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - }*/ } else if (col.AssociateType == core.AssociateNone && session.statement.cascadeMode == cascadeCompitable) || (col.AssociateType == core.AssociateBelongsTo && @@ -690,122 +668,193 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b return nil, errors.New("cascade obj is not exist") } } + } else if col.AssociateType == core.AssociateBelongsTo { + hasAssigned = true + err := convertAssign(fieldValue.FieldByName(table.PKColumns()[0].FieldName).Addr().Interface(), + vv.Interface()) + if err != nil { + return nil, err + } } case reflect.Ptr: - // !nashtsai! TODO merge duplicated codes above - switch fieldType { - // following types case matching ptr's native type, therefore assign ptr directly - case core.PtrStringType: - if rawValueType.Kind() == reflect.String { - x := vv.String() + if fieldType != core.PtrTimeType && fieldType.Elem().Kind() == reflect.Struct { + if (col.AssociateType == core.AssociateNone && + session.statement.cascadeMode == cascadeCompitable) || + (col.AssociateType == core.AssociateBelongsTo && + session.statement.cascadeMode == cascadeAutoLoad) { + table := col.AssociateTable + hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrBoolType: - if rawValueType.Kind() == reflect.Bool { - x := vv.Bool() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrTimeType: - if rawValueType == core.PtrTimeType { - hasAssigned = true - var x = rawValue.Interface().(time.Time) - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrFloat64Type: - if rawValueType.Kind() == reflect.Float64 { - x := vv.Float() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrUint64Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint64(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt64Type: - if rawValueType.Kind() == reflect.Int64 { - x := vv.Int() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrFloat32Type: - if rawValueType.Kind() == reflect.Float64 { - var x = float32(vv.Float()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrIntType: - if rawValueType.Kind() == reflect.Int64 { - var x = int(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrUintType: - if rawValueType.Kind() == reflect.Int64 { - var x = uint(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrUint32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.Uint8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.Uint16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.Complex64Type: - var x complex64 - if len([]byte(vv.String())) > 0 { - err := json.Unmarshal([]byte(vv.String()), &x) + if len(table.PrimaryKeys) != 1 { + panic("unsupported non or composited primary key cascade") + } + var pk = make(core.PK, len(table.PrimaryKeys)) + var err error + pk[0], err = asKind(vv, rawValueType) if err != nil { return nil, err } - fieldValue.Set(reflect.ValueOf(&x)) - } - hasAssigned = true - case core.Complex128Type: - var x complex128 - if len([]byte(vv.String())) > 0 { - err := json.Unmarshal([]byte(vv.String()), &x) + + if !isPKZero(pk) { + // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch + // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne + // property to be fetched lazily + var structInter reflect.Value + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + structInter = reflect.New(fieldValue.Type().Elem()) + } else { + structInter = *fieldValue + } + } else { + structInter = fieldValue.Addr() + } + + has, err := session.ID(pk).NoCascade().get(structInter.Interface()) + if err != nil { + return nil, err + } + if has { + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(structInter) + } + } else { + return nil, errors.New("cascade obj is not exist") + } + } + } else if col.AssociateType == core.AssociateBelongsTo { + hasAssigned = true + if fieldValue.IsNil() { + structInter := reflect.New(fieldValue.Type().Elem()) + fieldValue.Set(structInter) + } + + //fieldValue.Elem().FieldByName(table.PKColumns()[0].FieldName).Set(vv) + err := convertAssign(fieldValue.Elem().FieldByName(table.PKColumns()[0].FieldName).Addr().Interface(), + vv.Interface()) if err != nil { return nil, err } - fieldValue.Set(reflect.ValueOf(&x)) } - hasAssigned = true - } // switch fieldType + } else { + // !nashtsai! TODO merge duplicated codes above + //typeStr := fieldType.String() + switch fieldType { + // following types case matching ptr's native type, therefore assign ptr directly + case core.PtrStringType: + if rawValueType.Kind() == reflect.String { + x := vv.String() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrBoolType: + if rawValueType.Kind() == reflect.Bool { + x := vv.Bool() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrTimeType: + if rawValueType == core.PtrTimeType { + hasAssigned = true + var x = rawValue.Interface().(time.Time) + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrFloat64Type: + if rawValueType.Kind() == reflect.Float64 { + x := vv.Float() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint64Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint64(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt64Type: + if rawValueType.Kind() == reflect.Int64 { + x := vv.Int() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrFloat32Type: + if rawValueType.Kind() == reflect.Float64 { + var x = float32(vv.Float()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrIntType: + if rawValueType.Kind() == reflect.Int64 { + var x = int(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt32Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt8Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt16Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUintType: + if rawValueType.Kind() == reflect.Int64 { + var x = uint(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint32Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint8Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint16Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrComplex64Type: + var x complex64 + if len([]byte(vv.String())) > 0 { + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + session.engine.logger.Error(err) + } else { + fieldValue.Set(reflect.ValueOf(&x)) + } + } + hasAssigned = true + case core.PtrComplex128Type: + var x complex128 + if len([]byte(vv.String())) > 0 { + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + session.engine.logger.Error(err) + } else { + fieldValue.Set(reflect.ValueOf(&x)) + } + } + hasAssigned = true + } // switch fieldType + } } // switch fieldType.Kind() // !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value diff --git a/session_associate.go b/session_associate.go new file mode 100644 index 00000000..8e2c6231 --- /dev/null +++ b/session_associate.go @@ -0,0 +1,58 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "errors" + "reflect" + + "github.com/go-xorm/core" +) + +// EagerLoad load bean's belongs to tag field immedicatlly +func (session *Session) EagerLoad(bean interface{}, cols ...string) error { + if session.isAutoClose { + defer session.Close() + } + + v := rValue(bean) + tb, err := session.engine.autoMapType(v) + if err != nil { + return err + } + + for _, col := range tb.Columns() { + if col.AssociateTable != nil { + if col.AssociateType == core.AssociateBelongsTo { + colV, err := col.ValueOfV(&v) + if err != nil { + return err + } + + pk, err := session.engine.idOfV(*colV) + if err != nil { + return err + } + var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + } + + if !isZero(pk[0]) { + has, err := session.ID(pk).get(colPtr.Interface()) + if err != nil { + return err + } + if !has { + return errors.New("load bean does not exist") + } + } + } + } + } + return nil +} diff --git a/xorm_test.go b/xorm_test.go index f7581e06..d86fedf0 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -33,10 +33,8 @@ var ( func createEngine(dbType, connStr string) error { if testEngine == nil { var err error - if !*cluster { testEngine, err = NewEngine(dbType, connStr) - } else { testEngine, err = NewEngineGroup(dbType, strings.Split(connStr, *splitter)) }