set column value to belongs_to bean when cascade is disabled

This commit is contained in:
Lunny Xiao 2017-04-04 12:19:37 +08:00
parent 8a3fa4464d
commit 8490767f1e
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
7 changed files with 268 additions and 144 deletions

View File

@ -43,16 +43,21 @@ func TestBelongsTo_Get(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
assert.Equal(t, cfgNose.Id, nose.Id) assert.Equal(t, cfgNose.Id, nose.Id)
// FIXME: the id should be set back to the field assert.Equal(t, cfgNose.Face.Id, nose.Face.Id)
//assert.Equal(t, cfgNose.Face.Id, nose.Face.Id)
assert.Equal(t, "", cfgNose.Face.Name) assert.Equal(t, "", cfgNose.Face.Name)
err = testEngine.EagerLoad(&cfgNose)
assert.NoError(t, err)
assert.Equal(t, cfgNose.Id, nose.Id)
assert.Equal(t, nose.Face.Id, cfgNose.Face.Id)
assert.Equal(t, "face1", cfgNose.Face.Name)
var cfgNose2 Nose var cfgNose2 Nose
has, err = testEngine.Cascade().Get(&cfgNose2) has, err = testEngine.Cascade().Get(&cfgNose2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
assert.Equal(t, cfgNose2.Id, nose.Id) assert.Equal(t, cfgNose2.Id, nose.Id)
assert.Equal(t, cfgNose2.Face.Id, nose.Face.Id) assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id)
assert.Equal(t, "face1", cfgNose2.Face.Name) assert.Equal(t, "face1", cfgNose2.Face.Name)
} }
@ -92,16 +97,21 @@ func TestBelongsTo_GetPtr(t *testing.T) {
has, err = testEngine.Get(&cfgNose) has, err = testEngine.Get(&cfgNose)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
assert.Equal(t, cfgNose.Id, nose.Id) assert.Equal(t, nose.Id, cfgNose.Id)
// FIXME: the id should be set back to the field assert.Equal(t, nose.Face.Id, cfgNose.Face.Id)
//assert.Equal(t, cfgNose.Face.Id, nose.Face.Id)
err = testEngine.EagerLoad(&cfgNose)
assert.NoError(t, err)
assert.Equal(t, nose.Id, cfgNose.Id)
assert.Equal(t, nose.Face.Id, cfgNose.Face.Id)
assert.Equal(t, "face1", cfgNose.Face.Name)
var cfgNose2 Nose var cfgNose2 Nose
has, err = testEngine.Cascade().Get(&cfgNose2) has, err = testEngine.Cascade().Get(&cfgNose2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
assert.Equal(t, cfgNose2.Id, nose.Id) assert.Equal(t, nose.Id, cfgNose2.Id)
assert.Equal(t, cfgNose2.Face.Id, nose.Face.Id) assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id)
assert.Equal(t, "face1", cfgNose2.Face.Name) assert.Equal(t, "face1", cfgNose2.Face.Name)
} }

View File

@ -25,11 +25,10 @@ func strconvErr(err error) error {
func cloneBytes(b []byte) []byte { func cloneBytes(b []byte) []byte {
if b == nil { if b == nil {
return nil return nil
} else {
c := make([]byte, len(b))
copy(c, b)
return c
} }
c := make([]byte, len(b))
copy(c, b)
return c
} }
func asString(src interface{}) string { func asString(src interface{}) string {
@ -274,11 +273,12 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
return vv.String(), nil return vv.String(), nil
case reflect.Slice: case reflect.Slice:
if tp.Elem().Kind() == reflect.Uint8 { if tp.Elem().Kind() == reflect.Uint8 {
v, err := strconv.ParseInt(string(vv.Interface().([]byte)), 10, 64) return string(vv.Interface().([]byte)), nil
/*v, err := strconv.ParseInt(string(vv.Interface().([]byte)), 10, 64)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return v, nil return v, nil*/
} }
} }

View File

@ -1498,6 +1498,13 @@ func (engine *Engine) Exist(bean ...interface{}) (bool, error) {
return session.Exist(bean...) return session.Exist(bean...)
} }
// EagerLoad loads bean's belongs to tag field immedicatlly
func (engine *Engine) EagerLoad(bean interface{}, cols ...string) error {
session := engine.NewSession()
defer session.Close()
return session.EagerLoad(bean, cols...)
}
// Find retrieve records from table, condiBeans's non-empty fields // Find retrieve records from table, condiBeans's non-empty fields
// are conditions. beans could be []Struct, []*Struct, map[int64]Struct // are conditions. beans could be []Struct, []*Struct, map[int64]Struct
// map[int64]*Struct // map[int64]*Struct

View File

@ -18,6 +18,7 @@ type Interface interface {
Alias(alias string) *Session Alias(alias string) *Session
Asc(colNames ...string) *Session Asc(colNames ...string) *Session
BufferSize(size int) *Session BufferSize(size int) *Session
Cascade(...bool) *Session
Cols(columns ...string) *Session Cols(columns ...string) *Session
Count(...interface{}) (int64, error) Count(...interface{}) (int64, error)
CreateIndexes(bean interface{}) error CreateIndexes(bean interface{}) error
@ -27,6 +28,7 @@ type Interface interface {
Delete(interface{}) (int64, error) Delete(interface{}) (int64, error)
Distinct(columns ...string) *Session Distinct(columns ...string) *Session
DropIndexes(bean interface{}) error DropIndexes(bean interface{}) error
EagerLoad(interface{}, ...string) error
Exec(string, ...interface{}) (sql.Result, error) Exec(string, ...interface{}) (sql.Result, error)
Exist(bean ...interface{}) (bool, error) Exist(bean ...interface{}) (bool, error)
Find(interface{}, ...interface{}) error Find(interface{}, ...interface{}) error

View File

@ -636,28 +636,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
session.engine.logger.Error("sql.Sanner error:", err.Error()) session.engine.logger.Error("sql.Sanner error:", err.Error())
hasAssigned = false hasAssigned = false
} }
/*} else if col.SQLType.IsJson() {
if rawValueType.Kind() == reflect.String {
hasAssigned = true
x := reflect.New(fieldType)
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
}
} else if rawValueType.Kind() == reflect.Slice {
hasAssigned = true
x := reflect.New(fieldType)
if len(vv.Bytes()) > 0 {
err := json.Unmarshal(vv.Bytes(), x.Interface())
if err != nil {
return nil, err
}
fieldValue.Set(x.Elem())
}
}*/
} else if (col.AssociateType == core.AssociateNone && } else if (col.AssociateType == core.AssociateNone &&
session.statement.cascadeMode == cascadeCompitable) || session.statement.cascadeMode == cascadeCompitable) ||
(col.AssociateType == core.AssociateBelongsTo && (col.AssociateType == core.AssociateBelongsTo &&
@ -690,122 +668,193 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
return nil, errors.New("cascade obj is not exist") return nil, errors.New("cascade obj is not exist")
} }
} }
} else if col.AssociateType == core.AssociateBelongsTo {
hasAssigned = true
err := convertAssign(fieldValue.FieldByName(table.PKColumns()[0].FieldName).Addr().Interface(),
vv.Interface())
if err != nil {
return nil, err
}
} }
case reflect.Ptr: case reflect.Ptr:
// !nashtsai! TODO merge duplicated codes above if fieldType != core.PtrTimeType && fieldType.Elem().Kind() == reflect.Struct {
switch fieldType { if (col.AssociateType == core.AssociateNone &&
// following types case matching ptr's native type, therefore assign ptr directly session.statement.cascadeMode == cascadeCompitable) ||
case core.PtrStringType: (col.AssociateType == core.AssociateBelongsTo &&
if rawValueType.Kind() == reflect.String { session.statement.cascadeMode == cascadeAutoLoad) {
x := vv.String() table := col.AssociateTable
hasAssigned = true hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x)) if len(table.PrimaryKeys) != 1 {
} panic("unsupported non or composited primary key cascade")
case core.PtrBoolType: }
if rawValueType.Kind() == reflect.Bool { var pk = make(core.PK, len(table.PrimaryKeys))
x := vv.Bool() var err error
hasAssigned = true pk[0], err = asKind(vv, rawValueType)
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrTimeType:
if rawValueType == core.PtrTimeType {
hasAssigned = true
var x = rawValue.Interface().(time.Time)
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrFloat64Type:
if rawValueType.Kind() == reflect.Float64 {
x := vv.Float()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUint64Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint64(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt64Type:
if rawValueType.Kind() == reflect.Int64 {
x := vv.Int()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrFloat32Type:
if rawValueType.Kind() == reflect.Float64 {
var x = float32(vv.Float())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrIntType:
if rawValueType.Kind() == reflect.Int64 {
var x = int(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUintType:
if rawValueType.Kind() == reflect.Int64 {
var x = uint(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUint32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Uint8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Uint16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Complex64Type:
var x complex64
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), &x)
if err != nil { if err != nil {
return nil, err return nil, err
} }
fieldValue.Set(reflect.ValueOf(&x))
} if !isPKZero(pk) {
hasAssigned = true // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
case core.Complex128Type: // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
var x complex128 // property to be fetched lazily
if len([]byte(vv.String())) > 0 { var structInter reflect.Value
err := json.Unmarshal([]byte(vv.String()), &x) if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
structInter = reflect.New(fieldValue.Type().Elem())
} else {
structInter = *fieldValue
}
} else {
structInter = fieldValue.Addr()
}
has, err := session.ID(pk).NoCascade().get(structInter.Interface())
if err != nil {
return nil, err
}
if has {
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
fieldValue.Set(structInter)
}
} else {
return nil, errors.New("cascade obj is not exist")
}
}
} else if col.AssociateType == core.AssociateBelongsTo {
hasAssigned = true
if fieldValue.IsNil() {
structInter := reflect.New(fieldValue.Type().Elem())
fieldValue.Set(structInter)
}
//fieldValue.Elem().FieldByName(table.PKColumns()[0].FieldName).Set(vv)
err := convertAssign(fieldValue.Elem().FieldByName(table.PKColumns()[0].FieldName).Addr().Interface(),
vv.Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
fieldValue.Set(reflect.ValueOf(&x))
} }
hasAssigned = true } else {
} // switch fieldType // !nashtsai! TODO merge duplicated codes above
//typeStr := fieldType.String()
switch fieldType {
// following types case matching ptr's native type, therefore assign ptr directly
case core.PtrStringType:
if rawValueType.Kind() == reflect.String {
x := vv.String()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrBoolType:
if rawValueType.Kind() == reflect.Bool {
x := vv.Bool()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrTimeType:
if rawValueType == core.PtrTimeType {
hasAssigned = true
var x = rawValue.Interface().(time.Time)
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrFloat64Type:
if rawValueType.Kind() == reflect.Float64 {
x := vv.Float()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUint64Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint64(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt64Type:
if rawValueType.Kind() == reflect.Int64 {
x := vv.Int()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrFloat32Type:
if rawValueType.Kind() == reflect.Float64 {
var x = float32(vv.Float())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrIntType:
if rawValueType.Kind() == reflect.Int64 {
var x = int(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUintType:
if rawValueType.Kind() == reflect.Int64 {
var x = uint(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUint32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUint8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUint16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrComplex64Type:
var x complex64
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), &x)
if err != nil {
session.engine.logger.Error(err)
} else {
fieldValue.Set(reflect.ValueOf(&x))
}
}
hasAssigned = true
case core.PtrComplex128Type:
var x complex128
if len([]byte(vv.String())) > 0 {
err := json.Unmarshal([]byte(vv.String()), &x)
if err != nil {
session.engine.logger.Error(err)
} else {
fieldValue.Set(reflect.ValueOf(&x))
}
}
hasAssigned = true
} // switch fieldType
}
} // switch fieldType.Kind() } // switch fieldType.Kind()
// !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value // !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value

58
session_associate.go Normal file
View File

@ -0,0 +1,58 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"errors"
"reflect"
"github.com/go-xorm/core"
)
// EagerLoad load bean's belongs to tag field immedicatlly
func (session *Session) EagerLoad(bean interface{}, cols ...string) error {
if session.isAutoClose {
defer session.Close()
}
v := rValue(bean)
tb, err := session.engine.autoMapType(v)
if err != nil {
return err
}
for _, col := range tb.Columns() {
if col.AssociateTable != nil {
if col.AssociateType == core.AssociateBelongsTo {
colV, err := col.ValueOfV(&v)
if err != nil {
return err
}
pk, err := session.engine.idOfV(*colV)
if err != nil {
return err
}
var colPtr reflect.Value
if colV.Kind() == reflect.Ptr {
colPtr = *colV
} else {
colPtr = colV.Addr()
}
if !isZero(pk[0]) {
has, err := session.ID(pk).get(colPtr.Interface())
if err != nil {
return err
}
if !has {
return errors.New("load bean does not exist")
}
}
}
}
}
return nil
}

View File

@ -33,10 +33,8 @@ var (
func createEngine(dbType, connStr string) error { func createEngine(dbType, connStr string) error {
if testEngine == nil { if testEngine == nil {
var err error var err error
if !*cluster { if !*cluster {
testEngine, err = NewEngine(dbType, connStr) testEngine, err = NewEngine(dbType, connStr)
} else { } else {
testEngine, err = NewEngineGroup(dbType, strings.Split(connStr, *splitter)) testEngine, err = NewEngineGroup(dbType, strings.Split(connStr, *splitter))
} }