From a4f9c21c1782e3677444876ce8b3a9a108cdd56a Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 28 Oct 2023 11:59:16 +0000 Subject: [PATCH] Some refactors (#2361) Reviewed-on: https://gitea.com/xorm/xorm/pulls/2361 --- Makefile | 2 +- dialects/dameng.go | 16 +++---- dialects/dialect.go | 16 +++---- dialects/mssql.go | 14 +++--- dialects/mysql.go | 10 ++-- dialects/oracle.go | 16 +++---- dialects/postgres.go | 14 +++--- dialects/sqlite3.go | 12 ++--- engine.go | 6 +-- internal/statements/cache.go | 87 ----------------------------------- internal/statements/update.go | 22 ++++----- schema.go | 4 +- sync.go | 2 +- 13 files changed, 67 insertions(+), 154 deletions(-) delete mode 100644 internal/statements/cache.go diff --git a/Makefile b/Makefile index 4cf77258..ff93f181 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ GOFMT ?= gofmt -s TAGS ?= SED_INPLACE := sed -i -GO_DIRS := contexts tests core dialects internal log migrate names schemas tags +GO_DIRS := contexts tests dialects internal log migrate names schemas tags GOFILES := $(wildcard *.go) GOFILES += $(shell find $(GO_DIRS) -name "*.go" -type f) INTEGRATION_PACKAGES := xorm.io/xorm/v2/tests diff --git a/dialects/dameng.go b/dialects/dameng.go index ca03062c..3ec7fa51 100644 --- a/dialects/dameng.go +++ b/dialects/dameng.go @@ -755,8 +755,8 @@ func (db *dameng) IndexCheckSQL(tableName, idxName string) (string, []any) { `WHERE TABLE_NAME = ? AND INDEX_NAME = ?`, args } -func (db *dameng) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { - return db.HasRecords(queryer, ctx, `SELECT table_name FROM user_tables WHERE table_name = ?`, tableName) +func (db *dameng) IsTableExist(ctx context.Context, queryer core.Queryer, tableName string) (bool, error) { + return db.HasRecords(ctx, queryer, `SELECT table_name FROM user_tables WHERE table_name = ?`, tableName) } func (db *dameng) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) { @@ -779,11 +779,11 @@ func (db *dameng) IsSequenceExist(ctx context.Context, queryer core.Queryer, seq return cnt > 0, nil } -func (db *dameng) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { +func (db *dameng) IsColumnExist(ctx context.Context, queryer core.Queryer, tableName, colName string) (bool, error) { args := []any{tableName, colName} query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" + " AND column_name = ?" - return db.HasRecords(queryer, ctx, query, args...) + return db.HasRecords(ctx, queryer, query, args...) } var _ sql.Scanner = &dmClobScanner{} @@ -850,7 +850,7 @@ func addSingleQuote(name string) string { return b.String() } -func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *dameng) GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error) { s := `select column_name from user_cons_columns where constraint_name = (select constraint_name from user_constraints where table_name = ? and constraint_type ='P')` @@ -925,7 +925,7 @@ func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableNam } if utils.IndexSlice(pkNames, col.Name) > -1 { col.IsPrimaryKey = true - has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = ?", utils.SeqName(tableName)) + has, err := db.HasRecords(ctx, queryer, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = ?", utils.SeqName(tableName)) if err != nil { return nil, nil, err } @@ -1002,7 +1002,7 @@ func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableNam return colSeq, cols, nil } -func (db *dameng) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { +func (db *dameng) GetTables(ctx context.Context, queryer core.Queryer) ([]*schemas.Table, error) { s := "SELECT table_name FROM user_tables WHERE temporary = 'N' AND table_name NOT LIKE ?" args := []any{strings.ToUpper(db.uri.User), "%$%"} @@ -1028,7 +1028,7 @@ func (db *dameng) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem return tables, nil } -func (db *dameng) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { +func (db *dameng) GetIndexes(ctx context.Context, queryer core.Queryer, tableName string) (map[string]*schemas.Index, error) { args := []any{tableName, tableName} s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =?" + diff --git a/dialects/dialect.go b/dialects/dialect.go index f1778041..b1d26c63 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -66,13 +66,13 @@ type Dialect interface { AutoIncrStr() string - GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) + GetIndexes(ctx context.Context, queryer core.Queryer, tableName string) (map[string]*schemas.Index, error) IndexCheckSQL(tableName, idxName string) (string, []any) CreateIndexSQL(tableName string, index *schemas.Index) string DropIndexSQL(tableName string, index *schemas.Index) string - GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) - IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) + GetTables(ctx context.Context, queryer core.Queryer) ([]*schemas.Table, error) + IsTableExist(ctx context.Context, queryer core.Queryer, tableName string) (bool, error) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) DropTableSQL(tableName string) (string, bool) @@ -80,8 +80,8 @@ type Dialect interface { IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) DropSequenceSQL(seqName string) (string, error) - GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) - IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error) + GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error) + IsColumnExist(ctx context.Context, queryer core.Queryer, tableName string, colName string) (bool, error) AddColumnSQL(tableName string, col *schemas.Column) string ModifyColumnSQL(tableName string, col *schemas.Column) string @@ -177,7 +177,7 @@ func (db *Base) DropTableSQL(tableName string) (string, bool) { } // HasRecords returns true if the SQL has records returned -func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query string, args ...any) (bool, error) { +func (db *Base) HasRecords(ctx context.Context, queryer core.Queryer, query string, args ...any) (bool, error) { rows, err := queryer.QueryContext(ctx, query, args...) if err != nil { return false, err @@ -191,7 +191,7 @@ func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query stri } // IsColumnExist returns true if the column of the table exist -func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { +func (db *Base) IsColumnExist(ctx context.Context, queryer core.Queryer, tableName, colName string) (bool, error) { quote := db.dialect.Quoter().Quote query := fmt.Sprintf( "SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?", @@ -202,7 +202,7 @@ func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableNa quote("TABLE_NAME"), quote("COLUMN_NAME"), ) - return db.HasRecords(queryer, ctx, query, db.uri.DBName, tableName, colName) + return db.HasRecords(ctx, queryer, query, db.uri.DBName, tableName, colName) } // AddColumnSQL returns a SQL to add a column diff --git a/dialects/mssql.go b/dialects/mssql.go index c24c2728..d936fd80 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -448,18 +448,18 @@ func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []any) { return sql, args } -func (db *mssql) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { +func (db *mssql) IsColumnExist(ctx context.Context, queryer core.Queryer, tableName, colName string) (bool, error) { query := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?` - return db.HasRecords(queryer, ctx, query, tableName, colName) + return db.HasRecords(ctx, queryer, query, tableName, colName) } -func (db *mssql) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { +func (db *mssql) IsTableExist(ctx context.Context, queryer core.Queryer, tableName string) (bool, error) { sql := "select * from sysobjects where id = object_id(N'" + tableName + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1" - return db.HasRecords(queryer, ctx, sql) + return db.HasRecords(ctx, queryer, sql) } -func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *mssql) GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error) { args := []any{} s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, "default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END), @@ -553,7 +553,7 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName return colSeq, cols, nil } -func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { +func (db *mssql) GetTables(ctx context.Context, queryer core.Queryer) ([]*schemas.Table, error) { args := []any{} s := `select name from sysobjects where xtype ='U'` @@ -580,7 +580,7 @@ func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema return tables, nil } -func (db *mssql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { +func (db *mssql) GetIndexes(ctx context.Context, queryer core.Queryer, tableName string) (map[string]*schemas.Index, error) { args := []any{tableName} s := `SELECT IXS.NAME AS [INDEX_NAME], diff --git a/dialects/mysql.go b/dialects/mysql.go index 7445ee07..9cf2350b 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -377,9 +377,9 @@ func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []any) { return sql, args } -func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { +func (db *mysql) IsTableExist(ctx context.Context, queryer core.Queryer, tableName string) (bool, error) { sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" - return db.HasRecords(queryer, ctx, sql, db.uri.DBName, tableName) + return db.HasRecords(ctx, queryer, sql, db.uri.DBName, tableName) } func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { @@ -407,7 +407,7 @@ func (db *mysql) ModifyColumnSQL(tableName string, col *schemas.Column) string { return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s) } -func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *mysql) GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error) { args := []any{db.uri.DBName, tableName} alreadyQuoted := "(INSTR(VERSION(), 'maria') > 0 && " + "(SUBSTRING_INDEX(VERSION(), '.', 1) > 10 || " + @@ -544,7 +544,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName return colSeq, cols, nil } -func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { +func (db *mysql) GetTables(ctx context.Context, queryer core.Queryer) ([]*schemas.Table, error) { args := []any{db.uri.DBName} s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT`, `TABLE_COLLATION` from " + "`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')" @@ -596,7 +596,7 @@ func (db *mysql) SetQuotePolicy(quotePolicy QuotePolicy) { } } -func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { +func (db *mysql) GetIndexes(ctx context.Context, queryer core.Queryer, tableName string) (map[string]*schemas.Index, error) { args := []any{db.uri.DBName, tableName} s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `SEQ_IN_INDEX`" diff --git a/dialects/oracle.go b/dialects/oracle.go index 778c6689..d0d5fab5 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -658,7 +658,7 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl } func (db *oracle) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) { - return db.HasRecords(queryer, ctx, `SELECT sequence_name FROM user_sequences WHERE sequence_name = :1`, seqName) + return db.HasRecords(ctx, queryer, `SELECT sequence_name FROM user_sequences WHERE sequence_name = :1`, seqName) } func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { @@ -684,18 +684,18 @@ func (db *oracle) IndexCheckSQL(tableName, idxName string) (string, []any) { `WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args } -func (db *oracle) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { - return db.HasRecords(queryer, ctx, `SELECT table_name FROM user_tables WHERE table_name = :1`, tableName) +func (db *oracle) IsTableExist(ctx context.Context, queryer core.Queryer, tableName string) (bool, error) { + return db.HasRecords(ctx, queryer, `SELECT table_name FROM user_tables WHERE table_name = :1`, tableName) } -func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { +func (db *oracle) IsColumnExist(ctx context.Context, queryer core.Queryer, tableName, colName string) (bool, error) { args := []any{tableName, colName} query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" + " AND column_name = :2" - return db.HasRecords(queryer, ctx, query, args...) + return db.HasRecords(ctx, queryer, query, args...) } -func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *oracle) GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error) { args := []any{tableName} s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" @@ -795,7 +795,7 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam return colSeq, cols, nil } -func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { +func (db *oracle) GetTables(ctx context.Context, queryer core.Queryer) ([]*schemas.Table, error) { args := []any{} s := "SELECT table_name FROM user_tables" @@ -821,7 +821,7 @@ func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem return tables, nil } -func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { +func (db *oracle) GetIndexes(ctx context.Context, queryer core.Queryer, tableName string) (map[string]*schemas.Index, error) { args := []any{tableName} s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1" diff --git a/dialects/postgres.go b/dialects/postgres.go index 0bc33edc..e4ae0907 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -1009,12 +1009,12 @@ func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []any) { `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args } -func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { +func (db *postgres) IsTableExist(ctx context.Context, queryer core.Queryer, tableName string) (bool, error) { if len(db.getSchema()) == 0 { - return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName) + return db.HasRecords(ctx, queryer, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName) } - return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`, + return db.HasRecords(ctx, queryer, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`, db.getSchema(), tableName) } @@ -1070,7 +1070,7 @@ func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } -func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { +func (db *postgres) IsColumnExist(ctx context.Context, queryer core.Queryer, tableName, colName string) (bool, error) { args := []any{db.getSchema(), tableName, colName} query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + " AND column_name = $3" @@ -1092,7 +1092,7 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab return false, rows.Err() } -func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *postgres) GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error) { args := []any{tableName} s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, description, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, @@ -1245,7 +1245,7 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r' AND c.relname = $1%s AND f.a return colSeq, cols, nil } -func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { +func (db *postgres) GetTables(ctx context.Context, queryer core.Queryer) ([]*schemas.Table, error) { args := []any{} s := "SELECT tablename FROM pg_tables" schema := db.getSchema() @@ -1288,7 +1288,7 @@ func getIndexColName(indexdef string) []string { return colNames } -func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { +func (db *postgres) GetIndexes(ctx context.Context, queryer core.Queryer, tableName string) (map[string]*schemas.Index, error) { args := []any{tableName} s := "SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1" if len(db.getSchema()) != 0 { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 15bb00ee..d6c341f1 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -272,8 +272,8 @@ func (db *sqlite3) IndexCheckSQL(tableName, idxName string) (string, []any) { return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args } -func (db *sqlite3) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { - return db.HasRecords(queryer, ctx, "SELECT name FROM sqlite_master WHERE type='table' and name = ?", tableName) +func (db *sqlite3) IsTableExist(ctx context.Context, queryer core.Queryer, tableName string) (bool, error) { + return db.HasRecords(ctx, queryer, "SELECT name FROM sqlite_master WHERE type='table' and name = ?", tableName) } func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string { @@ -291,7 +291,7 @@ func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string { return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } -func (db *sqlite3) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { +func (db *sqlite3) IsColumnExist(ctx context.Context, queryer core.Queryer, tableName, colName string) (bool, error) { query := "SELECT * FROM " + tableName + " LIMIT 0" rows, err := queryer.QueryContext(ctx, query) if err != nil { @@ -375,7 +375,7 @@ func parseString(colStr string) (*schemas.Column, error) { return col, nil } -func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { +func (db *sqlite3) GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error) { args := []any{tableName} s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" @@ -434,7 +434,7 @@ func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableNa return colSeq, cols, nil } -func (db *sqlite3) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { +func (db *sqlite3) GetTables(ctx context.Context, queryer core.Queryer) ([]*schemas.Table, error) { args := []any{} s := "SELECT name FROM sqlite_master WHERE type='table'" @@ -462,7 +462,7 @@ func (db *sqlite3) GetTables(queryer core.Queryer, ctx context.Context) ([]*sche return tables, nil } -func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { +func (db *sqlite3) GetIndexes(ctx context.Context, queryer core.Queryer, tableName string) (map[string]*schemas.Index, error) { args := []any{tableName} s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" diff --git a/engine.go b/engine.go index b678e3fb..b6af4b20 100644 --- a/engine.go +++ b/engine.go @@ -308,14 +308,14 @@ func (engine *Engine) NoAutoCondition(no ...bool) *Session { } func (engine *Engine) loadTableInfo(ctx context.Context, table *schemas.Table) error { - colSeq, cols, err := engine.dialect.GetColumns(engine.db, ctx, table.Name) + colSeq, cols, err := engine.dialect.GetColumns(ctx, engine.db, table.Name) if err != nil { return err } for _, name := range colSeq { table.AddColumn(cols[name]) } - indexes, err := engine.dialect.GetIndexes(engine.db, ctx, table.Name) + indexes, err := engine.dialect.GetIndexes(ctx, engine.db, table.Name) if err != nil { return err } @@ -345,7 +345,7 @@ func (engine *Engine) loadTableInfo(ctx context.Context, table *schemas.Table) e // DBMetas Retrieve all tables, columns, indexes' informations from database. func (engine *Engine) DBMetas() ([]*schemas.Table, error) { - tables, err := engine.dialect.GetTables(engine.db, engine.defaultContext) + tables, err := engine.dialect.GetTables(engine.defaultContext, engine.db) if err != nil { return nil, err } diff --git a/internal/statements/cache.go b/internal/statements/cache.go deleted file mode 100644 index e6277fd3..00000000 --- a/internal/statements/cache.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2019 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package statements - -import ( - "fmt" - "strconv" - "strings" - - "xorm.io/xorm/v2/internal/utils" - "xorm.io/xorm/v2/schemas" -) - -// ConvertIDSQL converts SQL with id -func (statement *Statement) ConvertIDSQL(sqlStr string) string { - if statement.RefTable != nil { - cols := statement.RefTable.PKColumns() - if len(cols) == 0 { - return "" - } - - colstrs := statement.joinColumns(cols, false) - sqls := utils.SplitNNoCase(sqlStr, " from ", 2) - if len(sqls) != 2 { - return "" - } - - var b strings.Builder - b.WriteString("SELECT ") - pLimitN := statement.LimitN - if pLimitN != nil && statement.dialect.URI().DBType == schemas.MSSQL { - b.WriteString("TOP ") - b.WriteString(strconv.Itoa(*pLimitN)) - b.WriteString(" ") - } - b.WriteString(colstrs) - b.WriteString(" FROM ") - b.WriteString(sqls[1]) - - return b.String() - } - return "" -} - -// ConvertUpdateSQL converts update SQL -func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) { - if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { - return "", "" - } - - colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true) - sqls := utils.SplitNNoCase(sqlStr, "where", 2) - if len(sqls) != 2 { - if len(sqls) == 1 { - return sqls[0], fmt.Sprintf("SELECT %v FROM %v", - colstrs, statement.quote(statement.TableName())) - } - return "", "" - } - - whereStr := sqls[1] - - // TODO: for postgres only, if any other database? - var paraStr string - if statement.dialect.URI().DBType == schemas.POSTGRES { - paraStr = "$" - } else if statement.dialect.URI().DBType == schemas.MSSQL { - paraStr = ":" - } - - if paraStr != "" { - if strings.Contains(sqls[1], paraStr) { - dollers := strings.Split(sqls[1], paraStr) - whereStr = dollers[0] - for i, c := range dollers[1:] { - ccs := strings.SplitN(c, " ", 2) - whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1]) - } - } - } - - return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", - colstrs, statement.quote(statement.TableName()), - whereStr) -} diff --git a/internal/statements/update.go b/internal/statements/update.go index 886e834c..cd09e7a9 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -565,6 +565,12 @@ func (statement *Statement) writeSetColumns(colNames []string, args []any) func( } func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []any) error { + // write set + if _, err := fmt.Fprint(w, " SET "); err != nil { + return err + } + previousLen := w.Len() + if err := statement.writeSetColumns(colNames, args)(w); err != nil { return err } @@ -588,6 +594,11 @@ func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Va if err := statement.writeVersionIncrSet(w, v, setNumber > 0); err != nil { return err } + + // if no columns to be updated, return error + if previousLen == w.Len() { + return ErrNoColumnsTobeUpdated + } return nil } @@ -606,21 +617,10 @@ func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond return err } - // write set - if _, err := fmt.Fprint(updateWriter, " SET "); err != nil { - return err - } - previousLen := updateWriter.Len() - if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil { return err } - // if no columns to be updated, return error - if previousLen == updateWriter.Len() { - return ErrNoColumnsTobeUpdated - } - // write from if err := statement.writeUpdateFrom(updateWriter); err != nil { return err diff --git a/schema.go b/schema.go index 08b2868c..b3ede83a 100644 --- a/schema.go +++ b/schema.go @@ -150,7 +150,7 @@ func (session *Session) dropTable(beanOrTableName any) error { tableName := session.engine.TableName(beanOrTableName) sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true)) if !checkIfExist { - exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) + exist, err := session.engine.dialect.IsTableExist(session.ctx, session.getQueryer(), tableName) if err != nil { return err } @@ -197,7 +197,7 @@ func (session *Session) IsTableExist(beanOrTableName any) (bool, error) { } func (session *Session) isTableExist(tableName string) (bool, error) { - return session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) + return session.engine.dialect.IsTableExist(session.ctx, session.getQueryer(), tableName) } // IsTableEmpty if table have any records diff --git a/sync.go b/sync.go index e92d2084..20e12971 100644 --- a/sync.go +++ b/sync.go @@ -67,7 +67,7 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...any) (*SyncRe defer session.Close() } - tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx) + tables, err := engine.dialect.GetTables(session.ctx, session.getQueryer()) if err != nil { return nil, err }