From d5a638544e05484867c7e37fe9ccb3ec23b4c32d Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 14 Mar 2020 14:59:09 +0800 Subject: [PATCH] Fix bug --- engine.go | 52 +++++-------------------------------------- engine_test.go | 6 ++++- session_schema.go | 56 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 47 deletions(-) diff --git a/engine.go b/engine.go index b59494d3..99412c4f 100644 --- a/engine.go +++ b/engine.go @@ -5,7 +5,6 @@ package xorm import ( - "bufio" "context" "database/sql" "errors" @@ -1163,55 +1162,16 @@ func (engine *Engine) SumsInt(bean interface{}, colNames ...string) ([]int64, er // ImportFile SQL DDL file func (engine *Engine) ImportFile(ddlPath string) ([]sql.Result, error) { - file, err := os.Open(ddlPath) - if err != nil { - return nil, err - } - defer file.Close() - return engine.Import(file) + session := engine.NewSession() + defer session.Close() + return session.ImportFile(ddlPath) } // Import SQL DDL from io.Reader func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) { - var results []sql.Result - var lastError error - scanner := bufio.NewScanner(r) - - var inSingleQuote bool - semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - for i, b := range data { - if b == '\'' { - inSingleQuote = !inSingleQuote - } - if !inSingleQuote && b == ';' { - return i + 1, data[0:i], nil - } - } - // If we're at EOF, we have a final, non-terminated line. Return it. - if atEOF { - return len(data), data, nil - } - // Request more data. - return 0, nil, nil - } - - scanner.Split(semiColSpliter) - - for scanner.Scan() { - query := strings.Trim(scanner.Text(), " \t\n\r") - if len(query) > 0 { - result, err := engine.DB().ExecContext(engine.defaultContext, query) - results = append(results, result) - if err != nil { - return nil, err - } - } - } - - return results, lastError + session := engine.NewSession() + defer session.Close() + return session.Import(r) } // nowTime return current time diff --git a/engine_test.go b/engine_test.go index e61ff924..459d63c4 100644 --- a/engine_test.go +++ b/engine_test.go @@ -90,6 +90,10 @@ func TestDump(t *testing.T) { assert.NoError(t, prepareEngine()) - _, err := testEngine.ImportFile(fp) + sess := testEngine.NewSession() + defer sess.Close() + assert.NoError(t, sess.Begin()) + _, err := sess.ImportFile(fp) assert.NoError(t, err) + assert.NoError(t, sess.Commit()) } diff --git a/session_schema.go b/session_schema.go index ca4e2d75..84eb586e 100644 --- a/session_schema.go +++ b/session_schema.go @@ -5,8 +5,11 @@ package xorm import ( + "bufio" "database/sql" "fmt" + "io" + "os" "strings" "xorm.io/xorm/internal/utils" @@ -432,3 +435,56 @@ func (session *Session) Sync2(beans ...interface{}) error { return nil } + +// ImportFile SQL DDL file +func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) { + file, err := os.Open(ddlPath) + if err != nil { + return nil, err + } + defer file.Close() + return session.Import(file) +} + +// Import SQL DDL from io.Reader +func (session *Session) Import(r io.Reader) ([]sql.Result, error) { + var results []sql.Result + var lastError error + scanner := bufio.NewScanner(r) + + var inSingleQuote bool + semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + for i, b := range data { + if b == '\'' { + inSingleQuote = !inSingleQuote + } + if !inSingleQuote && b == ';' { + return i + 1, data[0:i], nil + } + } + // If we're at EOF, we have a final, non-terminated line. Return it. + if atEOF { + return len(data), data, nil + } + // Request more data. + return 0, nil, nil + } + + scanner.Split(semiColSpliter) + + for scanner.Scan() { + query := strings.Trim(scanner.Text(), " \t\n\r") + if len(query) > 0 { + result, err := session.Exec(query) + results = append(results, result) + if err != nil { + return nil, err + } + } + } + + return results, lastError +}