Improve dialect interface (#1578)

Improve dialect interface

Reviewed-on: https://gitea.com/xorm/xorm/pulls/1578
This commit is contained in:
Lunny Xiao 2020-03-07 08:51:30 +00:00
parent 0f166d82da
commit 7f22948be9
23 changed files with 162 additions and 240 deletions

View File

@ -34,7 +34,6 @@ type Dialect interface {
Init(*core.DB, *URI) error Init(*core.DB, *URI) error
URI() *URI URI() *URI
DB() *core.DB DB() *core.DB
DBType() schemas.DBType
SQLType(*schemas.Column) string SQLType(*schemas.Column) string
FormatBytes(b []byte) string FormatBytes(b []byte) string
DefaultSchema() string DefaultSchema() string
@ -44,33 +43,26 @@ type Dialect interface {
SetQuotePolicy(quotePolicy QuotePolicy) SetQuotePolicy(quotePolicy QuotePolicy)
AutoIncrStr() string AutoIncrStr() string
SupportInsertMany() bool SupportInsertMany() bool
SupportEngine() bool
SupportCharset() bool
SupportDropIfExists() bool SupportDropIfExists() bool
IndexOnTable() bool
ShowCreateNull() bool
GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error)
IndexCheckSQL(tableName, idxName string) (string, []interface{}) IndexCheckSQL(tableName, idxName string) (string, []interface{})
TableCheckSQL(tableName string) (string, []interface{})
IsColumnExist(ctx context.Context, tableName string, colName string) (bool, error)
CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string
DropTableSQL(tableName string) string
CreateIndexSQL(tableName string, index *schemas.Index) string CreateIndexSQL(tableName string, index *schemas.Index) string
DropIndexSQL(tableName string, index *schemas.Index) string DropIndexSQL(tableName string, index *schemas.Index) string
GetTables(ctx context.Context) ([]*schemas.Table, error)
TableCheckSQL(tableName string) (string, []interface{})
CreateTableSQL(table *schemas.Table, tableName string) string
DropTableSQL(tableName string) string
GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)
IsColumnExist(ctx context.Context, tableName string, colName string) (bool, error)
AddColumnSQL(tableName string, col *schemas.Column) string AddColumnSQL(tableName string, col *schemas.Column) string
ModifyColumnSQL(tableName string, col *schemas.Column) string ModifyColumnSQL(tableName string, col *schemas.Column) string
ForUpdateSQL(query string) string ForUpdateSQL(query string) string
GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)
GetTables(ctx context.Context) ([]*schemas.Table, error)
GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error)
Filters() []Filter Filters() []Filter
SetParams(params map[string]string) SetParams(params map[string]string)
} }
@ -125,13 +117,11 @@ func (b *Base) String(col *schemas.Column) string {
sql += "DEFAULT " + col.Default + " " sql += "DEFAULT " + col.Default + " "
} }
if b.dialect.ShowCreateNull() {
if col.Nullable { if col.Nullable {
sql += "NULL " sql += "NULL "
} else { } else {
sql += "NOT NULL " sql += "NOT NULL "
} }
}
return sql return sql
} }
@ -146,13 +136,11 @@ func (b *Base) StringNoPk(col *schemas.Column) string {
sql += "DEFAULT " + col.Default + " " sql += "DEFAULT " + col.Default + " "
} }
if b.dialect.ShowCreateNull() {
if col.Nullable { if col.Nullable {
sql += "NULL " sql += "NULL "
} else { } else {
sql += "NOT NULL " sql += "NOT NULL "
} }
}
return sql return sql
} }
@ -161,10 +149,6 @@ func (b *Base) FormatBytes(bs []byte) string {
return fmt.Sprintf("0x%x", bs) return fmt.Sprintf("0x%x", bs)
} }
func (b *Base) ShowCreateNull() bool {
return true
}
func (db *Base) SupportDropIfExists() bool { func (db *Base) SupportDropIfExists() bool {
return true return true
} }
@ -234,59 +218,6 @@ func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string {
return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, db.StringNoPk(col)) return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, db.StringNoPk(col))
} }
func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
}
quoter := b.dialect.Quoter()
sql += quoter.Quote(tableName)
sql += " ("
if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += b.String(col)
} else {
sql += b.StringNoPk(col)
}
sql = strings.TrimSpace(sql)
if b.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ", "
}
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += quoter.Join(pkList, ",")
sql += " ), "
}
sql = sql[:len(sql)-2]
}
sql += ")"
if b.dialect.SupportEngine() && storeEngine != "" {
sql += " ENGINE=" + storeEngine
}
if b.dialect.SupportCharset() {
if len(charset) == 0 {
charset = b.dialect.URI().Charset
}
if len(charset) > 0 {
sql += " DEFAULT CHARSET " + charset
}
}
return sql
}
func (b *Base) ForUpdateSQL(query string) string { func (b *Base) ForUpdateSQL(query string) string {
return query + " FOR UPDATE" return query + " FOR UPDATE"
} }

View File

@ -307,10 +307,6 @@ func (db *mssql) SetQuotePolicy(quotePolicy QuotePolicy) {
} }
} }
func (db *mssql) SupportEngine() bool {
return false
}
func (db *mssql) AutoIncrStr() string { func (db *mssql) AutoIncrStr() string {
return "IDENTITY" return "IDENTITY"
} }
@ -321,26 +317,12 @@ func (db *mssql) DropTableSQL(tableName string) string {
"DROP TABLE \"%s\"", tableName, tableName) "DROP TABLE \"%s\"", tableName, tableName)
} }
func (db *mssql) SupportCharset() bool {
return false
}
func (db *mssql) IndexOnTable() bool {
return true
}
func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{idxName} args := []interface{}{idxName}
sql := "select name from sysindexes where id=object_id('" + tableName + "') and name=?" sql := "select name from sysindexes where id=object_id('" + tableName + "') and name=?"
return sql, args return sql, args
} }
/*func (db *mssql) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName, colName}
sql := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?`
return sql, args
}*/
func (db *mssql) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { func (db *mssql) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) {
query := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?` query := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?`
@ -509,7 +491,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
return indexes, nil return indexes, nil
} }
func (db *mssql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string { func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) string {
var sql string var sql string
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name

View File

@ -279,22 +279,10 @@ func (db *mysql) IsReserved(name string) bool {
return ok return ok
} }
func (db *mysql) SupportEngine() bool {
return true
}
func (db *mysql) AutoIncrStr() string { func (db *mysql) AutoIncrStr() string {
return "AUTO_INCREMENT" return "AUTO_INCREMENT"
} }
func (db *mysql) SupportCharset() bool {
return true
}
func (db *mysql) IndexOnTable() bool {
return true
}
func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{db.uri.DBName, tableName, idxName} args := []interface{}{db.uri.DBName, tableName, idxName}
sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`" sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`"
@ -524,7 +512,7 @@ func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]*
return indexes, nil return indexes, nil
} }
func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string { func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) string {
var sql = "CREATE TABLE IF NOT EXISTS " var sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
@ -562,10 +550,11 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch
} }
sql += ")" sql += ")"
if storeEngine != "" { if table.StoreEngine != "" {
sql += " ENGINE=" + storeEngine sql += " ENGINE=" + table.StoreEngine
} }
var charset = table.Charset
if len(charset) == 0 { if len(charset) == 0 {
charset = db.URI().Charset charset = db.URI().Charset
} }

View File

@ -556,27 +556,15 @@ func (db *oracle) IsReserved(name string) bool {
return ok return ok
} }
func (db *oracle) SupportEngine() bool {
return false
}
func (db *oracle) SupportCharset() bool {
return false
}
func (db *oracle) SupportDropIfExists() bool { func (db *oracle) SupportDropIfExists() bool {
return false return false
} }
func (db *oracle) IndexOnTable() bool {
return false
}
func (db *oracle) DropTableSQL(tableName string) string { func (db *oracle) DropTableSQL(tableName string) string {
return fmt.Sprintf("DROP TABLE `%s`", tableName) return fmt.Sprintf("DROP TABLE `%s`", tableName)
} }
func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string { func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) string {
var sql = "CREATE TABLE " var sql = "CREATE TABLE "
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
@ -605,17 +593,6 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c
} }
sql = sql[:len(sql)-2] + ")" sql = sql[:len(sql)-2] + ")"
if db.SupportEngine() && storeEngine != "" {
sql += " ENGINE=" + storeEngine
}
if db.SupportCharset() {
if len(charset) == 0 {
charset = db.URI().Charset
}
if len(charset) > 0 {
sql += " DEFAULT CHARSET " + charset
}
}
return sql return sql
} }

View File

@ -897,16 +897,42 @@ func (db *postgres) AutoIncrStr() string {
return "" return ""
} }
func (db *postgres) SupportEngine() bool { func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) string {
return false var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
} }
func (db *postgres) SupportCharset() bool { quoter := db.Quoter()
return false sql += quoter.Quote(tableName)
sql += " ("
if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += db.String(col)
} else {
sql += db.StringNoPk(col)
}
sql = strings.TrimSpace(sql)
sql += ", "
} }
func (db *postgres) IndexOnTable() bool { if len(pkList) > 1 {
return false sql += "PRIMARY KEY ( "
sql += quoter.Join(pkList, ",")
sql += " ), "
}
sql = sql[:len(sql)-2]
}
sql += ")"
return sql
} }
func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {

View File

@ -224,18 +224,6 @@ func (db *sqlite3) AutoIncrStr() string {
return "AUTOINCREMENT" return "AUTOINCREMENT"
} }
func (db *sqlite3) SupportEngine() bool {
return false
}
func (db *sqlite3) SupportCharset() bool {
return false
}
func (db *sqlite3) IndexOnTable() bool {
return false
}
func (db *sqlite3) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { func (db *sqlite3) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{idxName} args := []interface{}{idxName}
return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args
@ -261,6 +249,44 @@ func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string {
return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName))
} }
func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) string {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
}
quoter := db.Quoter()
sql += quoter.Quote(tableName)
sql += " ("
if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += db.String(col)
} else {
sql += db.StringNoPk(col)
}
sql = strings.TrimSpace(sql)
sql += ", "
}
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += quoter.Join(pkList, ",")
sql += " ), "
}
sql = sql[:len(sql)-2]
}
sql += ")"
return sql
}
func (db *sqlite3) ForUpdateSQL(query string) string { func (db *sqlite3) ForUpdateSQL(query string) string {
return query return query
} }

View File

@ -21,7 +21,7 @@ func FormatTime(dialect Dialect, sqlTypeName string, t time.Time) (v interface{}
case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar. case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar.
v = t.Format("2006-01-02 15:04:05") v = t.Format("2006-01-02 15:04:05")
case schemas.TimeStampz: case schemas.TimeStampz:
if dialect.DBType() == schemas.MSSQL { if dialect.URI().DBType == schemas.MSSQL {
v = t.Format("2006-01-02T15:04:05.9999999Z07:00") v = t.Format("2006-01-02T15:04:05.9999999Z07:00")
} else { } else {
v = t.Format(time.RFC3339Nano) v = t.Format(time.RFC3339Nano)

View File

@ -365,7 +365,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
var distDBName string var distDBName string
if len(tp) == 0 { if len(tp) == 0 {
dialect = engine.dialect dialect = engine.dialect
distDBName = string(engine.dialect.DBType()) distDBName = string(engine.dialect.URI().DBType)
} else { } else {
dialect = dialects.QueryDialect(tp[0]) dialect = dialects.QueryDialect(tp[0])
if dialect == nil { if dialect == nil {
@ -376,7 +376,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
} }
_, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm v%s %s, from %s to %s*/\n\n", _, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm v%s %s, from %s to %s*/\n\n",
Version, time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.DBType(), strings.ToUpper(distDBName))) Version, time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, strings.ToUpper(distDBName)))
if err != nil { if err != nil {
return err return err
} }
@ -388,7 +388,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return err return err
} }
} }
_, err = io.WriteString(w, dialect.CreateTableSQL(table, "", table.StoreEngine, "")+";\n") _, err = io.WriteString(w, dialect.CreateTableSQL(table, "")+";\n")
if err != nil { if err != nil {
return err return err
} }
@ -486,7 +486,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
} }
// FIXME: Hack for postgres // FIXME: Hack for postgres
if dialect.DBType() == schemas.POSTGRES && table.AutoIncrColumn() != nil { if dialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil {
_, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quoter().Quote(table.Name)+"), 1), false);\n") _, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quoter().Quote(table.Name)+"), 1), false);\n")
if err != nil { if err != nil {
return err return err

View File

@ -27,7 +27,7 @@ func (statement *Statement) ConvertIDSQL(sqlStr string) string {
var top string var top string
pLimitN := statement.LimitN pLimitN := statement.LimitN
if pLimitN != nil && statement.dialect.DBType() == schemas.MSSQL { if pLimitN != nil && statement.dialect.URI().DBType == schemas.MSSQL {
top = fmt.Sprintf("TOP %d ", *pLimitN) top = fmt.Sprintf("TOP %d ", *pLimitN)
} }
@ -56,9 +56,9 @@ func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) {
// TODO: for postgres only, if any other database? // TODO: for postgres only, if any other database?
var paraStr string var paraStr string
if statement.dialect.DBType() == schemas.POSTGRES { if statement.dialect.URI().DBType == schemas.POSTGRES {
paraStr = "$" paraStr = "$"
} else if statement.dialect.DBType() == schemas.MSSQL { } else if statement.dialect.URI().DBType == schemas.MSSQL {
paraStr = ":" paraStr = ":"
} }

View File

@ -201,14 +201,14 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
whereStr = " WHERE " + condSQL whereStr = " WHERE " + condSQL
} }
if dialect.DBType() == schemas.MSSQL && strings.Contains(statement.TableName(), "..") { if dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
fromStr += statement.TableName() fromStr += statement.TableName()
} else { } else {
fromStr += quote(statement.TableName()) fromStr += quote(statement.TableName())
} }
if statement.TableAlias != "" { if statement.TableAlias != "" {
if dialect.DBType() == schemas.ORACLE { if dialect.URI().DBType == schemas.ORACLE {
fromStr += " " + quote(statement.TableAlias) fromStr += " " + quote(statement.TableAlias)
} else { } else {
fromStr += " AS " + quote(statement.TableAlias) fromStr += " AS " + quote(statement.TableAlias)
@ -219,7 +219,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} }
pLimitN := statement.LimitN pLimitN := statement.LimitN
if dialect.DBType() == schemas.MSSQL { if dialect.URI().DBType == schemas.MSSQL {
if pLimitN != nil { if pLimitN != nil {
LimitNValue := *pLimitN LimitNValue := *pLimitN
top = fmt.Sprintf("TOP %d ", LimitNValue) top = fmt.Sprintf("TOP %d ", LimitNValue)
@ -281,7 +281,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
} }
if needLimit { if needLimit {
if dialect.DBType() != schemas.MSSQL && dialect.DBType() != schemas.ORACLE { if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE {
if statement.Start > 0 { if statement.Start > 0 {
if pLimitN != nil { if pLimitN != nil {
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start) fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
@ -291,7 +291,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} else if pLimitN != nil { } else if pLimitN != nil {
fmt.Fprint(&buf, " LIMIT ", *pLimitN) fmt.Fprint(&buf, " LIMIT ", *pLimitN)
} }
} else if dialect.DBType() == schemas.ORACLE { } else if dialect.URI().DBType == schemas.ORACLE {
if statement.Start != 0 || pLimitN != nil { if statement.Start != 0 || pLimitN != nil {
oldString := buf.String() oldString := buf.String()
buf.Reset() buf.Reset()
@ -337,18 +337,18 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
return "", nil, err return "", nil, err
} }
if statement.dialect.DBType() == schemas.MSSQL { if statement.dialect.URI().DBType == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL) sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL)
} else if statement.dialect.DBType() == schemas.ORACLE { } else if statement.dialect.URI().DBType == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL) sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL)
} else { } else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL) sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL)
} }
args = condArgs args = condArgs
} else { } else {
if statement.dialect.DBType() == schemas.MSSQL { if statement.dialect.URI().DBType == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr) sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr)
} else if statement.dialect.DBType() == schemas.ORACLE { } else if statement.dialect.URI().DBType == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr) sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr)
} else { } else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr) sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr)

View File

@ -641,8 +641,9 @@ func (statement *Statement) genColumnStr() string {
} }
func (statement *Statement) GenCreateTableSQL() string { func (statement *Statement) GenCreateTableSQL() string {
return statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName(), statement.RefTable.StoreEngine = statement.StoreEngine
statement.StoreEngine, statement.Charset) statement.RefTable.Charset = statement.Charset
return statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName())
} }
func (statement *Statement) GenIndexSQL() []string { func (statement *Statement) GenIndexSQL() []string {
@ -680,20 +681,8 @@ func (statement *Statement) GenDelIndexSQL() []string {
if idx > -1 { if idx > -1 {
tbName = tbName[idx+1:] tbName = tbName[idx+1:]
} }
idxPrefixName := strings.Replace(tbName, `"`, "", -1) for _, index := range statement.RefTable.Indexes {
idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1) sqls = append(sqls, statement.dialect.DropIndexSQL(tbName, index))
for idxName, index := range statement.RefTable.Indexes {
var rIdxName string
if index.Type == schemas.UniqueType {
rIdxName = uniqueName(idxPrefixName, idxName)
} else if index.Type == schemas.IndexType {
rIdxName = utils.IndexName(idxPrefixName, idxName)
}
sql := fmt.Sprintf("DROP INDEX %v", statement.quote(dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), rIdxName, true)))
if statement.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.quote(tbName))
}
sqls = append(sqls, sql)
} }
return sqls return sqls
} }
@ -714,7 +703,8 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
continue continue
} }
if statement.dialect.DBType() == schemas.MSSQL && (col.SQLType.Name == schemas.Text || col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) { if statement.dialect.URI().DBType == schemas.MSSQL && (col.SQLType.Name == schemas.Text ||
col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) {
continue continue
} }
if col.SQLType.IsJson() { if col.SQLType.IsJson() {
@ -1002,7 +992,7 @@ func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond {
cond = builder.Eq{colName: 0} cond = builder.Eq{colName: 0}
} else { } else {
// FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value. // FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value.
if statement.dialect.DBType() != schemas.MSSQL { if statement.dialect.URI().DBType != schemas.MSSQL {
cond = builder.Eq{colName: utils.ZeroTime1} cond = builder.Eq{colName: utils.ZeroTime1}
} }
} }

View File

@ -80,7 +80,7 @@ const insertSelectPlaceHolder = true
func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error { func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error {
switch argv := arg.(type) { switch argv := arg.(type) {
case bool: case bool:
if statement.dialect.DBType() == schemas.MSSQL { if statement.dialect.URI().DBType == schemas.MSSQL {
if argv { if argv {
if _, err := w.WriteString("1"); err != nil { if _, err := w.WriteString("1"); err != nil {
return err return err
@ -119,7 +119,7 @@ func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) er
w.Append(arg) w.Append(arg)
} else { } else {
var convertFunc = convertStringSingleQuote var convertFunc = convertStringSingleQuote
if statement.dialect.DBType() == schemas.MYSQL { if statement.dialect.URI().DBType == schemas.MYSQL {
convertFunc = convertString convertFunc = convertString
} }
if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil { if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil {

View File

@ -45,7 +45,7 @@ func TestSetExpr(t *testing.T) {
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var not = "NOT" var not = "NOT"
if testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.MSSQL {
not = "~" not = "~"
} }
cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr)) cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr))

View File

@ -65,7 +65,7 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time
} }
sdata = strings.TrimSpace(sdata) sdata = strings.TrimSpace(sdata)
if session.engine.dialect.DBType() == schemas.MYSQL && len(sdata) > 8 { if session.engine.dialect.URI().DBType == schemas.MYSQL && len(sdata) > 8 {
sdata = sdata[len(sdata)-8:] sdata = sdata[len(sdata)-8:]
} }
@ -159,7 +159,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == schemas.Bit && if col.SQLType.Name == schemas.Bit &&
session.engine.dialect.DBType() == schemas.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API session.engine.dialect.URI().DBType == schemas.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API
if len(data) == 1 { if len(data) == 1 {
x = int64(data[0]) x = int64(data[0])
} else { } else {
@ -399,7 +399,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == schemas.Bit && if col.SQLType.Name == schemas.Bit &&
session.engine.dialect.DBType() == schemas.MYSQL { session.engine.dialect.URI().DBType == schemas.MYSQL {
if len(data) == 1 { if len(data) == 1 {
x = int32(data[0]) x = int32(data[0])
} else { } else {

View File

@ -135,7 +135,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
} }
if len(orderSQL) > 0 { if len(orderSQL) > 0 {
switch session.engine.dialect.DBType() { switch session.engine.dialect.URI().DBType {
case schemas.POSTGRES: case schemas.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 { if len(condSQL) > 0 {
@ -176,7 +176,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
condSQL) condSQL)
if len(orderSQL) > 0 { if len(orderSQL) > 0 {
switch session.engine.dialect.DBType() { switch session.engine.dialect.URI().DBType {
case schemas.POSTGRES: case schemas.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 { if len(condSQL) > 0 {

View File

@ -28,7 +28,7 @@ func TestDelete(t *testing.T) {
defer session.Close() defer session.Close()
var err error var err error
if testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Begin() err = session.Begin()
assert.NoError(t, err) assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT userinfo_delete ON") _, err = session.Exec("SET IDENTITY_INSERT userinfo_delete ON")
@ -40,7 +40,7 @@ func TestDelete(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Commit() err = session.Commit()
assert.NoError(t, err) assert.NoError(t, err)
} }

View File

@ -154,7 +154,7 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money)) assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money))
var money2 float64 var money2 float64
if testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.MSSQL {
has, err = testEngine.SQL("SELECT TOP 1 money FROM " + testEngine.TableName("get_var", true)).Get(&money2) has, err = testEngine.SQL("SELECT TOP 1 money FROM " + testEngine.TableName("get_var", true)).Get(&money2)
} else { } else {
has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2) has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2)
@ -234,7 +234,7 @@ func TestGetStruct(t *testing.T) {
defer session.Close() defer session.Close()
var err error var err error
if testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Begin() err = session.Begin()
assert.NoError(t, err) assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT userinfo_get ON") _, err = session.Exec("SET IDENTITY_INSERT userinfo_get ON")
@ -243,7 +243,7 @@ func TestGetStruct(t *testing.T) {
cnt, err := session.Insert(&UserinfoGet{Uid: 2}) cnt, err := session.Insert(&UserinfoGet{Uid: 2})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Commit() err = session.Commit()
assert.NoError(t, err) assert.NoError(t, err)
} }

View File

@ -254,7 +254,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
quoter := session.engine.dialect.Quoter() quoter := session.engine.dialect.Quoter()
var sql string var sql string
colStr := quoter.Join(colNames, ",") colStr := quoter.Join(colNames, ",")
if session.engine.dialect.DBType() == schemas.ORACLE { if session.engine.dialect.URI().DBType == schemas.ORACLE {
temp := fmt.Sprintf(") INTO %s (%v) VALUES (", temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
quoter.Quote(tableName), quoter.Quote(tableName),
colStr) colStr)
@ -361,7 +361,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
var tableName = session.statement.TableName() var tableName = session.statement.TableName()
var output string var output string
if session.engine.dialect.DBType() == schemas.MSSQL && len(table.AutoIncrement) > 0 { if session.engine.dialect.URI().DBType == schemas.MSSQL && len(table.AutoIncrement) > 0 {
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement) output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
} }
@ -371,7 +371,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
} }
if len(colPlaces) <= 0 { if len(colPlaces) <= 0 {
if session.engine.dialect.DBType() == schemas.MYSQL { if session.engine.dialect.URI().DBType == schemas.MYSQL {
if _, err := buf.WriteString(" VALUES ()"); err != nil { if _, err := buf.WriteString(" VALUES ()"); err != nil {
return 0, err return 0, err
} }
@ -433,7 +433,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
} }
} }
if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == schemas.POSTGRES { if len(table.AutoIncrement) > 0 && session.engine.dialect.URI().DBType == schemas.POSTGRES {
if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil { if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil {
return 0, err return 0, err
} }
@ -472,7 +472,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
// for postgres, many of them didn't implement lastInsertId, so we should // for postgres, many of them didn't implement lastInsertId, so we should
// implemented it ourself. // implemented it ourself.
if session.engine.dialect.DBType() == schemas.ORACLE && len(table.AutoIncrement) > 0 { if session.engine.dialect.URI().DBType == schemas.ORACLE && len(table.AutoIncrement) > 0 {
res, err := session.queryBytes("select seq_atable.currval from dual", args...) res, err := session.queryBytes("select seq_atable.currval from dual", args...)
if err != nil { if err != nil {
return 0, err return 0, err
@ -513,7 +513,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue.Set(int64ToIntValue(id, aiValue.Type())) aiValue.Set(int64ToIntValue(id, aiValue.Type()))
return 1, nil return 1, nil
} else if len(table.AutoIncrement) > 0 && (session.engine.dialect.DBType() == schemas.POSTGRES || session.engine.dialect.DBType() == schemas.MSSQL) { } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES ||
session.engine.dialect.URI().DBType == schemas.MSSQL) {
res, err := session.queryBytes(sqlStr, args...) res, err := session.queryBytes(sqlStr, args...)
if err != nil { if err != nil {

View File

@ -207,7 +207,7 @@ func TestQueryStringNoParam(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0]["id"]) assert.EqualValues(t, "1", records[0]["id"])
if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL {
assert.EqualValues(t, "false", records[0]["msg"]) assert.EqualValues(t, "false", records[0]["msg"])
} else { } else {
assert.EqualValues(t, "0", records[0]["msg"]) assert.EqualValues(t, "0", records[0]["msg"])
@ -217,7 +217,7 @@ func TestQueryStringNoParam(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0]["id"]) assert.EqualValues(t, "1", records[0]["id"])
if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL {
assert.EqualValues(t, "false", records[0]["msg"]) assert.EqualValues(t, "false", records[0]["msg"])
} else { } else {
assert.EqualValues(t, "0", records[0]["msg"]) assert.EqualValues(t, "0", records[0]["msg"])
@ -244,7 +244,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0][0]) assert.EqualValues(t, "1", records[0][0])
if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL {
assert.EqualValues(t, "false", records[0][1]) assert.EqualValues(t, "false", records[0][1])
} else { } else {
assert.EqualValues(t, "0", records[0][1]) assert.EqualValues(t, "0", records[0][1])
@ -254,7 +254,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0][0]) assert.EqualValues(t, "1", records[0][0])
if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL {
assert.EqualValues(t, "false", records[0][1]) assert.EqualValues(t, "false", records[0][1])
} else { } else {
assert.EqualValues(t, "0", records[0][1]) assert.EqualValues(t, "0", records[0][1])

View File

@ -313,8 +313,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
if expectedType == schemas.Text && if expectedType == schemas.Text &&
strings.HasPrefix(curType, schemas.Varchar) { strings.HasPrefix(curType, schemas.Varchar) {
// currently only support mysql & postgres // currently only support mysql & postgres
if engine.dialect.DBType() == schemas.MYSQL || if engine.dialect.URI().DBType == schemas.MYSQL ||
engine.dialect.DBType() == schemas.POSTGRES { engine.dialect.URI().DBType == schemas.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n", engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbNameWithSchema, col.Name, curType, expectedType) tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
@ -323,7 +323,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
tbNameWithSchema, col.Name, curType, expectedType) tbNameWithSchema, col.Name, curType, expectedType)
} }
} else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) { } else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) {
if engine.dialect.DBType() == schemas.MYSQL { if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length { if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length) tbNameWithSchema, col.Name, oriCol.Length, col.Length)
@ -337,7 +337,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
} }
} }
} else if expectedType == schemas.Varchar { } else if expectedType == schemas.Varchar {
if engine.dialect.DBType() == schemas.MYSQL { if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length { if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length) tbNameWithSchema, col.Name, oriCol.Length, col.Length)

View File

@ -335,9 +335,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var top string var top string
if st.LimitN != nil { if st.LimitN != nil {
limitValue := *st.LimitN limitValue := *st.LimitN
if session.engine.dialect.DBType() == schemas.MYSQL { if session.engine.dialect.URI().DBType == schemas.MYSQL {
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue) condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
} else if session.engine.dialect.DBType() == schemas.SQLITE { } else if session.engine.dialect.URI().DBType == schemas.SQLITE {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...)) session.engine.Quote(tableName), tempCondSQL), condArgs...))
@ -348,7 +348,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if len(condSQL) > 0 { if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL condSQL = "WHERE " + condSQL
} }
} else if session.engine.dialect.DBType() == schemas.POSTGRES { } else if session.engine.dialect.URI().DBType == schemas.POSTGRES {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...)) session.engine.Quote(tableName), tempCondSQL), condArgs...))
@ -360,8 +360,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if len(condSQL) > 0 { if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL condSQL = "WHERE " + condSQL
} }
} else if session.engine.dialect.DBType() == schemas.MSSQL { } else if session.engine.dialect.URI().DBType == schemas.MSSQL {
if st.OrderStr != "" && session.engine.dialect.DBType() == schemas.MSSQL && if st.OrderStr != "" && session.engine.dialect.URI().DBType == schemas.MSSQL &&
table != nil && len(table.PrimaryKeys) == 1 { table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
@ -387,7 +387,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var tableAlias = session.engine.Quote(tableName) var tableAlias = session.engine.Quote(tableName)
var fromSQL string var fromSQL string
if session.statement.TableAlias != "" { if session.statement.TableAlias != "" {
switch session.engine.dialect.DBType() { switch session.engine.dialect.URI().DBType {
case schemas.MSSQL: case schemas.MSSQL:
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias) fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias)
tableAlias = session.statement.TableAlias tableAlias = session.statement.TableAlias

View File

@ -238,7 +238,7 @@ func TestExtends2(t *testing.T) {
defer session.Close() defer session.Close()
// MSSQL deny insert identity column excep declare as below // MSSQL deny insert identity column excep declare as below
if testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Begin() err = session.Begin()
assert.NoError(t, err) assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT message ON") _, err = session.Exec("SET IDENTITY_INSERT message ON")
@ -248,7 +248,7 @@ func TestExtends2(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Commit() err = session.Commit()
assert.NoError(t, err) assert.NoError(t, err)
} }
@ -299,7 +299,7 @@ func TestExtends3(t *testing.T) {
defer session.Close() defer session.Close()
// MSSQL deny insert identity column excep declare as below // MSSQL deny insert identity column excep declare as below
if testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Begin() err = session.Begin()
assert.NoError(t, err) assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT message ON") _, err = session.Exec("SET IDENTITY_INSERT message ON")
@ -308,7 +308,7 @@ func TestExtends3(t *testing.T) {
_, err = session.Insert(&msg) _, err = session.Insert(&msg)
assert.NoError(t, err) assert.NoError(t, err)
if testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Commit() err = session.Commit()
assert.NoError(t, err) assert.NoError(t, err)
} }
@ -362,7 +362,7 @@ func TestExtends4(t *testing.T) {
defer session.Close() defer session.Close()
// MSSQL deny insert identity column excep declare as below // MSSQL deny insert identity column excep declare as below
if testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Begin() err = session.Begin()
assert.NoError(t, err) assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT message ON") _, err = session.Exec("SET IDENTITY_INSERT message ON")
@ -371,7 +371,7 @@ func TestExtends4(t *testing.T) {
_, err = session.Insert(&msg) _, err = session.Insert(&msg)
assert.NoError(t, err) assert.NoError(t, err)
if testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Commit() err = session.Commit()
assert.NoError(t, err) assert.NoError(t, err)
} }
@ -800,7 +800,7 @@ func TestAutoIncrTag(t *testing.T) {
func TestTagComment(t *testing.T) { func TestTagComment(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
// FIXME: only support mysql // FIXME: only support mysql
if testEngine.Dialect().DBType() != schemas.MYSQL { if testEngine.Dialect().URI().DBType != schemas.MYSQL {
return return
} }

View File

@ -314,7 +314,7 @@ func TestCustomType2(t *testing.T) {
session := testEngine.NewSession() session := testEngine.NewSession()
defer session.Close() defer session.Close()
if testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Begin() err = session.Begin()
assert.NoError(t, err) assert.NoError(t, err)
_, err = session.Exec("set IDENTITY_INSERT " + tableName + " on") _, err = session.Exec("set IDENTITY_INSERT " + tableName + " on")
@ -325,7 +325,7 @@ func TestCustomType2(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == schemas.MSSQL { if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Commit() err = session.Commit()
assert.NoError(t, err) assert.NoError(t, err)
} }