Merge remote-tracking branch 'cyjay/master'

This commit is contained in:
CyJaySong 2023-10-28 09:51:12 +08:00
commit ad73876977
6 changed files with 214 additions and 18 deletions

View File

@ -16,11 +16,21 @@ import (
"time" "time"
) )
// ConversionFrom is an inteface to allow retrieve data from database
type ConversionFrom interface {
FromDB([]byte) error
}
// ConversionTo is an interface to allow store data to database
type ConversionTo interface {
ToDB() ([]byte, error)
}
// Conversion is an interface. A type implements Conversion will according // Conversion is an interface. A type implements Conversion will according
// the custom method to fill into database and retrieve from database. // the custom method to fill into database and retrieve from database.
type Conversion interface { type Conversion interface {
FromDB([]byte) error ConversionFrom
ToDB() ([]byte, error) ConversionTo
} }
// ErrNilPtr represents an error // ErrNilPtr represents an error

View File

@ -121,6 +121,7 @@ type EngineInterface interface {
ShowSQL(show ...bool) ShowSQL(show ...bool)
Sync(...interface{}) error Sync(...interface{}) error
Sync2(...interface{}) error Sync2(...interface{}) error
SyncWithOptions(SyncOptions, ...interface{}) (*SyncResult, error)
StoreEngine(storeEngine string) *Session StoreEngine(storeEngine string) *Session
TableInfo(bean interface{}) (*schemas.Table, error) TableInfo(bean interface{}) (*schemas.Table, error)
TableName(interface{}, ...bool) string TableName(interface{}, ...bool) string

View File

@ -644,6 +644,23 @@ func (statement *Statement) convertSQLOrArgs(sqlOrArgs ...interface{}) (string,
newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05")) newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05"))
} else if v, ok := arg.(*time.Time); ok && v != nil { } else if v, ok := arg.(*time.Time); ok && v != nil {
newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05")) newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05"))
} else if v, ok := arg.(convert.ConversionTo); ok {
r, err := v.ToDB()
if err != nil {
return "", nil, err
}
if r != nil {
// for nvarchar column on mssql, bytes have to be converted as ucs-2 external of driver
// for binary column, a string will be converted as bytes directly. So we have to
// convert bytes as string
if statement.dialect.URI().DBType == schemas.MSSQL {
newArgs = append(newArgs, string(r))
} else {
newArgs = append(newArgs, r)
}
} else {
newArgs = append(newArgs, nil)
}
} else { } else {
newArgs = append(newArgs, arg) newArgs = append(newArgs, arg)
} }

54
sync.go
View File

@ -13,6 +13,10 @@ import (
type SyncOptions struct { type SyncOptions struct {
WarnIfDatabaseColumnMissed bool WarnIfDatabaseColumnMissed bool
// IgnoreConstrains will not add, delete or update unique constrains
IgnoreConstrains bool
// IgnoreIndices will not add or delete indices
IgnoreIndices bool
} }
type SyncResult struct{} type SyncResult struct{}
@ -49,6 +53,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
func (session *Session) Sync(beans ...interface{}) error { func (session *Session) Sync(beans ...interface{}) error {
_, err := session.SyncWithOptions(SyncOptions{ _, err := session.SyncWithOptions(SyncOptions{
WarnIfDatabaseColumnMissed: false, WarnIfDatabaseColumnMissed: false,
IgnoreConstrains: false,
IgnoreIndices: false,
}, beans...) }, beans...)
return err return err
} }
@ -103,15 +109,20 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{})
return nil, err return nil, err
} }
err = session.createUniques(bean) if !opts.IgnoreConstrains {
if err != nil { err = session.createUniques(bean)
return nil, err if err != nil {
return nil, err
}
} }
err = session.createIndexes(bean) if !opts.IgnoreIndices {
if err != nil { err = session.createIndexes(bean)
return nil, err if err != nil {
return nil, err
}
} }
continue continue
} }
@ -208,9 +219,12 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{})
} }
} }
// indices found in orig table
foundIndexNames := make(map[string]bool) foundIndexNames := make(map[string]bool)
// indices to be added
addedNames := make(map[string]*schemas.Index) addedNames := make(map[string]*schemas.Index)
// drop indices that exist in orig and new table schema but are not equal
for name, index := range table.Indexes { for name, index := range table.Indexes {
var oriIndex *schemas.Index var oriIndex *schemas.Index
for name2, index2 := range oriTable.Indexes { for name2, index2 := range oriTable.Indexes {
@ -221,15 +235,13 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{})
} }
} }
if oriIndex != nil { if oriIndex != nil && oriIndex.Type != index.Type {
if oriIndex.Type != index.Type { sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex)
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex) _, err = session.exec(sql)
_, err = session.exec(sql) if err != nil {
if err != nil { return nil, err
return nil, err
}
oriIndex = nil
} }
oriIndex = nil
} }
if oriIndex == nil { if oriIndex == nil {
@ -237,8 +249,17 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{})
} }
} }
// drop all indices that do not exist in new schema or have changed
for name2, index2 := range oriTable.Indexes { for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok { if _, ok := foundIndexNames[name2]; !ok {
// ignore based on there type
if (index2.Type == schemas.IndexType && opts.IgnoreIndices) ||
(index2.Type == schemas.UniqueType && opts.IgnoreConstrains) {
// make sure we do not add a index with same name later
delete(addedNames, name2)
continue
}
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2) sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2)
_, err = session.exec(sql) _, err = session.exec(sql)
if err != nil { if err != nil {
@ -247,12 +268,13 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{})
} }
} }
// Add new indices because either they did not exist before or were dropped to update them
for name, index := range addedNames { for name, index := range addedNames {
if index.Type == schemas.UniqueType { if index.Type == schemas.UniqueType && !opts.IgnoreConstrains {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema) session.statement.SetTableName(tbNameWithSchema)
err = session.addUnique(tbNameWithSchema, name) err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == schemas.IndexType { } else if index.Type == schemas.IndexType && !opts.IgnoreIndices {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema) session.statement.SetTableName(tbNameWithSchema)
err = session.addIndex(tbNameWithSchema, name) err = session.addIndex(tbNameWithSchema, name)

View File

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/xorm"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -645,3 +646,101 @@ func TestCollate(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
} }
type SyncWithOpts1 struct {
Id int64
Index int `xorm:"index"`
Unique int `xorm:"unique"`
Group1 int `xorm:"index(ttt)"`
Group2 int `xorm:"index(ttt)"`
UniGroup1 int `xorm:"unique(lll)"`
UniGroup2 int `xorm:"unique(lll)"`
}
func (*SyncWithOpts1) TableName() string {
return "sync_with_opts"
}
type SyncWithOpts2 struct {
Id int64
Index int `xorm:"index"`
Unique int `xorm:""`
Group1 int `xorm:"index(ttt)"`
Group2 int `xorm:"index(ttt)"`
UniGroup1 int `xorm:""`
UniGroup2 int `xorm:"unique(lll)"`
}
func (*SyncWithOpts2) TableName() string {
return "sync_with_opts"
}
type SyncWithOpts3 struct {
Id int64
Index int `xorm:""`
Unique int `xorm:"unique"`
Group1 int `xorm:""`
Group2 int `xorm:"index(ttt)"`
UniGroup1 int `xorm:"unique(lll)"`
UniGroup2 int `xorm:"unique(lll)"`
}
func (*SyncWithOpts3) TableName() string {
return "sync_with_opts"
}
func TestSyncWithOptions(t *testing.T) {
assert.NoError(t, PrepareEngine())
// ignore indices and constrains
result, err := testEngine.SyncWithOptions(xorm.SyncOptions{IgnoreIndices: true, IgnoreConstrains: true}, &SyncWithOpts1{})
assert.NoError(t, err)
assert.NotNil(t, result)
assert.Len(t, getIndicesOfBeanFromDB(t, &SyncWithOpts1{}), 0)
// only ignore indices
result, err = testEngine.SyncWithOptions(xorm.SyncOptions{IgnoreConstrains: true}, &SyncWithOpts2{})
assert.NoError(t, err)
assert.NotNil(t, result)
indices := getIndicesOfBeanFromDB(t, &SyncWithOpts1{})
assert.Len(t, indices, 2)
assert.ElementsMatch(t, []string{"ttt", "index"}, getKeysFromMap(indices))
// only ignore constrains
result, err = testEngine.SyncWithOptions(xorm.SyncOptions{IgnoreIndices: true}, &SyncWithOpts3{})
assert.NoError(t, err)
assert.NotNil(t, result)
indices = getIndicesOfBeanFromDB(t, &SyncWithOpts1{})
assert.Len(t, indices, 4)
assert.ElementsMatch(t, []string{"ttt", "index", "unique", "lll"}, getKeysFromMap(indices))
tableInfoFromStruct, _ := testEngine.TableInfo(&SyncWithOpts1{})
assert.ElementsMatch(t, getKeysFromMap(tableInfoFromStruct.Indexes), getKeysFromMap(getIndicesOfBeanFromDB(t, &SyncWithOpts1{})))
}
func getIndicesOfBeanFromDB(t *testing.T, bean interface{}) map[string]*schemas.Index {
dbm, err := testEngine.DBMetas()
assert.NoError(t, err)
tName := testEngine.TableName(bean)
var tSchema *schemas.Table
for _, t := range dbm {
if t.Name == tName {
tSchema = t
break
}
}
if !assert.NotNil(t, tSchema) {
return nil
}
return tSchema.Indexes
}
func getKeysFromMap(m map[string]*schemas.Index) []string {
var ss []string
for k := range m {
ss = append(ss, k)
}
return ss
}

View File

@ -9,6 +9,8 @@ import (
"testing" "testing"
"time" "time"
"xorm.io/xorm/convert"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -65,3 +67,48 @@ func TestExecTime(t *testing.T) {
assert.True(t, has) assert.True(t, has)
assert.EqualValues(t, now.In(testEngine.GetTZLocation()).Format("2006-01-02 15:04:05"), uet.Created.Format("2006-01-02 15:04:05")) assert.EqualValues(t, now.In(testEngine.GetTZLocation()).Format("2006-01-02 15:04:05"), uet.Created.Format("2006-01-02 15:04:05"))
} }
type ConversionData struct {
MyData string
}
var _ convert.Conversion = new(ConversionData)
func (c ConversionData) ToDB() ([]byte, error) {
return []byte(c.MyData), nil
}
func (c *ConversionData) FromDB(bs []byte) error {
if bs != nil {
c.MyData = string(bs)
}
return nil
}
func TestExecCustomTypes(t *testing.T) {
assert.NoError(t, PrepareEngine())
type UserinfoExec struct {
Uid int
Name string
Data string
}
assert.NoError(t, testEngine.Sync2(new(UserinfoExec)))
res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_exec`", true)+" (uid, name,data) VALUES (?, ?, ?)",
1, "user", ConversionData{"data"})
assert.NoError(t, err)
cnt, err := res.RowsAffected()
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
results, err := testEngine.QueryString("select * from " + testEngine.TableName("userinfo_exec", true))
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
id, err := strconv.Atoi(results[0]["uid"])
assert.NoError(t, err)
assert.EqualValues(t, 1, id)
assert.Equal(t, "user", results[0]["name"])
assert.EqualValues(t, "data", results[0]["data"])
}