Fix more bugs

This commit is contained in:
Lunny Xiao 2021-08-04 21:59:08 +08:00
parent b3bf20a83e
commit dc980514bd
7 changed files with 126 additions and 56 deletions

View File

@ -18,7 +18,6 @@ import (
"gitee.com/travelliu/dm"
"xorm.io/xorm/core"
"xorm.io/xorm/internal/convert"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
@ -529,7 +528,7 @@ func (db *dameng) Init(uri *URI) error {
}
func (db *dameng) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) {
rows, err := queryer.QueryContext(ctx, "select * from v$version where banner like 'Oracle%'")
rows, err := queryer.QueryContext(ctx, "SELECT * FROM V$VERSION") // select id_code
if err != nil {
return nil, err
}
@ -554,6 +553,7 @@ func (db *dameng) Version(ctx context.Context, queryer core.Queryer) (*schemas.V
func (db *dameng) Features() *DialectFeatures {
return &DialectFeatures{
AutoincrMode: SequenceAutoincrMode,
SupportSequence: true,
}
}
@ -570,8 +570,12 @@ func (db *dameng) SQLType(c *schemas.Column) string {
return "BIGINT"
case schemas.Bit, schemas.Bool:
return schemas.Bit
case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea:
return schemas.Binary
case schemas.Binary:
if c.Length == 0 {
return schemas.Binary + "(MAX)"
}
case schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea:
return schemas.VarBinary
case schemas.Date:
return schemas.Date
case schemas.Time:
@ -635,7 +639,7 @@ func (db *dameng) DropTableSQL(tableName string) (string, bool) {
return fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)), false
}
func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) ([]string, bool, error) {
func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
if tableName == "" {
tableName = table.Name
}
@ -667,38 +671,7 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
}
b.WriteString(")")
var seqName = utils.SeqName(tableName)
if table.AutoIncrColumn() != nil {
var cnt int
rows, err := queryer.QueryContext(ctx, "SELECT COUNT(*) FROM user_sequences WHERE sequence_name = ?", seqName)
if err != nil {
return nil, false, err
}
defer rows.Close()
if !rows.Next() {
if rows.Err() != nil {
return nil, false, rows.Err()
}
return nil, false, errors.New("query sequence failed")
}
if err := rows.Scan(&cnt); err != nil {
return nil, false, err
}
if cnt == 0 {
var sql2 = fmt.Sprintf(`CREATE sequence %s
minvalue 1
nomaxvalue
start with 1
increment by 1
nocycle
nocache`, seqName)
return []string{b.String(), sql2}, false, nil
}
}
return []string{b.String()}, false, nil
return b.String(), false, nil
}
func (db *dameng) SetQuotePolicy(quotePolicy QuotePolicy) {
@ -728,6 +701,26 @@ func (db *dameng) IsTableExist(queryer core.Queryer, ctx context.Context, tableN
return db.HasRecords(queryer, ctx, `SELECT table_name FROM user_tables WHERE table_name = ?`, tableName)
}
func (db *dameng) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) {
var cnt int
rows, err := queryer.QueryContext(ctx, "SELECT COUNT(*) FROM user_sequences WHERE sequence_name = ?", seqName)
if err != nil {
return false, err
}
defer rows.Close()
if !rows.Next() {
if rows.Err() != nil {
return false, rows.Err()
}
return false, errors.New("query sequence failed")
}
if err := rows.Scan(&cnt); err != nil {
return false, err
}
return cnt > 0, nil
}
func (db *dameng) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
args := []interface{}{tableName, colName}
query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" +
@ -839,7 +832,8 @@ func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
col.Name = strings.Trim(colName.String, `" `)
if colDefault.valid {
col.Default = colDefault.data
col.DefaultIsEmpty = false
} else {
col.DefaultIsEmpty = true
}
if nullable.String == "Y" {
@ -1052,12 +1046,21 @@ func (d *damengDriver) GenScanResult(colType string) (interface{}, error) {
case "NUMBER":
var s sql.NullString
return &s, nil
case "DATE":
var s sql.NullTime
case "BIGINT":
var s sql.NullInt64
return &s, nil
case "INTEGER":
var s sql.NullInt32
return &s, nil
case "DATE", "TIMESTAMP":
var s sql.NullString
return &s, nil
case "BLOB":
var r sql.RawBytes
return &r, nil
case "FLOAT":
var s sql.NullFloat64
return &s, nil
default:
var r sql.RawBytes
return &r, nil

View File

@ -45,6 +45,7 @@ const (
type DialectFeatures struct {
AutoincrMode int // 0 autoincrement column, 1 sequence
SupportSequence bool
}
// Dialect represents a kind of database
@ -71,9 +72,13 @@ type Dialect interface {
GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error)
IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error)
CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) ([]string, bool, error)
CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error)
DropTableSQL(tableName string) (string, bool)
CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error)
IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error)
DropSequenceSQL(seqName string) (string, error)
GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)
IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error)
AddColumnSQL(tableName string, col *schemas.Column) string
@ -146,6 +151,24 @@ func (db *Base) CreateTableSQL(table *schemas.Table, tableName string) ([]string
return []string{b.String()}, false
}
func (db *Base) CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error) {
return fmt.Sprintf(`CREATE SEQUENCE %s
minvalue 1
nomaxvalue
start with 1
increment by 1
nocycle
nocache`, seqName), nil
}
func (db *Base) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) {
return false, fmt.Errorf("unsupported sequence feature")
}
func (db *Base) DropSequenceSQL(seqName string) (string, error) {
return fmt.Sprintf("DROP SEQUENCE %s", seqName), nil
}
// DropTableSQL returns drop table SQL
func (db *Base) DropTableSQL(tableName string) (string, bool) {
quote := db.dialect.Quoter().Quote
@ -309,7 +332,7 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool)
}
}
if col.Default != "" {
if !col.DefaultIsEmpty {
if _, err := bd.WriteString(" DEFAULT "); err != nil {
return "", err
}

View File

@ -684,11 +684,15 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) ([]strin
b.WriteString(" ROW_FORMAT=")
b.WriteString(db.rowFormat)
}
<<<<<<< HEAD
<<<<<<< HEAD
return []string{b.String()}, true
=======
return []string{sql}, true, nil
>>>>>>> 4dbe145 (fix insert)
=======
return sql, true, nil
>>>>>>> 21b6352 (Fix more bugs)
}
func (db *mysql) Filters() []Filter {

View File

@ -605,7 +605,7 @@ func (db *oracle) DropTableSQL(tableName string) (string, bool) {
return fmt.Sprintf("DROP TABLE `%s`", tableName), false
}
func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) ([]string, bool, error) {
func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
var sql = "CREATE TABLE "
if tableName == "" {
tableName = table.Name
@ -635,7 +635,7 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
}
sql = sql[:len(sql)-2] + ")"
return []string{sql}, false, nil
return sql, false, nil
}
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {

View File

@ -504,16 +504,26 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
}
}
sqls, _, err := dstDialect.CreateTableSQL(ctx, engine.db, dstTable, dstTableName)
sqlstr, _, err := dstDialect.CreateTableSQL(ctx, engine.db, dstTable, dstTableName)
if err != nil {
return err
}
for _, s := range sqls {
_, err = io.WriteString(w, s+";\n")
_, err = io.WriteString(w, sqlstr+";\n")
if err != nil {
return err
}
if dstTable.AutoIncrement != "" && dstDialect.Features().SupportSequence {
sqlstr, err = dstDialect.CreateSequenceSQL(ctx, engine.db, utils.SeqName(dstTableName))
if err != nil {
return err
}
_, err = io.WriteString(w, sqlstr+";\n")
if err != nil {
return err
}
}
if len(dstTable.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL {
fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", dstTable.Name)
}

View File

@ -991,7 +991,7 @@ func TestMoreExtends(t *testing.T) {
books = make([]MoreExtendsBooksExtend, 0, len(books))
err = testEngine.Table("more_extends_books").
Alias("m").
Select("m.*, `more_extends_users`.*").
Select("`m`.*, `more_extends_users`.*").
Join("INNER", "more_extends_users", "`m`.`user_id` = `more_extends_users`.`id`").
Where("`m`.`name` LIKE ?", "abc").
Limit(10, 10).

View File

@ -43,17 +43,26 @@ func (session *Session) createTable(bean interface{}) error {
session.statement.RefTable.StoreEngine = session.statement.StoreEngine
session.statement.RefTable.Charset = session.statement.Charset
sqlStrs, _, err := session.engine.dialect.CreateTableSQL(context.Background(), session.engine.db, session.statement.RefTable, session.statement.TableName())
tableName := session.statement.TableName()
refTable := session.statement.RefTable
sqlStr, _, err := session.engine.dialect.CreateTableSQL(context.Background(), session.engine.db, refTable, tableName)
if err != nil {
return err
}
if _, err := session.exec(sqlStr); err != nil {
return err
}
for _, s := range sqlStrs {
_, err := session.exec(s)
if refTable.AutoIncrement != "" && session.engine.dialect.Features().SupportSequence {
sqlStr, err = session.engine.dialect.CreateSequenceSQL(context.Background(), session.engine.db, utils.SeqName(tableName))
if err != nil {
return err
}
if _, err := session.exec(sqlStr); err != nil {
return err
}
}
return nil
}
@ -148,13 +157,34 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
checkIfExist = exist
}
if checkIfExist {
_, err := session.exec(sqlStr)
if !checkIfExist {
return nil
}
if _, err := session.exec(sqlStr); err != nil {
return err
}
if !session.engine.dialect.Features().SupportSequence {
return nil
}
var seqName = utils.SeqName(tableName)
exist, err := session.engine.dialect.IsSequenceExist(session.ctx, session.getQueryer(), seqName)
if err != nil {
return err
}
if !exist {
return nil
}
sqlStr, err = session.engine.dialect.DropSequenceSQL(seqName)
if err != nil {
return err
}
_, err = session.exec(sqlStr)
return err
}
// IsTableExist if a table is exist
func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) {
if session.isAutoClose {