diff --git a/README.md b/README.md index 62b40ba3..1d90fd21 100644 --- a/README.md +++ b/README.md @@ -275,6 +275,17 @@ affected, err := engine.Where(...).Delete(&user) affected, err := engine.ID(2).Delete(&user) // DELETE FROM user Where id = ? + +// soft delete customer + +eg, err := xorm.NewEngine("mysql", dns) +if err != nil { + panic("failed to connect database " + err.Error()) +} +eg.ShowSQL(true) + +eg.SetSoftDeleteHandler(&xorm.DefaultSoftDeleteHandler{}) + ``` * `Count` count records diff --git a/engine.go b/engine.go index 96100fce..89ab2bdd 100644 --- a/engine.go +++ b/engine.go @@ -55,6 +55,7 @@ type Engine struct { cacherLock sync.RWMutex defaultContext context.Context + softDelete SoftDelete } func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { @@ -95,6 +96,9 @@ func (engine *Engine) CondDeleted(colName string) builder.Cond { if engine.dialect.DBType() == core.MSSQL { return builder.IsNull{colName} } + if engine.softDelete != nil { + return engine.softDelete.getSelectFilter(colName) + } return builder.IsNull{colName}.Or(builder.Eq{colName: zeroTime1}) } @@ -315,8 +319,14 @@ func (engine *Engine) Dialect() core.Dialect { func (engine *Engine) NewSession() *Session { session := &Session{engine: engine} session.Init() + if engine.softDelete != nil { + session.setSoftDelete(engine.softDelete) + } return session } +func (engine *Engine) SetSoftDeleteHandler(handler SoftDelete) { + engine.softDelete = handler +} // Close the engine func (engine *Engine) Close() error { diff --git a/interface.go b/interface.go index a564db12..685ed287 100644 --- a/interface.go +++ b/interface.go @@ -70,7 +70,7 @@ type Interface interface { // EngineInterface defines the interface which Engine, EngineGroup will implementate. type EngineInterface interface { Interface - + SetSoftDeleteHandler(SoftDelete) Before(func(interface{})) *Session Charset(charset string) *Session ClearCache(...interface{}) error diff --git a/session.go b/session.go index b33955fd..49bb3e6d 100644 --- a/session.go +++ b/session.go @@ -60,6 +60,7 @@ type Session struct { ctx context.Context sessionType sessionType + softDelete SoftDelete } // Clone copy all the session's content and return a new session @@ -67,7 +68,10 @@ func (session *Session) Clone() *Session { var sess = *session return &sess } - +func (session *Session) setSoftDelete(softDelete SoftDelete) *Session { + session.softDelete = softDelete + return session +} // Init reset the session as the init status. func (session *Session) Init() { session.statement.Init() diff --git a/session_delete.go b/session_delete.go index 675d4d8c..1b54a7d5 100644 --- a/session_delete.go +++ b/session_delete.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "strconv" + "time" "xorm.io/core" ) @@ -192,15 +193,24 @@ func (session *Session) Delete(bean interface{}) (int64, error) { condArgs = append(condArgs, "") paramsLen := len(condArgs) copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1]) - - val, t := session.engine.nowTime(deletedColumn) - condArgs[0] = val - var colName = deletedColumn.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnTime(bean, col, t) - }) + var t, val interface{} + if session.softDelete != nil { + val = session.softDelete.getDeleteValue() + condArgs[0] = val + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + session.softDelete.setBeanConumenAttr(bean, col, val) + }) + } else { + + val, t = session.engine.nowTime(deletedColumn) + condArgs[0] = val + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t.(time.Time)) + }) + } } if cacher := session.engine.getCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache { diff --git a/session_delete_test.go b/session_delete_test.go index 5edb0718..7ac60639 100644 --- a/session_delete_test.go +++ b/session_delete_test.go @@ -5,11 +5,12 @@ package xorm import ( + "fmt" "testing" "time" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/core" ) func TestDelete(t *testing.T) { @@ -237,3 +238,94 @@ func TestUnscopeDelete(t *testing.T) { assert.NoError(t, err) assert.False(t, has) } + +func TestSoftDeleted(t *testing.T) { + type YySoftDeleted struct { + Id int64 `xorm:"pk"` + Name string + DeletedAt int64 `xorm:"not null default '0' comment('删除状态') deleted "` + } + testSoftEngine, err := createEngine(dbType, connString) + assert.NoError(t, err) + + testSoftEngine.SetSoftDeleteHandler(&DefaultSoftDeleteHandler{}) + defer testSoftEngine.SetSoftDeleteHandler(nil) + err = testSoftEngine.DropTables(&YySoftDeleted{}) + assert.NoError(t, err) + + err = testSoftEngine.CreateTables(&YySoftDeleted{}) + assert.NoError(t, err) + + _, err = testSoftEngine.InsertOne(&YySoftDeleted{Id: 1, Name: "4444"}) + assert.NoError(t, err) + + _, err = testSoftEngine.InsertOne(&YySoftDeleted{Id: 2, Name: "5555"}) + assert.NoError(t, err) + + _, err = testSoftEngine.InsertOne(&YySoftDeleted{Id: 3, Name: "6666"}) + assert.NoError(t, err) + + // Test normal Find() + var records1 []YySoftDeleted + err = testSoftEngine.Where("`"+testSoftEngine.GetColumnMapper().Obj2Table("Id")+"` > 0").Find(&records1, &YySoftDeleted{}) + fmt.Printf("%+v", records1) + assert.EqualValues(t, 3, len(records1)) + // Test normal Get() + record1 := &YySoftDeleted{} + has, err := testSoftEngine.ID(1).Get(record1) + assert.NoError(t, err) + assert.True(t, has) + + // Test Delete() with deleted + affected, err := testSoftEngine.ID(1).Delete(&YySoftDeleted{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, affected) + + has, err = testSoftEngine.ID(1).Get(&YySoftDeleted{}) + assert.NoError(t, err) + assert.False(t, has) + + var records2 []YySoftDeleted + err = testSoftEngine.Where("`" + testSoftEngine.GetColumnMapper().Obj2Table("Id") + "` > 0").Find(&records2) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(records2)) + + // Test no rows affected after Delete() again. + affected, err = testSoftEngine.ID(1).Delete(&YySoftDeleted{}) + assert.NoError(t, err) + assert.EqualValues(t, 0, affected) + + // Deleted.DeletedAt must not be updated. + affected, err = testSoftEngine.ID(2).Update(&YySoftDeleted{Name: "23", DeletedAt: 1}) + assert.NoError(t, err) + assert.EqualValues(t, 1, affected) + + record2 := &YySoftDeleted{} + has, err = testSoftEngine.ID(2).Get(record2) + assert.NoError(t, err) + // fmt.Printf("%+v", reco) + assert.True(t, record2.DeletedAt == 0) + + // Test find all records whatever `deleted`. + var unscopedRecords1 []YySoftDeleted + err = testSoftEngine.Unscoped().Where("`"+testSoftEngine.GetColumnMapper().Obj2Table("Id")+"` > 0").Find(&unscopedRecords1, &YySoftDeleted{}) + assert.NoError(t, err) + assert.EqualValues(t, 3, len(unscopedRecords1)) + + // Delete() must really delete a record with Unscoped() + affected, err = testSoftEngine.Unscoped().ID(1).Delete(&YySoftDeleted{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, affected) + + var unscopedRecords2 []YySoftDeleted + err = testSoftEngine.Unscoped().Where("`"+testSoftEngine.GetColumnMapper().Obj2Table("Id")+"` > 0").Find(&unscopedRecords2, &YySoftDeleted{}) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(unscopedRecords2)) + + var records3 []YySoftDeleted + err = testSoftEngine.Where("`"+testSoftEngine.GetColumnMapper().Obj2Table("Id")+"` > 0").And("`"+testSoftEngine.GetColumnMapper().Obj2Table("Id")+"`> 1"). + Or("`"+testSoftEngine.GetColumnMapper().Obj2Table("Id")+"` = ?", 3).Find(&records3) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(records3)) + +} diff --git a/soft_delete.go b/soft_delete.go new file mode 100644 index 00000000..0e50b50a --- /dev/null +++ b/soft_delete.go @@ -0,0 +1,39 @@ +package xorm + +import ( + "reflect" + + "xorm.io/builder" + "xorm.io/core" +) + +type SoftDelete interface { + getDeleteValue() interface{} + getSelectFilter(deleteField string) builder.Cond + setBeanConumenAttr(bean interface{}, col *core.Column, val interface{}) +} + +type DefaultSoftDeleteHandler struct { +} + +func (h *DefaultSoftDeleteHandler) setBeanConumenAttr(bean interface{}, col *core.Column, val interface{}) { + t := val.(int64) + v, err := col.ValueOf(bean) + if err != nil { + return + } + if v.CanSet() { + switch v.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32: + v.SetInt(t) + case reflect.Uint, reflect.Uint64, reflect.Uint32: + v.SetUint(uint64(t)) + } + } +} +func (h *DefaultSoftDeleteHandler) getDeleteValue() interface{} { + return int64(1) +} +func (h *DefaultSoftDeleteHandler) getSelectFilter(deleteField string) builder.Cond { + return builder.Eq{deleteField: int64(0)} +} diff --git a/xorm_test.go b/xorm_test.go index c0302df3..d2a0adf8 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -37,41 +37,40 @@ var ( ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb") ) -func createEngine(dbType, connStr string) error { - if testEngine == nil { - var err error +func createEngine(dbType, connStr string) (testEngine EngineInterface,err error) { + if testEngine == nil { if !*cluster { switch strings.ToLower(dbType) { case core.MSSQL: db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "master", -1)) if err != nil { - return err + return nil,err } if _, err = db.Exec("If(db_id(N'xorm_test') IS NULL) BEGIN CREATE DATABASE xorm_test; END;"); err != nil { - return fmt.Errorf("db.Exec: %v", err) + return nil,fmt.Errorf("db.Exec: %v", err) } db.Close() *ignoreSelectUpdate = true case core.POSTGRES: db, err := sql.Open(dbType, connStr) if err != nil { - return err + return nil,err } rows, err := db.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = 'xorm_test'")) if err != nil { - return fmt.Errorf("db.Query: %v", err) + return nil,fmt.Errorf("db.Query: %v", err) } defer rows.Close() if !rows.Next() { if _, err = db.Exec("CREATE DATABASE xorm_test"); err != nil { - return fmt.Errorf("CREATE DATABASE: %v", err) + return nil,fmt.Errorf("CREATE DATABASE: %v", err) } } if *schema != "" { if _, err = db.Exec("CREATE SCHEMA IF NOT EXISTS " + *schema); err != nil { - return fmt.Errorf("CREATE SCHEMA: %v", err) + return nil,fmt.Errorf("CREATE SCHEMA: %v", err) } } db.Close() @@ -79,10 +78,10 @@ func createEngine(dbType, connStr string) error { case core.MYSQL: db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "mysql", -1)) if err != nil { - return err + return nil,err } if _, err = db.Exec("CREATE DATABASE IF NOT EXISTS xorm_test"); err != nil { - return fmt.Errorf("db.Exec: %v", err) + return nil,fmt.Errorf("db.Exec: %v", err) } db.Close() default: @@ -97,7 +96,7 @@ func createEngine(dbType, connStr string) error { } } if err != nil { - return err + return nil,err } if *schema != "" { @@ -124,20 +123,23 @@ func createEngine(dbType, connStr string) error { tables, err := testEngine.DBMetas() if err != nil { - return err + return nil,err } var tableNames = make([]interface{}, 0, len(tables)) for _, table := range tables { tableNames = append(tableNames, table.Name) } if err = testEngine.DropTables(tableNames...); err != nil { - return err + return nil,err } - return nil + return testEngine,nil } func prepareEngine() error { - return createEngine(dbType, connString) + var err error + testEngine ,err = createEngine(dbType, connString) + + return err } func TestMain(m *testing.M) {