diff --git a/dialects/postgres.go b/dialects/postgres.go index 2111f195..6fd9e64a 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -1071,6 +1071,8 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att col.DefaultIsEmpty = false if strings.HasPrefix(col.Default, "nextval(") { col.IsAutoIncrement = true + col.Default = "" + col.DefaultIsEmpty = true } } else { col.DefaultIsEmpty = true diff --git a/engine.go b/engine.go index c6fd5c7e..99412c4f 100644 --- a/engine.go +++ b/engine.go @@ -5,8 +5,6 @@ package xorm import ( - "bufio" - "bytes" "context" "database/sql" "errors" @@ -389,6 +387,10 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } } + if len(table.PKColumns()) > 0 && engine.dialect.URI().DBType == schemas.MSSQL { + fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", table.Name) + } + for _, index := range table.Indexes { _, err = io.WriteString(w, dialect.CreateIndexSQL(table.Name, index)+";\n") if err != nil { @@ -1160,49 +1162,16 @@ func (engine *Engine) SumsInt(bean interface{}, colNames ...string) ([]int64, er // ImportFile SQL DDL file func (engine *Engine) ImportFile(ddlPath string) ([]sql.Result, error) { - file, err := os.Open(ddlPath) - if err != nil { - return nil, err - } - defer file.Close() - return engine.Import(file) + session := engine.NewSession() + defer session.Close() + return session.ImportFile(ddlPath) } // Import SQL DDL from io.Reader func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) { - var results []sql.Result - var lastError error - scanner := bufio.NewScanner(r) - - semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := bytes.IndexByte(data, ';'); i >= 0 { - return i + 1, data[0:i], nil - } - // If we're at EOF, we have a final, non-terminated line. Return it. - if atEOF { - return len(data), data, nil - } - // Request more data. - return 0, nil, nil - } - - scanner.Split(semiColSpliter) - - for scanner.Scan() { - query := strings.Trim(scanner.Text(), " \t\n\r") - if len(query) > 0 { - result, err := engine.DB().ExecContext(engine.defaultContext, query) - results = append(results, result) - if err != nil { - return nil, err - } - } - } - - return results, lastError + session := engine.NewSession() + defer session.Close() + return session.Import(r) } // nowTime return current time diff --git a/engine_test.go b/engine_test.go index b82ee96a..459d63c4 100644 --- a/engine_test.go +++ b/engine_test.go @@ -7,6 +7,7 @@ package xorm import ( "context" "fmt" + "os" "testing" "time" @@ -64,3 +65,35 @@ func TestAutoTransaction(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, false, has) } + +func TestDump(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type TestDumpStruct struct { + Id int64 + Name string + } + + assertSync(t, new(TestDumpStruct)) + + testEngine.Insert([]TestDumpStruct{ + {Name: "1"}, + {Name: "2\n"}, + {Name: "3;"}, + {Name: "4\n;\n''"}, + {Name: "5'\n"}, + }) + + fp := testEngine.Dialect().URI().DBName + ".sql" + os.Remove(fp) + assert.NoError(t, testEngine.DumpAllToFile(fp)) + + assert.NoError(t, prepareEngine()) + + sess := testEngine.NewSession() + defer sess.Close() + assert.NoError(t, sess.Begin()) + _, err := sess.ImportFile(fp) + assert.NoError(t, err) + assert.NoError(t, sess.Commit()) +} diff --git a/interface.go b/interface.go index be4da707..262a2cfe 100644 --- a/interface.go +++ b/interface.go @@ -92,6 +92,7 @@ type EngineInterface interface { GetTableMapper() names.Mapper GetTZDatabase() *time.Location GetTZLocation() *time.Location + ImportFile(fp string) ([]sql.Result, error) MapCacher(interface{}, caches.Cacher) error NewSession() *Session NoAutoTime() *Session diff --git a/session_schema.go b/session_schema.go index ca4e2d75..84eb586e 100644 --- a/session_schema.go +++ b/session_schema.go @@ -5,8 +5,11 @@ package xorm import ( + "bufio" "database/sql" "fmt" + "io" + "os" "strings" "xorm.io/xorm/internal/utils" @@ -432,3 +435,56 @@ func (session *Session) Sync2(beans ...interface{}) error { return nil } + +// ImportFile SQL DDL file +func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) { + file, err := os.Open(ddlPath) + if err != nil { + return nil, err + } + defer file.Close() + return session.Import(file) +} + +// Import SQL DDL from io.Reader +func (session *Session) Import(r io.Reader) ([]sql.Result, error) { + var results []sql.Result + var lastError error + scanner := bufio.NewScanner(r) + + var inSingleQuote bool + semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + for i, b := range data { + if b == '\'' { + inSingleQuote = !inSingleQuote + } + if !inSingleQuote && b == ';' { + return i + 1, data[0:i], nil + } + } + // If we're at EOF, we have a final, non-terminated line. Return it. + if atEOF { + return len(data), data, nil + } + // Request more data. + return 0, nil, nil + } + + scanner.Split(semiColSpliter) + + for scanner.Scan() { + query := strings.Trim(scanner.Text(), " \t\n\r") + if len(query) > 0 { + result, err := session.Exec(query) + results = append(results, result) + if err != nil { + return nil, err + } + } + } + + return results, lastError +} diff --git a/session_schema_test.go b/session_schema_test.go index a20a1f97..37a1246b 100644 --- a/session_schema_test.go +++ b/session_schema_test.go @@ -6,7 +6,6 @@ package xorm import ( "fmt" - "os" "testing" "time" @@ -210,14 +209,6 @@ func TestCustomTableName(t *testing.T) { assert.NoError(t, testEngine.CreateTables(c)) } -func TestDump(t *testing.T) { - assert.NoError(t, prepareEngine()) - - fp := testEngine.Dialect().URI().DBName + ".sql" - os.Remove(fp) - assert.NoError(t, testEngine.DumpAllToFile(fp)) -} - type IndexOrUnique struct { Id int64 Index int `xorm:"index"`