Use provided type to create the dstTable rather than inferring from the SQL types

Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
Andrew Thornton 2021-02-28 16:08:57 +00:00
parent eed7e65bd9
commit 6d1aba4e23
No known key found for this signature in database
GPG Key ID: 3CDE74631F13A748
1 changed files with 73 additions and 30 deletions

103
engine.go
View File

@ -537,6 +537,8 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
} }
dstDialect.Init(&destURI) dstDialect.Init(&destURI)
} }
cacherMgr := caches.NewManager()
dstTableCache := tags.NewParser("xorm", dstDialect, engine.GetTableMapper(), engine.GetColumnMapper(), cacherMgr)
_, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm %s, from %s to %s*/\n\n", _, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm %s, from %s to %s*/\n\n",
time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, dstDialect.URI().DBType)) time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, dstDialect.URI().DBType))
@ -545,9 +547,18 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
} }
for i, table := range tables { for i, table := range tables {
tableName := table.Name dstTable := table
if table.Type != nil {
dstTable, err = dstTableCache.Parse(reflect.New(table.Type))
if err != nil {
engine.logger.Errorf("Unable to infer table for %s in new dialect. Error: %v", table.Name)
dstTable = table
}
}
dstTableName := dstTable.Name
if dstDialect.URI().Schema != "" { if dstDialect.URI().Schema != "" {
tableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, table.Name) dstTableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, dstTable.Name)
} }
originalTableName := table.Name originalTableName := table.Name
if engine.dialect.URI().Schema != "" { if engine.dialect.URI().Schema != "" {
@ -559,27 +570,30 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return err return err
} }
} }
sqls, _ := dstDialect.CreateTableSQL(table, tableName)
sqls, _ := dstDialect.CreateTableSQL(dstTable, dstTableName)
for _, s := range sqls { for _, s := range sqls {
_, err = io.WriteString(w, s+";\n") _, err = io.WriteString(w, s+";\n")
if err != nil { if err != nil {
return err return err
} }
} }
if len(table.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL { if len(dstTable.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL {
fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", table.Name) fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", dstTable.Name)
} }
for _, index := range table.Indexes { for _, index := range dstTable.Indexes {
_, err = io.WriteString(w, dstDialect.CreateIndexSQL(table.Name, index)+";\n") _, err = io.WriteString(w, dstDialect.CreateIndexSQL(dstTable.Name, index)+";\n")
if err != nil { if err != nil {
return err return err
} }
} }
cols := table.ColumnsSeq() cols := table.ColumnsSeq()
dstCols := dstTable.ColumnsSeq()
colNames := engine.dialect.Quoter().Join(cols, ", ") colNames := engine.dialect.Quoter().Join(cols, ", ")
destColNames := dstDialect.Quoter().Join(cols, ", ") destColNames := dstDialect.Quoter().Join(dstCols, ", ")
rows, err := engine.DB().QueryContext(engine.defaultContext, "SELECT "+colNames+" FROM "+engine.Quote(originalTableName)) rows, err := engine.DB().QueryContext(engine.defaultContext, "SELECT "+colNames+" FROM "+engine.Quote(originalTableName))
if err != nil { if err != nil {
@ -587,35 +601,64 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
} }
defer rows.Close() defer rows.Close()
for rows.Next() { if table.Type != nil {
dest := make([]interface{}, len(cols)) val := reflect.New(table.Type)
err = rows.ScanSlice(&dest) for rows.Next() {
if err != nil { err = rows.ScanStructByName(val)
return err if err != nil {
} return err
}
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(tableName)+" ("+destColNames+") VALUES (")
if err != nil { _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
return err if err != nil {
} return err
}
var temp string
for i, d := range dest { var temp string
col := table.GetColumn(cols[i]) for _, d := range dstCols {
if col == nil { col := table.GetColumn(d)
return errors.New("unknow column error") if col == nil {
return errors.New("unknow column error")
}
temp += "," + formatColumnValue(dstDialect, val.FieldByName(col.FieldName), col)
}
_, err = io.WriteString(w, temp[1:]+");\n")
if err != nil {
return err
} }
temp += "," + formatColumnValue(dstDialect, d, col)
} }
_, err = io.WriteString(w, temp[1:]+");\n") } else {
if err != nil { for rows.Next() {
return err dest := make([]interface{}, len(cols))
err = rows.ScanSlice(&dest)
if err != nil {
return err
}
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
if err != nil {
return err
}
var temp string
for i, d := range dest {
col := table.GetColumn(cols[i])
if col == nil {
return errors.New("unknow column error")
}
temp += "," + formatColumnValue(dstDialect, d, col)
}
_, err = io.WriteString(w, temp[1:]+");\n")
if err != nil {
return err
}
} }
} }
// FIXME: Hack for postgres // FIXME: Hack for postgres
if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil { if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil {
_, err = io.WriteString(w, "SELECT setval('"+tableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(tableName)+"), 1), false);\n") _, err = io.WriteString(w, "SELECT setval('"+dstTableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(dstTableName)+"), 1), false);\n")
if err != nil { if err != nil {
return err return err
} }