Fix bug on dumptable

This commit is contained in:
Lunny Xiao 2021-07-10 17:05:12 +08:00
parent 6f46e68425
commit 56029ace6f
2 changed files with 75 additions and 68 deletions

104
engine.go
View File

@ -451,13 +451,46 @@ func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d
return "NULL" return "NULL"
} }
if dq, ok := d.(bool); ok && (dstDialect.URI().DBType == schemas.SQLITE || switch t := d.(type) {
dstDialect.URI().DBType == schemas.MSSQL) { case *sql.NullInt64:
if dq { if t.Valid {
return fmt.Sprintf("%d", t.Int64)
}
return "NULL"
case *sql.NullTime:
if t.Valid {
return fmt.Sprintf("'%s'", t.Time.Format("2006-1-02 15:04:05"))
}
return "NULL"
case *sql.NullString:
if t.Valid {
return "'" + strings.Replace(t.String, "'", "''", -1) + "'"
}
return "NULL"
case *sql.NullInt32:
if t.Valid {
return fmt.Sprintf("%d", t.Int32)
}
return "NULL"
case *sql.NullFloat64:
if t.Valid {
return fmt.Sprintf("%f", t.Float64)
}
return "NULL"
case *sql.NullBool:
if t.Valid {
return strconv.FormatBool(t.Bool)
}
return "NULL"
case bool:
if dstDialect.URI().DBType == schemas.SQLITE ||
dstDialect.URI().DBType == schemas.MSSQL {
if t {
return "1" return "1"
} }
return "0" return "0"
} }
}
if col.SQLType.IsText() { if col.SQLType.IsText() {
var v string var v string
@ -487,7 +520,7 @@ func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d
return "'" + strings.Replace(v, "'", "''", -1) + "'" return "'" + strings.Replace(v, "'", "''", -1) + "'"
} else if col.SQLType.IsBlob() { } else if col.SQLType.IsBlob() {
if reflect.TypeOf(d).Kind() == reflect.Slice { if reflect.TypeOf(d).Kind() == reflect.Slice {
return fmt.Sprintf("%s", dstDialect.FormatBytes(d.([]byte))) return dstDialect.FormatBytes(d.([]byte))
} else if reflect.TypeOf(d).Kind() == reflect.String { } else if reflect.TypeOf(d).Kind() == reflect.String {
return fmt.Sprintf("'%s'", d.(string)) return fmt.Sprintf("'%s'", d.(string))
} }
@ -497,7 +530,7 @@ func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d
if col.SQLType.Name == schemas.Bool { if col.SQLType.Name == schemas.Bool {
return fmt.Sprintf("%v", strconv.FormatBool(d.([]byte)[0] != byte('0'))) return fmt.Sprintf("%v", strconv.FormatBool(d.([]byte)[0] != byte('0')))
} }
return fmt.Sprintf("%s", string(d.([]byte))) return string(d.([]byte))
case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int: case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int:
if col.SQLType.Name == schemas.Bool { if col.SQLType.Name == schemas.Bool {
v := reflect.ValueOf(d).Int() > 0 v := reflect.ValueOf(d).Int() > 0
@ -609,6 +642,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
cols := table.ColumnsSeq() cols := table.ColumnsSeq()
dstCols := dstTable.ColumnsSeq() dstCols := dstTable.ColumnsSeq()
dstColumns := dstTable.Columns()
colNames := engine.dialect.Quoter().Join(cols, ", ") colNames := engine.dialect.Quoter().Join(cols, ", ")
destColNames := dstDialect.Quoter().Join(dstCols, ", ") destColNames := dstDialect.Quoter().Join(dstCols, ", ")
@ -619,75 +653,39 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
} }
defer rows.Close() defer rows.Close()
if table.Type != nil { types, err := rows.ColumnTypes()
if err != nil {
return err
}
sess := engine.NewSession() sess := engine.NewSession()
defer sess.Close() defer sess.Close()
for rows.Next() { for rows.Next() {
beanValue := reflect.New(table.Type)
bean := beanValue.Interface()
fields, err := rows.Columns()
if err != nil {
return err
}
scanResults, err := sess.row2Slice(rows, fields, bean)
if err != nil {
return err
}
dataStruct := utils.ReflectValue(bean)
_, err = sess.slice2Bean(scanResults, fields, bean, &dataStruct, table)
if err != nil {
return err
}
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
if err != nil { if err != nil {
return err return err
} }
var temp string scanResults, err := sess.engine.scanInterfaces(rows, types)
for _, d := range dstCols {
col := table.GetColumn(d)
if col == nil {
return errors.New("unknown column error")
}
field := dataStruct.FieldByIndex(col.FieldIndex)
temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, field.Interface(), col)
}
_, err = io.WriteString(w, temp[1:]+");\n")
if err != nil { if err != nil {
return err return err
} }
} for i, scanResult := range scanResults {
} else { s := formatColumnValue(engine.DatabaseTZ, dstDialect, scanResult, dstColumns[i])
for rows.Next() { if _, err = io.WriteString(w, s); err != nil {
dest := make([]interface{}, len(cols))
err = rows.ScanSlice(&dest)
if err != nil {
return err return err
} }
if i < len(scanResults)-1 {
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") if _, err = io.WriteString(w, ","); err != nil {
if err != nil {
return err return err
} }
var temp string
for i, d := range dest {
col := table.GetColumn(cols[i])
if col == nil {
return errors.New("unknow column error")
} }
temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, d, col)
} }
_, err = io.WriteString(w, temp[1:]+");\n") _, err = io.WriteString(w, ");\n")
if err != nil { if err != nil {
return err return err
} }
} }
}
// FIXME: Hack for postgres // FIXME: Hack for postgres
if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil { if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil {

View File

@ -172,8 +172,17 @@ func TestDumpTables(t *testing.T) {
name := fmt.Sprintf("dump_%v-table.sql", tp) name := fmt.Sprintf("dump_%v-table.sql", tp)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, name, tp)) assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, name, tp))
}) })
} }
assert.NoError(t, testEngine.DropTables(new(TestDumpTableStruct)))
importPath := fmt.Sprintf("dump_%v-table.sql", testEngine.Dialect().URI().DBType)
t.Run("import_"+importPath, func(t *testing.T) {
_, err = testEngine.ImportFile(importPath)
assert.NoError(t, err)
})
} }
func TestDumpTables2(t *testing.T) { func TestDumpTables2(t *testing.T) {