This commit is contained in:
liuchenrang 2019-09-29 04:31:09 +00:00 committed by GitHub
commit 7ea90a1aea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 195 additions and 27 deletions

View File

@ -275,6 +275,17 @@ affected, err := engine.Where(...).Delete(&user)
affected, err := engine.ID(2).Delete(&user) affected, err := engine.ID(2).Delete(&user)
// DELETE FROM user Where id = ? // DELETE FROM user Where id = ?
// soft delete customer
eg, err := xorm.NewEngine("mysql", dns)
if err != nil {
panic("failed to connect database " + err.Error())
}
eg.ShowSQL(true)
eg.SetSoftDeleteHandler(&xorm.DefaultSoftDeleteHandler{})
``` ```
* `Count` count records * `Count` count records

View File

@ -55,6 +55,7 @@ type Engine struct {
cacherLock sync.RWMutex cacherLock sync.RWMutex
defaultContext context.Context defaultContext context.Context
softDelete SoftDelete
} }
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {
@ -95,6 +96,9 @@ func (engine *Engine) CondDeleted(colName string) builder.Cond {
if engine.dialect.DBType() == core.MSSQL { if engine.dialect.DBType() == core.MSSQL {
return builder.IsNull{colName} return builder.IsNull{colName}
} }
if engine.softDelete != nil {
return engine.softDelete.getSelectFilter(colName)
}
return builder.IsNull{colName}.Or(builder.Eq{colName: zeroTime1}) return builder.IsNull{colName}.Or(builder.Eq{colName: zeroTime1})
} }
@ -315,8 +319,14 @@ func (engine *Engine) Dialect() core.Dialect {
func (engine *Engine) NewSession() *Session { func (engine *Engine) NewSession() *Session {
session := &Session{engine: engine} session := &Session{engine: engine}
session.Init() session.Init()
if engine.softDelete != nil {
session.setSoftDelete(engine.softDelete)
}
return session return session
} }
func (engine *Engine) SetSoftDeleteHandler(handler SoftDelete) {
engine.softDelete = handler
}
// Close the engine // Close the engine
func (engine *Engine) Close() error { func (engine *Engine) Close() error {

View File

@ -70,7 +70,7 @@ type Interface interface {
// EngineInterface defines the interface which Engine, EngineGroup will implementate. // EngineInterface defines the interface which Engine, EngineGroup will implementate.
type EngineInterface interface { type EngineInterface interface {
Interface Interface
SetSoftDeleteHandler(SoftDelete)
Before(func(interface{})) *Session Before(func(interface{})) *Session
Charset(charset string) *Session Charset(charset string) *Session
ClearCache(...interface{}) error ClearCache(...interface{}) error

View File

@ -60,6 +60,7 @@ type Session struct {
ctx context.Context ctx context.Context
sessionType sessionType sessionType sessionType
softDelete SoftDelete
} }
// Clone copy all the session's content and return a new session // Clone copy all the session's content and return a new session
@ -67,7 +68,10 @@ func (session *Session) Clone() *Session {
var sess = *session var sess = *session
return &sess return &sess
} }
func (session *Session) setSoftDelete(softDelete SoftDelete) *Session {
session.softDelete = softDelete
return session
}
// Init reset the session as the init status. // Init reset the session as the init status.
func (session *Session) Init() { func (session *Session) Init() {
session.statement.Init() session.statement.Init()

View File

@ -8,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"time"
"xorm.io/core" "xorm.io/core"
) )
@ -192,15 +193,24 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
condArgs = append(condArgs, "") condArgs = append(condArgs, "")
paramsLen := len(condArgs) paramsLen := len(condArgs)
copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1]) copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1])
val, t := session.engine.nowTime(deletedColumn)
condArgs[0] = val
var colName = deletedColumn.Name var colName = deletedColumn.Name
var t, val interface{}
if session.softDelete != nil {
val = session.softDelete.getDeleteValue()
condArgs[0] = val
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) session.softDelete.setBeanConumenAttr(bean, col, val)
}) })
} else {
val, t = session.engine.nowTime(deletedColumn)
condArgs[0] = val
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t.(time.Time))
})
}
} }
if cacher := session.engine.getCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache { if cacher := session.engine.getCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache {

View File

@ -5,11 +5,12 @@
package xorm package xorm
import ( import (
"fmt"
"testing" "testing"
"time" "time"
"xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core"
) )
func TestDelete(t *testing.T) { func TestDelete(t *testing.T) {
@ -237,3 +238,94 @@ func TestUnscopeDelete(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, has) assert.False(t, has)
} }
func TestSoftDeleted(t *testing.T) {
type YySoftDeleted struct {
Id int64 `xorm:"pk"`
Name string
DeletedAt int64 `xorm:"not null default '0' comment('删除状态') deleted "`
}
testSoftEngine, err := createEngine(dbType, connString)
assert.NoError(t, err)
testSoftEngine.SetSoftDeleteHandler(&DefaultSoftDeleteHandler{})
defer testSoftEngine.SetSoftDeleteHandler(nil)
err = testSoftEngine.DropTables(&YySoftDeleted{})
assert.NoError(t, err)
err = testSoftEngine.CreateTables(&YySoftDeleted{})
assert.NoError(t, err)
_, err = testSoftEngine.InsertOne(&YySoftDeleted{Id: 1, Name: "4444"})
assert.NoError(t, err)
_, err = testSoftEngine.InsertOne(&YySoftDeleted{Id: 2, Name: "5555"})
assert.NoError(t, err)
_, err = testSoftEngine.InsertOne(&YySoftDeleted{Id: 3, Name: "6666"})
assert.NoError(t, err)
// Test normal Find()
var records1 []YySoftDeleted
err = testSoftEngine.Where("`"+testSoftEngine.GetColumnMapper().Obj2Table("Id")+"` > 0").Find(&records1, &YySoftDeleted{})
fmt.Printf("%+v", records1)
assert.EqualValues(t, 3, len(records1))
// Test normal Get()
record1 := &YySoftDeleted{}
has, err := testSoftEngine.ID(1).Get(record1)
assert.NoError(t, err)
assert.True(t, has)
// Test Delete() with deleted
affected, err := testSoftEngine.ID(1).Delete(&YySoftDeleted{})
assert.NoError(t, err)
assert.EqualValues(t, 1, affected)
has, err = testSoftEngine.ID(1).Get(&YySoftDeleted{})
assert.NoError(t, err)
assert.False(t, has)
var records2 []YySoftDeleted
err = testSoftEngine.Where("`" + testSoftEngine.GetColumnMapper().Obj2Table("Id") + "` > 0").Find(&records2)
assert.NoError(t, err)
assert.EqualValues(t, 2, len(records2))
// Test no rows affected after Delete() again.
affected, err = testSoftEngine.ID(1).Delete(&YySoftDeleted{})
assert.NoError(t, err)
assert.EqualValues(t, 0, affected)
// Deleted.DeletedAt must not be updated.
affected, err = testSoftEngine.ID(2).Update(&YySoftDeleted{Name: "23", DeletedAt: 1})
assert.NoError(t, err)
assert.EqualValues(t, 1, affected)
record2 := &YySoftDeleted{}
has, err = testSoftEngine.ID(2).Get(record2)
assert.NoError(t, err)
// fmt.Printf("%+v", reco)
assert.True(t, record2.DeletedAt == 0)
// Test find all records whatever `deleted`.
var unscopedRecords1 []YySoftDeleted
err = testSoftEngine.Unscoped().Where("`"+testSoftEngine.GetColumnMapper().Obj2Table("Id")+"` > 0").Find(&unscopedRecords1, &YySoftDeleted{})
assert.NoError(t, err)
assert.EqualValues(t, 3, len(unscopedRecords1))
// Delete() must really delete a record with Unscoped()
affected, err = testSoftEngine.Unscoped().ID(1).Delete(&YySoftDeleted{})
assert.NoError(t, err)
assert.EqualValues(t, 1, affected)
var unscopedRecords2 []YySoftDeleted
err = testSoftEngine.Unscoped().Where("`"+testSoftEngine.GetColumnMapper().Obj2Table("Id")+"` > 0").Find(&unscopedRecords2, &YySoftDeleted{})
assert.NoError(t, err)
assert.EqualValues(t, 2, len(unscopedRecords2))
var records3 []YySoftDeleted
err = testSoftEngine.Where("`"+testSoftEngine.GetColumnMapper().Obj2Table("Id")+"` > 0").And("`"+testSoftEngine.GetColumnMapper().Obj2Table("Id")+"`> 1").
Or("`"+testSoftEngine.GetColumnMapper().Obj2Table("Id")+"` = ?", 3).Find(&records3)
assert.NoError(t, err)
assert.EqualValues(t, 2, len(records3))
}

39
soft_delete.go Normal file
View File

@ -0,0 +1,39 @@
package xorm
import (
"reflect"
"xorm.io/builder"
"xorm.io/core"
)
type SoftDelete interface {
getDeleteValue() interface{}
getSelectFilter(deleteField string) builder.Cond
setBeanConumenAttr(bean interface{}, col *core.Column, val interface{})
}
type DefaultSoftDeleteHandler struct {
}
func (h *DefaultSoftDeleteHandler) setBeanConumenAttr(bean interface{}, col *core.Column, val interface{}) {
t := val.(int64)
v, err := col.ValueOf(bean)
if err != nil {
return
}
if v.CanSet() {
switch v.Type().Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
v.SetInt(t)
case reflect.Uint, reflect.Uint64, reflect.Uint32:
v.SetUint(uint64(t))
}
}
}
func (h *DefaultSoftDeleteHandler) getDeleteValue() interface{} {
return int64(1)
}
func (h *DefaultSoftDeleteHandler) getSelectFilter(deleteField string) builder.Cond {
return builder.Eq{deleteField: int64(0)}
}

View File

@ -37,41 +37,40 @@ var (
ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb") ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb")
) )
func createEngine(dbType, connStr string) error { func createEngine(dbType, connStr string) (testEngine EngineInterface,err error) {
if testEngine == nil {
var err error
if testEngine == nil {
if !*cluster { if !*cluster {
switch strings.ToLower(dbType) { switch strings.ToLower(dbType) {
case core.MSSQL: case core.MSSQL:
db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "master", -1)) db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "master", -1))
if err != nil { if err != nil {
return err return nil,err
} }
if _, err = db.Exec("If(db_id(N'xorm_test') IS NULL) BEGIN CREATE DATABASE xorm_test; END;"); err != nil { if _, err = db.Exec("If(db_id(N'xorm_test') IS NULL) BEGIN CREATE DATABASE xorm_test; END;"); err != nil {
return fmt.Errorf("db.Exec: %v", err) return nil,fmt.Errorf("db.Exec: %v", err)
} }
db.Close() db.Close()
*ignoreSelectUpdate = true *ignoreSelectUpdate = true
case core.POSTGRES: case core.POSTGRES:
db, err := sql.Open(dbType, connStr) db, err := sql.Open(dbType, connStr)
if err != nil { if err != nil {
return err return nil,err
} }
rows, err := db.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = 'xorm_test'")) rows, err := db.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = 'xorm_test'"))
if err != nil { if err != nil {
return fmt.Errorf("db.Query: %v", err) return nil,fmt.Errorf("db.Query: %v", err)
} }
defer rows.Close() defer rows.Close()
if !rows.Next() { if !rows.Next() {
if _, err = db.Exec("CREATE DATABASE xorm_test"); err != nil { if _, err = db.Exec("CREATE DATABASE xorm_test"); err != nil {
return fmt.Errorf("CREATE DATABASE: %v", err) return nil,fmt.Errorf("CREATE DATABASE: %v", err)
} }
} }
if *schema != "" { if *schema != "" {
if _, err = db.Exec("CREATE SCHEMA IF NOT EXISTS " + *schema); err != nil { if _, err = db.Exec("CREATE SCHEMA IF NOT EXISTS " + *schema); err != nil {
return fmt.Errorf("CREATE SCHEMA: %v", err) return nil,fmt.Errorf("CREATE SCHEMA: %v", err)
} }
} }
db.Close() db.Close()
@ -79,10 +78,10 @@ func createEngine(dbType, connStr string) error {
case core.MYSQL: case core.MYSQL:
db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "mysql", -1)) db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "mysql", -1))
if err != nil { if err != nil {
return err return nil,err
} }
if _, err = db.Exec("CREATE DATABASE IF NOT EXISTS xorm_test"); err != nil { if _, err = db.Exec("CREATE DATABASE IF NOT EXISTS xorm_test"); err != nil {
return fmt.Errorf("db.Exec: %v", err) return nil,fmt.Errorf("db.Exec: %v", err)
} }
db.Close() db.Close()
default: default:
@ -97,7 +96,7 @@ func createEngine(dbType, connStr string) error {
} }
} }
if err != nil { if err != nil {
return err return nil,err
} }
if *schema != "" { if *schema != "" {
@ -124,20 +123,23 @@ func createEngine(dbType, connStr string) error {
tables, err := testEngine.DBMetas() tables, err := testEngine.DBMetas()
if err != nil { if err != nil {
return err return nil,err
} }
var tableNames = make([]interface{}, 0, len(tables)) var tableNames = make([]interface{}, 0, len(tables))
for _, table := range tables { for _, table := range tables {
tableNames = append(tableNames, table.Name) tableNames = append(tableNames, table.Name)
} }
if err = testEngine.DropTables(tableNames...); err != nil { if err = testEngine.DropTables(tableNames...); err != nil {
return err return nil,err
} }
return nil return testEngine,nil
} }
func prepareEngine() error { func prepareEngine() error {
return createEngine(dbType, connString) var err error
testEngine ,err = createEngine(dbType, connString)
return err
} }
func TestMain(m *testing.M) { func TestMain(m *testing.M) {