diff --git a/core/db.go b/core/db.go index 9aa771ba..6845565d 100644 --- a/core/db.go +++ b/core/db.go @@ -77,6 +77,10 @@ type cacheStruct struct { idx int } +var ( + _ QueryExecuter = &DB{} +) + // DB is a wrap of sql.DB with extra contents type DB struct { *sql.DB diff --git a/core/interface.go b/core/interface.go new file mode 100644 index 00000000..a5c8e4e2 --- /dev/null +++ b/core/interface.go @@ -0,0 +1,22 @@ +package core + +import ( + "context" + "database/sql" +) + +// Queryer represents an interface to query a SQL to get data from database +type Queryer interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) +} + +// Executer represents an interface to execute a SQL +type Executer interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + +// QueryExecuter combines the Queryer and Executer +type QueryExecuter interface { + Queryer + Executer +} diff --git a/core/stmt.go b/core/stmt.go index 9d5954bd..4b1c7605 100644 --- a/core/stmt.go +++ b/core/stmt.go @@ -27,7 +27,7 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { var i int query = re.ReplaceAllStringFunc(query, func(src string) string { names[src[1:]] = i - i += 1 + i++ return "?" }) diff --git a/core/tx.go b/core/tx.go index 07713267..99a8097d 100644 --- a/core/tx.go +++ b/core/tx.go @@ -12,6 +12,11 @@ import ( "xorm.io/xorm/log" ) +var ( + _ QueryExecuter = &Tx{} +) + +// Tx represents a transaction type Tx struct { *sql.Tx db *DB @@ -50,7 +55,7 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { var i int query = re.ReplaceAllStringFunc(query, func(src string) string { names[src[1:]] = i - i += 1 + i++ return "?" }) diff --git a/dialects/dialect.go b/dialects/dialect.go index 4fdf35e9..a3328e05 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -39,9 +39,8 @@ func (uri *URI) SetSchema(schema string) { // Dialect represents a kind of database type Dialect interface { - Init(*core.DB, *URI) error + Init(*URI) error URI() *URI - DB() *core.DB SQLType(*schemas.Column) string FormatBytes(b []byte) string DefaultSchema() string @@ -52,18 +51,18 @@ type Dialect interface { AutoIncrStr() string - GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) + GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) IndexCheckSQL(tableName, idxName string) (string, []interface{}) CreateIndexSQL(tableName string, index *schemas.Index) string DropIndexSQL(tableName string, index *schemas.Index) string - GetTables(ctx context.Context) ([]*schemas.Table, error) - IsTableExist(ctx context.Context, tableName string) (bool, error) + GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) + IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) DropTableSQL(tableName string) (string, bool) - GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) - IsColumnExist(ctx context.Context, tableName string, colName string) (bool, error) + GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) + IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error) AddColumnSQL(tableName string, col *schemas.Column) string ModifyColumnSQL(tableName string, col *schemas.Column) string @@ -75,7 +74,6 @@ type Dialect interface { // Base represents a basic dialect and all real dialects could embed this struct type Base struct { - db *core.DB dialect Dialect uri *URI quoter schemas.Quoter @@ -85,16 +83,12 @@ func (b *Base) Quoter() schemas.Quoter { return b.quoter } -func (b *Base) DB() *core.DB { - return b.db -} - func (b *Base) DefaultSchema() string { return "" } -func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI) error { - b.db, b.dialect, b.uri = db, dialect, uri +func (b *Base) Init(dialect Dialect, uri *URI) error { + b.dialect, b.uri = dialect, uri return nil } @@ -160,8 +154,8 @@ func (db *Base) DropTableSQL(tableName string) (string, bool) { return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true } -func (db *Base) HasRecords(ctx context.Context, query string, args ...interface{}) (bool, error) { - rows, err := db.DB().QueryContext(ctx, query, args...) +func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query string, args ...interface{}) (bool, error) { + rows, err := queryer.QueryContext(ctx, query, args...) if err != nil { return false, err } @@ -173,7 +167,7 @@ func (db *Base) HasRecords(ctx context.Context, query string, args ...interface{ return false, nil } -func (db *Base) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { +func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { quote := db.dialect.Quoter().Quote query := fmt.Sprintf( "SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?", @@ -184,7 +178,7 @@ func (db *Base) IsColumnExist(ctx context.Context, tableName, colName string) (b quote("TABLE_NAME"), quote("COLUMN_NAME"), ) - return db.HasRecords(ctx, query, db.uri.DBName, tableName, colName) + return db.HasRecords(queryer, ctx, query, db.uri.DBName, tableName, colName) } func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { diff --git a/dialects/driver.go b/dialects/driver.go index 89d21bfc..ae3afe42 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -6,8 +6,6 @@ package dialects import ( "fmt" - - "xorm.io/xorm/core" ) type Driver interface { @@ -53,11 +51,7 @@ func OpenDialect(driverName, connstr string) (Dialect, error) { return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType) } - db, err := core.Open(driverName, connstr) - if err != nil { - return nil, err - } - dialect.Init(db, uri) + dialect.Init(uri) return dialect, nil } diff --git a/dialects/mssql.go b/dialects/mssql.go index dd3f4247..92457ff9 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -212,9 +212,9 @@ type mssql struct { Base } -func (db *mssql) Init(d *core.DB, uri *URI) error { +func (db *mssql) Init(uri *URI) error { db.quoter = mssqlQuoter - return db.Base.Init(d, db, uri) + return db.Base.Init(db, uri) } func (db *mssql) SQLType(c *schemas.Column) string { @@ -319,18 +319,18 @@ func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{} return sql, args } -func (db *mssql) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { +func (db *mssql) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { query := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?` - return db.HasRecords(ctx, query, tableName, colName) + return db.HasRecords(queryer, ctx, query, tableName, colName) } -func (db *mssql) IsTableExist(ctx context.Context, tableName string) (bool, error) { +func (db *mssql) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { sql := "select * from sysobjects where id = object_id(N'" + tableName + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1" - return db.HasRecords(ctx, sql) + return db.HasRecords(queryer, ctx, sql) } -func (db *mssql) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{} s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, "default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END), @@ -346,7 +346,7 @@ func (db *mssql) GetColumns(ctx context.Context, tableName string) ([]string, ma ) as p on p.object_id = a.object_id AND p.column_id = a.column_id where a.object_id=object_id('` + tableName + `')` - rows, err := db.DB().QueryContext(ctx, s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } @@ -401,11 +401,11 @@ func (db *mssql) GetColumns(ctx context.Context, tableName string) ([]string, ma return colSeq, cols, nil } -func (db *mssql) GetTables(ctx context.Context) ([]*schemas.Table, error) { +func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := `select name from sysobjects where xtype ='U'` - rows, err := db.DB().QueryContext(ctx, s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } @@ -425,7 +425,7 @@ func (db *mssql) GetTables(ctx context.Context) ([]*schemas.Table, error) { return tables, nil } -func (db *mssql) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) { +func (db *mssql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := `SELECT IXS.NAME AS [INDEX_NAME], @@ -439,7 +439,7 @@ AND IXCS.COLUMN_ID=C.COLUMN_ID WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? ` - rows, err := db.DB().QueryContext(ctx, s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } diff --git a/dialects/mysql.go b/dialects/mysql.go index b7598680..289e0ec7 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -179,9 +179,9 @@ type mysql struct { rowFormat string } -func (db *mysql) Init(d *core.DB, uri *URI) error { +func (db *mysql) Init(uri *URI) error { db.quoter = mysqlQuoter - return db.Base.Init(d, db, uri) + return db.Base.Init(db, uri) } func (db *mysql) SetParams(params map[string]string) { @@ -286,9 +286,9 @@ func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{} return sql, args } -func (db *mysql) IsTableExist(ctx context.Context, tableName string) (bool, error) { +func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" - return db.HasRecords(ctx, sql, db.uri.DBName, tableName) + return db.HasRecords(queryer, ctx, sql, db.uri.DBName, tableName) } func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { @@ -301,12 +301,12 @@ func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { return sql } -func (db *mysql) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{db.uri.DBName, tableName} s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + " `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - rows, err := db.DB().QueryContext(ctx, s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } @@ -411,12 +411,12 @@ func (db *mysql) GetColumns(ctx context.Context, tableName string) ([]string, ma return colSeq, cols, nil } -func (db *mysql) GetTables(ctx context.Context) ([]*schemas.Table, error) { +func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{db.uri.DBName} s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + "`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')" - rows, err := db.DB().QueryContext(ctx, s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } @@ -459,11 +459,11 @@ func (db *mysql) SetQuotePolicy(quotePolicy QuotePolicy) { } } -func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) { +func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{db.uri.DBName, tableName} s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - rows, err := db.DB().QueryContext(ctx, s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } diff --git a/dialects/oracle.go b/dialects/oracle.go index 466b6a45..db823e95 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -506,9 +506,9 @@ type oracle struct { Base } -func (db *oracle) Init(d *core.DB, uri *URI) error { +func (db *oracle) Init(uri *URI) error { db.quoter = oracleQuoter - return db.Base.Init(d, db, uri) + return db.Base.Init(db, uri) } func (db *oracle) SQLType(c *schemas.Column) string { @@ -611,23 +611,23 @@ func (db *oracle) IndexCheckSQL(tableName, idxName string) (string, []interface{ `WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args } -func (db *oracle) IsTableExist(ctx context.Context, tableName string) (bool, error) { - return db.HasRecords(ctx, `SELECT table_name FROM user_tables WHERE table_name = :1`, tableName) +func (db *oracle) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { + return db.HasRecords(queryer, ctx, `SELECT table_name FROM user_tables WHERE table_name = :1`, tableName) } -func (db *oracle) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { +func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { args := []interface{}{tableName, colName} query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" + " AND column_name = :2" - return db.HasRecords(ctx, query, args...) + return db.HasRecords(queryer, ctx, query, args...) } -func (db *oracle) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" - rows, err := db.DB().QueryContext(ctx, s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } @@ -719,11 +719,11 @@ func (db *oracle) GetColumns(ctx context.Context, tableName string) ([]string, m return colSeq, cols, nil } -func (db *oracle) GetTables(ctx context.Context) ([]*schemas.Table, error) { +func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT table_name FROM user_tables" - rows, err := db.DB().QueryContext(ctx, s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } @@ -742,12 +742,12 @@ func (db *oracle) GetTables(ctx context.Context) ([]*schemas.Table, error) { return tables, nil } -func (db *oracle) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) { +func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1" - rows, err := db.DB().QueryContext(ctx, s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } diff --git a/dialects/postgres.go b/dialects/postgres.go index 0a851fe2..a83c3a5c 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -776,9 +776,9 @@ type postgres struct { Base } -func (db *postgres) Init(d *core.DB, uri *URI) error { +func (db *postgres) Init(uri *URI) error { db.quoter = postgresQuoter - err := db.Base.Init(d, db, uri) + err := db.Base.Init(db, uri) if err != nil { return err } @@ -942,12 +942,12 @@ func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interfac `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args } -func (db *postgres) IsTableExist(ctx context.Context, tableName string) (bool, error) { +func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { if len(db.uri.Schema) == 0 { - return db.HasRecords(ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName) + return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName) } - return db.HasRecords(ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`, + return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`, db.uri.Schema, tableName) } @@ -980,7 +980,7 @@ func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } -func (db *postgres) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { +func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { args := []interface{}{db.uri.Schema, tableName, colName} query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + " AND column_name = $3" @@ -990,7 +990,7 @@ func (db *postgres) IsColumnExist(ctx context.Context, tableName, colName string " AND column_name = $2" } - rows, err := db.DB().QueryContext(ctx, query, args...) + rows, err := queryer.QueryContext(ctx, query, args...) if err != nil { return false, err } @@ -999,7 +999,7 @@ func (db *postgres) IsColumnExist(ctx context.Context, tableName, colName string return rows.Next(), nil } -func (db *postgres) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{db.uri.Schema, tableName, db.uri.Schema} s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, @@ -1013,7 +1013,7 @@ FROM pg_attribute f LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_schema = $3 AND f.attnum > 0 ORDER BY f.attnum;` - rows, err := db.DB().QueryContext(ctx, s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } @@ -1132,7 +1132,7 @@ WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_sch return colSeq, cols, nil } -func (db *postgres) GetTables(ctx context.Context) ([]*schemas.Table, error) { +func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT tablename FROM pg_tables" if len(db.uri.Schema) != 0 { @@ -1140,7 +1140,7 @@ func (db *postgres) GetTables(ctx context.Context) ([]*schemas.Table, error) { s = s + " WHERE schemaname = $1" } - rows, err := db.DB().QueryContext(ctx, s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } @@ -1171,7 +1171,7 @@ func getIndexColName(indexdef string) []string { return colNames } -func (db *postgres) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) { +func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") if len(db.uri.Schema) != 0 { @@ -1179,7 +1179,7 @@ func (db *postgres) GetIndexes(ctx context.Context, tableName string) (map[strin s = s + " AND schemaname=$2" } - rows, err := db.DB().QueryContext(ctx, s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 710babe6..0e95ebc7 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -151,9 +151,9 @@ type sqlite3 struct { Base } -func (db *sqlite3) Init(d *core.DB, uri *URI) error { +func (db *sqlite3) Init(uri *URI) error { db.quoter = sqlite3Quoter - return db.Base.Init(d, db, uri) + return db.Base.Init(db, uri) } func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) { @@ -225,8 +225,8 @@ func (db *sqlite3) IndexCheckSQL(tableName, idxName string) (string, []interface return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args } -func (db *sqlite3) IsTableExist(ctx context.Context, tableName string) (bool, error) { - return db.HasRecords(ctx, "SELECT name FROM sqlite_master WHERE type='table' and name = ?", tableName) +func (db *sqlite3) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { + return db.HasRecords(queryer, ctx, "SELECT name FROM sqlite_master WHERE type='table' and name = ?", tableName) } func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string { @@ -286,9 +286,9 @@ func (db *sqlite3) ForUpdateSQL(query string) string { return query } -func (db *sqlite3) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { +func (db *sqlite3) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { query := "SELECT * FROM " + tableName + " LIMIT 0" - rows, err := db.DB().QueryContext(ctx, query) + rows, err := queryer.QueryContext(ctx, query) if err != nil { return false, err } @@ -370,11 +370,11 @@ func parseString(colStr string) (*schemas.Column, error) { return col, nil } -func (db *sqlite3) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" - rows, err := db.DB().QueryContext(ctx, s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } @@ -427,11 +427,11 @@ func (db *sqlite3) GetColumns(ctx context.Context, tableName string) ([]string, return colSeq, cols, nil } -func (db *sqlite3) GetTables(ctx context.Context) ([]*schemas.Table, error) { +func (db *sqlite3) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT name FROM sqlite_master WHERE type='table'" - rows, err := db.DB().QueryContext(ctx, s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } @@ -452,11 +452,11 @@ func (db *sqlite3) GetTables(ctx context.Context) ([]*schemas.Table, error) { return tables, nil } -func (db *sqlite3) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) { +func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" - rows, err := db.DB().QueryContext(ctx, s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } diff --git a/engine.go b/engine.go index 4694e1c0..dd34fe1b 100644 --- a/engine.go +++ b/engine.go @@ -35,6 +35,7 @@ type Engine struct { engineGroup *EngineGroup logger log.ContextLogger tagParser *tags.Parser + db *core.DB driverName string dataSourceName string @@ -211,7 +212,7 @@ func (engine *Engine) NewDB() (*core.DB, error) { // DB return the wrapper of sql.DB func (engine *Engine) DB() *core.DB { - return engine.dialect.DB() + return engine.db } // Dialect return database dialect @@ -267,14 +268,14 @@ func (engine *Engine) NoAutoCondition(no ...bool) *Session { } func (engine *Engine) loadTableInfo(table *schemas.Table) error { - colSeq, cols, err := engine.dialect.GetColumns(engine.defaultContext, table.Name) + colSeq, cols, err := engine.dialect.GetColumns(engine.db, engine.defaultContext, table.Name) if err != nil { return err } for _, name := range colSeq { table.AddColumn(cols[name]) } - indexes, err := engine.dialect.GetIndexes(engine.defaultContext, table.Name) + indexes, err := engine.dialect.GetIndexes(engine.db, engine.defaultContext, table.Name) if err != nil { return err } @@ -301,7 +302,7 @@ func (engine *Engine) loadTableInfo(table *schemas.Table) error { // DBMetas Retrieve all tables, columns, indexes' informations from database. func (engine *Engine) DBMetas() ([]*schemas.Table, error) { - tables, err := engine.dialect.GetTables(engine.defaultContext) + tables, err := engine.dialect.GetTables(engine.db, engine.defaultContext) if err != nil { return nil, err } @@ -361,7 +362,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch uri := engine.dialect.URI() destURI := *uri - dstDialect.Init(nil, &destURI) + dstDialect.Init(&destURI) } _, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm %s, from %s to %s*/\n\n", @@ -911,7 +912,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { } } else { for _, col := range table.Columns() { - isExist, err := engine.dialect.IsColumnExist(session.ctx, tableNameNoSchema, col.Name) + isExist, err := engine.dialect.IsColumnExist(engine.db, session.ctx, tableNameNoSchema, col.Name) if err != nil { return err } diff --git a/engine_group.go b/engine_group.go index 868d4dc9..d557645e 100644 --- a/engine_group.go +++ b/engine_group.go @@ -161,17 +161,17 @@ func (eg *EngineGroup) SetMapper(mapper names.Mapper) { // SetMaxIdleConns set the max idle connections on pool, default is 2 func (eg *EngineGroup) SetMaxIdleConns(conns int) { - eg.Engine.dialect.DB().SetMaxIdleConns(conns) + eg.Engine.DB().SetMaxIdleConns(conns) for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].dialect.DB().SetMaxIdleConns(conns) + eg.slaves[i].DB().SetMaxIdleConns(conns) } } // SetMaxOpenConns is only available for go 1.2+ func (eg *EngineGroup) SetMaxOpenConns(conns int) { - eg.Engine.dialect.DB().SetMaxOpenConns(conns) + eg.Engine.DB().SetMaxOpenConns(conns) for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].dialect.DB().SetMaxOpenConns(conns) + eg.slaves[i].DB().SetMaxOpenConns(conns) } } diff --git a/session.go b/session.go index 4842883b..6b8bfbaf 100644 --- a/session.go +++ b/session.go @@ -99,7 +99,7 @@ func (session *Session) Init() { session.engine.tagParser, session.engine.DatabaseTZ, ) - + session.db = session.engine.db session.isAutoCommit = true session.isCommitedOrRollbacked = false session.isAutoClose = false @@ -140,6 +140,13 @@ func (session *Session) Close() { } } +func (session *Session) getQueryer() core.Queryer { + if session.tx != nil { + return session.tx + } + return session.db +} + // ContextCache enable context cache or not func (session *Session) ContextCache(context contexts.ContextCache) *Session { session.statement.SetContextCache(context) diff --git a/session_schema.go b/session_schema.go index 84eb586e..9ccf8abe 100644 --- a/session_schema.go +++ b/session_schema.go @@ -134,7 +134,7 @@ func (session *Session) dropTable(beanOrTableName interface{}) error { tableName := session.engine.TableName(beanOrTableName) sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true)) if !checkIfExist { - exist, err := session.engine.dialect.IsTableExist(session.ctx, tableName) + exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) if err != nil { return err } @@ -160,7 +160,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) } func (session *Session) isTableExist(tableName string) (bool, error) { - return session.engine.dialect.IsTableExist(session.ctx, tableName) + return session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) } // IsTableEmpty if table have any records @@ -187,7 +187,7 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) { // find if index is exist according cols func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) { - indexes, err := session.engine.dialect.GetIndexes(session.ctx, tableName) + indexes, err := session.engine.dialect.GetIndexes(session.getQueryer(), session.ctx, tableName) if err != nil { return false, err } @@ -233,7 +233,7 @@ func (session *Session) Sync2(beans ...interface{}) error { defer session.Close() } - tables, err := engine.dialect.GetTables(session.ctx) + tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx) if err != nil { return err } diff --git a/xorm.go b/xorm.go index 2025522f..e9cd7415 100644 --- a/xorm.go +++ b/xorm.go @@ -13,6 +13,7 @@ import ( "time" "xorm.io/xorm/caches" + "xorm.io/xorm/core" "xorm.io/xorm/dialects" "xorm.io/xorm/log" "xorm.io/xorm/names" @@ -32,6 +33,11 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { return nil, err } + db, err := core.Open(driverName, dataSourceName) + if err != nil { + return nil, err + } + cacherMgr := caches.NewManager() mapper := names.NewCacheMapper(new(names.SnakeMapper)) tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr) @@ -44,6 +50,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { tagParser: tagParser, driverName: driverName, dataSourceName: dataSourceName, + db: db, } if dialect.URI().DBType == schemas.SQLITE {