Some refactors
This commit is contained in:
parent
4cde28ca21
commit
c8611087d8
2
Makefile
2
Makefile
|
@ -6,7 +6,7 @@ GOFMT ?= gofmt -s
|
|||
TAGS ?=
|
||||
SED_INPLACE := sed -i
|
||||
|
||||
GO_DIRS := contexts tests core dialects internal log migrate names schemas tags
|
||||
GO_DIRS := contexts tests dialects internal log migrate names schemas tags
|
||||
GOFILES := $(wildcard *.go)
|
||||
GOFILES += $(shell find $(GO_DIRS) -name "*.go" -type f)
|
||||
INTEGRATION_PACKAGES := xorm.io/xorm/v2/tests
|
||||
|
|
|
@ -755,8 +755,8 @@ func (db *dameng) IndexCheckSQL(tableName, idxName string) (string, []any) {
|
|||
`WHERE TABLE_NAME = ? AND INDEX_NAME = ?`, args
|
||||
}
|
||||
|
||||
func (db *dameng) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) {
|
||||
return db.HasRecords(queryer, ctx, `SELECT table_name FROM user_tables WHERE table_name = ?`, tableName)
|
||||
func (db *dameng) IsTableExist(ctx context.Context, queryer core.Queryer, tableName string) (bool, error) {
|
||||
return db.HasRecords(ctx, queryer, `SELECT table_name FROM user_tables WHERE table_name = ?`, tableName)
|
||||
}
|
||||
|
||||
func (db *dameng) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) {
|
||||
|
@ -779,11 +779,11 @@ func (db *dameng) IsSequenceExist(ctx context.Context, queryer core.Queryer, seq
|
|||
return cnt > 0, nil
|
||||
}
|
||||
|
||||
func (db *dameng) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
|
||||
func (db *dameng) IsColumnExist(ctx context.Context, queryer core.Queryer, tableName, colName string) (bool, error) {
|
||||
args := []any{tableName, colName}
|
||||
query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" +
|
||||
" AND column_name = ?"
|
||||
return db.HasRecords(queryer, ctx, query, args...)
|
||||
return db.HasRecords(ctx, queryer, query, args...)
|
||||
}
|
||||
|
||||
var _ sql.Scanner = &dmClobScanner{}
|
||||
|
@ -850,7 +850,7 @@ func addSingleQuote(name string) string {
|
|||
return b.String()
|
||||
}
|
||||
|
||||
func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
func (db *dameng) GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
s := `select column_name from user_cons_columns
|
||||
where constraint_name = (select constraint_name from user_constraints
|
||||
where table_name = ? and constraint_type ='P')`
|
||||
|
@ -925,7 +925,7 @@ func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
|
|||
}
|
||||
if utils.IndexSlice(pkNames, col.Name) > -1 {
|
||||
col.IsPrimaryKey = true
|
||||
has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = ?", utils.SeqName(tableName))
|
||||
has, err := db.HasRecords(ctx, queryer, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = ?", utils.SeqName(tableName))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -1002,7 +1002,7 @@ func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
|
|||
return colSeq, cols, nil
|
||||
}
|
||||
|
||||
func (db *dameng) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
|
||||
func (db *dameng) GetTables(ctx context.Context, queryer core.Queryer) ([]*schemas.Table, error) {
|
||||
s := "SELECT table_name FROM user_tables WHERE temporary = 'N' AND table_name NOT LIKE ?"
|
||||
args := []any{strings.ToUpper(db.uri.User), "%$%"}
|
||||
|
||||
|
@ -1028,7 +1028,7 @@ func (db *dameng) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem
|
|||
return tables, nil
|
||||
}
|
||||
|
||||
func (db *dameng) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
||||
func (db *dameng) GetIndexes(ctx context.Context, queryer core.Queryer, tableName string) (map[string]*schemas.Index, error) {
|
||||
args := []any{tableName, tableName}
|
||||
s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " +
|
||||
"WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =?" +
|
||||
|
|
|
@ -66,13 +66,13 @@ type Dialect interface {
|
|||
|
||||
AutoIncrStr() string
|
||||
|
||||
GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error)
|
||||
GetIndexes(ctx context.Context, queryer core.Queryer, tableName string) (map[string]*schemas.Index, error)
|
||||
IndexCheckSQL(tableName, idxName string) (string, []any)
|
||||
CreateIndexSQL(tableName string, index *schemas.Index) string
|
||||
DropIndexSQL(tableName string, index *schemas.Index) string
|
||||
|
||||
GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error)
|
||||
IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error)
|
||||
GetTables(ctx context.Context, queryer core.Queryer) ([]*schemas.Table, error)
|
||||
IsTableExist(ctx context.Context, queryer core.Queryer, tableName string) (bool, error)
|
||||
CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error)
|
||||
DropTableSQL(tableName string) (string, bool)
|
||||
|
||||
|
@ -80,8 +80,8 @@ type Dialect interface {
|
|||
IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error)
|
||||
DropSequenceSQL(seqName string) (string, error)
|
||||
|
||||
GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)
|
||||
IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error)
|
||||
GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error)
|
||||
IsColumnExist(ctx context.Context, queryer core.Queryer, tableName string, colName string) (bool, error)
|
||||
AddColumnSQL(tableName string, col *schemas.Column) string
|
||||
ModifyColumnSQL(tableName string, col *schemas.Column) string
|
||||
|
||||
|
@ -177,7 +177,7 @@ func (db *Base) DropTableSQL(tableName string) (string, bool) {
|
|||
}
|
||||
|
||||
// HasRecords returns true if the SQL has records returned
|
||||
func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query string, args ...any) (bool, error) {
|
||||
func (db *Base) HasRecords(ctx context.Context, queryer core.Queryer, query string, args ...any) (bool, error) {
|
||||
rows, err := queryer.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
@ -191,7 +191,7 @@ func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query stri
|
|||
}
|
||||
|
||||
// IsColumnExist returns true if the column of the table exist
|
||||
func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
|
||||
func (db *Base) IsColumnExist(ctx context.Context, queryer core.Queryer, tableName, colName string) (bool, error) {
|
||||
quote := db.dialect.Quoter().Quote
|
||||
query := fmt.Sprintf(
|
||||
"SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?",
|
||||
|
@ -202,7 +202,7 @@ func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableNa
|
|||
quote("TABLE_NAME"),
|
||||
quote("COLUMN_NAME"),
|
||||
)
|
||||
return db.HasRecords(queryer, ctx, query, db.uri.DBName, tableName, colName)
|
||||
return db.HasRecords(ctx, queryer, query, db.uri.DBName, tableName, colName)
|
||||
}
|
||||
|
||||
// AddColumnSQL returns a SQL to add a column
|
||||
|
|
|
@ -448,18 +448,18 @@ func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []any) {
|
|||
return sql, args
|
||||
}
|
||||
|
||||
func (db *mssql) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
|
||||
func (db *mssql) IsColumnExist(ctx context.Context, queryer core.Queryer, tableName, colName string) (bool, error) {
|
||||
query := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?`
|
||||
|
||||
return db.HasRecords(queryer, ctx, query, tableName, colName)
|
||||
return db.HasRecords(ctx, queryer, query, tableName, colName)
|
||||
}
|
||||
|
||||
func (db *mssql) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) {
|
||||
func (db *mssql) IsTableExist(ctx context.Context, queryer core.Queryer, tableName string) (bool, error) {
|
||||
sql := "select * from sysobjects where id = object_id(N'" + tableName + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1"
|
||||
return db.HasRecords(queryer, ctx, sql)
|
||||
return db.HasRecords(ctx, queryer, sql)
|
||||
}
|
||||
|
||||
func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
func (db *mssql) GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
args := []any{}
|
||||
s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable,
|
||||
"default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END),
|
||||
|
@ -553,7 +553,7 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
|
|||
return colSeq, cols, nil
|
||||
}
|
||||
|
||||
func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
|
||||
func (db *mssql) GetTables(ctx context.Context, queryer core.Queryer) ([]*schemas.Table, error) {
|
||||
args := []any{}
|
||||
s := `select name from sysobjects where xtype ='U'`
|
||||
|
||||
|
@ -580,7 +580,7 @@ func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema
|
|||
return tables, nil
|
||||
}
|
||||
|
||||
func (db *mssql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
||||
func (db *mssql) GetIndexes(ctx context.Context, queryer core.Queryer, tableName string) (map[string]*schemas.Index, error) {
|
||||
args := []any{tableName}
|
||||
s := `SELECT
|
||||
IXS.NAME AS [INDEX_NAME],
|
||||
|
|
|
@ -377,9 +377,9 @@ func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []any) {
|
|||
return sql, args
|
||||
}
|
||||
|
||||
func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) {
|
||||
func (db *mysql) IsTableExist(ctx context.Context, queryer core.Queryer, tableName string) (bool, error) {
|
||||
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
|
||||
return db.HasRecords(queryer, ctx, sql, db.uri.DBName, tableName)
|
||||
return db.HasRecords(ctx, queryer, sql, db.uri.DBName, tableName)
|
||||
}
|
||||
|
||||
func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string {
|
||||
|
@ -407,7 +407,7 @@ func (db *mysql) ModifyColumnSQL(tableName string, col *schemas.Column) string {
|
|||
return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s)
|
||||
}
|
||||
|
||||
func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
func (db *mysql) GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
args := []any{db.uri.DBName, tableName}
|
||||
alreadyQuoted := "(INSTR(VERSION(), 'maria') > 0 && " +
|
||||
"(SUBSTRING_INDEX(VERSION(), '.', 1) > 10 || " +
|
||||
|
@ -544,7 +544,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
|
|||
return colSeq, cols, nil
|
||||
}
|
||||
|
||||
func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
|
||||
func (db *mysql) GetTables(ctx context.Context, queryer core.Queryer) ([]*schemas.Table, error) {
|
||||
args := []any{db.uri.DBName}
|
||||
s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT`, `TABLE_COLLATION` from " +
|
||||
"`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')"
|
||||
|
@ -596,7 +596,7 @@ func (db *mysql) SetQuotePolicy(quotePolicy QuotePolicy) {
|
|||
}
|
||||
}
|
||||
|
||||
func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
||||
func (db *mysql) GetIndexes(ctx context.Context, queryer core.Queryer, tableName string) (map[string]*schemas.Index, error) {
|
||||
args := []any{db.uri.DBName, tableName}
|
||||
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `SEQ_IN_INDEX`"
|
||||
|
||||
|
|
|
@ -658,7 +658,7 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
|
|||
}
|
||||
|
||||
func (db *oracle) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) {
|
||||
return db.HasRecords(queryer, ctx, `SELECT sequence_name FROM user_sequences WHERE sequence_name = :1`, seqName)
|
||||
return db.HasRecords(ctx, queryer, `SELECT sequence_name FROM user_sequences WHERE sequence_name = :1`, seqName)
|
||||
}
|
||||
|
||||
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {
|
||||
|
@ -684,18 +684,18 @@ func (db *oracle) IndexCheckSQL(tableName, idxName string) (string, []any) {
|
|||
`WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args
|
||||
}
|
||||
|
||||
func (db *oracle) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) {
|
||||
return db.HasRecords(queryer, ctx, `SELECT table_name FROM user_tables WHERE table_name = :1`, tableName)
|
||||
func (db *oracle) IsTableExist(ctx context.Context, queryer core.Queryer, tableName string) (bool, error) {
|
||||
return db.HasRecords(ctx, queryer, `SELECT table_name FROM user_tables WHERE table_name = :1`, tableName)
|
||||
}
|
||||
|
||||
func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
|
||||
func (db *oracle) IsColumnExist(ctx context.Context, queryer core.Queryer, tableName, colName string) (bool, error) {
|
||||
args := []any{tableName, colName}
|
||||
query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" +
|
||||
" AND column_name = :2"
|
||||
return db.HasRecords(queryer, ctx, query, args...)
|
||||
return db.HasRecords(ctx, queryer, query, args...)
|
||||
}
|
||||
|
||||
func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
func (db *oracle) GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
args := []any{tableName}
|
||||
s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
|
||||
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
|
||||
|
@ -795,7 +795,7 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
|
|||
return colSeq, cols, nil
|
||||
}
|
||||
|
||||
func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
|
||||
func (db *oracle) GetTables(ctx context.Context, queryer core.Queryer) ([]*schemas.Table, error) {
|
||||
args := []any{}
|
||||
s := "SELECT table_name FROM user_tables"
|
||||
|
||||
|
@ -821,7 +821,7 @@ func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem
|
|||
return tables, nil
|
||||
}
|
||||
|
||||
func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
||||
func (db *oracle) GetIndexes(ctx context.Context, queryer core.Queryer, tableName string) (map[string]*schemas.Index, error) {
|
||||
args := []any{tableName}
|
||||
s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " +
|
||||
"WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1"
|
||||
|
|
|
@ -1009,12 +1009,12 @@ func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []any) {
|
|||
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
|
||||
}
|
||||
|
||||
func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) {
|
||||
func (db *postgres) IsTableExist(ctx context.Context, queryer core.Queryer, tableName string) (bool, error) {
|
||||
if len(db.getSchema()) == 0 {
|
||||
return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName)
|
||||
return db.HasRecords(ctx, queryer, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName)
|
||||
}
|
||||
|
||||
return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`,
|
||||
return db.HasRecords(ctx, queryer, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`,
|
||||
db.getSchema(), tableName)
|
||||
}
|
||||
|
||||
|
@ -1070,7 +1070,7 @@ func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string
|
|||
return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName))
|
||||
}
|
||||
|
||||
func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
|
||||
func (db *postgres) IsColumnExist(ctx context.Context, queryer core.Queryer, tableName, colName string) (bool, error) {
|
||||
args := []any{db.getSchema(), tableName, colName}
|
||||
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" +
|
||||
" AND column_name = $3"
|
||||
|
@ -1092,7 +1092,7 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab
|
|||
return false, rows.Err()
|
||||
}
|
||||
|
||||
func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
func (db *postgres) GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
args := []any{tableName}
|
||||
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, description,
|
||||
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
|
||||
|
@ -1245,7 +1245,7 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r' AND c.relname = $1%s AND f.a
|
|||
return colSeq, cols, nil
|
||||
}
|
||||
|
||||
func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
|
||||
func (db *postgres) GetTables(ctx context.Context, queryer core.Queryer) ([]*schemas.Table, error) {
|
||||
args := []any{}
|
||||
s := "SELECT tablename FROM pg_tables"
|
||||
schema := db.getSchema()
|
||||
|
@ -1288,7 +1288,7 @@ func getIndexColName(indexdef string) []string {
|
|||
return colNames
|
||||
}
|
||||
|
||||
func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
||||
func (db *postgres) GetIndexes(ctx context.Context, queryer core.Queryer, tableName string) (map[string]*schemas.Index, error) {
|
||||
args := []any{tableName}
|
||||
s := "SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1"
|
||||
if len(db.getSchema()) != 0 {
|
||||
|
|
|
@ -272,8 +272,8 @@ func (db *sqlite3) IndexCheckSQL(tableName, idxName string) (string, []any) {
|
|||
return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args
|
||||
}
|
||||
|
||||
func (db *sqlite3) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) {
|
||||
return db.HasRecords(queryer, ctx, "SELECT name FROM sqlite_master WHERE type='table' and name = ?", tableName)
|
||||
func (db *sqlite3) IsTableExist(ctx context.Context, queryer core.Queryer, tableName string) (bool, error) {
|
||||
return db.HasRecords(ctx, queryer, "SELECT name FROM sqlite_master WHERE type='table' and name = ?", tableName)
|
||||
}
|
||||
|
||||
func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string {
|
||||
|
@ -291,7 +291,7 @@ func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string {
|
|||
return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName))
|
||||
}
|
||||
|
||||
func (db *sqlite3) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
|
||||
func (db *sqlite3) IsColumnExist(ctx context.Context, queryer core.Queryer, tableName, colName string) (bool, error) {
|
||||
query := "SELECT * FROM " + tableName + " LIMIT 0"
|
||||
rows, err := queryer.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
|
@ -375,7 +375,7 @@ func parseString(colStr string) (*schemas.Column, error) {
|
|||
return col, nil
|
||||
}
|
||||
|
||||
func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
func (db *sqlite3) GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
args := []any{tableName}
|
||||
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
|
||||
|
||||
|
@ -434,7 +434,7 @@ func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableNa
|
|||
return colSeq, cols, nil
|
||||
}
|
||||
|
||||
func (db *sqlite3) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
|
||||
func (db *sqlite3) GetTables(ctx context.Context, queryer core.Queryer) ([]*schemas.Table, error) {
|
||||
args := []any{}
|
||||
s := "SELECT name FROM sqlite_master WHERE type='table'"
|
||||
|
||||
|
@ -462,7 +462,7 @@ func (db *sqlite3) GetTables(queryer core.Queryer, ctx context.Context) ([]*sche
|
|||
return tables, nil
|
||||
}
|
||||
|
||||
func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
|
||||
func (db *sqlite3) GetIndexes(ctx context.Context, queryer core.Queryer, tableName string) (map[string]*schemas.Index, error) {
|
||||
args := []any{tableName}
|
||||
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
|
||||
|
||||
|
|
|
@ -308,14 +308,14 @@ func (engine *Engine) NoAutoCondition(no ...bool) *Session {
|
|||
}
|
||||
|
||||
func (engine *Engine) loadTableInfo(ctx context.Context, table *schemas.Table) error {
|
||||
colSeq, cols, err := engine.dialect.GetColumns(engine.db, ctx, table.Name)
|
||||
colSeq, cols, err := engine.dialect.GetColumns(ctx, engine.db, table.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, name := range colSeq {
|
||||
table.AddColumn(cols[name])
|
||||
}
|
||||
indexes, err := engine.dialect.GetIndexes(engine.db, ctx, table.Name)
|
||||
indexes, err := engine.dialect.GetIndexes(ctx, engine.db, table.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -345,7 +345,7 @@ func (engine *Engine) loadTableInfo(ctx context.Context, table *schemas.Table) e
|
|||
|
||||
// DBMetas Retrieve all tables, columns, indexes' informations from database.
|
||||
func (engine *Engine) DBMetas() ([]*schemas.Table, error) {
|
||||
tables, err := engine.dialect.GetTables(engine.db, engine.defaultContext)
|
||||
tables, err := engine.dialect.GetTables(engine.defaultContext, engine.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1,87 +0,0 @@
|
|||
// Copyright 2019 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package statements
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"xorm.io/xorm/v2/internal/utils"
|
||||
"xorm.io/xorm/v2/schemas"
|
||||
)
|
||||
|
||||
// ConvertIDSQL converts SQL with id
|
||||
func (statement *Statement) ConvertIDSQL(sqlStr string) string {
|
||||
if statement.RefTable != nil {
|
||||
cols := statement.RefTable.PKColumns()
|
||||
if len(cols) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
colstrs := statement.joinColumns(cols, false)
|
||||
sqls := utils.SplitNNoCase(sqlStr, " from ", 2)
|
||||
if len(sqls) != 2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("SELECT ")
|
||||
pLimitN := statement.LimitN
|
||||
if pLimitN != nil && statement.dialect.URI().DBType == schemas.MSSQL {
|
||||
b.WriteString("TOP ")
|
||||
b.WriteString(strconv.Itoa(*pLimitN))
|
||||
b.WriteString(" ")
|
||||
}
|
||||
b.WriteString(colstrs)
|
||||
b.WriteString(" FROM ")
|
||||
b.WriteString(sqls[1])
|
||||
|
||||
return b.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ConvertUpdateSQL converts update SQL
|
||||
func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) {
|
||||
if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
|
||||
sqls := utils.SplitNNoCase(sqlStr, "where", 2)
|
||||
if len(sqls) != 2 {
|
||||
if len(sqls) == 1 {
|
||||
return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
|
||||
colstrs, statement.quote(statement.TableName()))
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
whereStr := sqls[1]
|
||||
|
||||
// TODO: for postgres only, if any other database?
|
||||
var paraStr string
|
||||
if statement.dialect.URI().DBType == schemas.POSTGRES {
|
||||
paraStr = "$"
|
||||
} else if statement.dialect.URI().DBType == schemas.MSSQL {
|
||||
paraStr = ":"
|
||||
}
|
||||
|
||||
if paraStr != "" {
|
||||
if strings.Contains(sqls[1], paraStr) {
|
||||
dollers := strings.Split(sqls[1], paraStr)
|
||||
whereStr = dollers[0]
|
||||
for i, c := range dollers[1:] {
|
||||
ccs := strings.SplitN(c, " ", 2)
|
||||
whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
|
||||
colstrs, statement.quote(statement.TableName()),
|
||||
whereStr)
|
||||
}
|
|
@ -565,6 +565,12 @@ func (statement *Statement) writeSetColumns(colNames []string, args []any) func(
|
|||
}
|
||||
|
||||
func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []any) error {
|
||||
// write set
|
||||
if _, err := fmt.Fprint(w, " SET "); err != nil {
|
||||
return err
|
||||
}
|
||||
previousLen := w.Len()
|
||||
|
||||
if err := statement.writeSetColumns(colNames, args)(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -588,6 +594,11 @@ func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Va
|
|||
if err := statement.writeVersionIncrSet(w, v, setNumber > 0); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if no columns to be updated, return error
|
||||
if previousLen == w.Len() {
|
||||
return ErrNoColumnsTobeUpdated
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -606,21 +617,10 @@ func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond
|
|||
return err
|
||||
}
|
||||
|
||||
// write set
|
||||
if _, err := fmt.Fprint(updateWriter, " SET "); err != nil {
|
||||
return err
|
||||
}
|
||||
previousLen := updateWriter.Len()
|
||||
|
||||
if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if no columns to be updated, return error
|
||||
if previousLen == updateWriter.Len() {
|
||||
return ErrNoColumnsTobeUpdated
|
||||
}
|
||||
|
||||
// write from
|
||||
if err := statement.writeUpdateFrom(updateWriter); err != nil {
|
||||
return err
|
||||
|
|
|
@ -150,7 +150,7 @@ func (session *Session) dropTable(beanOrTableName any) error {
|
|||
tableName := session.engine.TableName(beanOrTableName)
|
||||
sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true))
|
||||
if !checkIfExist {
|
||||
exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName)
|
||||
exist, err := session.engine.dialect.IsTableExist(session.ctx, session.getQueryer(), tableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -197,7 +197,7 @@ func (session *Session) IsTableExist(beanOrTableName any) (bool, error) {
|
|||
}
|
||||
|
||||
func (session *Session) isTableExist(tableName string) (bool, error) {
|
||||
return session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName)
|
||||
return session.engine.dialect.IsTableExist(session.ctx, session.getQueryer(), tableName)
|
||||
}
|
||||
|
||||
// IsTableEmpty if table have any records
|
||||
|
|
2
sync.go
2
sync.go
|
@ -67,7 +67,7 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...any) (*SyncRe
|
|||
defer session.Close()
|
||||
}
|
||||
|
||||
tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx)
|
||||
tables, err := engine.dialect.GetTables(session.ctx, session.getQueryer())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue