merge EagerGet and EagerFind to Load

This commit is contained in:
Lunny Xiao 2017-09-11 15:43:39 +08:00
parent 3538ce1752
commit 23c6999de8
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
8 changed files with 80 additions and 91 deletions

View File

@ -1501,11 +1501,11 @@ func (engine *Engine) Exist(bean ...interface{}) (bool, error) {
return session.Exist(bean...)
}
// EagerGet loads bean's belongs to tag field immedicatlly
func (engine *Engine) EagerGet(bean interface{}, cols ...string) error {
// Load loads bean's belongs to tag field immedicatlly
func (engine *Engine) Load(bean interface{}, cols ...string) error {
session := engine.NewSession()
defer session.Close()
return session.EagerGet(bean, cols...)
return session.Load(bean, cols...)
}
// Find retrieve records from table, condiBeans's non-empty fields

View File

@ -28,7 +28,7 @@ type Interface interface {
Delete(interface{}) (int64, error)
Distinct(columns ...string) *Session
DropIndexes(bean interface{}) error
EagerGet(interface{}, ...string) error
Load(interface{}, ...string) error
Exec(string, ...interface{}) (sql.Result, error)
Exist(bean ...interface{}) (bool, error)
Find(interface{}, ...interface{}) error

View File

@ -21,6 +21,7 @@ type loadClosure struct {
Func func(core.PK, *reflect.Value) error
pk core.PK
fieldValue *reflect.Value
loaded bool
}
// Session keep a pointer to sql.DB and provides all execution of all
@ -57,6 +58,9 @@ type Session struct {
lastSQL string
lastSQLArgs []interface{}
cascadeMode cascadeMode
cascadeLevel int // load level
err error
}
@ -88,6 +92,9 @@ func (session *Session) Init() {
session.lastSQL = ""
session.lastSQLArgs = []interface{}{}
session.cascadeMode = cascadeCompitable
session.cascadeLevel = 2
}
// Close release the connection from pool
@ -155,7 +162,7 @@ func (session *Session) Alias(alias string) *Session {
// NoCascade indicate that no cascade load child object
func (session *Session) NoCascade() *Session {
session.statement.cascadeMode = cascadeManually
session.cascadeMode = cascadeLazy
return session
}
@ -210,16 +217,16 @@ func (session *Session) Charset(charset string) *Session {
// Cascade indicates if loading sub Struct
func (session *Session) Cascade(trueOrFalse ...bool) *Session {
var mode = cascadeAuto
var mode = cascadeEager
if len(trueOrFalse) >= 1 {
if trueOrFalse[0] {
mode = cascadeAuto
mode = cascadeEager
} else {
mode = cascadeManually
mode = cascadeLazy
}
}
session.statement.cascadeMode = mode
session.cascadeMode = mode
return session
}
@ -642,10 +649,10 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
session.engine.logger.Error("sql.Sanner error:", err.Error())
hasAssigned = false
}
} else if (col.AssociateType == core.AssociateNone &&
session.statement.cascadeMode == cascadeCompitable) ||
} else if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone &&
session.cascadeMode == cascadeCompitable) ||
(col.AssociateType == core.AssociateBelongsTo &&
session.statement.cascadeMode == cascadeAuto) {
session.cascadeMode == cascadeEager)) {
var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys))
var err error
rawValueType := col.AssociateTable.PKColumns()[0].FieldType
@ -660,10 +667,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
}
pk[0] = reflect.ValueOf(pk[0]).Elem().Interface()
if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type()))
}
session.afterProcessors = append(session.afterProcessors, executedProcessor{
fun: func(session *Session, bean interface{}) error {
fieldValue := bean.(*reflect.Value)
@ -672,17 +675,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
session: session,
bean: fieldValue,
})
} else {
v := fieldValue.Addr()
session.afterProcessors = append(session.afterProcessors, executedProcessor{
fun: func(session *Session, bean interface{}) error {
fieldValue := bean.(*reflect.Value)
return session.getByPK(pk, fieldValue)
},
session: session,
bean: &v,
})
}
session.cascadeLevel--
hasAssigned = true
} else if col.AssociateType == core.AssociateBelongsTo {
hasAssigned = true
@ -694,10 +687,10 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
}
case reflect.Ptr:
if fieldType != core.PtrTimeType && fieldType.Elem().Kind() == reflect.Struct {
if (col.AssociateType == core.AssociateNone &&
session.statement.cascadeMode == cascadeCompitable) ||
if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone &&
session.cascadeMode == cascadeCompitable) ||
(col.AssociateType == core.AssociateBelongsTo &&
session.statement.cascadeMode == cascadeAuto) {
session.cascadeMode == cascadeEager)) {
var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys))
var err error
rawValueType := col.AssociateTable.ColumnType(col.AssociateTable.PKColumns()[0].FieldName)
@ -712,8 +705,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
}
pk[0] = reflect.ValueOf(pk[0]).Elem().Interface()
fmt.Println("=====", fieldValue, fieldValue.IsNil())
fmt.Printf("%#v", fieldValue)
session.afterProcessors = append(session.afterProcessors, executedProcessor{
fun: func(session *Session, bean interface{}) error {
fieldValue := bean.(*reflect.Value)
@ -723,6 +714,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
bean: fieldValue,
})
session.cascadeLevel--
hasAssigned = true
} else if col.AssociateType == core.AssociateBelongsTo {
hasAssigned = true
@ -732,7 +724,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
fieldValue.Set(structInter)
}
//fieldValue.Elem().FieldByName(table.PKColumns()[0].FieldName).Set(vv)
err := convertAssign(fieldValue.Elem().FieldByName(table.PKColumns()[0].FieldName).Addr().Interface(),
vv.Interface())
if err != nil {
@ -741,7 +732,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
}
} else {
// !nashtsai! TODO merge duplicated codes above
//typeStr := fieldType.String()
switch fieldType {
// following types case matching ptr's native type, therefore assign ptr directly
case core.PtrStringType:
@ -889,14 +879,17 @@ func (session *Session) getByPK(pk core.PK, fieldValue *reflect.Value) error {
structInter = fieldValue.Addr()
}
has, err := session.ID(pk).NoCascade().get(structInter.Interface())
has, err := session.ID(pk).NoAutoCondition().get(structInter.Interface())
if err != nil {
return err
}
if has {
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
fieldValue.Set(structInter)
fmt.Println("333", fieldValue.IsNil())
} else if fieldValue.Kind() == reflect.Struct {
fieldValue.Set(structInter.Elem())
} else {
return errors.New("set value failed")
}
} else {
return errors.New("cascade obj is not exist")

View File

@ -11,8 +11,22 @@ import (
"github.com/go-xorm/core"
)
// EagerFind load 's belongs to tag field immedicatlly
func (session *Session) EagerFind(slices interface{}, cols ...string) error {
// Load loads associated fields from database
func (session *Session) Load(beanOrSlices interface{}, cols ...string) error {
v := reflect.ValueOf(beanOrSlices)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() == reflect.Slice {
return session.loadFind(beanOrSlices, cols...)
} else if v.Kind() == reflect.Struct {
return session.loadGet(beanOrSlices, cols...)
}
return errors.New("unsupported load type, must struct or slice")
}
// loadFind load 's belongs to tag field immedicatlly
func (session *Session) loadFind(slices interface{}, cols ...string) error {
/*v := reflect.ValueOf(slices)
if v.Kind() == reflect.Ptr {
v = v.Elem()
@ -80,8 +94,8 @@ func (session *Session) EagerFind(slices interface{}, cols ...string) error {
return nil
}
// EagerGet load bean's belongs to tag field immedicatlly
func (session *Session) EagerGet(bean interface{}, cols ...string) error {
// loadGet load bean's belongs to tag field immedicatlly
func (session *Session) loadGet(bean interface{}, cols ...string) error {
if session.isAutoClose {
defer session.Close()
}
@ -115,14 +129,15 @@ func (session *Session) EagerGet(bean interface{}, cols ...string) error {
colPtr = colV.Addr()
}
if !isZero(pk[0]) {
has, err := session.ID(pk).get(colPtr.Interface())
if !isZero(pk[0]) && session.cascadeLevel > 0 {
has, err := session.ID(pk).NoAutoCondition().get(colPtr.Interface())
if err != nil {
return err
}
if !has {
return errors.New("load bean does not exist")
}
session.cascadeLevel--
}
}
}

View File

@ -50,7 +50,7 @@ func TestBelongsTo_Get(t *testing.T) {
assert.Equal(t, nose.Face.Id, cfgNose.Face.Id)
assert.Equal(t, "", cfgNose.Face.Name)
err = testEngine.EagerGet(&cfgNose)
err = testEngine.Load(&cfgNose)
assert.NoError(t, err)
assert.Equal(t, nose.Id, cfgNose.Id)
assert.Equal(t, nose.Face.Id, cfgNose.Face.Id)
@ -104,7 +104,7 @@ func TestBelongsTo_GetPtr(t *testing.T) {
assert.Equal(t, nose.Id, cfgNose.Id)
assert.Equal(t, nose.Face.Id, cfgNose.Face.Id)
err = testEngine.EagerGet(&cfgNose)
err = testEngine.Load(&cfgNose)
assert.NoError(t, err)
assert.Equal(t, nose.Id, cfgNose.Id)
assert.Equal(t, nose.Face.Id, cfgNose.Face.Id)

View File

@ -202,10 +202,10 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
}
v = x
fieldValue.Set(reflect.ValueOf(v).Convert(fieldType))
} else if (col.AssociateType == core.AssociateNone &&
session.statement.cascadeMode == cascadeCompitable) ||
} else if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone &&
session.cascadeMode == cascadeCompitable) ||
(col.AssociateType == core.AssociateBelongsTo &&
session.statement.cascadeMode == cascadeAuto) {
session.cascadeMode == cascadeEager)) {
var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys))
// only 1 PK checked on tag parsing
rawValueType := col.AssociateTable.ColumnType(col.AssociateTable.PKColumns()[0].FieldName)
@ -215,6 +215,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
return err
}
session.cascadeLevel--
if err = session.getByPK(pk, fieldValue); err != nil {
return err
}
@ -466,10 +467,10 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
v = x
fieldValue.Set(reflect.ValueOf(&x))
default:
if (col.AssociateType == core.AssociateNone &&
session.statement.cascadeMode == cascadeCompitable) ||
if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone &&
session.cascadeMode == cascadeCompitable) ||
(col.AssociateType == core.AssociateBelongsTo &&
session.statement.cascadeMode == cascadeAuto) {
session.cascadeMode == cascadeEager)) {
var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys))
var err error
// only 1 PK checked on tag parsing
@ -479,6 +480,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
return err
}
session.cascadeLevel--
if err = session.getByPK(pk, fieldValue); err != nil {
return err
}

View File

@ -36,9 +36,10 @@ type exprParam struct {
type cascadeMode int
const (
cascadeCompitable cascadeMode = iota
cascadeAuto
cascadeManually
cascadeCompitable cascadeMode = iota // load field beans with another SQL with no
cascadeEager // load field beans with another SQL
cascadeJoin // load field beans with join
cascadeLazy // don't load anything
)
// Statement save all the sql info for executing SQL
@ -62,7 +63,6 @@ type Statement struct {
tableName string
RawSQL string
RawParams []interface{}
cascadeMode cascadeMode
UseAutoJoin bool
StoreEngine string
Charset string
@ -90,7 +90,6 @@ func (statement *Statement) Init() {
statement.Start = 0
statement.LimitN = 0
statement.OrderStr = ""
statement.cascadeMode = cascadeCompitable
statement.JoinStr = ""
statement.joinArgs = make([]interface{}, 0)
statement.GroupByStr = ""

View File

@ -169,22 +169,12 @@ func TestExtends(t *testing.T) {
tu6 := &tempUser3{&tempUser{0, "extends update"}, ""}
_, err = testEngine.ID(tu5.Temp.Id).Update(tu6)
if err != nil {
t.Error(err)
panic(err)
}
assert.Error(t, err)
users := make([]tempUser3, 0)
err = testEngine.Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
if len(users) != 1 {
err = errors.New("error get data not 1")
t.Error(err)
panic(err)
}
assert.NoError(t, err)
assert.EqualValues(t, 1, len(users))
assertSync(t, new(Userinfo), new(Userdetail))
@ -209,21 +199,11 @@ func TestExtends(t *testing.T) {
sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s",
qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid))
b, err := testEngine.SQL(sql).NoCascade().Get(&info)
if err != nil {
t.Error(err)
panic(err)
}
if !b {
err = errors.New("should has lest one record")
t.Error(err)
panic(err)
}
fmt.Println(info)
if info.Userinfo.Uid == 0 || info.Userdetail.Id == 0 {
err = errors.New("all of the id should has value")
t.Error(err)
panic(err)
}
assert.NoError(t, err)
assert.True(t, b)
assert.False(t, info.Userinfo.Uid == 0 || info.Userdetail.Id == 0)
panic("")
fmt.Println("----join--info2")
var info2 UserAndDetail