From 29d4a0330a00b9be468b70e3fb0f74109348c358 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 30 Sep 2017 09:26:13 +0800 Subject: [PATCH] improve processors (#743) --- processors.go | 40 ++++++++++++++++++++++----- processors_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++ rows.go | 10 +++++-- session.go | 68 +++++++++++++++++++++++++++++++++++----------- session_find.go | 7 ++++- session_get.go | 9 ++++-- 6 files changed, 170 insertions(+), 29 deletions(-) diff --git a/processors.go b/processors.go index 77dd30e5..dcd9c6ac 100644 --- a/processors.go +++ b/processors.go @@ -29,13 +29,6 @@ type AfterSetProcessor interface { AfterSet(string, Cell) } -// !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations -//// Executed before an object is validated -//type BeforeValidateProcessor interface { -// BeforeValidate() -//} -// -- - // AfterInsertProcessor executed after an object is persisted to the database type AfterInsertProcessor interface { AfterInsert() @@ -50,3 +43,36 @@ type AfterUpdateProcessor interface { type AfterDeleteProcessor interface { AfterDelete() } + +// AfterLoadProcessor executed after an ojbect has been loaded from database +type AfterLoadProcessor interface { + AfterLoad() +} + +// AfterLoadSessionProcessor executed after an ojbect has been loaded from database with session parameter +type AfterLoadSessionProcessor interface { + AfterLoad(*Session) +} + +type executedProcessorFunc func(*Session, interface{}) error + +type executedProcessor struct { + fun executedProcessorFunc + session *Session + bean interface{} +} + +func (executor *executedProcessor) execute() error { + return executor.fun(executor.session, executor.bean) +} + +func (session *Session) executeProcessors() error { + processors := session.afterProcessors + session.afterProcessors = make([]executedProcessor, 0) + for _, processor := range processors { + if err := processor.execute(); err != nil { + return err + } + } + return nil +} diff --git a/processors_test.go b/processors_test.go index 4ee59066..c5d7eb6e 100644 --- a/processors_test.go +++ b/processors_test.go @@ -964,3 +964,68 @@ func TestProcessorsTx(t *testing.T) { session.Close() // -- } + +type AfterLoadStructA struct { + Id int64 + Content string +} + +type AfterLoadStructB struct { + Id int64 + Content string + AId int64 + A AfterLoadStructA `xorm:"-"` + Err error `xorm:"-"` +} + +func (s *AfterLoadStructB) AfterLoad(session *Session) { + has, err := session.ID(s.AId).NoAutoCondition().Get(&s.A) + if err != nil { + s.Err = err + return + } + if !has { + s.Err = ErrNotExist + } +} + +func TestAfterLoadProcessor(t *testing.T) { + assert.NoError(t, prepareEngine()) + + assertSync(t, new(AfterLoadStructA), new(AfterLoadStructB)) + + var a = AfterLoadStructA{ + Content: "testa", + } + _, err := testEngine.Insert(&a) + assert.NoError(t, err) + + var b = AfterLoadStructB{ + Content: "testb", + AId: a.Id, + } + _, err = testEngine.Insert(&b) + assert.NoError(t, err) + + var b2 AfterLoadStructB + has, err := testEngine.ID(b.Id).Get(&b2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, a.Id, b2.A.Id) + assert.EqualValues(t, a.Content, b2.A.Content) + assert.NoError(t, b2.Err) + + b.Id = 0 + _, err = testEngine.Insert(&b) + assert.NoError(t, err) + + var bs []AfterLoadStructB + err = testEngine.Find(&bs) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(bs)) + for i := 0; i < len(bs); i++ { + assert.EqualValues(t, a.Id, bs[i].A.Id) + assert.EqualValues(t, a.Content, bs[i].A.Content) + assert.NoError(t, bs[i].Err) + } +} diff --git a/rows.go b/rows.go index 258d9f27..31e29ae2 100644 --- a/rows.go +++ b/rows.go @@ -99,13 +99,17 @@ func (rows *Rows) Scan(bean interface{}) error { return err } - scanResults, err := rows.session.row2Slice(rows.rows, rows.fields, len(rows.fields), bean) + scanResults, err := rows.session.row2Slice(rows.rows, rows.fields, bean) if err != nil { return err } - _, err = rows.session.slice2Bean(scanResults, rows.fields, len(rows.fields), bean, &dataStruct, rows.session.statement.RefTable) - return err + _, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable) + if err != nil { + return err + } + + return rows.session.executeProcessors() } // Close session if session.IsAutoClose is true, and claimed any opened resources diff --git a/session.go b/session.go index c69ac9e5..ed252058 100644 --- a/session.go +++ b/session.go @@ -41,6 +41,8 @@ type Session struct { beforeClosures []func(interface{}) afterClosures []func(interface{}) + afterProcessors []executedProcessor + prepareStmt bool stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) @@ -75,6 +77,8 @@ func (session *Session) Init() { session.beforeClosures = make([]func(interface{}), 0) session.afterClosures = make([]func(interface{}), 0) + session.afterProcessors = make([]executedProcessor, 0) + session.lastSQL = "" session.lastSQLArgs = []interface{}{} } @@ -296,37 +300,40 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *c // Cell cell is a result of one column field type Cell *interface{} -func (session *Session) rows2Beans(rows *core.Rows, fields []string, fieldsCount int, +func (session *Session) rows2Beans(rows *core.Rows, fields []string, table *core.Table, newElemFunc func([]string) reflect.Value, sliceValueSetFunc func(*reflect.Value, core.PK) error) error { for rows.Next() { var newValue = newElemFunc(fields) bean := newValue.Interface() - dataStruct := rValue(bean) + dataStruct := newValue.Elem() // handle beforeClosures - scanResults, err := session.row2Slice(rows, fields, fieldsCount, bean) + scanResults, err := session.row2Slice(rows, fields, bean) if err != nil { return err } - pk, err := session.slice2Bean(scanResults, fields, fieldsCount, bean, &dataStruct, table) - if err != nil { - return err - } - err = sliceValueSetFunc(&newValue, pk) + pk, err := session.slice2Bean(scanResults, fields, bean, &dataStruct, table) if err != nil { return err } + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(*Session, interface{}) error { + return sliceValueSetFunc(&newValue, pk) + }, + session: session, + bean: bean, + }) } return nil } -func (session *Session) row2Slice(rows *core.Rows, fields []string, fieldsCount int, bean interface{}) ([]interface{}, error) { +func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interface{}) ([]interface{}, error) { for _, closure := range session.beforeClosures { closure(bean) } - scanResults := make([]interface{}, fieldsCount) + scanResults := make([]interface{}, len(fields)) for i := 0; i < len(fields); i++ { var cell interface{} scanResults[i] = &cell @@ -343,20 +350,49 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, fieldsCount return scanResults, nil } -func (session *Session) slice2Bean(scanResults []interface{}, fields []string, fieldsCount int, bean interface{}, dataStruct *reflect.Value, table *core.Table) (core.PK, error) { +func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *core.Table) (core.PK, error) { defer func() { if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet { for ii, key := range fields { b.AfterSet(key, Cell(scanResults[ii].(*interface{}))) } } - - // handle afterClosures - for _, closure := range session.afterClosures { - closure(bean) - } }() + // handle afterClosures + for _, closure := range session.afterClosures { + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(sess *Session, bean interface{}) error { + closure(bean) + return nil + }, + session: session, + bean: bean, + }) + } + + if a, has := bean.(AfterLoadProcessor); has { + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(sess *Session, bean interface{}) error { + a.AfterLoad() + return nil + }, + session: session, + bean: bean, + }) + } + + if a, has := bean.(AfterLoadSessionProcessor); has { + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(sess *Session, bean interface{}) error { + a.AfterLoad(sess) + return nil + }, + session: session, + bean: bean, + }) + } + var tempMap = make(map[string]int) var pk core.PK for ii, key := range fields { diff --git a/session_find.go b/session_find.go index 05ec724f..f95dcfef 100644 --- a/session_find.go +++ b/session_find.go @@ -239,7 +239,12 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va if err != nil { return err } - return session.rows2Beans(rows, fields, len(fields), tb, newElemFunc, containerValueSetFunc) + err = session.rows2Beans(rows, fields, tb, newElemFunc, containerValueSetFunc) + rows.Close() + if err != nil { + return err + } + return session.executeProcessors() } for rows.Next() { diff --git a/session_get.go b/session_get.go index 1f1e61cd..8faf53c0 100644 --- a/session_get.go +++ b/session_get.go @@ -87,7 +87,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea return true, err } - scanResults, err := session.row2Slice(rows, fields, len(fields), bean) + scanResults, err := session.row2Slice(rows, fields, bean) if err != nil { return false, err } @@ -95,7 +95,12 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea rows.Close() dataStruct := rValue(bean) - _, err = session.slice2Bean(scanResults, fields, len(fields), bean, &dataStruct, table) + _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) + if err != nil { + return true, err + } + + return true, session.executeProcessors() case reflect.Slice: err = rows.ScanSlice(bean) case reflect.Map: