diff --git a/engine.go b/engine.go index a45771a2..126b3e37 100644 --- a/engine.go +++ b/engine.go @@ -451,12 +451,45 @@ func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d return "NULL" } - if dq, ok := d.(bool); ok && (dstDialect.URI().DBType == schemas.SQLITE || - dstDialect.URI().DBType == schemas.MSSQL) { - if dq { - return "1" + switch t := d.(type) { + case *sql.NullInt64: + if t.Valid { + return fmt.Sprintf("%d", t.Int64) + } + return "NULL" + case *sql.NullTime: + if t.Valid { + return fmt.Sprintf("'%s'", t.Time.Format("2006-1-02 15:04:05")) + } + return "NULL" + case *sql.NullString: + if t.Valid { + return "'" + strings.Replace(t.String, "'", "''", -1) + "'" + } + return "NULL" + case *sql.NullInt32: + if t.Valid { + return fmt.Sprintf("%d", t.Int32) + } + return "NULL" + case *sql.NullFloat64: + if t.Valid { + return fmt.Sprintf("%f", t.Float64) + } + return "NULL" + case *sql.NullBool: + if t.Valid { + return strconv.FormatBool(t.Bool) + } + return "NULL" + case bool: + if dstDialect.URI().DBType == schemas.SQLITE || + dstDialect.URI().DBType == schemas.MSSQL { + if t { + return "1" + } + return "0" } - return "0" } if col.SQLType.IsText() { @@ -487,7 +520,7 @@ func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d return "'" + strings.Replace(v, "'", "''", -1) + "'" } else if col.SQLType.IsBlob() { if reflect.TypeOf(d).Kind() == reflect.Slice { - return fmt.Sprintf("%s", dstDialect.FormatBytes(d.([]byte))) + return dstDialect.FormatBytes(d.([]byte)) } else if reflect.TypeOf(d).Kind() == reflect.String { return fmt.Sprintf("'%s'", d.(string)) } @@ -497,7 +530,7 @@ func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d if col.SQLType.Name == schemas.Bool { return fmt.Sprintf("%v", strconv.FormatBool(d.([]byte)[0] != byte('0'))) } - return fmt.Sprintf("%s", string(d.([]byte))) + return string(d.([]byte)) case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int: if col.SQLType.Name == schemas.Bool { v := reflect.ValueOf(d).Int() > 0 @@ -609,6 +642,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch cols := table.ColumnsSeq() dstCols := dstTable.ColumnsSeq() + dstColumns := dstTable.Columns() colNames := engine.dialect.Quoter().Join(cols, ", ") destColNames := dstDialect.Quoter().Join(dstCols, ", ") @@ -619,73 +653,37 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } defer rows.Close() - if table.Type != nil { - sess := engine.NewSession() - defer sess.Close() - for rows.Next() { - beanValue := reflect.New(table.Type) - bean := beanValue.Interface() - fields, err := rows.Columns() - if err != nil { - return err - } - scanResults, err := sess.row2Slice(rows, fields, bean) - if err != nil { - return err - } + types, err := rows.ColumnTypes() + if err != nil { + return err + } - dataStruct := utils.ReflectValue(bean) - _, err = sess.slice2Bean(scanResults, fields, bean, &dataStruct, table) - if err != nil { + sess := engine.NewSession() + defer sess.Close() + for rows.Next() { + _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") + if err != nil { + return err + } + + scanResults, err := sess.engine.scanInterfaces(rows, types) + if err != nil { + return err + } + for i, scanResult := range scanResults { + s := formatColumnValue(engine.DatabaseTZ, dstDialect, scanResult, dstColumns[i]) + if _, err = io.WriteString(w, s); err != nil { return err } - - _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") - if err != nil { - return err - } - - var temp string - for _, d := range dstCols { - col := table.GetColumn(d) - if col == nil { - return errors.New("unknown column error") + if i < len(scanResults)-1 { + if _, err = io.WriteString(w, ","); err != nil { + return err } - - field := dataStruct.FieldByIndex(col.FieldIndex) - temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, field.Interface(), col) - } - _, err = io.WriteString(w, temp[1:]+");\n") - if err != nil { - return err } } - } else { - for rows.Next() { - dest := make([]interface{}, len(cols)) - err = rows.ScanSlice(&dest) - if err != nil { - return err - } - - _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") - if err != nil { - return err - } - - var temp string - for i, d := range dest { - col := table.GetColumn(cols[i]) - if col == nil { - return errors.New("unknow column error") - } - - temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, d, col) - } - _, err = io.WriteString(w, temp[1:]+");\n") - if err != nil { - return err - } + _, err = io.WriteString(w, ");\n") + if err != nil { + return err } } diff --git a/integrations/engine_test.go b/integrations/engine_test.go index a06d91aa..48082ce5 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -172,8 +172,17 @@ func TestDumpTables(t *testing.T) { name := fmt.Sprintf("dump_%v-table.sql", tp) t.Run(name, func(t *testing.T) { assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, name, tp)) + }) } + + assert.NoError(t, testEngine.DropTables(new(TestDumpTableStruct))) + + importPath := fmt.Sprintf("dump_%v-table.sql", testEngine.Dialect().URI().DBType) + t.Run("import_"+importPath, func(t *testing.T) { + _, err = testEngine.ImportFile(importPath) + assert.NoError(t, err) + }) } func TestDumpTables2(t *testing.T) {