From 7f22948be93e8fb6428b0cefeb30677ce083905b Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 7 Mar 2020 08:51:30 +0000 Subject: [PATCH] Improve dialect interface (#1578) Improve dialect interface Reviewed-on: https://gitea.com/xorm/xorm/pulls/1578 --- dialects/dialect.go | 101 ++++---------------------- dialects/mssql.go | 20 +---- dialects/mysql.go | 19 +---- dialects/oracle.go | 25 +------ dialects/postgres.go | 42 +++++++++-- dialects/sqlite3.go | 50 ++++++++++--- dialects/time.go | 2 +- engine.go | 8 +- internal/statements/cache.go | 6 +- internal/statements/query.go | 18 ++--- internal/statements/statement.go | 26 ++----- internal/statements/statement_args.go | 4 +- session_cols_test.go | 2 +- session_convert.go | 6 +- session_delete.go | 4 +- session_delete_test.go | 4 +- session_get_test.go | 6 +- session_insert.go | 13 ++-- session_query_test.go | 8 +- session_schema.go | 8 +- session_update.go | 12 +-- tags_test.go | 14 ++-- types_test.go | 4 +- 23 files changed, 162 insertions(+), 240 deletions(-) diff --git a/dialects/dialect.go b/dialects/dialect.go index d89f1ebe..b074d485 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -34,7 +34,6 @@ type Dialect interface { Init(*core.DB, *URI) error URI() *URI DB() *core.DB - DBType() schemas.DBType SQLType(*schemas.Column) string FormatBytes(b []byte) string DefaultSchema() string @@ -44,33 +43,26 @@ type Dialect interface { SetQuotePolicy(quotePolicy QuotePolicy) AutoIncrStr() string - SupportInsertMany() bool - SupportEngine() bool - SupportCharset() bool SupportDropIfExists() bool - IndexOnTable() bool - ShowCreateNull() bool + GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) IndexCheckSQL(tableName, idxName string) (string, []interface{}) - TableCheckSQL(tableName string) (string, []interface{}) - - IsColumnExist(ctx context.Context, tableName string, colName string) (bool, error) - - CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string - DropTableSQL(tableName string) string CreateIndexSQL(tableName string, index *schemas.Index) string DropIndexSQL(tableName string, index *schemas.Index) string + GetTables(ctx context.Context) ([]*schemas.Table, error) + TableCheckSQL(tableName string) (string, []interface{}) + CreateTableSQL(table *schemas.Table, tableName string) string + DropTableSQL(tableName string) string + + GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) + IsColumnExist(ctx context.Context, tableName string, colName string) (bool, error) AddColumnSQL(tableName string, col *schemas.Column) string ModifyColumnSQL(tableName string, col *schemas.Column) string ForUpdateSQL(query string) string - GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) - GetTables(ctx context.Context) ([]*schemas.Table, error) - GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) - Filters() []Filter SetParams(params map[string]string) } @@ -125,12 +117,10 @@ func (b *Base) String(col *schemas.Column) string { sql += "DEFAULT " + col.Default + " " } - if b.dialect.ShowCreateNull() { - if col.Nullable { - sql += "NULL " - } else { - sql += "NOT NULL " - } + if col.Nullable { + sql += "NULL " + } else { + sql += "NOT NULL " } return sql @@ -146,12 +136,10 @@ func (b *Base) StringNoPk(col *schemas.Column) string { sql += "DEFAULT " + col.Default + " " } - if b.dialect.ShowCreateNull() { - if col.Nullable { - sql += "NULL " - } else { - sql += "NOT NULL " - } + if col.Nullable { + sql += "NULL " + } else { + sql += "NOT NULL " } return sql @@ -161,10 +149,6 @@ func (b *Base) FormatBytes(bs []byte) string { return fmt.Sprintf("0x%x", bs) } -func (b *Base) ShowCreateNull() bool { - return true -} - func (db *Base) SupportDropIfExists() bool { return true } @@ -234,59 +218,6 @@ func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string { return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, db.StringNoPk(col)) } -func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string { - var sql string - sql = "CREATE TABLE IF NOT EXISTS " - if tableName == "" { - tableName = table.Name - } - - quoter := b.dialect.Quoter() - sql += quoter.Quote(tableName) - sql += " (" - - if len(table.ColumnsSeq()) > 0 { - pkList := table.PrimaryKeys - - for _, colName := range table.ColumnsSeq() { - col := table.GetColumn(colName) - if col.IsPrimaryKey && len(pkList) == 1 { - sql += b.String(col) - } else { - sql += b.StringNoPk(col) - } - sql = strings.TrimSpace(sql) - if b.DBType() == schemas.MYSQL && len(col.Comment) > 0 { - sql += " COMMENT '" + col.Comment + "'" - } - sql += ", " - } - - if len(pkList) > 1 { - sql += "PRIMARY KEY ( " - sql += quoter.Join(pkList, ",") - sql += " ), " - } - - sql = sql[:len(sql)-2] - } - sql += ")" - - if b.dialect.SupportEngine() && storeEngine != "" { - sql += " ENGINE=" + storeEngine - } - if b.dialect.SupportCharset() { - if len(charset) == 0 { - charset = b.dialect.URI().Charset - } - if len(charset) > 0 { - sql += " DEFAULT CHARSET " + charset - } - } - - return sql -} - func (b *Base) ForUpdateSQL(query string) string { return query + " FOR UPDATE" } diff --git a/dialects/mssql.go b/dialects/mssql.go index a2cbb361..06ab0b78 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -307,10 +307,6 @@ func (db *mssql) SetQuotePolicy(quotePolicy QuotePolicy) { } } -func (db *mssql) SupportEngine() bool { - return false -} - func (db *mssql) AutoIncrStr() string { return "IDENTITY" } @@ -321,26 +317,12 @@ func (db *mssql) DropTableSQL(tableName string) string { "DROP TABLE \"%s\"", tableName, tableName) } -func (db *mssql) SupportCharset() bool { - return false -} - -func (db *mssql) IndexOnTable() bool { - return true -} - func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { args := []interface{}{idxName} sql := "select name from sysindexes where id=object_id('" + tableName + "') and name=?" return sql, args } -/*func (db *mssql) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{tableName, colName} - sql := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?` - return sql, args -}*/ - func (db *mssql) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { query := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?` @@ -509,7 +491,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? return indexes, nil } -func (db *mssql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string { +func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) string { var sql string if tableName == "" { tableName = table.Name diff --git a/dialects/mysql.go b/dialects/mysql.go index 5f36ed31..364f22b6 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -279,22 +279,10 @@ func (db *mysql) IsReserved(name string) bool { return ok } -func (db *mysql) SupportEngine() bool { - return true -} - func (db *mysql) AutoIncrStr() string { return "AUTO_INCREMENT" } -func (db *mysql) SupportCharset() bool { - return true -} - -func (db *mysql) IndexOnTable() bool { - return true -} - func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { args := []interface{}{db.uri.DBName, tableName, idxName} sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`" @@ -524,7 +512,7 @@ func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]* return indexes, nil } -func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string { +func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) string { var sql = "CREATE TABLE IF NOT EXISTS " if tableName == "" { tableName = table.Name @@ -562,10 +550,11 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch } sql += ")" - if storeEngine != "" { - sql += " ENGINE=" + storeEngine + if table.StoreEngine != "" { + sql += " ENGINE=" + table.StoreEngine } + var charset = table.Charset if len(charset) == 0 { charset = db.URI().Charset } diff --git a/dialects/oracle.go b/dialects/oracle.go index d54ca80c..e0d83115 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -556,27 +556,15 @@ func (db *oracle) IsReserved(name string) bool { return ok } -func (db *oracle) SupportEngine() bool { - return false -} - -func (db *oracle) SupportCharset() bool { - return false -} - func (db *oracle) SupportDropIfExists() bool { return false } -func (db *oracle) IndexOnTable() bool { - return false -} - func (db *oracle) DropTableSQL(tableName string) string { return fmt.Sprintf("DROP TABLE `%s`", tableName) } -func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string { +func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) string { var sql = "CREATE TABLE " if tableName == "" { tableName = table.Name @@ -605,17 +593,6 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c } sql = sql[:len(sql)-2] + ")" - if db.SupportEngine() && storeEngine != "" { - sql += " ENGINE=" + storeEngine - } - if db.SupportCharset() { - if len(charset) == 0 { - charset = db.URI().Charset - } - if len(charset) > 0 { - sql += " DEFAULT CHARSET " + charset - } - } return sql } diff --git a/dialects/postgres.go b/dialects/postgres.go index 0049cee6..31cd49b6 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -897,16 +897,42 @@ func (db *postgres) AutoIncrStr() string { return "" } -func (db *postgres) SupportEngine() bool { - return false -} +func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) string { + var sql string + sql = "CREATE TABLE IF NOT EXISTS " + if tableName == "" { + tableName = table.Name + } -func (db *postgres) SupportCharset() bool { - return false -} + quoter := db.Quoter() + sql += quoter.Quote(tableName) + sql += " (" -func (db *postgres) IndexOnTable() bool { - return false + if len(table.ColumnsSeq()) > 0 { + pkList := table.PrimaryKeys + + for _, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + if col.IsPrimaryKey && len(pkList) == 1 { + sql += db.String(col) + } else { + sql += db.StringNoPk(col) + } + sql = strings.TrimSpace(sql) + sql += ", " + } + + if len(pkList) > 1 { + sql += "PRIMARY KEY ( " + sql += quoter.Join(pkList, ",") + sql += " ), " + } + + sql = sql[:len(sql)-2] + } + sql += ")" + + return sql } func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 4af9b27e..212c5a8e 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -224,18 +224,6 @@ func (db *sqlite3) AutoIncrStr() string { return "AUTOINCREMENT" } -func (db *sqlite3) SupportEngine() bool { - return false -} - -func (db *sqlite3) SupportCharset() bool { - return false -} - -func (db *sqlite3) IndexOnTable() bool { - return false -} - func (db *sqlite3) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { args := []interface{}{idxName} return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args @@ -261,6 +249,44 @@ func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string { return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } +func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) string { + var sql string + sql = "CREATE TABLE IF NOT EXISTS " + if tableName == "" { + tableName = table.Name + } + + quoter := db.Quoter() + sql += quoter.Quote(tableName) + sql += " (" + + if len(table.ColumnsSeq()) > 0 { + pkList := table.PrimaryKeys + + for _, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + if col.IsPrimaryKey && len(pkList) == 1 { + sql += db.String(col) + } else { + sql += db.StringNoPk(col) + } + sql = strings.TrimSpace(sql) + sql += ", " + } + + if len(pkList) > 1 { + sql += "PRIMARY KEY ( " + sql += quoter.Join(pkList, ",") + sql += " ), " + } + + sql = sql[:len(sql)-2] + } + sql += ")" + + return sql +} + func (db *sqlite3) ForUpdateSQL(query string) string { return query } diff --git a/dialects/time.go b/dialects/time.go index 022dc960..b0394745 100644 --- a/dialects/time.go +++ b/dialects/time.go @@ -21,7 +21,7 @@ func FormatTime(dialect Dialect, sqlTypeName string, t time.Time) (v interface{} case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar. v = t.Format("2006-01-02 15:04:05") case schemas.TimeStampz: - if dialect.DBType() == schemas.MSSQL { + if dialect.URI().DBType == schemas.MSSQL { v = t.Format("2006-01-02T15:04:05.9999999Z07:00") } else { v = t.Format(time.RFC3339Nano) diff --git a/engine.go b/engine.go index c657cd1f..c330e9f5 100644 --- a/engine.go +++ b/engine.go @@ -365,7 +365,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch var distDBName string if len(tp) == 0 { dialect = engine.dialect - distDBName = string(engine.dialect.DBType()) + distDBName = string(engine.dialect.URI().DBType) } else { dialect = dialects.QueryDialect(tp[0]) if dialect == nil { @@ -376,7 +376,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } _, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm v%s %s, from %s to %s*/\n\n", - Version, time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.DBType(), strings.ToUpper(distDBName))) + Version, time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, strings.ToUpper(distDBName))) if err != nil { return err } @@ -388,7 +388,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } } - _, err = io.WriteString(w, dialect.CreateTableSQL(table, "", table.StoreEngine, "")+";\n") + _, err = io.WriteString(w, dialect.CreateTableSQL(table, "")+";\n") if err != nil { return err } @@ -486,7 +486,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } // FIXME: Hack for postgres - if dialect.DBType() == schemas.POSTGRES && table.AutoIncrColumn() != nil { + if dialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil { _, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quoter().Quote(table.Name)+"), 1), false);\n") if err != nil { return err diff --git a/internal/statements/cache.go b/internal/statements/cache.go index d7f72318..cb33df08 100644 --- a/internal/statements/cache.go +++ b/internal/statements/cache.go @@ -27,7 +27,7 @@ func (statement *Statement) ConvertIDSQL(sqlStr string) string { var top string pLimitN := statement.LimitN - if pLimitN != nil && statement.dialect.DBType() == schemas.MSSQL { + if pLimitN != nil && statement.dialect.URI().DBType == schemas.MSSQL { top = fmt.Sprintf("TOP %d ", *pLimitN) } @@ -56,9 +56,9 @@ func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) { // TODO: for postgres only, if any other database? var paraStr string - if statement.dialect.DBType() == schemas.POSTGRES { + if statement.dialect.URI().DBType == schemas.POSTGRES { paraStr = "$" - } else if statement.dialect.DBType() == schemas.MSSQL { + } else if statement.dialect.URI().DBType == schemas.MSSQL { paraStr = ":" } diff --git a/internal/statements/query.go b/internal/statements/query.go index a058f752..8f7aeebb 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -201,14 +201,14 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB whereStr = " WHERE " + condSQL } - if dialect.DBType() == schemas.MSSQL && strings.Contains(statement.TableName(), "..") { + if dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") { fromStr += statement.TableName() } else { fromStr += quote(statement.TableName()) } if statement.TableAlias != "" { - if dialect.DBType() == schemas.ORACLE { + if dialect.URI().DBType == schemas.ORACLE { fromStr += " " + quote(statement.TableAlias) } else { fromStr += " AS " + quote(statement.TableAlias) @@ -219,7 +219,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB } pLimitN := statement.LimitN - if dialect.DBType() == schemas.MSSQL { + if dialect.URI().DBType == schemas.MSSQL { if pLimitN != nil { LimitNValue := *pLimitN top = fmt.Sprintf("TOP %d ", LimitNValue) @@ -281,7 +281,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) } if needLimit { - if dialect.DBType() != schemas.MSSQL && dialect.DBType() != schemas.ORACLE { + if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE { if statement.Start > 0 { if pLimitN != nil { fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start) @@ -291,7 +291,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB } else if pLimitN != nil { fmt.Fprint(&buf, " LIMIT ", *pLimitN) } - } else if dialect.DBType() == schemas.ORACLE { + } else if dialect.URI().DBType == schemas.ORACLE { if statement.Start != 0 || pLimitN != nil { oldString := buf.String() buf.Reset() @@ -337,18 +337,18 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac return "", nil, err } - if statement.dialect.DBType() == schemas.MSSQL { + if statement.dialect.URI().DBType == schemas.MSSQL { sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL) - } else if statement.dialect.DBType() == schemas.ORACLE { + } else if statement.dialect.URI().DBType == schemas.ORACLE { sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL) } else { sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL) } args = condArgs } else { - if statement.dialect.DBType() == schemas.MSSQL { + if statement.dialect.URI().DBType == schemas.MSSQL { sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr) - } else if statement.dialect.DBType() == schemas.ORACLE { + } else if statement.dialect.URI().DBType == schemas.ORACLE { sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr) } else { sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr) diff --git a/internal/statements/statement.go b/internal/statements/statement.go index a2a356ff..d6dd58b1 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -641,8 +641,9 @@ func (statement *Statement) genColumnStr() string { } func (statement *Statement) GenCreateTableSQL() string { - return statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName(), - statement.StoreEngine, statement.Charset) + statement.RefTable.StoreEngine = statement.StoreEngine + statement.RefTable.Charset = statement.Charset + return statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName()) } func (statement *Statement) GenIndexSQL() []string { @@ -680,20 +681,8 @@ func (statement *Statement) GenDelIndexSQL() []string { if idx > -1 { tbName = tbName[idx+1:] } - idxPrefixName := strings.Replace(tbName, `"`, "", -1) - idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1) - for idxName, index := range statement.RefTable.Indexes { - var rIdxName string - if index.Type == schemas.UniqueType { - rIdxName = uniqueName(idxPrefixName, idxName) - } else if index.Type == schemas.IndexType { - rIdxName = utils.IndexName(idxPrefixName, idxName) - } - sql := fmt.Sprintf("DROP INDEX %v", statement.quote(dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), rIdxName, true))) - if statement.dialect.IndexOnTable() { - sql += fmt.Sprintf(" ON %v", statement.quote(tbName)) - } - sqls = append(sqls, sql) + for _, index := range statement.RefTable.Indexes { + sqls = append(sqls, statement.dialect.DropIndexSQL(tbName, index)) } return sqls } @@ -714,7 +703,8 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, continue } - if statement.dialect.DBType() == schemas.MSSQL && (col.SQLType.Name == schemas.Text || col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) { + if statement.dialect.URI().DBType == schemas.MSSQL && (col.SQLType.Name == schemas.Text || + col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) { continue } if col.SQLType.IsJson() { @@ -1002,7 +992,7 @@ func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { cond = builder.Eq{colName: 0} } else { // FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value. - if statement.dialect.DBType() != schemas.MSSQL { + if statement.dialect.URI().DBType != schemas.MSSQL { cond = builder.Eq{colName: utils.ZeroTime1} } } diff --git a/internal/statements/statement_args.go b/internal/statements/statement_args.go index 8eee246e..7d1ef9eb 100644 --- a/internal/statements/statement_args.go +++ b/internal/statements/statement_args.go @@ -80,7 +80,7 @@ const insertSelectPlaceHolder = true func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error { switch argv := arg.(type) { case bool: - if statement.dialect.DBType() == schemas.MSSQL { + if statement.dialect.URI().DBType == schemas.MSSQL { if argv { if _, err := w.WriteString("1"); err != nil { return err @@ -119,7 +119,7 @@ func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) er w.Append(arg) } else { var convertFunc = convertStringSingleQuote - if statement.dialect.DBType() == schemas.MYSQL { + if statement.dialect.URI().DBType == schemas.MYSQL { convertFunc = convertString } if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil { diff --git a/session_cols_test.go b/session_cols_test.go index 58b4e841..2847cf35 100644 --- a/session_cols_test.go +++ b/session_cols_test.go @@ -45,7 +45,7 @@ func TestSetExpr(t *testing.T) { assert.EqualValues(t, 1, cnt) var not = "NOT" - if testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { not = "~" } cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr)) diff --git a/session_convert.go b/session_convert.go index 0776bc45..28866d4d 100644 --- a/session_convert.go +++ b/session_convert.go @@ -65,7 +65,7 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time } sdata = strings.TrimSpace(sdata) - if session.engine.dialect.DBType() == schemas.MYSQL && len(sdata) > 8 { + if session.engine.dialect.URI().DBType == schemas.MYSQL && len(sdata) > 8 { sdata = sdata[len(sdata)-8:] } @@ -159,7 +159,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val var err error // for mysql, when use bit, it returned \x01 if col.SQLType.Name == schemas.Bit && - session.engine.dialect.DBType() == schemas.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API + session.engine.dialect.URI().DBType == schemas.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API if len(data) == 1 { x = int64(data[0]) } else { @@ -399,7 +399,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val var err error // for mysql, when use bit, it returned \x01 if col.SQLType.Name == schemas.Bit && - session.engine.dialect.DBType() == schemas.MYSQL { + session.engine.dialect.URI().DBType == schemas.MYSQL { if len(data) == 1 { x = int32(data[0]) } else { diff --git a/session_delete.go b/session_delete.go index eb5e2aea..ff28867a 100644 --- a/session_delete.go +++ b/session_delete.go @@ -135,7 +135,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { } if len(orderSQL) > 0 { - switch session.engine.dialect.DBType() { + switch session.engine.dialect.URI().DBType { case schemas.POSTGRES: inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) if len(condSQL) > 0 { @@ -176,7 +176,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { condSQL) if len(orderSQL) > 0 { - switch session.engine.dialect.DBType() { + switch session.engine.dialect.URI().DBType { case schemas.POSTGRES: inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) if len(condSQL) > 0 { diff --git a/session_delete_test.go b/session_delete_test.go index d7fb3110..6fba860b 100644 --- a/session_delete_test.go +++ b/session_delete_test.go @@ -28,7 +28,7 @@ func TestDelete(t *testing.T) { defer session.Close() var err error - if testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Begin() assert.NoError(t, err) _, err = session.Exec("SET IDENTITY_INSERT userinfo_delete ON") @@ -40,7 +40,7 @@ func TestDelete(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - if testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Commit() assert.NoError(t, err) } diff --git a/session_get_test.go b/session_get_test.go index 7e10bf54..b83a118b 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -154,7 +154,7 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money)) var money2 float64 - if testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { has, err = testEngine.SQL("SELECT TOP 1 money FROM " + testEngine.TableName("get_var", true)).Get(&money2) } else { has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2) @@ -234,7 +234,7 @@ func TestGetStruct(t *testing.T) { defer session.Close() var err error - if testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Begin() assert.NoError(t, err) _, err = session.Exec("SET IDENTITY_INSERT userinfo_get ON") @@ -243,7 +243,7 @@ func TestGetStruct(t *testing.T) { cnt, err := session.Insert(&UserinfoGet{Uid: 2}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - if testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Commit() assert.NoError(t, err) } diff --git a/session_insert.go b/session_insert.go index 4662e25a..e5368571 100644 --- a/session_insert.go +++ b/session_insert.go @@ -254,7 +254,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error quoter := session.engine.dialect.Quoter() var sql string colStr := quoter.Join(colNames, ",") - if session.engine.dialect.DBType() == schemas.ORACLE { + if session.engine.dialect.URI().DBType == schemas.ORACLE { temp := fmt.Sprintf(") INTO %s (%v) VALUES (", quoter.Quote(tableName), colStr) @@ -361,7 +361,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { var tableName = session.statement.TableName() var output string - if session.engine.dialect.DBType() == schemas.MSSQL && len(table.AutoIncrement) > 0 { + if session.engine.dialect.URI().DBType == schemas.MSSQL && len(table.AutoIncrement) > 0 { output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement) } @@ -371,7 +371,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } if len(colPlaces) <= 0 { - if session.engine.dialect.DBType() == schemas.MYSQL { + if session.engine.dialect.URI().DBType == schemas.MYSQL { if _, err := buf.WriteString(" VALUES ()"); err != nil { return 0, err } @@ -433,7 +433,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } } - if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == schemas.POSTGRES { + if len(table.AutoIncrement) > 0 && session.engine.dialect.URI().DBType == schemas.POSTGRES { if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil { return 0, err } @@ -472,7 +472,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { // for postgres, many of them didn't implement lastInsertId, so we should // implemented it ourself. - if session.engine.dialect.DBType() == schemas.ORACLE && len(table.AutoIncrement) > 0 { + if session.engine.dialect.URI().DBType == schemas.ORACLE && len(table.AutoIncrement) > 0 { res, err := session.queryBytes("select seq_atable.currval from dual", args...) if err != nil { return 0, err @@ -513,7 +513,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue.Set(int64ToIntValue(id, aiValue.Type())) return 1, nil - } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.DBType() == schemas.POSTGRES || session.engine.dialect.DBType() == schemas.MSSQL) { + } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || + session.engine.dialect.URI().DBType == schemas.MSSQL) { res, err := session.queryBytes(sqlStr, args...) if err != nil { diff --git a/session_query_test.go b/session_query_test.go index e4635d64..bed62be0 100644 --- a/session_query_test.go +++ b/session_query_test.go @@ -207,7 +207,7 @@ func TestQueryStringNoParam(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0]["id"]) - if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL { assert.EqualValues(t, "false", records[0]["msg"]) } else { assert.EqualValues(t, "0", records[0]["msg"]) @@ -217,7 +217,7 @@ func TestQueryStringNoParam(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0]["id"]) - if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL { assert.EqualValues(t, "false", records[0]["msg"]) } else { assert.EqualValues(t, "0", records[0]["msg"]) @@ -244,7 +244,7 @@ func TestQuerySliceStringNoParam(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0][0]) - if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL { assert.EqualValues(t, "false", records[0][1]) } else { assert.EqualValues(t, "0", records[0][1]) @@ -254,7 +254,7 @@ func TestQuerySliceStringNoParam(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0][0]) - if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL { assert.EqualValues(t, "false", records[0][1]) } else { assert.EqualValues(t, "0", records[0][1]) diff --git a/session_schema.go b/session_schema.go index 3617a6b8..6d363521 100644 --- a/session_schema.go +++ b/session_schema.go @@ -313,8 +313,8 @@ func (session *Session) Sync2(beans ...interface{}) error { if expectedType == schemas.Text && strings.HasPrefix(curType, schemas.Varchar) { // currently only support mysql & postgres - if engine.dialect.DBType() == schemas.MYSQL || - engine.dialect.DBType() == schemas.POSTGRES { + if engine.dialect.URI().DBType == schemas.MYSQL || + engine.dialect.URI().DBType == schemas.POSTGRES { engine.logger.Infof("Table %s column %s change type from %s to %s\n", tbNameWithSchema, col.Name, curType, expectedType) _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) @@ -323,7 +323,7 @@ func (session *Session) Sync2(beans ...interface{}) error { tbNameWithSchema, col.Name, curType, expectedType) } } else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) { - if engine.dialect.DBType() == schemas.MYSQL { + if engine.dialect.URI().DBType == schemas.MYSQL { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", tbNameWithSchema, col.Name, oriCol.Length, col.Length) @@ -337,7 +337,7 @@ func (session *Session) Sync2(beans ...interface{}) error { } } } else if expectedType == schemas.Varchar { - if engine.dialect.DBType() == schemas.MYSQL { + if engine.dialect.URI().DBType == schemas.MYSQL { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", tbNameWithSchema, col.Name, oriCol.Length, col.Length) diff --git a/session_update.go b/session_update.go index 551b8167..aa4718b6 100644 --- a/session_update.go +++ b/session_update.go @@ -335,9 +335,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var top string if st.LimitN != nil { limitValue := *st.LimitN - if session.engine.dialect.DBType() == schemas.MYSQL { + if session.engine.dialect.URI().DBType == schemas.MYSQL { condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue) - } else if session.engine.dialect.DBType() == schemas.SQLITE { + } else if session.engine.dialect.URI().DBType == schemas.SQLITE { tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) @@ -348,7 +348,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if session.engine.dialect.DBType() == schemas.POSTGRES { + } else if session.engine.dialect.URI().DBType == schemas.POSTGRES { tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) @@ -360,8 +360,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if session.engine.dialect.DBType() == schemas.MSSQL { - if st.OrderStr != "" && session.engine.dialect.DBType() == schemas.MSSQL && + } else if session.engine.dialect.URI().DBType == schemas.MSSQL { + if st.OrderStr != "" && session.engine.dialect.URI().DBType == schemas.MSSQL && table != nil && len(table.PrimaryKeys) == 1 { cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], @@ -387,7 +387,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var tableAlias = session.engine.Quote(tableName) var fromSQL string if session.statement.TableAlias != "" { - switch session.engine.dialect.DBType() { + switch session.engine.dialect.URI().DBType { case schemas.MSSQL: fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias) tableAlias = session.statement.TableAlias diff --git a/tags_test.go b/tags_test.go index 4473a12f..ff578def 100644 --- a/tags_test.go +++ b/tags_test.go @@ -238,7 +238,7 @@ func TestExtends2(t *testing.T) { defer session.Close() // MSSQL deny insert identity column excep declare as below - if testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Begin() assert.NoError(t, err) _, err = session.Exec("SET IDENTITY_INSERT message ON") @@ -248,7 +248,7 @@ func TestExtends2(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - if testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Commit() assert.NoError(t, err) } @@ -299,7 +299,7 @@ func TestExtends3(t *testing.T) { defer session.Close() // MSSQL deny insert identity column excep declare as below - if testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Begin() assert.NoError(t, err) _, err = session.Exec("SET IDENTITY_INSERT message ON") @@ -308,7 +308,7 @@ func TestExtends3(t *testing.T) { _, err = session.Insert(&msg) assert.NoError(t, err) - if testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Commit() assert.NoError(t, err) } @@ -362,7 +362,7 @@ func TestExtends4(t *testing.T) { defer session.Close() // MSSQL deny insert identity column excep declare as below - if testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Begin() assert.NoError(t, err) _, err = session.Exec("SET IDENTITY_INSERT message ON") @@ -371,7 +371,7 @@ func TestExtends4(t *testing.T) { _, err = session.Insert(&msg) assert.NoError(t, err) - if testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Commit() assert.NoError(t, err) } @@ -800,7 +800,7 @@ func TestAutoIncrTag(t *testing.T) { func TestTagComment(t *testing.T) { assert.NoError(t, prepareEngine()) // FIXME: only support mysql - if testEngine.Dialect().DBType() != schemas.MYSQL { + if testEngine.Dialect().URI().DBType != schemas.MYSQL { return } diff --git a/types_test.go b/types_test.go index d8fd8309..77407e98 100644 --- a/types_test.go +++ b/types_test.go @@ -314,7 +314,7 @@ func TestCustomType2(t *testing.T) { session := testEngine.NewSession() defer session.Close() - if testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Begin() assert.NoError(t, err) _, err = session.Exec("set IDENTITY_INSERT " + tableName + " on") @@ -325,7 +325,7 @@ func TestCustomType2(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - if testEngine.Dialect().DBType() == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Commit() assert.NoError(t, err) }