Don't keep db on dialects (#1623)
don't keep db on dialects Reviewed-on: https://gitea.com/xorm/xorm/pulls/1623
This commit is contained in:
parent
79cdec7d88
commit
5053c35701
|
@ -77,6 +77,10 @@ type cacheStruct struct {
|
||||||
idx int
|
idx int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ QueryExecuter = &DB{}
|
||||||
|
)
|
||||||
|
|
||||||
// DB is a wrap of sql.DB with extra contents
|
// DB is a wrap of sql.DB with extra contents
|
||||||
type DB struct {
|
type DB struct {
|
||||||
*sql.DB
|
*sql.DB
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
package core
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Queryer represents an interface to query a SQL to get data from database
|
||||||
|
type Queryer interface {
|
||||||
|
QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Executer represents an interface to execute a SQL
|
||||||
|
type Executer interface {
|
||||||
|
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryExecuter combines the Queryer and Executer
|
||||||
|
type QueryExecuter interface {
|
||||||
|
Queryer
|
||||||
|
Executer
|
||||||
|
}
|
|
@ -27,7 +27,7 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
|
||||||
var i int
|
var i int
|
||||||
query = re.ReplaceAllStringFunc(query, func(src string) string {
|
query = re.ReplaceAllStringFunc(query, func(src string) string {
|
||||||
names[src[1:]] = i
|
names[src[1:]] = i
|
||||||
i += 1
|
i++
|
||||||
return "?"
|
return "?"
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,11 @@ import (
|
||||||
"xorm.io/xorm/log"
|
"xorm.io/xorm/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ QueryExecuter = &Tx{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Tx represents a transaction
|
||||||
type Tx struct {
|
type Tx struct {
|
||||||
*sql.Tx
|
*sql.Tx
|
||||||
db *DB
|
db *DB
|
||||||
|
@ -50,7 +55,7 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
|
||||||
var i int
|
var i int
|
||||||
query = re.ReplaceAllStringFunc(query, func(src string) string {
|
query = re.ReplaceAllStringFunc(query, func(src string) string {
|
||||||
names[src[1:]] = i
|
names[src[1:]] = i
|
||||||
i += 1
|
i++
|
||||||
return "?"
|
return "?"
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -39,9 +39,8 @@ func (uri *URI) SetSchema(schema string) {
|
||||||
|
|
||||||
// Dialect represents a kind of database
|
// Dialect represents a kind of database
|
||||||
type Dialect interface {
|
type Dialect interface {
|
||||||
Init(*core.DB, *URI) error
|
Init(*URI) error
|
||||||
URI() *URI
|
URI() *URI
|
||||||
DB() *core.DB
|
|
||||||
SQLType(*schemas.Column) string
|
SQLType(*schemas.Column) string
|
||||||
FormatBytes(b []byte) string
|
FormatBytes(b []byte) string
|
||||||
DefaultSchema() string
|
DefaultSchema() string
|
||||||
|
@ -52,18 +51,18 @@ type Dialect interface {
|
||||||
|
|
||||||
AutoIncrStr() string
|
AutoIncrStr() string
|
||||||
|
|
||||||
GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error)
|
GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error)
|
||||||
IndexCheckSQL(tableName, idxName string) (string, []interface{})
|
IndexCheckSQL(tableName, idxName string) (string, []interface{})
|
||||||
CreateIndexSQL(tableName string, index *schemas.Index) string
|
CreateIndexSQL(tableName string, index *schemas.Index) string
|
||||||
DropIndexSQL(tableName string, index *schemas.Index) string
|
DropIndexSQL(tableName string, index *schemas.Index) string
|
||||||
|
|
||||||
GetTables(ctx context.Context) ([]*schemas.Table, error)
|
GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error)
|
||||||
IsTableExist(ctx context.Context, tableName string) (bool, error)
|
IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error)
|
||||||
CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool)
|
CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool)
|
||||||
DropTableSQL(tableName string) (string, bool)
|
DropTableSQL(tableName string) (string, bool)
|
||||||
|
|
||||||
GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)
|
GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)
|
||||||
IsColumnExist(ctx context.Context, tableName string, colName string) (bool, error)
|
IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error)
|
||||||
AddColumnSQL(tableName string, col *schemas.Column) string
|
AddColumnSQL(tableName string, col *schemas.Column) string
|
||||||
ModifyColumnSQL(tableName string, col *schemas.Column) string
|
ModifyColumnSQL(tableName string, col *schemas.Column) string
|
||||||
|
|
||||||
|
@ -75,7 +74,6 @@ type Dialect interface {
|
||||||
|
|
||||||
// Base represents a basic dialect and all real dialects could embed this struct
|
// Base represents a basic dialect and all real dialects could embed this struct
|
||||||
type Base struct {
|
type Base struct {
|
||||||
db *core.DB
|
|
||||||
dialect Dialect
|
dialect Dialect
|
||||||
uri *URI
|
uri *URI
|
||||||
quoter schemas.Quoter
|
quoter schemas.Quoter
|
||||||
|
@ -85,16 +83,12 @@ func (b *Base) Quoter() schemas.Quoter {
|
||||||
return b.quoter
|
return b.quoter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Base) DB() *core.DB {
|
|
||||||
return b.db
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Base) DefaultSchema() string {
|
func (b *Base) DefaultSchema() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI) error {
|
func (b *Base) Init(dialect Dialect, uri *URI) error {
|
||||||
b.db, b.dialect, b.uri = db, dialect, uri
|
b.dialect, b.uri = dialect, uri
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,8 +154,8 @@ func (db *Base) DropTableSQL(tableName string) (string, bool) {
|
||||||
return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true
|
return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *Base) HasRecords(ctx context.Context, query string, args ...interface{}) (bool, error) {
|
func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query string, args ...interface{}) (bool, error) {
|
||||||
rows, err := db.DB().QueryContext(ctx, query, args...)
|
rows, err := queryer.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -173,7 +167,7 @@ func (db *Base) HasRecords(ctx context.Context, query string, args ...interface{
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *Base) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) {
|
func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
|
||||||
quote := db.dialect.Quoter().Quote
|
quote := db.dialect.Quoter().Quote
|
||||||
query := fmt.Sprintf(
|
query := fmt.Sprintf(
|
||||||
"SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?",
|
"SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?",
|
||||||
|
@ -184,7 +178,7 @@ func (db *Base) IsColumnExist(ctx context.Context, tableName, colName string) (b
|
||||||
quote("TABLE_NAME"),
|
quote("TABLE_NAME"),
|
||||||
quote("COLUMN_NAME"),
|
quote("COLUMN_NAME"),
|
||||||
)
|
)
|
||||||
return db.HasRecords(ctx, query, db.uri.DBName, tableName, colName)
|
return db.HasRecords(queryer, ctx, query, db.uri.DBName, tableName, colName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string {
|
func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string {
|
||||||
|
|
|
@ -6,8 +6,6 @@ package dialects
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"xorm.io/xorm/core"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Driver interface {
|
type Driver interface {
|
||||||
|
@ -53,11 +51,7 @@ func OpenDialect(driverName, connstr string) (Dialect, error) {
|
||||||
return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType)
|
return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType)
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := core.Open(driverName, connstr)
|
dialect.Init(uri)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
dialect.Init(db, uri)
|
|
||||||
|
|
||||||
return dialect, nil
|
return dialect, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -212,9 +212,9 @@ type mssql struct {
|
||||||
Base
|
Base
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mssql) Init(d *core.DB, uri *URI) error {
|
func (db *mssql) Init(uri *URI) error {
|
||||||
db.quoter = mssqlQuoter
|
db.quoter = mssqlQuoter
|
||||||
return db.Base.Init(d, db, uri)
|
return db.Base.Init(db, uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mssql) SQLType(c *schemas.Column) string {
|
func (db *mssql) SQLType(c *schemas.Column) string {
|
||||||
|
@ -319,18 +319,18 @@ func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}
|
||||||
return sql, args
|
return sql, args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mssql) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) {
|
func (db *mssql) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
|
||||||
query := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?`
|
query := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?`
|
||||||
|
|
||||||
return db.HasRecords(ctx, query, tableName, colName)
|
return db.HasRecords(queryer, ctx, query, tableName, colName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mssql) IsTableExist(ctx context.Context, tableName string) (bool, error) {
|
func (db *mssql) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) {
|
||||||
sql := "select * from sysobjects where id = object_id(N'" + tableName + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1"
|
sql := "select * from sysobjects where id = object_id(N'" + tableName + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1"
|
||||||
return db.HasRecords(ctx, sql)
|
return db.HasRecords(queryer, ctx, sql)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mssql) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable,
|
s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable,
|
||||||
"default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END),
|
"default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END),
|
||||||
|
@ -346,7 +346,7 @@ func (db *mssql) GetColumns(ctx context.Context, tableName string) ([]string, ma
|
||||||
) as p on p.object_id = a.object_id AND p.column_id = a.column_id
|
) as p on p.object_id = a.object_id AND p.column_id = a.column_id
|
||||||
where a.object_id=object_id('` + tableName + `')`
|
where a.object_id=object_id('` + tableName + `')`
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -401,11 +401,11 @@ func (db *mssql) GetColumns(ctx context.Context, tableName string) ([]string, ma
|
||||||
return colSeq, cols, nil
|
return colSeq, cols, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mssql) GetTables(ctx context.Context) ([]*schemas.Table, error) {
|
func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
s := `select name from sysobjects where xtype ='U'`
|
s := `select name from sysobjects where xtype ='U'`
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -425,7 +425,7 @@ func (db *mssql) GetTables(ctx context.Context) ([]*schemas.Table, error) {
|
||||||
return tables, nil
|
return tables, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mssql) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
func (db *mssql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
s := `SELECT
|
s := `SELECT
|
||||||
IXS.NAME AS [INDEX_NAME],
|
IXS.NAME AS [INDEX_NAME],
|
||||||
|
@ -439,7 +439,7 @@ AND IXCS.COLUMN_ID=C.COLUMN_ID
|
||||||
WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
|
WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
|
||||||
`
|
`
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -179,9 +179,9 @@ type mysql struct {
|
||||||
rowFormat string
|
rowFormat string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) Init(d *core.DB, uri *URI) error {
|
func (db *mysql) Init(uri *URI) error {
|
||||||
db.quoter = mysqlQuoter
|
db.quoter = mysqlQuoter
|
||||||
return db.Base.Init(d, db, uri)
|
return db.Base.Init(db, uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) SetParams(params map[string]string) {
|
func (db *mysql) SetParams(params map[string]string) {
|
||||||
|
@ -286,9 +286,9 @@ func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{}
|
||||||
return sql, args
|
return sql, args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) IsTableExist(ctx context.Context, tableName string) (bool, error) {
|
func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) {
|
||||||
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
|
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
|
||||||
return db.HasRecords(ctx, sql, db.uri.DBName, tableName)
|
return db.HasRecords(queryer, ctx, sql, db.uri.DBName, tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string {
|
func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string {
|
||||||
|
@ -301,12 +301,12 @@ func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string {
|
||||||
return sql
|
return sql
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||||
args := []interface{}{db.uri.DBName, tableName}
|
args := []interface{}{db.uri.DBName, tableName}
|
||||||
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
|
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
|
||||||
" `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
|
" `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -411,12 +411,12 @@ func (db *mysql) GetColumns(ctx context.Context, tableName string) ([]string, ma
|
||||||
return colSeq, cols, nil
|
return colSeq, cols, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) GetTables(ctx context.Context) ([]*schemas.Table, error) {
|
func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
|
||||||
args := []interface{}{db.uri.DBName}
|
args := []interface{}{db.uri.DBName}
|
||||||
s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " +
|
s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " +
|
||||||
"`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')"
|
"`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')"
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -459,11 +459,11 @@ func (db *mysql) SetQuotePolicy(quotePolicy QuotePolicy) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
||||||
args := []interface{}{db.uri.DBName, tableName}
|
args := []interface{}{db.uri.DBName, tableName}
|
||||||
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
|
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -506,9 +506,9 @@ type oracle struct {
|
||||||
Base
|
Base
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *oracle) Init(d *core.DB, uri *URI) error {
|
func (db *oracle) Init(uri *URI) error {
|
||||||
db.quoter = oracleQuoter
|
db.quoter = oracleQuoter
|
||||||
return db.Base.Init(d, db, uri)
|
return db.Base.Init(db, uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *oracle) SQLType(c *schemas.Column) string {
|
func (db *oracle) SQLType(c *schemas.Column) string {
|
||||||
|
@ -611,23 +611,23 @@ func (db *oracle) IndexCheckSQL(tableName, idxName string) (string, []interface{
|
||||||
`WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args
|
`WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *oracle) IsTableExist(ctx context.Context, tableName string) (bool, error) {
|
func (db *oracle) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) {
|
||||||
return db.HasRecords(ctx, `SELECT table_name FROM user_tables WHERE table_name = :1`, tableName)
|
return db.HasRecords(queryer, ctx, `SELECT table_name FROM user_tables WHERE table_name = :1`, tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *oracle) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) {
|
func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
|
||||||
args := []interface{}{tableName, colName}
|
args := []interface{}{tableName, colName}
|
||||||
query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" +
|
query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" +
|
||||||
" AND column_name = :2"
|
" AND column_name = :2"
|
||||||
return db.HasRecords(ctx, query, args...)
|
return db.HasRecords(queryer, ctx, query, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *oracle) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
|
s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
|
||||||
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
|
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -719,11 +719,11 @@ func (db *oracle) GetColumns(ctx context.Context, tableName string) ([]string, m
|
||||||
return colSeq, cols, nil
|
return colSeq, cols, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *oracle) GetTables(ctx context.Context) ([]*schemas.Table, error) {
|
func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
s := "SELECT table_name FROM user_tables"
|
s := "SELECT table_name FROM user_tables"
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -742,12 +742,12 @@ func (db *oracle) GetTables(ctx context.Context) ([]*schemas.Table, error) {
|
||||||
return tables, nil
|
return tables, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *oracle) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " +
|
s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " +
|
||||||
"WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1"
|
"WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1"
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -776,9 +776,9 @@ type postgres struct {
|
||||||
Base
|
Base
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) Init(d *core.DB, uri *URI) error {
|
func (db *postgres) Init(uri *URI) error {
|
||||||
db.quoter = postgresQuoter
|
db.quoter = postgresQuoter
|
||||||
err := db.Base.Init(d, db, uri)
|
err := db.Base.Init(db, uri)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -942,12 +942,12 @@ func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interfac
|
||||||
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
|
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) IsTableExist(ctx context.Context, tableName string) (bool, error) {
|
func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) {
|
||||||
if len(db.uri.Schema) == 0 {
|
if len(db.uri.Schema) == 0 {
|
||||||
return db.HasRecords(ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName)
|
return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
return db.HasRecords(ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`,
|
return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`,
|
||||||
db.uri.Schema, tableName)
|
db.uri.Schema, tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -980,7 +980,7 @@ func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string
|
||||||
return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName))
|
return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) {
|
func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
|
||||||
args := []interface{}{db.uri.Schema, tableName, colName}
|
args := []interface{}{db.uri.Schema, tableName, colName}
|
||||||
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" +
|
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" +
|
||||||
" AND column_name = $3"
|
" AND column_name = $3"
|
||||||
|
@ -990,7 +990,7 @@ func (db *postgres) IsColumnExist(ctx context.Context, tableName, colName string
|
||||||
" AND column_name = $2"
|
" AND column_name = $2"
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, query, args...)
|
rows, err := queryer.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -999,7 +999,7 @@ func (db *postgres) IsColumnExist(ctx context.Context, tableName, colName string
|
||||||
return rows.Next(), nil
|
return rows.Next(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||||
args := []interface{}{db.uri.Schema, tableName, db.uri.Schema}
|
args := []interface{}{db.uri.Schema, tableName, db.uri.Schema}
|
||||||
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length,
|
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length,
|
||||||
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
|
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
|
||||||
|
@ -1013,7 +1013,7 @@ FROM pg_attribute f
|
||||||
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
|
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
|
||||||
WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_schema = $3 AND f.attnum > 0 ORDER BY f.attnum;`
|
WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_schema = $3 AND f.attnum > 0 ORDER BY f.attnum;`
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -1132,7 +1132,7 @@ WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_sch
|
||||||
return colSeq, cols, nil
|
return colSeq, cols, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) GetTables(ctx context.Context) ([]*schemas.Table, error) {
|
func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
s := "SELECT tablename FROM pg_tables"
|
s := "SELECT tablename FROM pg_tables"
|
||||||
if len(db.uri.Schema) != 0 {
|
if len(db.uri.Schema) != 0 {
|
||||||
|
@ -1140,7 +1140,7 @@ func (db *postgres) GetTables(ctx context.Context) ([]*schemas.Table, error) {
|
||||||
s = s + " WHERE schemaname = $1"
|
s = s + " WHERE schemaname = $1"
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -1171,7 +1171,7 @@ func getIndexColName(indexdef string) []string {
|
||||||
return colNames
|
return colNames
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
|
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
|
||||||
if len(db.uri.Schema) != 0 {
|
if len(db.uri.Schema) != 0 {
|
||||||
|
@ -1179,7 +1179,7 @@ func (db *postgres) GetIndexes(ctx context.Context, tableName string) (map[strin
|
||||||
s = s + " AND schemaname=$2"
|
s = s + " AND schemaname=$2"
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -151,9 +151,9 @@ type sqlite3 struct {
|
||||||
Base
|
Base
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) Init(d *core.DB, uri *URI) error {
|
func (db *sqlite3) Init(uri *URI) error {
|
||||||
db.quoter = sqlite3Quoter
|
db.quoter = sqlite3Quoter
|
||||||
return db.Base.Init(d, db, uri)
|
return db.Base.Init(db, uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) {
|
func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) {
|
||||||
|
@ -225,8 +225,8 @@ func (db *sqlite3) IndexCheckSQL(tableName, idxName string) (string, []interface
|
||||||
return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args
|
return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) IsTableExist(ctx context.Context, tableName string) (bool, error) {
|
func (db *sqlite3) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) {
|
||||||
return db.HasRecords(ctx, "SELECT name FROM sqlite_master WHERE type='table' and name = ?", tableName)
|
return db.HasRecords(queryer, ctx, "SELECT name FROM sqlite_master WHERE type='table' and name = ?", tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string {
|
func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string {
|
||||||
|
@ -286,9 +286,9 @@ func (db *sqlite3) ForUpdateSQL(query string) string {
|
||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) {
|
func (db *sqlite3) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
|
||||||
query := "SELECT * FROM " + tableName + " LIMIT 0"
|
query := "SELECT * FROM " + tableName + " LIMIT 0"
|
||||||
rows, err := db.DB().QueryContext(ctx, query)
|
rows, err := queryer.QueryContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -370,11 +370,11 @@ func parseString(colStr string) (*schemas.Column, error) {
|
||||||
return col, nil
|
return col, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
|
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -427,11 +427,11 @@ func (db *sqlite3) GetColumns(ctx context.Context, tableName string) ([]string,
|
||||||
return colSeq, cols, nil
|
return colSeq, cols, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) GetTables(ctx context.Context) ([]*schemas.Table, error) {
|
func (db *sqlite3) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
s := "SELECT name FROM sqlite_master WHERE type='table'"
|
s := "SELECT name FROM sqlite_master WHERE type='table'"
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -452,11 +452,11 @@ func (db *sqlite3) GetTables(ctx context.Context) ([]*schemas.Table, error) {
|
||||||
return tables, nil
|
return tables, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
|
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
|
||||||
|
|
||||||
rows, err := db.DB().QueryContext(ctx, s, args...)
|
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
13
engine.go
13
engine.go
|
@ -35,6 +35,7 @@ type Engine struct {
|
||||||
engineGroup *EngineGroup
|
engineGroup *EngineGroup
|
||||||
logger log.ContextLogger
|
logger log.ContextLogger
|
||||||
tagParser *tags.Parser
|
tagParser *tags.Parser
|
||||||
|
db *core.DB
|
||||||
|
|
||||||
driverName string
|
driverName string
|
||||||
dataSourceName string
|
dataSourceName string
|
||||||
|
@ -211,7 +212,7 @@ func (engine *Engine) NewDB() (*core.DB, error) {
|
||||||
|
|
||||||
// DB return the wrapper of sql.DB
|
// DB return the wrapper of sql.DB
|
||||||
func (engine *Engine) DB() *core.DB {
|
func (engine *Engine) DB() *core.DB {
|
||||||
return engine.dialect.DB()
|
return engine.db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dialect return database dialect
|
// Dialect return database dialect
|
||||||
|
@ -267,14 +268,14 @@ func (engine *Engine) NoAutoCondition(no ...bool) *Session {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (engine *Engine) loadTableInfo(table *schemas.Table) error {
|
func (engine *Engine) loadTableInfo(table *schemas.Table) error {
|
||||||
colSeq, cols, err := engine.dialect.GetColumns(engine.defaultContext, table.Name)
|
colSeq, cols, err := engine.dialect.GetColumns(engine.db, engine.defaultContext, table.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, name := range colSeq {
|
for _, name := range colSeq {
|
||||||
table.AddColumn(cols[name])
|
table.AddColumn(cols[name])
|
||||||
}
|
}
|
||||||
indexes, err := engine.dialect.GetIndexes(engine.defaultContext, table.Name)
|
indexes, err := engine.dialect.GetIndexes(engine.db, engine.defaultContext, table.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -301,7 +302,7 @@ func (engine *Engine) loadTableInfo(table *schemas.Table) error {
|
||||||
|
|
||||||
// DBMetas Retrieve all tables, columns, indexes' informations from database.
|
// DBMetas Retrieve all tables, columns, indexes' informations from database.
|
||||||
func (engine *Engine) DBMetas() ([]*schemas.Table, error) {
|
func (engine *Engine) DBMetas() ([]*schemas.Table, error) {
|
||||||
tables, err := engine.dialect.GetTables(engine.defaultContext)
|
tables, err := engine.dialect.GetTables(engine.db, engine.defaultContext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -361,7 +362,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
||||||
|
|
||||||
uri := engine.dialect.URI()
|
uri := engine.dialect.URI()
|
||||||
destURI := *uri
|
destURI := *uri
|
||||||
dstDialect.Init(nil, &destURI)
|
dstDialect.Init(&destURI)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm %s, from %s to %s*/\n\n",
|
_, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm %s, from %s to %s*/\n\n",
|
||||||
|
@ -911,7 +912,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for _, col := range table.Columns() {
|
for _, col := range table.Columns() {
|
||||||
isExist, err := engine.dialect.IsColumnExist(session.ctx, tableNameNoSchema, col.Name)
|
isExist, err := engine.dialect.IsColumnExist(engine.db, session.ctx, tableNameNoSchema, col.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -161,17 +161,17 @@ func (eg *EngineGroup) SetMapper(mapper names.Mapper) {
|
||||||
|
|
||||||
// SetMaxIdleConns set the max idle connections on pool, default is 2
|
// SetMaxIdleConns set the max idle connections on pool, default is 2
|
||||||
func (eg *EngineGroup) SetMaxIdleConns(conns int) {
|
func (eg *EngineGroup) SetMaxIdleConns(conns int) {
|
||||||
eg.Engine.dialect.DB().SetMaxIdleConns(conns)
|
eg.Engine.DB().SetMaxIdleConns(conns)
|
||||||
for i := 0; i < len(eg.slaves); i++ {
|
for i := 0; i < len(eg.slaves); i++ {
|
||||||
eg.slaves[i].dialect.DB().SetMaxIdleConns(conns)
|
eg.slaves[i].DB().SetMaxIdleConns(conns)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMaxOpenConns is only available for go 1.2+
|
// SetMaxOpenConns is only available for go 1.2+
|
||||||
func (eg *EngineGroup) SetMaxOpenConns(conns int) {
|
func (eg *EngineGroup) SetMaxOpenConns(conns int) {
|
||||||
eg.Engine.dialect.DB().SetMaxOpenConns(conns)
|
eg.Engine.DB().SetMaxOpenConns(conns)
|
||||||
for i := 0; i < len(eg.slaves); i++ {
|
for i := 0; i < len(eg.slaves); i++ {
|
||||||
eg.slaves[i].dialect.DB().SetMaxOpenConns(conns)
|
eg.slaves[i].DB().SetMaxOpenConns(conns)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -99,7 +99,7 @@ func (session *Session) Init() {
|
||||||
session.engine.tagParser,
|
session.engine.tagParser,
|
||||||
session.engine.DatabaseTZ,
|
session.engine.DatabaseTZ,
|
||||||
)
|
)
|
||||||
|
session.db = session.engine.db
|
||||||
session.isAutoCommit = true
|
session.isAutoCommit = true
|
||||||
session.isCommitedOrRollbacked = false
|
session.isCommitedOrRollbacked = false
|
||||||
session.isAutoClose = false
|
session.isAutoClose = false
|
||||||
|
@ -140,6 +140,13 @@ func (session *Session) Close() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (session *Session) getQueryer() core.Queryer {
|
||||||
|
if session.tx != nil {
|
||||||
|
return session.tx
|
||||||
|
}
|
||||||
|
return session.db
|
||||||
|
}
|
||||||
|
|
||||||
// ContextCache enable context cache or not
|
// ContextCache enable context cache or not
|
||||||
func (session *Session) ContextCache(context contexts.ContextCache) *Session {
|
func (session *Session) ContextCache(context contexts.ContextCache) *Session {
|
||||||
session.statement.SetContextCache(context)
|
session.statement.SetContextCache(context)
|
||||||
|
|
|
@ -134,7 +134,7 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
|
||||||
tableName := session.engine.TableName(beanOrTableName)
|
tableName := session.engine.TableName(beanOrTableName)
|
||||||
sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true))
|
sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true))
|
||||||
if !checkIfExist {
|
if !checkIfExist {
|
||||||
exist, err := session.engine.dialect.IsTableExist(session.ctx, tableName)
|
exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -160,7 +160,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *Session) isTableExist(tableName string) (bool, error) {
|
func (session *Session) isTableExist(tableName string) (bool, error) {
|
||||||
return session.engine.dialect.IsTableExist(session.ctx, tableName)
|
return session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsTableEmpty if table have any records
|
// IsTableEmpty if table have any records
|
||||||
|
@ -187,7 +187,7 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) {
|
||||||
|
|
||||||
// find if index is exist according cols
|
// find if index is exist according cols
|
||||||
func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) {
|
func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) {
|
||||||
indexes, err := session.engine.dialect.GetIndexes(session.ctx, tableName)
|
indexes, err := session.engine.dialect.GetIndexes(session.getQueryer(), session.ctx, tableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -233,7 +233,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
tables, err := engine.dialect.GetTables(session.ctx)
|
tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
7
xorm.go
7
xorm.go
|
@ -13,6 +13,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"xorm.io/xorm/caches"
|
"xorm.io/xorm/caches"
|
||||||
|
"xorm.io/xorm/core"
|
||||||
"xorm.io/xorm/dialects"
|
"xorm.io/xorm/dialects"
|
||||||
"xorm.io/xorm/log"
|
"xorm.io/xorm/log"
|
||||||
"xorm.io/xorm/names"
|
"xorm.io/xorm/names"
|
||||||
|
@ -32,6 +33,11 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db, err := core.Open(driverName, dataSourceName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
cacherMgr := caches.NewManager()
|
cacherMgr := caches.NewManager()
|
||||||
mapper := names.NewCacheMapper(new(names.SnakeMapper))
|
mapper := names.NewCacheMapper(new(names.SnakeMapper))
|
||||||
tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr)
|
tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr)
|
||||||
|
@ -44,6 +50,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
|
||||||
tagParser: tagParser,
|
tagParser: tagParser,
|
||||||
driverName: driverName,
|
driverName: driverName,
|
||||||
dataSourceName: dataSourceName,
|
dataSourceName: dataSourceName,
|
||||||
|
db: db,
|
||||||
}
|
}
|
||||||
|
|
||||||
if dialect.URI().DBType == schemas.SQLITE {
|
if dialect.URI().DBType == schemas.SQLITE {
|
||||||
|
|
Loading…
Reference in New Issue