From 4c2b0e0f551384083b419cc0d7f797462a12fdff Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 27 Feb 2020 15:31:05 +0000 Subject: [PATCH] Add context for dialects (#1558) More improvements Add context for dialects Reviewed-on: https://gitea.com/xorm/xorm/pulls/1558 --- dialects/dialect.go | 157 +++++++++++++++++++------------------------ dialects/mssql.go | 21 +++--- dialects/mysql.go | 17 ++--- dialects/oracle.go | 46 +++---------- dialects/postgres.go | 17 ++--- dialects/sqlite3.go | 17 ++--- engine.go | 10 +-- session_schema.go | 8 +-- statement.go | 11 --- 9 files changed, 128 insertions(+), 176 deletions(-) diff --git a/dialects/dialect.go b/dialects/dialect.go index 3ed867f4..26d6521a 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -5,6 +5,7 @@ package dialects import ( + "context" "fmt" "strings" "time" @@ -60,23 +61,21 @@ type Dialect interface { IndexCheckSQL(tableName, idxName string) (string, []interface{}) TableCheckSQL(tableName string) (string, []interface{}) - IsColumnExist(tableName string, colName string) (bool, error) + IsColumnExist(ctx context.Context, tableName string, colName string) (bool, error) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string DropTableSQL(tableName string) string CreateIndexSQL(tableName string, index *schemas.Index) string DropIndexSQL(tableName string, index *schemas.Index) string + AddColumnSQL(tableName string, col *schemas.Column) string ModifyColumnSQL(tableName string, col *schemas.Column) string ForUpdateSQL(query string) string - // CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error - // MustDropTable(tableName string) error - - GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) - GetTables() ([]*schemas.Table, error) - GetIndexes(tableName string) (map[string]*schemas.Index, error) + GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) + GetTables(ctx context.Context) ([]*schemas.Table, error) + GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) Filters() []Filter SetParams(params map[string]string) @@ -96,55 +95,6 @@ type Base struct { uri *URI } -// String generate column description string according dialect -func String(d Dialect, col *schemas.Column) string { - sql := d.Quoter().Quote(col.Name) + " " - - sql += d.SQLType(col) + " " - - if col.IsPrimaryKey { - sql += "PRIMARY KEY " - if col.IsAutoIncrement { - sql += d.AutoIncrStr() + " " - } - } - - if col.Default != "" { - sql += "DEFAULT " + col.Default + " " - } - - if d.ShowCreateNull() { - if col.Nullable { - sql += "NULL " - } else { - sql += "NOT NULL " - } - } - - return sql -} - -// StringNoPk generate column description string according dialect without primary keys -func StringNoPk(d Dialect, col *schemas.Column) string { - sql := d.Quoter().Quote(col.Name) + " " - - sql += d.SQLType(col) + " " - - if col.Default != "" { - sql += "DEFAULT " + col.Default + " " - } - - if d.ShowCreateNull() { - if col.Nullable { - sql += "NULL " - } else { - sql += "NOT NULL " - } - } - - return sql -} - func (b *Base) DB() *core.DB { return b.db } @@ -167,6 +117,55 @@ func (b *Base) DBType() DBType { return b.uri.DBType } +// String generate column description string according dialect +func (b *Base) String(col *schemas.Column) string { + sql := b.dialect.Quoter().Quote(col.Name) + " " + + sql += b.dialect.SQLType(col) + " " + + if col.IsPrimaryKey { + sql += "PRIMARY KEY " + if col.IsAutoIncrement { + sql += b.dialect.AutoIncrStr() + " " + } + } + + if col.Default != "" { + sql += "DEFAULT " + col.Default + " " + } + + if b.dialect.ShowCreateNull() { + if col.Nullable { + sql += "NULL " + } else { + sql += "NOT NULL " + } + } + + return sql +} + +// StringNoPk generate column description string according dialect without primary keys +func (b *Base) StringNoPk(col *schemas.Column) string { + sql := b.dialect.Quoter().Quote(col.Name) + " " + + sql += b.dialect.SQLType(col) + " " + + if col.Default != "" { + sql += "DEFAULT " + col.Default + " " + } + + if b.dialect.ShowCreateNull() { + if col.Nullable { + sql += "NULL " + } else { + sql += "NOT NULL " + } + } + + return sql +} + func (b *Base) FormatBytes(bs []byte) string { return fmt.Sprintf("0x%x", bs) } @@ -196,9 +195,9 @@ func (db *Base) DropTableSQL(tableName string) string { return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)) } -func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) { +func (db *Base) HasRecords(ctx context.Context, query string, args ...interface{}) (bool, error) { db.LogSQL(query, args) - rows, err := db.DB().Query(query, args...) + rows, err := db.DB().QueryContext(ctx, query, args...) if err != nil { return false, err } @@ -210,7 +209,7 @@ func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) { return false, nil } -func (db *Base) IsColumnExist(tableName, colName string) (bool, error) { +func (db *Base) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { quote := db.dialect.Quoter().Quote query := fmt.Sprintf( "SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?", @@ -221,32 +220,18 @@ func (db *Base) IsColumnExist(tableName, colName string) (bool, error) { quote("TABLE_NAME"), quote("COLUMN_NAME"), ) - return db.HasRecords(query, db.uri.DBName, tableName, colName) + return db.HasRecords(ctx, query, db.uri.DBName, tableName, colName) } -/* -func (db *Base) CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error { - sql, args := db.dialect.TableCheckSQL(tableName) - rows, err := db.DB().Query(sql, args...) - if db.Logger != nil { - db.Logger.Info("[sql]", sql, args) +func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { + quoter := db.dialect.Quoter() + sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), + db.String(col)) + if db.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 { + sql += " COMMENT '" + col.Comment + "'" } - if err != nil { - return err - } - defer rows.Close() - - if rows.Next() { - return nil - } - - sql = db.dialect.CreateTableSQL(table, tableName, storeEngine, charset) - _, err = db.DB().Exec(sql) - if db.Logger != nil { - db.Logger.Info("[sql]", sql) - } - return err -}*/ + return sql +} func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string { quoter := db.dialect.Quoter() @@ -273,7 +258,7 @@ func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string { } func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string { - return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, StringNoPk(db.dialect, col)) + return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, db.StringNoPk(col)) } func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string { @@ -293,12 +278,12 @@ func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, char for _, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) if col.IsPrimaryKey && len(pkList) == 1 { - sql += String(b.dialect, col) + sql += b.String(col) } else { - sql += StringNoPk(b.dialect, col) + sql += b.StringNoPk(col) } sql = strings.TrimSpace(sql) - if b.DriverName() == schemas.MYSQL && len(col.Comment) > 0 { + if b.DBType() == schemas.MYSQL && len(col.Comment) > 0 { sql += " COMMENT '" + col.Comment + "'" } sql += ", " diff --git a/dialects/mssql.go b/dialects/mssql.go index 74a3bb63..83844f4e 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -5,6 +5,7 @@ package dialects import ( + "context" "errors" "fmt" "net/url" @@ -324,10 +325,10 @@ func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{} return sql, args }*/ -func (db *mssql) IsColumnExist(tableName, colName string) (bool, error) { +func (db *mssql) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { query := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?` - return db.HasRecords(query, tableName, colName) + return db.HasRecords(ctx, query, tableName, colName) } func (db *mssql) TableCheckSQL(tableName string) (string, []interface{}) { @@ -336,7 +337,7 @@ func (db *mssql) TableCheckSQL(tableName string) (string, []interface{}) { return sql, args } -func (db *mssql) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *mssql) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{} s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, "default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END), @@ -352,7 +353,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*schemas.Col where a.object_id=object_id('` + tableName + `')` db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } @@ -407,12 +408,12 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*schemas.Col return colSeq, cols, nil } -func (db *mssql) GetTables() ([]*schemas.Table, error) { +func (db *mssql) GetTables(ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := `select name from sysobjects where xtype ='U'` db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, err } @@ -432,7 +433,7 @@ func (db *mssql) GetTables() ([]*schemas.Table, error) { return tables, nil } -func (db *mssql) GetIndexes(tableName string) (map[string]*schemas.Index, error) { +func (db *mssql) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := `SELECT IXS.NAME AS [INDEX_NAME], @@ -447,7 +448,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? ` db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, err } @@ -510,9 +511,9 @@ func (db *mssql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch for _, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) if col.IsPrimaryKey && len(pkList) == 1 { - sql += String(db, col) + sql += db.String(col) } else { - sql += StringNoPk(db, col) + sql += db.StringNoPk(col) } sql = strings.TrimSpace(sql) sql += ", " diff --git a/dialects/mysql.go b/dialects/mysql.go index 32dc25b7..62fc6eb1 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -5,6 +5,7 @@ package dialects import ( + "context" "crypto/tls" "errors" "fmt" @@ -314,13 +315,13 @@ func (db *mysql) TableCheckSQL(tableName string) (string, []interface{}) { return sql, args } -func (db *mysql) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *mysql) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{db.uri.DBName, tableName} s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + " `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } @@ -425,13 +426,13 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*schemas.Col return colSeq, cols, nil } -func (db *mysql) GetTables() ([]*schemas.Table, error) { +func (db *mysql) GetTables(ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{db.uri.DBName} s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + "`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')" db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, err } @@ -455,12 +456,12 @@ func (db *mysql) GetTables() ([]*schemas.Table, error) { return tables, nil } -func (db *mysql) GetIndexes(tableName string) (map[string]*schemas.Index, error) { +func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{db.uri.DBName, tableName} s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, err } @@ -523,9 +524,9 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch for _, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) if col.IsPrimaryKey && len(pkList) == 1 { - sql += String(db, col) + sql += db.String(col) } else { - sql += StringNoPk(db, col) + sql += db.StringNoPk(col) } sql = strings.TrimSpace(sql) if len(col.Comment) > 0 { diff --git a/dialects/oracle.go b/dialects/oracle.go index 46f7aca2..1247d7a4 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -5,6 +5,7 @@ package dialects import ( + "context" "errors" "fmt" "regexp" @@ -592,7 +593,7 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c /*if col.IsPrimaryKey && len(pkList) == 1 { sql += col.String(b.dialect) } else {*/ - sql += StringNoPk(db, col) + sql += db.StringNoPk(col) // } sql = strings.TrimSpace(sql) sql += ", " @@ -630,40 +631,13 @@ func (db *oracle) TableCheckSQL(tableName string) (string, []interface{}) { return `SELECT table_name FROM user_tables WHERE table_name = :1`, args } -func (db *oracle) MustDropTable(tableName string) error { - sql, args := db.TableCheckSQL(tableName) - db.LogSQL(sql, args) - - rows, err := db.DB().Query(sql, args...) - if err != nil { - return err - } - defer rows.Close() - - if !rows.Next() { - return nil - } - - sql = "Drop Table \"" + tableName + "\"" - db.LogSQL(sql, args) - - _, err = db.DB().Exec(sql) - return err -} - -/*func (db *oracle) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(colName)} - return "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" + - " AND column_name = ?", args -}*/ - -func (db *oracle) IsColumnExist(tableName, colName string) (bool, error) { +func (db *oracle) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { args := []interface{}{tableName, colName} query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" + " AND column_name = :2" db.LogSQL(query, args) - rows, err := db.DB().Query(query, args...) + rows, err := db.DB().QueryContext(ctx, query, args...) if err != nil { return false, err } @@ -675,13 +649,13 @@ func (db *oracle) IsColumnExist(tableName, colName string) (bool, error) { return false, nil } -func (db *oracle) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *oracle) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } @@ -773,12 +747,12 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*schemas.Co return colSeq, cols, nil } -func (db *oracle) GetTables() ([]*schemas.Table, error) { +func (db *oracle) GetTables(ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT table_name FROM user_tables" db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, err } @@ -797,13 +771,13 @@ func (db *oracle) GetTables() ([]*schemas.Table, error) { return tables, nil } -func (db *oracle) GetIndexes(tableName string) (map[string]*schemas.Index, error) { +func (db *oracle) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1" db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, err } diff --git a/dialects/postgres.go b/dialects/postgres.go index cab7eaef..d6847b02 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -5,6 +5,7 @@ package dialects import ( + "context" "errors" "fmt" "net/url" @@ -929,7 +930,7 @@ func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } -func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { +func (db *postgres) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { args := []interface{}{db.uri.Schema, tableName, colName} query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + " AND column_name = $3" @@ -940,7 +941,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { } db.LogSQL(query, args) - rows, err := db.DB().Query(query, args...) + rows, err := db.DB().QueryContext(ctx, query, args...) if err != nil { return false, err } @@ -949,7 +950,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { return rows.Next(), nil } -func (db *postgres) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *postgres) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, @@ -972,7 +973,7 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } @@ -1064,7 +1065,7 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att return colSeq, cols, nil } -func (db *postgres) GetTables() ([]*schemas.Table, error) { +func (db *postgres) GetTables(ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT tablename FROM pg_tables" if len(db.uri.Schema) != 0 { @@ -1074,7 +1075,7 @@ func (db *postgres) GetTables() ([]*schemas.Table, error) { db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, err } @@ -1105,7 +1106,7 @@ func getIndexColName(indexdef string) []string { return colNames } -func (db *postgres) GetIndexes(tableName string) (map[string]*schemas.Index, error) { +func (db *postgres) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") if len(db.uri.Schema) != 0 { @@ -1114,7 +1115,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*schemas.Index, err } db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, err } diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 0fd80b73..5511468f 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -5,6 +5,7 @@ package dialects import ( + "context" "database/sql" "errors" "fmt" @@ -254,11 +255,11 @@ func (db *sqlite3) ForUpdateSQL(query string) string { return sql, args }*/ -func (db *sqlite3) IsColumnExist(tableName, colName string) (bool, error) { +func (db *sqlite3) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { args := []interface{}{tableName} query := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" db.LogSQL(query, args) - rows, err := db.DB().Query(query, args...) + rows, err := db.DB().QueryContext(ctx, query, args...) if err != nil { return false, err } @@ -332,11 +333,11 @@ func parseString(colStr string) (*schemas.Column, error) { return col, nil } -func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *sqlite3) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } @@ -389,12 +390,12 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*schemas.C return colSeq, cols, nil } -func (db *sqlite3) GetTables() ([]*schemas.Table, error) { +func (db *sqlite3) GetTables(ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT name FROM sqlite_master WHERE type='table'" db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, err } @@ -415,12 +416,12 @@ func (db *sqlite3) GetTables() ([]*schemas.Table, error) { return tables, nil } -func (db *sqlite3) GetIndexes(tableName string) (map[string]*schemas.Index, error) { +func (db *sqlite3) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, err } diff --git a/engine.go b/engine.go index b97d1c06..50b0958c 100644 --- a/engine.go +++ b/engine.go @@ -321,14 +321,14 @@ func (engine *Engine) NoAutoCondition(no ...bool) *Session { } func (engine *Engine) loadTableInfo(table *schemas.Table) error { - colSeq, cols, err := engine.dialect.GetColumns(table.Name) + colSeq, cols, err := engine.dialect.GetColumns(engine.defaultContext, table.Name) if err != nil { return err } for _, name := range colSeq { table.AddColumn(cols[name]) } - indexes, err := engine.dialect.GetIndexes(table.Name) + indexes, err := engine.dialect.GetIndexes(engine.defaultContext, table.Name) if err != nil { return err } @@ -348,7 +348,7 @@ func (engine *Engine) loadTableInfo(table *schemas.Table) error { // DBMetas Retrieve all tables, columns, indexes' informations from database. func (engine *Engine) DBMetas() ([]*schemas.Table, error) { - tables, err := engine.dialect.GetTables() + tables, err := engine.dialect.GetTables(engine.defaultContext) if err != nil { return nil, err } @@ -439,7 +439,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...dia colNames := engine.dialect.Quoter().Join(cols, ", ") destColNames := dialect.Quoter().Join(cols, ", ") - rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.Quote(table.Name)) + rows, err := engine.DB().QueryContext(engine.defaultContext, "SELECT "+colNames+" FROM "+engine.Quote(table.Name)) if err != nil { return err } @@ -979,7 +979,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { } } else { for _, col := range table.Columns() { - isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name) + isExist, err := engine.dialect.IsColumnExist(session.ctx, tableNameNoSchema, col.Name) if err != nil { return err } diff --git a/session_schema.go b/session_schema.go index 809f158f..05b24c91 100644 --- a/session_schema.go +++ b/session_schema.go @@ -183,7 +183,7 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) { // find if index is exist according cols func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) { - indexes, err := session.engine.dialect.GetIndexes(tableName) + indexes, err := session.engine.dialect.GetIndexes(session.ctx, tableName) if err != nil { return false, err } @@ -201,8 +201,8 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo func (session *Session) addColumn(colName string) error { col := session.statement.RefTable.GetColumn(colName) - sql, args := session.statement.genAddColumnStr(col) - _, err := session.exec(sql, args...) + sql := session.statement.dialect.AddColumnSQL(session.statement.TableName(), col) + _, err := session.exec(sql) return err } @@ -229,7 +229,7 @@ func (session *Session) Sync2(beans ...interface{}) error { defer session.Close() } - tables, err := engine.dialect.GetTables() + tables, err := engine.dialect.GetTables(session.ctx) if err != nil { return err } diff --git a/statement.go b/statement.go index c07ddfe9..b1593621 100644 --- a/statement.go +++ b/statement.go @@ -902,17 +902,6 @@ func (statement *Statement) genDelIndexSQL() []string { return sqls } -func (statement *Statement) genAddColumnStr(col *schemas.Column) (string, []interface{}) { - quote := statement.quote - sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()), - dialects.String(statement.dialect, col)) - if statement.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 { - sql += " COMMENT '" + col.Comment + "'" - } - sql += ";" - return sql, []interface{}{} -} - func (statement *Statement) buildConds(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)