Improve dialect interface (#1579)

Fix bug

Improve dialect interface

Reviewed-on: https://gitea.com/xorm/xorm/pulls/1579
This commit is contained in:
Lunny Xiao 2020-03-07 10:00:05 +00:00
parent 7f22948be9
commit ccf65397e8
10 changed files with 46 additions and 111 deletions

View File

@ -43,8 +43,6 @@ type Dialect interface {
SetQuotePolicy(quotePolicy QuotePolicy) SetQuotePolicy(quotePolicy QuotePolicy)
AutoIncrStr() string AutoIncrStr() string
SupportInsertMany() bool
SupportDropIfExists() bool
GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error)
IndexCheckSQL(tableName, idxName string) (string, []interface{}) IndexCheckSQL(tableName, idxName string) (string, []interface{})
@ -52,9 +50,9 @@ type Dialect interface {
DropIndexSQL(tableName string, index *schemas.Index) string DropIndexSQL(tableName string, index *schemas.Index) string
GetTables(ctx context.Context) ([]*schemas.Table, error) GetTables(ctx context.Context) ([]*schemas.Table, error)
TableCheckSQL(tableName string) (string, []interface{}) IsTableExist(ctx context.Context, tableName string) (bool, error)
CreateTableSQL(table *schemas.Table, tableName string) string CreateTableSQL(table *schemas.Table, tableName string) (string, bool)
DropTableSQL(tableName string) string DropTableSQL(tableName string) (string, bool)
GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)
IsColumnExist(ctx context.Context, tableName string, colName string) (bool, error) IsColumnExist(ctx context.Context, tableName string, colName string) (bool, error)
@ -149,13 +147,9 @@ func (b *Base) FormatBytes(bs []byte) string {
return fmt.Sprintf("0x%x", bs) return fmt.Sprintf("0x%x", bs)
} }
func (db *Base) SupportDropIfExists() bool { func (db *Base) DropTableSQL(tableName string) (string, bool) {
return true
}
func (db *Base) DropTableSQL(tableName string) string {
quote := db.dialect.Quoter().Quote quote := db.dialect.Quoter().Quote
return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)) return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true
} }
func (db *Base) HasRecords(ctx context.Context, query string, args ...interface{}) (bool, error) { func (db *Base) HasRecords(ctx context.Context, query string, args ...interface{}) (bool, error) {

View File

@ -281,10 +281,6 @@ func (db *mssql) SQLType(c *schemas.Column) string {
return res return res
} }
func (db *mssql) SupportInsertMany() bool {
return true
}
func (db *mssql) IsReserved(name string) bool { func (db *mssql) IsReserved(name string) bool {
_, ok := mssqlReservedWords[strings.ToUpper(name)] _, ok := mssqlReservedWords[strings.ToUpper(name)]
return ok return ok
@ -311,10 +307,10 @@ func (db *mssql) AutoIncrStr() string {
return "IDENTITY" return "IDENTITY"
} }
func (db *mssql) DropTableSQL(tableName string) string { func (db *mssql) DropTableSQL(tableName string) (string, bool) {
return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+ return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+
"object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+ "object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+
"DROP TABLE \"%s\"", tableName, tableName) "DROP TABLE \"%s\"", tableName, tableName), true
} }
func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
@ -329,10 +325,9 @@ func (db *mssql) IsColumnExist(ctx context.Context, tableName, colName string) (
return db.HasRecords(ctx, query, tableName, colName) return db.HasRecords(ctx, query, tableName, colName)
} }
func (db *mssql) TableCheckSQL(tableName string) (string, []interface{}) { func (db *mssql) IsTableExist(ctx context.Context, tableName string) (bool, error) {
args := []interface{}{}
sql := "select * from sysobjects where id = object_id(N'" + tableName + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1" sql := "select * from sysobjects where id = object_id(N'" + tableName + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1"
return sql, args return db.HasRecords(ctx, sql)
} }
func (db *mssql) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { func (db *mssql) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
@ -491,7 +486,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
return indexes, nil return indexes, nil
} }
func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) string { func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) {
var sql string var sql string
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
@ -522,7 +517,7 @@ func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) string {
sql = sql[:len(sql)-2] + ")" sql = sql[:len(sql)-2] + ")"
sql += ";" sql += ";"
return sql return sql, true
} }
func (db *mssql) ForUpdateSQL(query string) string { func (db *mssql) ForUpdateSQL(query string) string {

View File

@ -270,10 +270,6 @@ func (db *mysql) SQLType(c *schemas.Column) string {
return res return res
} }
func (db *mysql) SupportInsertMany() bool {
return true
}
func (db *mysql) IsReserved(name string) bool { func (db *mysql) IsReserved(name string) bool {
_, ok := mysqlReservedWords[strings.ToUpper(name)] _, ok := mysqlReservedWords[strings.ToUpper(name)]
return ok return ok
@ -290,10 +286,9 @@ func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{}
return sql, args return sql, args
} }
func (db *mysql) TableCheckSQL(tableName string) (string, []interface{}) { func (db *mysql) IsTableExist(ctx context.Context, tableName string) (bool, error) {
args := []interface{}{db.uri.DBName, tableName}
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
return sql, args return db.HasRecords(ctx, sql, db.uri.DBName, tableName)
} }
func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string {
@ -512,7 +507,7 @@ func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]*
return indexes, nil return indexes, nil
} }
func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) string { func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) {
var sql = "CREATE TABLE IF NOT EXISTS " var sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
@ -565,7 +560,7 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) string {
if db.rowFormat != "" { if db.rowFormat != "" {
sql += " ROW_FORMAT=" + db.rowFormat sql += " ROW_FORMAT=" + db.rowFormat
} }
return sql return sql, true
} }
func (db *mysql) Filters() []Filter { func (db *mysql) Filters() []Filter {

View File

@ -547,24 +547,16 @@ func (db *oracle) AutoIncrStr() string {
return "AUTO_INCREMENT" return "AUTO_INCREMENT"
} }
func (db *oracle) SupportInsertMany() bool {
return true
}
func (db *oracle) IsReserved(name string) bool { func (db *oracle) IsReserved(name string) bool {
_, ok := oracleReservedWords[strings.ToUpper(name)] _, ok := oracleReservedWords[strings.ToUpper(name)]
return ok return ok
} }
func (db *oracle) SupportDropIfExists() bool { func (db *oracle) DropTableSQL(tableName string) (string, bool) {
return false return fmt.Sprintf("DROP TABLE `%s`", tableName), false
} }
func (db *oracle) DropTableSQL(tableName string) string { func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) {
return fmt.Sprintf("DROP TABLE `%s`", tableName)
}
func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) string {
var sql = "CREATE TABLE " var sql = "CREATE TABLE "
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
@ -593,7 +585,7 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) string
} }
sql = sql[:len(sql)-2] + ")" sql = sql[:len(sql)-2] + ")"
return sql return sql, false
} }
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {
@ -619,26 +611,15 @@ func (db *oracle) IndexCheckSQL(tableName, idxName string) (string, []interface{
`WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args `WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args
} }
func (db *oracle) TableCheckSQL(tableName string) (string, []interface{}) { func (db *oracle) IsTableExist(ctx context.Context, tableName string) (bool, error) {
args := []interface{}{tableName} return db.HasRecords(ctx, `SELECT table_name FROM user_tables WHERE table_name = :1`, tableName)
return `SELECT table_name FROM user_tables WHERE table_name = :1`, args
} }
func (db *oracle) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { func (db *oracle) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) {
args := []interface{}{tableName, colName} args := []interface{}{tableName, colName}
query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" + query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" +
" AND column_name = :2" " AND column_name = :2"
return db.HasRecords(ctx, query, args...)
rows, err := db.DB().QueryContext(ctx, query, args...)
if err != nil {
return false, err
}
defer rows.Close()
if rows.Next() {
return true, nil
}
return false, nil
} }
func (db *oracle) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { func (db *oracle) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {

View File

@ -884,10 +884,6 @@ func (db *postgres) SQLType(c *schemas.Column) string {
return res return res
} }
func (db *postgres) SupportInsertMany() bool {
return true
}
func (db *postgres) IsReserved(name string) bool { func (db *postgres) IsReserved(name string) bool {
_, ok := postgresReservedWords[strings.ToUpper(name)] _, ok := postgresReservedWords[strings.ToUpper(name)]
return ok return ok
@ -897,7 +893,7 @@ func (db *postgres) AutoIncrStr() string {
return "" return ""
} }
func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) string { func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) {
var sql string var sql string
sql = "CREATE TABLE IF NOT EXISTS " sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" { if tableName == "" {
@ -932,7 +928,7 @@ func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) strin
} }
sql += ")" sql += ")"
return sql return sql, true
} }
func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
@ -946,14 +942,13 @@ func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interfac
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
} }
func (db *postgres) TableCheckSQL(tableName string) (string, []interface{}) { func (db *postgres) IsTableExist(ctx context.Context, tableName string) (bool, error) {
if len(db.uri.Schema) == 0 { if len(db.uri.Schema) == 0 {
args := []interface{}{tableName} return db.HasRecords(ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName)
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
} }
args := []interface{}{db.uri.Schema, tableName} return db.HasRecords(ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`,
return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args db.uri.Schema, tableName)
} }
func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string {

View File

@ -211,10 +211,6 @@ func (db *sqlite3) FormatBytes(bs []byte) string {
return fmt.Sprintf("X'%x'", bs) return fmt.Sprintf("X'%x'", bs)
} }
func (db *sqlite3) SupportInsertMany() bool {
return true
}
func (db *sqlite3) IsReserved(name string) bool { func (db *sqlite3) IsReserved(name string) bool {
_, ok := sqlite3ReservedWords[strings.ToUpper(name)] _, ok := sqlite3ReservedWords[strings.ToUpper(name)]
return ok return ok
@ -229,9 +225,8 @@ func (db *sqlite3) IndexCheckSQL(tableName, idxName string) (string, []interface
return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args
} }
func (db *sqlite3) TableCheckSQL(tableName string) (string, []interface{}) { func (db *sqlite3) IsTableExist(ctx context.Context, tableName string) (bool, error) {
args := []interface{}{tableName} return db.HasRecords(ctx, "SELECT name FROM sqlite_master WHERE type='table' and name = ?", tableName)
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
} }
func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string { func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string {
@ -249,7 +244,7 @@ func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string {
return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName))
} }
func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) string { func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) (string, bool) {
var sql string var sql string
sql = "CREATE TABLE IF NOT EXISTS " sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" { if tableName == "" {
@ -284,7 +279,7 @@ func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) string
} }
sql += ")" sql += ")"
return sql return sql, true
} }
func (db *sqlite3) ForUpdateSQL(query string) string { func (db *sqlite3) ForUpdateSQL(query string) string {

View File

@ -125,14 +125,6 @@ func (engine *Engine) SetColumnMapper(mapper names.Mapper) {
engine.tagParser.SetColumnMapper(mapper) engine.tagParser.SetColumnMapper(mapper)
} }
// SupportInsertMany If engine's database support batch insert records like
// "insert into user values (name, age), (name, age)".
// When the return is ture, then engine.Insert(&users) will
// generate batch sql and exeute.
func (engine *Engine) SupportInsertMany() bool {
return engine.dialect.SupportInsertMany()
}
// Quote Use QuoteStr quote the string sql // Quote Use QuoteStr quote the string sql
func (engine *Engine) Quote(value string) string { func (engine *Engine) Quote(value string) string {
value = strings.TrimSpace(value) value = strings.TrimSpace(value)
@ -388,7 +380,8 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return err return err
} }
} }
_, err = io.WriteString(w, dialect.CreateTableSQL(table, "")+";\n") s, _ := dialect.CreateTableSQL(table, "")
_, err = io.WriteString(w, s+";\n")
if err != nil { if err != nil {
return err return err
} }

View File

@ -643,7 +643,8 @@ func (statement *Statement) genColumnStr() string {
func (statement *Statement) GenCreateTableSQL() string { func (statement *Statement) GenCreateTableSQL() string {
statement.RefTable.StoreEngine = statement.StoreEngine statement.RefTable.StoreEngine = statement.StoreEngine
statement.RefTable.Charset = statement.Charset statement.RefTable.Charset = statement.Charset
return statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName()) s, _ := statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName())
return s
} }
func (statement *Statement) GenIndexSQL() []string { func (statement *Statement) GenIndexSQL() []string {

View File

@ -75,21 +75,11 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
return 0, ErrNoElementsOnSlice return 0, ErrNoElementsOnSlice
} }
if session.engine.SupportInsertMany() { cnt, err := session.innerInsertMulti(bean)
cnt, err := session.innerInsertMulti(bean) if err != nil {
if err != nil { return affected, err
return affected, err
}
affected += cnt
} else {
for i := 0; i < size; i++ {
cnt, err := session.innerInsert(sliceValue.Index(i).Interface())
if err != nil {
return affected, err
}
affected += cnt
}
} }
affected += cnt
} else { } else {
cnt, err := session.innerInsert(bean) cnt, err := session.innerInsert(bean)
if err != nil { if err != nil {

View File

@ -124,18 +124,16 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
func (session *Session) dropTable(beanOrTableName interface{}) error { func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName := session.engine.TableName(beanOrTableName) tableName := session.engine.TableName(beanOrTableName)
var needDrop = true sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true))
if !session.engine.dialect.SupportDropIfExists() { if !checkIfExist {
sqlStr, args := session.engine.dialect.TableCheckSQL(tableName) exist, err := session.engine.dialect.IsTableExist(session.ctx, tableName)
results, err := session.queryBytes(sqlStr, args...)
if err != nil { if err != nil {
return err return err
} }
needDrop = len(results) > 0 checkIfExist = exist
} }
if needDrop { if checkIfExist {
sqlStr := session.engine.Dialect().DropTableSQL(session.engine.TableName(tableName, true))
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
return err return err
} }
@ -154,9 +152,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
} }
func (session *Session) isTableExist(tableName string) (bool, error) { func (session *Session) isTableExist(tableName string) (bool, error) {
sqlStr, args := session.engine.dialect.TableCheckSQL(tableName) return session.engine.dialect.IsTableExist(session.ctx, tableName)
results, err := session.queryBytes(sqlStr, args...)
return len(results) > 0, err
} }
// IsTableEmpty if table have any records // IsTableEmpty if table have any records