diff --git a/dialects/dialect.go b/dialects/dialect.go index 555d96c6..4d4fc150 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -89,6 +89,8 @@ type Dialect interface { Filters() []Filter SetParams(params map[string]string) + IsShadow(ctx context.Context) bool + SetShadowable(s Shadowable) } // Base represents a basic dialect and all real dialects could embed this struct @@ -96,6 +98,7 @@ type Base struct { dialect Dialect uri *URI quoter schemas.Quoter + shadow Shadowable } // Alias returned col itself @@ -254,6 +257,16 @@ func (db *Base) ForUpdateSQL(query string) string { func (db *Base) SetParams(params map[string]string) { } +func (db *Base) IsShadow(ctx context.Context) bool { + if db.shadow != nil { + return db.shadow.IsShadow(ctx) + } + return false +} +func (db *Base) SetShadowable(shadow Shadowable) { + db.shadow = shadow +} + var ( dialects = map[string]func() Dialect{} ) diff --git a/dialects/shadow.go b/dialects/shadow.go new file mode 100644 index 00000000..1b504610 --- /dev/null +++ b/dialects/shadow.go @@ -0,0 +1,23 @@ +package dialects + +import "context" + +type Shadowable interface { + IsShadow(ctx context.Context) bool +} + +type TrueShadow struct{} +type FalseShadow struct{} + +func NewTrueShadow() Shadowable { + return &TrueShadow{} +} +func NewFalseShadow() Shadowable { + return &FalseShadow{} +} +func (t *TrueShadow) IsShadow(ctx context.Context) bool { + return true +} +func (f *FalseShadow) IsShadow(ctx context.Context) bool { + return false +} diff --git a/dialects/table_name.go b/dialects/table_name.go index 8a0baeac..8e2425fb 100644 --- a/dialects/table_name.go +++ b/dialects/table_name.go @@ -5,6 +5,7 @@ package dialects import ( + "context" "fmt" "reflect" "strings" @@ -14,6 +15,8 @@ import ( "xorm.io/xorm/schemas" ) +const ShadowDBNamePrefix = "shadow_" + // TableNameWithSchema will add schema prefix on table name if possible func TableNameWithSchema(dialect Dialect, tableName string) string { // Add schema name as prefix of table name. @@ -24,6 +27,18 @@ func TableNameWithSchema(dialect Dialect, tableName string) string { return tableName } +// TableNameWithDBName will add database name prefix on table name if possible +func TableNameWithDBName(dialect Dialect, tableName string) string { + // Add schema name as prefix of table name. + // Only for postgres database. + if dialect.URI().DBName != "" && + dialect.URI().DBType == schemas.MYSQL && + strings.Index(tableName, ".") == -1 { + return fmt.Sprintf("%s.%s", dialect.URI().DBName, tableName) + } + return tableName +} + // TableNameNoSchema returns table name with given tableName func TableNameNoSchema(dialect Dialect, mapper names.Mapper, tableName interface{}) string { quote := dialect.Quoter().Quote @@ -84,10 +99,19 @@ func TableNameNoSchema(dialect Dialect, mapper names.Mapper, tableName interface } // FullTableName returns table name with quote and schema according parameter -func FullTableName(dialect Dialect, mapper names.Mapper, bean interface{}, includeSchema ...bool) string { +func FullTableName(ctx context.Context, dialect Dialect, mapper names.Mapper, bean interface{}, includeSchema ...bool) string { tbName := TableNameNoSchema(dialect, mapper, bean) if len(includeSchema) > 0 && includeSchema[0] && !utils.IsSubQuery(tbName) { tbName = TableNameWithSchema(dialect, tbName) } + if dialect.URI() != nil && + (dialect.URI().DBType == schemas.MYSQL || dialect.URI().DBType == schemas.SQLITE) && + dialect.IsShadow(ctx) && !hasShadowPrefix(tbName) { + tbName = ShadowDBNamePrefix + TableNameWithDBName(dialect, tbName) + } return tbName } + +func hasShadowPrefix(tableName string) bool { + return strings.HasPrefix(tableName, ShadowDBNamePrefix) +} diff --git a/dialects/table_name_test.go b/dialects/table_name_test.go index 66edc2b4..ebccc0c3 100644 --- a/dialects/table_name_test.go +++ b/dialects/table_name_test.go @@ -5,6 +5,7 @@ package dialects import ( + "context" "testing" "xorm.io/xorm/names" @@ -23,8 +24,14 @@ func (mcc *MCC) TableName() string { } func TestFullTableName(t *testing.T) { - dialect := QueryDialect("mysql") - - assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, &MCC{})) - assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, "mcc")) + dialect, err := OpenDialect("mysql", "root:root@tcp(127.0.0.1:3306)/test?charset=utf8") + if err != nil { + panic("unknow dialect") + } + dialect.SetShadowable(NewTrueShadow()) + assert.EqualValues(t, "shadow_test.mcc", FullTableName(context.Background(), dialect, names.SnakeMapper{}, &MCC{})) + assert.EqualValues(t, "shadow_test.mcc", FullTableName(context.Background(), dialect, names.SnakeMapper{}, "mcc")) + dialect.SetShadowable(NewFalseShadow()) + assert.EqualValues(t, "mcc", FullTableName(context.Background(), dialect, names.SnakeMapper{}, &MCC{})) + assert.EqualValues(t, "mcc", FullTableName(context.Background(), dialect, names.SnakeMapper{}, "mcc")) } diff --git a/engine.go b/engine.go index 81cfc7a9..6963cda1 100644 --- a/engine.go +++ b/engine.go @@ -291,7 +291,11 @@ func (engine *Engine) NoCascade() *Session { // MapCacher Set a table use a special cacher func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error { - engine.SetCacher(dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, true), cacher) + for _, v := range []dialects.Shadowable{dialects.NewTrueShadow(), dialects.NewFalseShadow()} { + engine.dialect.SetShadowable(v) + engine.SetCacher(dialects.FullTableName(context.Background(), engine.dialect, engine.GetTableMapper(), bean, true), cacher) + engine.SetCacher(dialects.FullTableName(context.Background(), engine.dialect, engine.GetTableMapper(), bean, true), cacher) + } return nil } @@ -1067,7 +1071,12 @@ func (engine *Engine) IsTableExist(beanOrTableName interface{}) (bool, error) { // TableName returns table name with schema prefix if has func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { - return dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, includeSchema...) + return dialects.FullTableName(context.Background(), engine.dialect, engine.GetTableMapper(), bean, includeSchema...) +} + +// ContextTableName returns table name with schema and database prefix if has +func (engine *Engine) ContextTableName(ctx context.Context, bean interface{}, includeSchema ...bool) string { + return dialects.FullTableName(ctx, engine.dialect, engine.GetTableMapper(), bean, includeSchema...) } // CreateIndexes create indexes @@ -1086,23 +1095,29 @@ func (engine *Engine) CreateUniques(bean interface{}) error { // ClearCacheBean if enabled cache, clear the cache bean func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { - tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean) - cacher := engine.GetCacher(tableName) - if cacher != nil { - cacher.ClearIds(tableName) - cacher.DelBean(tableName, id) + for _, v := range []dialects.Shadowable{dialects.NewTrueShadow(), dialects.NewFalseShadow()} { + engine.dialect.SetShadowable(v) + tableName := dialects.FullTableName(context.Background(), engine.dialect, engine.GetTableMapper(), bean) + cacher := engine.GetCacher(tableName) + if cacher != nil { + cacher.ClearIds(tableName) + cacher.DelBean(tableName, id) + } } return nil } // ClearCache if enabled cache, clear some tables' cache func (engine *Engine) ClearCache(beans ...interface{}) error { - for _, bean := range beans { - tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean) - cacher := engine.GetCacher(tableName) - if cacher != nil { - cacher.ClearIds(tableName) - cacher.ClearBeans(tableName) + for _, v := range []dialects.Shadowable{dialects.NewTrueShadow(), dialects.NewFalseShadow()} { + engine.dialect.SetShadowable(v) + for _, bean := range beans { + tableName := dialects.FullTableName(context.Background(), engine.dialect, engine.GetTableMapper(), bean) + cacher := engine.GetCacher(tableName) + if cacher != nil { + cacher.ClearIds(tableName) + cacher.ClearBeans(tableName) + } } } return nil @@ -1431,3 +1446,8 @@ func (engine *Engine) Transaction(f func(*Session) (interface{}, error)) (interf return result, nil } + +// SetShadow Set whether to use shadow database algorithm, should be called after modify the cache setting +func (engine *Engine) SetShadow(shadow dialects.Shadowable) { + engine.dialect.SetShadowable(shadow) +} diff --git a/engine_group.go b/engine_group.go index f2fe913d..a7dbf43b 100644 --- a/engine_group.go +++ b/engine_group.go @@ -265,3 +265,16 @@ func (eg *EngineGroup) Rows(bean interface{}) (*Rows, error) { sess.isAutoClose = true return sess.Rows(bean) } + +// SetShadow Set whether to use shadow database algorithm, should be called after modify the cache setting +func (eg *EngineGroup) SetShadow(shadow dialects.Shadowable) { + eg.Engine.SetShadow(shadow) + for i := 0; i < len(eg.slaves); i++ { + eg.slaves[i].SetShadow(shadow) + } +} + +// ContextTableName returns table name with schema and database prefix if has +func (engine *EngineGroup) ContextTableName(ctx context.Context, bean interface{}, includeSchema ...bool) string { + return dialects.FullTableName(ctx, engine.dialect, engine.GetTableMapper(), bean, includeSchema...) +} diff --git a/go.mod b/go.mod index 7bde41ae..8c9ed6a4 100644 --- a/go.mod +++ b/go.mod @@ -17,5 +17,5 @@ require ( github.com/syndtr/goleveldb v1.0.0 github.com/ziutek/mymysql v1.5.4 modernc.org/sqlite v1.14.2 - xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978 + xorm.io/builder v0.3.11 ) diff --git a/go.sum b/go.sum index 8bdc9798..24d48f60 100644 --- a/go.sum +++ b/go.sum @@ -661,3 +661,5 @@ sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU= xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978 h1:bvLlAPW1ZMTWA32LuZMBEGHAUOcATZjzHcotf3SWweM= xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= +xorm.io/builder v0.3.11 h1:naLkJitGyYW7ZZdncsh/JW+HF4HshmvTHTyUyPwJS00= +xorm.io/builder v0.3.11/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= diff --git a/integrations/session_delete_test.go b/integrations/session_delete_test.go index b4e40edb..18b29102 100644 --- a/integrations/session_delete_test.go +++ b/integrations/session_delete_test.go @@ -5,8 +5,11 @@ package integrations import ( + "context" "testing" "time" + "xorm.io/xorm" + "xorm.io/xorm/dialects" "xorm.io/xorm/caches" "xorm.io/xorm/schemas" @@ -191,6 +194,42 @@ func TestCacheDelete(t *testing.T) { testEngine.SetDefaultCacher(oldCacher) } +func TestShadowCacheDelete(t *testing.T) { + testEngine, err := xorm.NewEngine(string(schemas.MYSQL), "root:root@tcp(127.0.0.1:3306)/test?charset=utf8") + assert.NoError(t, err) + testEngine.ShowSQL(true) + _, err = testEngine.NewSession().Exec("CREATE DATABASE IF NOT EXISTS shadow_test") + testEngine.SetShadow(dialects.NewTrueShadow()) + + oldCacher := testEngine.GetDefaultCacher() + cacher := caches.NewLRUCacher(caches.NewMemoryStore(), 1000) + testEngine.SetDefaultCacher(cacher) + + type CacheDeleteStruct struct { + Id int64 + } + assert.NoError(t, testEngine.Context(context.Background()).Sync(&CacheDeleteStruct{})) + err = testEngine.CreateTables(&CacheDeleteStruct{}) + assert.NoError(t, err) + + _, err = testEngine.Insert(&CacheDeleteStruct{}) + assert.NoError(t, err) + + aff, err := testEngine.Delete(&CacheDeleteStruct{ + Id: 1, + }) + assert.NoError(t, err) + assert.EqualValues(t, aff, 1) + + aff, err = testEngine.Unscoped().Delete(&CacheDeleteStruct{ + Id: 1, + }) + assert.NoError(t, err) + assert.EqualValues(t, aff, 0) + + testEngine.SetDefaultCacher(oldCacher) +} + func TestUnscopeDelete(t *testing.T) { assert.NoError(t, PrepareEngine()) diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index 841ec709..306cf6be 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -5,6 +5,7 @@ package integrations import ( + "context" "database/sql" "errors" "fmt" @@ -22,6 +23,214 @@ import ( "github.com/stretchr/testify/assert" ) +func TestShadowGetVar(t *testing.T) { + testEngine, err := xorm.NewEngine(string(schemas.MYSQL), "root:root@tcp(127.0.0.1:3306)/test?charset=utf8") + assert.NoError(t, err) + testEngine.ShowSQL(true) + _, err = testEngine.NewSession().Exec("CREATE DATABASE IF NOT EXISTS shadow_test") + assert.NoError(t, err) + type GetVar struct { + Id int64 `xorm:"autoincr pk"` + Msg string `xorm:"varchar(255)"` + Age int + Money float32 + Created time.Time `xorm:"created"` + } + testEngine.SetShadow(dialects.NewTrueShadow()) + + assert.NoError(t, testEngine.Context(context.Background()).Sync(new(GetVar))) + + data := GetVar{ + Msg: "hi", + Age: 28, + Money: 1.5, + } + _, err = testEngine.InsertOne(&data) + assert.NoError(t, err) + + var msg string + has, err := testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Cols("msg").Get(&msg) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, "hi", msg) + + var age int + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Cols("age").Get(&age) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, 28, age) + + var ageMax int + has, err = testEngine.SQL("SELECT max(`age`) FROM "+testEngine.Quote(testEngine.ContextTableName(context.Background(), "get_var"))+" WHERE `id` = ?", data.Id).Get(&ageMax) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, 28, ageMax) + + var age2 int64 + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Cols("age"). + Where("`age` > ?", 20). + And("`age` < ?", 30). + Get(&age2) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age2) + + var age3 int8 + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Cols("age").Get(&age3) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age3) + + var age4 int16 + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Cols("age"). + Where("`age` > ?", 20). + And("`age` < ?", 30). + Get(&age4) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age4) + + var age5 int32 + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Cols("age"). + Where("`age` > ?", 20). + And("`age` < ?", 30). + Get(&age5) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age5) + + var age6 int + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Cols("age").Get(&age6) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age6) + + var age7 int64 + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Cols("age"). + Where("`age` > ?", 20). + And("`age` < ?", 30). + Get(&age7) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age7) + + var age8 int8 + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Cols("age").Get(&age8) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age8) + + var age9 int16 + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Cols("age"). + Where("`age` > ?", 20). + And("`age` < ?", 30). + Get(&age9) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age9) + + var age10 int32 + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Cols("age"). + Where("`age` > ?", 20). + And("`age` < ?", 30). + Get(&age10) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age10) + + var id sql.NullInt64 + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Cols("id").Get(&id) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, true, id.Valid) + + var msgNull sql.NullString + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Cols("msg").Get(&msgNull) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, true, msgNull.Valid) + assert.EqualValues(t, data.Msg, msgNull.String) + + var nullMoney sql.NullFloat64 + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Cols("money").Get(&nullMoney) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, true, nullMoney.Valid) + assert.EqualValues(t, data.Money, nullMoney.Float64) + + var money float64 + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Cols("money").Get(&money) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money)) + + var money2 float64 + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + has, err = testEngine.SQL("SELECT TOP 1 `money` FROM " + testEngine.Quote(testEngine.ContextTableName(context.Background(), "get_var"))).Get(&money2) + } else { + has, err = testEngine.SQL("SELECT `money` FROM " + testEngine.Quote(testEngine.ContextTableName(context.Background(), "get_var")) + " LIMIT 1").Get(&money2) + } + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2)) + + var money3 float64 + has, err = testEngine.SQL("SELECT `money` FROM " + testEngine.Quote(testEngine.ContextTableName(context.Background(), "get_var")) + " WHERE `money` > 20").Get(&money3) + assert.NoError(t, err) + assert.Equal(t, false, has) + + valuesString := make(map[string]string) + has, err = testEngine.Table("get_var").Get(&valuesString) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, 5, len(valuesString)) + assert.Equal(t, "1", valuesString["id"]) + assert.Equal(t, "hi", valuesString["msg"]) + assert.Equal(t, "28", valuesString["age"]) + assert.Equal(t, "1.5", valuesString["money"]) + + // for mymysql driver, interface{} will be []byte, so ignore it currently + if testEngine.DriverName() != "mymysql" { + valuesInter := make(map[string]interface{}) + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Where("`id` = ?", 1).Select("*").Get(&valuesInter) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, 5, len(valuesInter)) + assert.EqualValues(t, 1, valuesInter["id"]) + assert.Equal(t, "hi", fmt.Sprintf("%s", valuesInter["msg"])) + assert.EqualValues(t, 28, valuesInter["age"]) + assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesInter["money"])) + } + + valuesSliceString := make([]string, 5) + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Get(&valuesSliceString) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, "1", valuesSliceString[0]) + assert.Equal(t, "hi", valuesSliceString[1]) + assert.Equal(t, "28", valuesSliceString[2]) + assert.Equal(t, "1.5", valuesSliceString[3]) + + valuesSliceInter := make([]interface{}, 5) + has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "get_var")).Get(&valuesSliceInter) + assert.NoError(t, err) + assert.Equal(t, true, has) + + v1, err := convert.AsInt64(valuesSliceInter[0]) + assert.NoError(t, err) + assert.EqualValues(t, 1, v1) + + assert.Equal(t, "hi", fmt.Sprintf("%s", valuesSliceInter[1])) + + v3, err := convert.AsInt64(valuesSliceInter[2]) + assert.NoError(t, err) + assert.EqualValues(t, 28, v3) + + v4, err := convert.AsFloat64(valuesSliceInter[3]) + assert.NoError(t, err) + assert.Equal(t, "1.5", fmt.Sprintf("%v", v4)) +} + func TestGetVar(t *testing.T) { assert.NoError(t, PrepareEngine()) diff --git a/integrations/session_update_test.go b/integrations/session_update_test.go index 45338cad..89c39359 100644 --- a/integrations/session_update_test.go +++ b/integrations/session_update_test.go @@ -5,10 +5,12 @@ package integrations import ( + "context" "fmt" "sync" "testing" "time" + "xorm.io/xorm/dialects" "github.com/stretchr/testify/assert" "xorm.io/xorm" @@ -1470,3 +1472,138 @@ func TestNilFromDB(t *testing.T) { assert.NotNil(t, tt4.Field1) assert.NotNil(t, tt4.Field1.cb) } + +func TestShadowUpdate1(t *testing.T) { + testEngine, err := xorm.NewEngine(string(schemas.MYSQL), "root:root@tcp(127.0.0.1:3306)/test?charset=utf8") + assert.NoError(t, err) + testEngine.ShowSQL(true) + _, err = testEngine.NewSession().Exec("CREATE DATABASE IF NOT EXISTS shadow_test") + testEngine.SetShadow(dialects.NewTrueShadow()) + assert.NoError(t, testEngine.Context(context.Background()).Sync(&Userinfo{})) + + _, err = testEngine.Insert(&Userinfo{ + Username: "user1", + }) + assert.NoError(t, err) + + var ori Userinfo + has, err := testEngine.Get(&ori) + assert.NoError(t, err) + assert.True(t, has) + + // update by id + user := Userinfo{Username: "xxx", Height: 1.2} + cnt, err := testEngine.ID(ori.Uid).Update(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + condi := Condi{"username": "zzz", "departname": ""} + cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.Update(&Userinfo{Username: "yyy"}, &user) + assert.NoError(t, err) + + total, err := testEngine.Count(&user) + assert.NoError(t, err) + assert.EqualValues(t, cnt, total) + + // nullable update + { + user := &Userinfo{Username: "not null data", Height: 180.5} + _, err := testEngine.Insert(user) + assert.NoError(t, err) + userID := user.Uid + + has, err := testEngine.ID(userID). + And("`username` = ?", user.Username). + And("`height` = ?", user.Height). + And("`departname` = ?", ""). + And("`detail_id` = ?", 0). + And("`is_man` = ?", false). + Get(&Userinfo{}) + assert.NoError(t, err) + assert.True(t, has, "cannot insert properly") + + updatedUser := &Userinfo{Username: "null data"} + cnt, err = testEngine.ID(userID). + Nullable("height", "departname", "is_man", "created"). + Update(updatedUser) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt, "update not returned 1") + + has, err = testEngine.ID(userID). + And("`username` = ?", updatedUser.Username). + And("`height` IS NULL"). + And("`departname` IS NULL"). + And("`is_man` IS NULL"). + And("`created` IS NULL"). + And("`detail_id` = ?", 0). + Get(&Userinfo{}) + assert.NoError(t, err) + assert.True(t, has, "cannot update with null properly") + + cnt, err = testEngine.ID(userID).Delete(&Userinfo{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt, "delete not returned 1") + } + + err = testEngine.StoreEngine("Innodb").Sync(&Article{}) + assert.NoError(t, err) + + defer func() { + err = testEngine.DropTables(&Article{}) + assert.NoError(t, err) + }() + + a := &Article{0, "1", "2", "3", "4", "5", 2} + cnt, err = testEngine.Insert(a) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt, fmt.Sprintf("insert not returned 1 but %d", cnt)) + assert.Greater(t, a.Id, int32(0), "insert returned id is 0") + + cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var s = "test" + + col1 := &UpdateAllCols{Ptr: &s} + err = testEngine.Sync(col1) + assert.NoError(t, err) + + _, err = testEngine.Insert(col1) + assert.NoError(t, err) + + col2 := &UpdateAllCols{col1.Id, true, "", nil} + _, err = testEngine.ID(col2.Id).AllCols().Update(col2) + assert.NoError(t, err) + + col3 := &UpdateAllCols{} + has, err = testEngine.ID(col2.Id).Get(col3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, *col2, *col3) + + { + col1 := &UpdateMustCols{} + err = testEngine.Sync(col1) + assert.NoError(t, err) + + _, err = testEngine.Insert(col1) + assert.NoError(t, err) + + col2 := &UpdateMustCols{col1.Id, true, ""} + boolStr := testEngine.GetColumnMapper().Obj2Table("Bool") + stringStr := testEngine.GetColumnMapper().Obj2Table("String") + _, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2) + assert.NoError(t, err) + + col3 := &UpdateMustCols{} + has, err := testEngine.ID(col2.Id).Get(col3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, *col2, *col3) + } +} diff --git a/interface.go b/interface.go index 55ffebe4..15516d6e 100644 --- a/interface.go +++ b/interface.go @@ -123,8 +123,10 @@ type EngineInterface interface { StoreEngine(storeEngine string) *Session TableInfo(bean interface{}) (*schemas.Table, error) TableName(interface{}, ...bool) string + ContextTableName(context.Context, interface{}, ...bool) string UnMapType(reflect.Type) EnableSessionID(bool) + SetShadow(shadow dialects.Shadowable) } var ( diff --git a/internal/statements/join.go b/internal/statements/join.go index 45fc2441..910902cb 100644 --- a/internal/statements/join.go +++ b/internal/statements/join.go @@ -51,7 +51,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition)) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) default: - tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) + tbName := dialects.FullTableName(statement.ctx, statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) if !utils.IsSubQuery(tbName) { var buf strings.Builder _ = statement.dialect.Quoter().QuoteTo(&buf, tbName) diff --git a/internal/statements/statement.go b/internal/statements/statement.go index a8fe34fa..8ecc5a25 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -5,6 +5,7 @@ package statements import ( + "context" "database/sql/driver" "errors" "fmt" @@ -36,6 +37,7 @@ var ( // Statement save all the sql info for executing SQL type Statement struct { + ctx context.Context RefTable *schemas.Table dialect dialects.Dialect defaultTimeZone *time.Location @@ -82,8 +84,9 @@ type Statement struct { } // NewStatement creates a new statement -func NewStatement(dialect dialects.Dialect, tagParser *tags.Parser, defaultTimeZone *time.Location) *Statement { +func NewStatement(ctx context.Context, dialect dialects.Dialect, tagParser *tags.Parser, defaultTimeZone *time.Location) *Statement { statement := &Statement{ + ctx: ctx, dialect: dialect, tagParser: tagParser, defaultTimeZone: defaultTimeZone, @@ -186,7 +189,8 @@ func (statement *Statement) SetRefValue(v reflect.Value) error { if err != nil { return err } - statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), v, true) + statement.tableName = dialects.FullTableName(statement.ctx, statement.dialect, + statement.tagParser.GetTableMapper(), v, true) return nil } @@ -201,7 +205,8 @@ func (statement *Statement) SetRefBean(bean interface{}) error { if err != nil { return err } - statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), bean, true) + statement.tableName = dialects.FullTableName(statement.ctx, statement.dialect, + statement.tagParser.GetTableMapper(), bean, true) return nil } @@ -280,7 +285,8 @@ func (statement *Statement) SetTable(tableNameOrBean interface{}) error { } } - statement.AltTableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tableNameOrBean, true) + statement.AltTableName = dialects.FullTableName(statement.ctx, statement.dialect, statement.tagParser.GetTableMapper(), + tableNameOrBean, true) return nil } diff --git a/internal/statements/statement_test.go b/internal/statements/statement_test.go index 31428efa..4d8b7f0b 100644 --- a/internal/statements/statement_test.go +++ b/internal/statements/statement_test.go @@ -5,6 +5,7 @@ package statements import ( + "context" "os" "reflect" "strings" @@ -171,7 +172,7 @@ func (TestType) TableName() string { } func createTestStatement() (*Statement, error) { - statement := NewStatement(dialect, tagParser, time.Local) + statement := NewStatement(context.Background(), dialect, tagParser, time.Local) if err := statement.SetRefValue(reflect.ValueOf(TestType{})); err != nil { return nil, err } diff --git a/session.go b/session.go index 388678cd..309be6e5 100644 --- a/session.go +++ b/session.go @@ -113,6 +113,7 @@ func newSession(engine *Engine) *Session { engine: engine, tx: nil, statement: statements.NewStatement( + ctx, engine.dialect, engine.tagParser, engine.DatabaseTZ, diff --git a/session_find.go b/session_find.go index 2270454b..1bf5dd3c 100644 --- a/session_find.go +++ b/session_find.go @@ -392,6 +392,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in statement := session.statement session.statement = statements.NewStatement( + session.ctx, session.engine.dialect, session.engine.tagParser, session.engine.DatabaseTZ, diff --git a/session_schema.go b/session_schema.go index e66c3b42..7c85e0f8 100644 --- a/session_schema.go +++ b/session_schema.go @@ -280,7 +280,11 @@ func (session *Session) Sync(beans ...interface{}) error { if len(session.statement.AltTableName) > 0 { tbName = session.statement.AltTableName } else { - tbName = engine.TableName(bean) + if session.ctx != nil { + tbName = engine.ContextTableName(session.ctx, bean) + } else { + tbName = engine.TableName(bean) + } } tbNameWithSchema := engine.tbNameWithSchema(tbName)