Fix more bugs

This commit is contained in:
Lunny Xiao 2021-08-04 21:59:08 +08:00
parent b3bf20a83e
commit dc980514bd
7 changed files with 126 additions and 56 deletions

View File

@ -18,7 +18,6 @@ import (
"gitee.com/travelliu/dm" "gitee.com/travelliu/dm"
"xorm.io/xorm/core" "xorm.io/xorm/core"
"xorm.io/xorm/internal/convert" "xorm.io/xorm/internal/convert"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -529,7 +528,7 @@ func (db *dameng) Init(uri *URI) error {
} }
func (db *dameng) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) { func (db *dameng) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) {
rows, err := queryer.QueryContext(ctx, "select * from v$version where banner like 'Oracle%'") rows, err := queryer.QueryContext(ctx, "SELECT * FROM V$VERSION") // select id_code
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -553,7 +552,8 @@ func (db *dameng) Version(ctx context.Context, queryer core.Queryer) (*schemas.V
func (db *dameng) Features() *DialectFeatures { func (db *dameng) Features() *DialectFeatures {
return &DialectFeatures{ return &DialectFeatures{
AutoincrMode: SequenceAutoincrMode, AutoincrMode: SequenceAutoincrMode,
SupportSequence: true,
} }
} }
@ -570,8 +570,12 @@ func (db *dameng) SQLType(c *schemas.Column) string {
return "BIGINT" return "BIGINT"
case schemas.Bit, schemas.Bool: case schemas.Bit, schemas.Bool:
return schemas.Bit return schemas.Bit
case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea: case schemas.Binary:
return schemas.Binary if c.Length == 0 {
return schemas.Binary + "(MAX)"
}
case schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea:
return schemas.VarBinary
case schemas.Date: case schemas.Date:
return schemas.Date return schemas.Date
case schemas.Time: case schemas.Time:
@ -635,7 +639,7 @@ func (db *dameng) DropTableSQL(tableName string) (string, bool) {
return fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)), false return fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)), false
} }
func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) ([]string, bool, error) { func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
} }
@ -667,38 +671,7 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
} }
b.WriteString(")") b.WriteString(")")
var seqName = utils.SeqName(tableName) return b.String(), false, nil
if table.AutoIncrColumn() != nil {
var cnt int
rows, err := queryer.QueryContext(ctx, "SELECT COUNT(*) FROM user_sequences WHERE sequence_name = ?", seqName)
if err != nil {
return nil, false, err
}
defer rows.Close()
if !rows.Next() {
if rows.Err() != nil {
return nil, false, rows.Err()
}
return nil, false, errors.New("query sequence failed")
}
if err := rows.Scan(&cnt); err != nil {
return nil, false, err
}
if cnt == 0 {
var sql2 = fmt.Sprintf(`CREATE sequence %s
minvalue 1
nomaxvalue
start with 1
increment by 1
nocycle
nocache`, seqName)
return []string{b.String(), sql2}, false, nil
}
}
return []string{b.String()}, false, nil
} }
func (db *dameng) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *dameng) SetQuotePolicy(quotePolicy QuotePolicy) {
@ -728,6 +701,26 @@ func (db *dameng) IsTableExist(queryer core.Queryer, ctx context.Context, tableN
return db.HasRecords(queryer, ctx, `SELECT table_name FROM user_tables WHERE table_name = ?`, tableName) return db.HasRecords(queryer, ctx, `SELECT table_name FROM user_tables WHERE table_name = ?`, tableName)
} }
func (db *dameng) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) {
var cnt int
rows, err := queryer.QueryContext(ctx, "SELECT COUNT(*) FROM user_sequences WHERE sequence_name = ?", seqName)
if err != nil {
return false, err
}
defer rows.Close()
if !rows.Next() {
if rows.Err() != nil {
return false, rows.Err()
}
return false, errors.New("query sequence failed")
}
if err := rows.Scan(&cnt); err != nil {
return false, err
}
return cnt > 0, nil
}
func (db *dameng) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { func (db *dameng) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
args := []interface{}{tableName, colName} args := []interface{}{tableName, colName}
query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" + query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" +
@ -839,7 +832,8 @@ func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
col.Name = strings.Trim(colName.String, `" `) col.Name = strings.Trim(colName.String, `" `)
if colDefault.valid { if colDefault.valid {
col.Default = colDefault.data col.Default = colDefault.data
col.DefaultIsEmpty = false } else {
col.DefaultIsEmpty = true
} }
if nullable.String == "Y" { if nullable.String == "Y" {
@ -1052,12 +1046,21 @@ func (d *damengDriver) GenScanResult(colType string) (interface{}, error) {
case "NUMBER": case "NUMBER":
var s sql.NullString var s sql.NullString
return &s, nil return &s, nil
case "DATE": case "BIGINT":
var s sql.NullTime var s sql.NullInt64
return &s, nil
case "INTEGER":
var s sql.NullInt32
return &s, nil
case "DATE", "TIMESTAMP":
var s sql.NullString
return &s, nil return &s, nil
case "BLOB": case "BLOB":
var r sql.RawBytes var r sql.RawBytes
return &r, nil return &r, nil
case "FLOAT":
var s sql.NullFloat64
return &s, nil
default: default:
var r sql.RawBytes var r sql.RawBytes
return &r, nil return &r, nil

View File

@ -44,7 +44,8 @@ const (
) )
type DialectFeatures struct { type DialectFeatures struct {
AutoincrMode int // 0 autoincrement column, 1 sequence AutoincrMode int // 0 autoincrement column, 1 sequence
SupportSequence bool
} }
// Dialect represents a kind of database // Dialect represents a kind of database
@ -71,9 +72,13 @@ type Dialect interface {
GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error)
IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error)
CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) ([]string, bool, error) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error)
DropTableSQL(tableName string) (string, bool) DropTableSQL(tableName string) (string, bool)
CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error)
IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error)
DropSequenceSQL(seqName string) (string, error)
GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)
IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error)
AddColumnSQL(tableName string, col *schemas.Column) string AddColumnSQL(tableName string, col *schemas.Column) string
@ -146,6 +151,24 @@ func (db *Base) CreateTableSQL(table *schemas.Table, tableName string) ([]string
return []string{b.String()}, false return []string{b.String()}, false
} }
func (db *Base) CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error) {
return fmt.Sprintf(`CREATE SEQUENCE %s
minvalue 1
nomaxvalue
start with 1
increment by 1
nocycle
nocache`, seqName), nil
}
func (db *Base) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) {
return false, fmt.Errorf("unsupported sequence feature")
}
func (db *Base) DropSequenceSQL(seqName string) (string, error) {
return fmt.Sprintf("DROP SEQUENCE %s", seqName), nil
}
// DropTableSQL returns drop table SQL // DropTableSQL returns drop table SQL
func (db *Base) DropTableSQL(tableName string) (string, bool) { func (db *Base) DropTableSQL(tableName string) (string, bool) {
quote := db.dialect.Quoter().Quote quote := db.dialect.Quoter().Quote
@ -309,7 +332,7 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool)
} }
} }
if col.Default != "" { if !col.DefaultIsEmpty {
if _, err := bd.WriteString(" DEFAULT "); err != nil { if _, err := bd.WriteString(" DEFAULT "); err != nil {
return "", err return "", err
} }

View File

@ -684,11 +684,15 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) ([]strin
b.WriteString(" ROW_FORMAT=") b.WriteString(" ROW_FORMAT=")
b.WriteString(db.rowFormat) b.WriteString(db.rowFormat)
} }
<<<<<<< HEAD
<<<<<<< HEAD <<<<<<< HEAD
return []string{b.String()}, true return []string{b.String()}, true
======= =======
return []string{sql}, true, nil return []string{sql}, true, nil
>>>>>>> 4dbe145 (fix insert) >>>>>>> 4dbe145 (fix insert)
=======
return sql, true, nil
>>>>>>> 21b6352 (Fix more bugs)
} }
func (db *mysql) Filters() []Filter { func (db *mysql) Filters() []Filter {

View File

@ -605,7 +605,7 @@ func (db *oracle) DropTableSQL(tableName string) (string, bool) {
return fmt.Sprintf("DROP TABLE `%s`", tableName), false return fmt.Sprintf("DROP TABLE `%s`", tableName), false
} }
func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) ([]string, bool, error) { func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
var sql = "CREATE TABLE " var sql = "CREATE TABLE "
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
@ -635,7 +635,7 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
} }
sql = sql[:len(sql)-2] + ")" sql = sql[:len(sql)-2] + ")"
return []string{sql}, false, nil return sql, false, nil
} }
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {

View File

@ -504,16 +504,26 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
} }
} }
sqls, _, err := dstDialect.CreateTableSQL(ctx, engine.db, dstTable, dstTableName) sqlstr, _, err := dstDialect.CreateTableSQL(ctx, engine.db, dstTable, dstTableName)
if err != nil { if err != nil {
return err return err
} }
for _, s := range sqls { _, err = io.WriteString(w, sqlstr+";\n")
_, err = io.WriteString(w, s+";\n") if err != nil {
return err
}
if dstTable.AutoIncrement != "" && dstDialect.Features().SupportSequence {
sqlstr, err = dstDialect.CreateSequenceSQL(ctx, engine.db, utils.SeqName(dstTableName))
if err != nil {
return err
}
_, err = io.WriteString(w, sqlstr+";\n")
if err != nil { if err != nil {
return err return err
} }
} }
if len(dstTable.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL { if len(dstTable.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL {
fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", dstTable.Name) fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", dstTable.Name)
} }

View File

@ -991,7 +991,7 @@ func TestMoreExtends(t *testing.T) {
books = make([]MoreExtendsBooksExtend, 0, len(books)) books = make([]MoreExtendsBooksExtend, 0, len(books))
err = testEngine.Table("more_extends_books"). err = testEngine.Table("more_extends_books").
Alias("m"). Alias("m").
Select("m.*, `more_extends_users`.*"). Select("`m`.*, `more_extends_users`.*").
Join("INNER", "more_extends_users", "`m`.`user_id` = `more_extends_users`.`id`"). Join("INNER", "more_extends_users", "`m`.`user_id` = `more_extends_users`.`id`").
Where("`m`.`name` LIKE ?", "abc"). Where("`m`.`name` LIKE ?", "abc").
Limit(10, 10). Limit(10, 10).

View File

@ -43,17 +43,26 @@ func (session *Session) createTable(bean interface{}) error {
session.statement.RefTable.StoreEngine = session.statement.StoreEngine session.statement.RefTable.StoreEngine = session.statement.StoreEngine
session.statement.RefTable.Charset = session.statement.Charset session.statement.RefTable.Charset = session.statement.Charset
sqlStrs, _, err := session.engine.dialect.CreateTableSQL(context.Background(), session.engine.db, session.statement.RefTable, session.statement.TableName()) tableName := session.statement.TableName()
refTable := session.statement.RefTable
sqlStr, _, err := session.engine.dialect.CreateTableSQL(context.Background(), session.engine.db, refTable, tableName)
if err != nil { if err != nil {
return err return err
} }
if _, err := session.exec(sqlStr); err != nil {
return err
}
for _, s := range sqlStrs { if refTable.AutoIncrement != "" && session.engine.dialect.Features().SupportSequence {
_, err := session.exec(s) sqlStr, err = session.engine.dialect.CreateSequenceSQL(context.Background(), session.engine.db, utils.SeqName(tableName))
if err != nil { if err != nil {
return err return err
} }
if _, err := session.exec(sqlStr); err != nil {
return err
}
} }
return nil return nil
} }
@ -148,11 +157,32 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
checkIfExist = exist checkIfExist = exist
} }
if checkIfExist { if !checkIfExist {
_, err := session.exec(sqlStr) return nil
}
if _, err := session.exec(sqlStr); err != nil {
return err return err
} }
return nil
if !session.engine.dialect.Features().SupportSequence {
return nil
}
var seqName = utils.SeqName(tableName)
exist, err := session.engine.dialect.IsSequenceExist(session.ctx, session.getQueryer(), seqName)
if err != nil {
return err
}
if !exist {
return nil
}
sqlStr, err = session.engine.dialect.DropSequenceSQL(seqName)
if err != nil {
return err
}
_, err = session.exec(sqlStr)
return err
} }
// IsTableExist if a table is exist // IsTableExist if a table is exist