This commit is contained in:
Lunny Xiao 2020-03-14 14:59:09 +08:00
parent 9c55057ef2
commit d5a638544e
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
3 changed files with 67 additions and 47 deletions

View File

@ -5,7 +5,6 @@
package xorm package xorm
import ( import (
"bufio"
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
@ -1163,55 +1162,16 @@ func (engine *Engine) SumsInt(bean interface{}, colNames ...string) ([]int64, er
// ImportFile SQL DDL file // ImportFile SQL DDL file
func (engine *Engine) ImportFile(ddlPath string) ([]sql.Result, error) { func (engine *Engine) ImportFile(ddlPath string) ([]sql.Result, error) {
file, err := os.Open(ddlPath) session := engine.NewSession()
if err != nil { defer session.Close()
return nil, err return session.ImportFile(ddlPath)
}
defer file.Close()
return engine.Import(file)
} }
// Import SQL DDL from io.Reader // Import SQL DDL from io.Reader
func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) { func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) {
var results []sql.Result session := engine.NewSession()
var lastError error defer session.Close()
scanner := bufio.NewScanner(r) return session.Import(r)
var inSingleQuote bool
semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
for i, b := range data {
if b == '\'' {
inSingleQuote = !inSingleQuote
}
if !inSingleQuote && b == ';' {
return i + 1, data[0:i], nil
}
}
// If we're at EOF, we have a final, non-terminated line. Return it.
if atEOF {
return len(data), data, nil
}
// Request more data.
return 0, nil, nil
}
scanner.Split(semiColSpliter)
for scanner.Scan() {
query := strings.Trim(scanner.Text(), " \t\n\r")
if len(query) > 0 {
result, err := engine.DB().ExecContext(engine.defaultContext, query)
results = append(results, result)
if err != nil {
return nil, err
}
}
}
return results, lastError
} }
// nowTime return current time // nowTime return current time

View File

@ -90,6 +90,10 @@ func TestDump(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
_, err := testEngine.ImportFile(fp) sess := testEngine.NewSession()
defer sess.Close()
assert.NoError(t, sess.Begin())
_, err := sess.ImportFile(fp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NoError(t, sess.Commit())
} }

View File

@ -5,8 +5,11 @@
package xorm package xorm
import ( import (
"bufio"
"database/sql" "database/sql"
"fmt" "fmt"
"io"
"os"
"strings" "strings"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
@ -432,3 +435,56 @@ func (session *Session) Sync2(beans ...interface{}) error {
return nil return nil
} }
// ImportFile SQL DDL file
func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) {
file, err := os.Open(ddlPath)
if err != nil {
return nil, err
}
defer file.Close()
return session.Import(file)
}
// Import SQL DDL from io.Reader
func (session *Session) Import(r io.Reader) ([]sql.Result, error) {
var results []sql.Result
var lastError error
scanner := bufio.NewScanner(r)
var inSingleQuote bool
semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
for i, b := range data {
if b == '\'' {
inSingleQuote = !inSingleQuote
}
if !inSingleQuote && b == ';' {
return i + 1, data[0:i], nil
}
}
// If we're at EOF, we have a final, non-terminated line. Return it.
if atEOF {
return len(data), data, nil
}
// Request more data.
return 0, nil, nil
}
scanner.Split(semiColSpliter)
for scanner.Scan() {
query := strings.Trim(scanner.Text(), " \t\n\r")
if len(query) > 0 {
result, err := session.Exec(query)
results = append(results, result)
if err != nil {
return nil, err
}
}
}
return results, lastError
}