diff --git a/core/tx.go b/core/tx.go index a2f745f8..2cc64966 100644 --- a/core/tx.go +++ b/core/tx.go @@ -236,3 +236,7 @@ func (tx *Tx) QueryRowStructContext(ctx context.Context, query string, st interf func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row { return tx.QueryRowStructContext(context.Background(), query, st) } + +func (tx *Tx) GetDB() *DB { + return tx.db +} diff --git a/dialects/ydb.go b/dialects/ydb.go index f0a33518..d5bbd1df 100644 --- a/dialects/ydb.go +++ b/dialects/ydb.go @@ -397,6 +397,11 @@ func (db *ydb) getDB(queryer interface{}) *core.DB { if internalDB, ok := queryer.(*core.DB); ok { return internalDB } + if txGetDB, ok := queryer.(interface { + GetDB() *core.DB + }); ok { + return txGetDB.GetDB() + } return nil } diff --git a/session.go b/session.go index 31af8b9a..14d0781e 100644 --- a/session.go +++ b/session.go @@ -187,9 +187,6 @@ func (session *Session) Tx() *core.Tx { } func (session *Session) getQueryer() core.Queryer { - if session.engine.dialect.URI().DBType == schemas.YDB { - return session.db() - } if session.tx != nil { return session.tx }