diff --git a/base_test.go b/base_test.go index 6a7fd85b..9f7a578a 100644 --- a/base_test.go +++ b/base_test.go @@ -1810,6 +1810,7 @@ func testProcessors(engine *Engine, t *testing.T) { t.Error(err) panic(err) } + b4InsertFunc := func(bean interface{}) { if v, ok := (bean).(*ProcessorsStruct); ok { v.B4InsertViaExt = 1 @@ -1864,7 +1865,9 @@ func testProcessors(engine *Engine, t *testing.T) { t.Error(errors.New("AfterInsertedViaExt is set")) } } + // -- + // test update processors b4UpdateFunc := func(bean interface{}) { if v, ok := (bean).(*ProcessorsStruct); ok { v.B4UpdateViaExt = 1 @@ -1921,7 +1924,9 @@ func testProcessors(engine *Engine, t *testing.T) { t.Error(errors.New("AfterUpdatedViaExt is set: " + string(p.AfterUpdatedViaExt))) } } + // -- + // test delete processors b4DeleteFunc := func(bean interface{}) { if v, ok := (bean).(*ProcessorsStruct); ok { v.B4DeleteViaExt = 1 @@ -1957,6 +1962,7 @@ func testProcessors(engine *Engine, t *testing.T) { t.Error(errors.New("AfterDeletedViaExt not set")) } } + // -- // test insert multi pslice := make([]*ProcessorsStruct, 0) @@ -2007,6 +2013,460 @@ func testProcessors(engine *Engine, t *testing.T) { } } } + // -- +} + +func testProcessorsTx(engine *Engine, t *testing.T) { + tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName) + if err != nil { + t.Error(err) + panic(err) + } + + tempEngine.ShowSQL = true + err = tempEngine.DropTables(&ProcessorsStruct{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = tempEngine.CreateTables(&ProcessorsStruct{}) + if err != nil { + t.Error(err) + panic(err) + } + + // test insert processors with tx rollback + session := tempEngine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + + p := &ProcessorsStruct{} + b4InsertFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.B4InsertViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + afterInsertFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.AfterInsertedViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + _, err = session.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("B4InsertFlag is set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + } + + err = session.Rollback() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("B4InsertFlag is set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + } + session.Close() + p2 := &ProcessorsStruct{} + _, err = tempEngine.Id(p.Id).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.Id > 0 { + err = errors.New("tx got committed upon insert!?") + t.Error(err) + panic(err) + } + } + // -- + + // test insert processors with tx commit + session = tempEngine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + + p = &ProcessorsStruct{} + _, err = session.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + } + + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag == 0 { + t.Error(errors.New("AfterInsertedFlag not set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt == 0 { + t.Error(errors.New("AfterInsertedViaExt not set")) + } + } + session.Close() + p2 = &ProcessorsStruct{} + _, err = tempEngine.Id(p.Id).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p2.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p2.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p2.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + } + insertedId := p2.Id + // -- + + // test update processors with tx rollback + session = tempEngine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + + b4UpdateFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.B4UpdateViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + afterUpdateFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.AfterUpdatedViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + p = p2 // reset + + _, err = session.Id(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) + } + } + err = session.Rollback() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) + } + } + + session.Close() + p2 = &ProcessorsStruct{} + _, err = tempEngine.Id(insertedId).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.B4UpdateFlag != 0 { + t.Error(errors.New("B4UpdateFlag is set")) + } + if p2.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p2.B4UpdateViaExt != 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p2.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) + } + } + // -- + + // test update processors with tx commit + session = tempEngine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + + p = &ProcessorsStruct{} + + _, err = session.Id(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) + } + } + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag == 0 { + t.Error(errors.New("AfterUpdatedFlag not set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt == 0 { + t.Error(errors.New("AfterUpdatedViaExt not set")) + } + } + session.Close() + p2 = &ProcessorsStruct{} + _, err = tempEngine.Id(insertedId).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag == 0 { + t.Error(errors.New("AfterUpdatedFlag not set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt == 0 { + t.Error(errors.New("AfterUpdatedViaExt not set")) + } + } + // -- + + // test delete processors with tx rollback + session = tempEngine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + + b4DeleteFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.B4DeleteViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + afterDeleteFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.AfterDeletedViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + p = &ProcessorsStruct{} // reset + + _, err = session.Id(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) + } + } + err = session.Rollback() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) + } + } + session.Close() + + p2 = &ProcessorsStruct{} + _, err = tempEngine.Id(insertedId).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.B4DeleteFlag != 0 { + t.Error(errors.New("B4DeleteFlag is set")) + } + if p2.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p2.B4DeleteViaExt != 0 { + t.Error(errors.New("B4DeleteViaExt is set")) + } + if p2.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) + } + } + // -- + + // test delete processors with tx commit + session = tempEngine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + + p = &ProcessorsStruct{} + + _, err = session.Id(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) + } + } + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag == 0 { + t.Error(errors.New("AfterDeletedFlag not set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt == 0 { + t.Error(errors.New("AfterDeletedViaExt not set")) + } + } + session.Close() + // -- } func testAll(engine *Engine, t *testing.T) { @@ -2107,6 +2567,8 @@ func testAll2(engine *Engine, t *testing.T) { //testCreatedUpdated(engine, t) fmt.Println("-------------- processors --------------") testProcessors(engine, t) + fmt.Println("-------------- processors TX --------------") + testProcessorsTx(engine, t) fmt.Println("-------------- transaction --------------") transaction(engine, t) } diff --git a/session.go b/session.go index 17c744a2..64e65ca4 100644 --- a/session.go +++ b/session.go @@ -22,15 +22,15 @@ type Session struct { IsCommitedOrRollbacked bool TransType string IsAutoClose bool - - // !nashtsai! storing these beans due to yet committed tx + + // !nashtsai! storing these beans due to yet committed tx afterInsertBeans []interface{} afterUpdateBeans []interface{} afterDeleteBeans []interface{} // -- - + beforeClosures []func(interface{}) - afterClosures []func(interface{}) + afterClosures []func(interface{}) } // Method Init reset the session as the init status. @@ -40,7 +40,7 @@ func (session *Session) Init() { session.IsAutoCommit = true session.IsCommitedOrRollbacked = false session.IsAutoClose = false - + // !nashtsai! is lazy init better? session.afterInsertBeans = make([]interface{}, 0) session.afterUpdateBeans = make([]interface{}, 0) @@ -98,7 +98,7 @@ func (session *Session) Before(closures func(interface{})) *Session { session.beforeClosures = append(session.beforeClosures, closures) } return session -} +} // Apply after Processor, affected bean is passed to closure arg func (session *Session) After(closures func(interface{})) *Session { @@ -284,13 +284,13 @@ func (session *Session) Commit() error { session.Engine.LogSQL("COMMIT") session.IsCommitedOrRollbacked = true var err error - if err = session.Tx.Commit(); err == nil { + if err = session.Tx.Commit(); err == nil { // handle processors after tx committed for _, elem := range session.afterInsertBeans { for _, closure := range session.afterClosures { closure(elem) } - + if processor, ok := interface{}(elem).(AfterInsertProcessor); ok { processor.AfterInsert() } @@ -306,7 +306,7 @@ func (session *Session) Commit() error { for _, elem := range session.afterDeleteBeans { for _, closure := range session.afterClosures { closure(elem) - } + } if processor, ok := interface{}(elem).(AfterDeleteProcessor); ok { processor.AfterDelete() } @@ -327,7 +327,7 @@ func (session *Session) Commit() error { // !nash! should session based processors get cleanup? cleanUpProcessorsFunc(&session.afterClosures) - } + } return err } return nil @@ -1304,6 +1304,30 @@ func rows2maps(rows *sql.Rows) (resultsSlice []map[string][]byte, err error) { return resultsSlice, nil } +func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { + for _, filter := range session.Engine.Filters { + sql = filter.Do(sql, session) + } + + session.Engine.LogSQL(sql) + session.Engine.LogSQL(paramStr) + + if session.IsAutoCommit { + return query(session.Db, sql, paramStr...) + } + return txQuery(session.Tx, sql, paramStr...) +} + +func txQuery(tx *sql.Tx, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { + rows, err := tx.Query(sql, params...) + if err != nil { + return nil, err + } + defer rows.Close() + + return rows2maps(rows) +} + func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { s, err := db.Prepare(sql) if err != nil { @@ -1319,17 +1343,6 @@ func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[st return rows2maps(rows) } -func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { - for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) - } - - session.Engine.LogSQL(sql) - session.Engine.LogSQL(paramStr) - - return query(session.Db, sql, paramStr...) -} - // Exec a raw sql and return records as []map[string][]byte func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { err = session.newDb() @@ -1410,7 +1423,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error for i := 0; i < size; i++ { elemValue := sliceValue.Index(i).Interface() colPlaces := make([]string, 0) - + // handle BeforeInsertProcessor for _, closure := range session.beforeClosures { closure(elemValue) @@ -1420,7 +1433,6 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error processor.BeforeInsert() } // -- - if i == 0 { for _, col := range table.Columns { @@ -1497,8 +1509,8 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error if table.Cacher != nil && session.Statement.UseCache { session.cacheInsert(session.Statement.TableName()) } - - hasAfterClosures := len(session.afterClosures) > 0 + + hasAfterClosures := len(session.afterClosures) > 0 for i := 0; i < size; i++ { elemValue := sliceValue.Index(i).Interface() // handle AfterInsertProcessor @@ -1511,10 +1523,10 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } } else { if hasAfterClosures { - session.afterInsertBeans = append(session.afterInsertBeans, elemValue) + session.afterInsertBeans = append(session.afterInsertBeans, elemValue) } else { if _, ok := interface{}(elemValue).(AfterInsertProcessor); ok { - session.afterInsertBeans = append(session.afterInsertBeans, elemValue) + session.afterInsertBeans = append(session.afterInsertBeans, elemValue) } } } @@ -1802,7 +1814,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } else { session.Engine.LogDebug(session.Statement.TableName(), " has no before insert processor") } - // -- + // -- colNames, args, err := table.genCols(session, bean, false, false) if err != nil { @@ -1821,27 +1833,26 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.Engine.QuoteStr(), colPlaces) - handleAfterInsertProcessorFunc := func(bean interface{}) { if session.IsAutoCommit { for _, closure := range session.afterClosures { closure(bean) } - if processor, ok := interface{}(bean).(AfterInsertProcessor); ok { - session.Engine.LogDebug(session.Statement.TableName(), " has after insert processor") + if processor, ok := interface{}(bean).(AfterInsertProcessor); ok { + session.Engine.LogDebug(session.Statement.TableName(), " has after insert processor") processor.AfterInsert() } } else { if len(session.afterClosures) > 0 { - session.afterInsertBeans = append(session.afterInsertBeans, bean) + session.afterInsertBeans = append(session.afterInsertBeans, bean) } else { - if _, ok := interface{}(bean).(AfterInsertProcessor); ok { - session.afterInsertBeans = append(session.afterInsertBeans, bean) + if _, ok := interface{}(bean).(AfterInsertProcessor); ok { + session.afterInsertBeans = append(session.afterInsertBeans, bean) } } } - } + } // for postgres, many of them didn't implement lastInsertId, so we should // implemented it ourself. @@ -2127,7 +2138,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var args []interface{} var table *Table - // handle before update processors for _, closure := range session.beforeClosures { closure(bean) @@ -2247,13 +2257,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } else { if len(session.afterClosures) > 0 { - session.afterUpdateBeans = append(session.afterUpdateBeans, bean) + session.afterUpdateBeans = append(session.afterUpdateBeans, bean) } else { if _, ok := interface{}(bean).(AfterUpdateProcessor); ok { - session.afterUpdateBeans = append(session.afterUpdateBeans, bean) + session.afterUpdateBeans = append(session.afterUpdateBeans, bean) } } - } + } // -- return res.RowsAffected() @@ -2321,7 +2331,6 @@ func (session *Session) Delete(bean interface{}) (int64, error) { defer session.Close() } - // handle before delete processors for _, closure := range session.beforeClosures { closure(bean) @@ -2375,13 +2384,13 @@ func (session *Session) Delete(bean interface{}) (int64, error) { } } else { if len(session.afterClosures) > 0 { - session.afterDeleteBeans = append(session.afterDeleteBeans, bean) + session.afterDeleteBeans = append(session.afterDeleteBeans, bean) } else { if _, ok := interface{}(bean).(AfterDeleteProcessor); ok { - session.afterDeleteBeans = append(session.afterDeleteBeans, bean) + session.afterDeleteBeans = append(session.afterDeleteBeans, bean) } } - } + } // -- return res.RowsAffected()