This commit is contained in:
brookechen 2023-05-19 16:45:00 +08:00
parent d5a0855ace
commit 58c7a413a0
1 changed files with 405 additions and 405 deletions

View File

@ -5,534 +5,534 @@
package xorm package xorm
import ( import (
"bufio" "bufio"
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"io" "io"
"os" "os"
"strings" "strings"
"xorm.io/xorm/dialects" "xorm.io/xorm/dialects"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
// Ping test if database is ok // Ping test if database is ok
func (session *Session) Ping() error { func (session *Session) Ping() error {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
return session.DB().PingContext(session.ctx) return session.DB().PingContext(session.ctx)
} }
// CreateTable create a table according a bean // CreateTable create a table according a bean
func (session *Session) CreateTable(bean interface{}) error { func (session *Session) CreateTable(bean interface{}) error {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
return session.createTable(bean) return session.createTable(bean)
} }
func (session *Session) createTable(bean interface{}) error { func (session *Session) createTable(bean interface{}) error {
if err := session.statement.SetRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return err return err
} }
session.statement.RefTable.StoreEngine = session.statement.StoreEngine session.statement.RefTable.StoreEngine = session.statement.StoreEngine
session.statement.RefTable.Charset = session.statement.Charset session.statement.RefTable.Charset = session.statement.Charset
tableName := session.statement.TableName() tableName := session.statement.TableName()
refTable := session.statement.RefTable refTable := session.statement.RefTable
if refTable.AutoIncrement != "" && session.engine.dialect.Features().AutoincrMode == dialects.SequenceAutoincrMode { if refTable.AutoIncrement != "" && session.engine.dialect.Features().AutoincrMode == dialects.SequenceAutoincrMode {
sqlStr, err := session.engine.dialect.CreateSequenceSQL(context.Background(), session.engine.db, utils.SeqName(tableName)) sqlStr, err := session.engine.dialect.CreateSequenceSQL(context.Background(), session.engine.db, utils.SeqName(tableName))
if err != nil { if err != nil {
return err return err
} }
if _, err := session.exec(sqlStr); err != nil { if _, err := session.exec(sqlStr); err != nil {
return err return err
} }
} }
sqlStr, _, err := session.engine.dialect.CreateTableSQL(context.Background(), session.engine.db, refTable, tableName) sqlStr, _, err := session.engine.dialect.CreateTableSQL(context.Background(), session.engine.db, refTable, tableName)
if err != nil { if err != nil {
return err return err
} }
if _, err := session.exec(sqlStr); err != nil { if _, err := session.exec(sqlStr); err != nil {
return err return err
} }
return nil return nil
} }
// CreateIndexes create indexes // CreateIndexes create indexes
func (session *Session) CreateIndexes(bean interface{}) error { func (session *Session) CreateIndexes(bean interface{}) error {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
return session.createIndexes(bean) return session.createIndexes(bean)
} }
func (session *Session) createIndexes(bean interface{}) error { func (session *Session) createIndexes(bean interface{}) error {
if err := session.statement.SetRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return err return err
} }
sqls := session.statement.GenIndexSQL() sqls := session.statement.GenIndexSQL()
for _, sqlStr := range sqls { for _, sqlStr := range sqls {
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
if err != nil { if err != nil {
return err return err
} }
} }
return nil return nil
} }
// CreateUniques create uniques // CreateUniques create uniques
func (session *Session) CreateUniques(bean interface{}) error { func (session *Session) CreateUniques(bean interface{}) error {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
return session.createUniques(bean) return session.createUniques(bean)
} }
func (session *Session) createUniques(bean interface{}) error { func (session *Session) createUniques(bean interface{}) error {
if err := session.statement.SetRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return err return err
} }
sqls := session.statement.GenUniqueSQL() sqls := session.statement.GenUniqueSQL()
for _, sqlStr := range sqls { for _, sqlStr := range sqls {
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
if err != nil { if err != nil {
return err return err
} }
} }
return nil return nil
} }
// DropIndexes drop indexes // DropIndexes drop indexes
func (session *Session) DropIndexes(bean interface{}) error { func (session *Session) DropIndexes(bean interface{}) error {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
return session.dropIndexes(bean) return session.dropIndexes(bean)
} }
func (session *Session) dropIndexes(bean interface{}) error { func (session *Session) dropIndexes(bean interface{}) error {
if err := session.statement.SetRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return err return err
} }
sqls := session.statement.GenDelIndexSQL() sqls := session.statement.GenDelIndexSQL()
for _, sqlStr := range sqls { for _, sqlStr := range sqls {
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
if err != nil { if err != nil {
return err return err
} }
} }
return nil return nil
} }
// DropTable drop table will drop table if exist, if drop failed, it will return error // DropTable drop table will drop table if exist, if drop failed, it will return error
func (session *Session) DropTable(beanOrTableName interface{}) error { func (session *Session) DropTable(beanOrTableName interface{}) error {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
return session.dropTable(beanOrTableName) return session.dropTable(beanOrTableName)
} }
func (session *Session) dropTable(beanOrTableName interface{}) error { func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName := session.engine.TableName(beanOrTableName) tableName := session.engine.TableName(beanOrTableName)
sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true)) sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true))
if !checkIfExist { if !checkIfExist {
exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName)
if err != nil { if err != nil {
return err return err
} }
checkIfExist = exist checkIfExist = exist
} }
if !checkIfExist { if !checkIfExist {
return nil return nil
} }
if _, err := session.exec(sqlStr); err != nil { if _, err := session.exec(sqlStr); err != nil {
return err return err
} }
if session.engine.dialect.Features().AutoincrMode == dialects.IncrAutoincrMode { if session.engine.dialect.Features().AutoincrMode == dialects.IncrAutoincrMode {
return nil return nil
} }
var seqName = utils.SeqName(tableName) var seqName = utils.SeqName(tableName)
exist, err := session.engine.dialect.IsSequenceExist(session.ctx, session.getQueryer(), seqName) exist, err := session.engine.dialect.IsSequenceExist(session.ctx, session.getQueryer(), seqName)
if err != nil { if err != nil {
return err return err
} }
if !exist { if !exist {
return nil return nil
} }
sqlStr, err = session.engine.dialect.DropSequenceSQL(seqName) sqlStr, err = session.engine.dialect.DropSequenceSQL(seqName)
if err != nil { if err != nil {
return err return err
} }
_, err = session.exec(sqlStr) _, err = session.exec(sqlStr)
return err return err
} }
// IsTableExist if a table is exist // IsTableExist if a table is exist
func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) { func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
tableName := session.engine.TableName(beanOrTableName) tableName := session.engine.TableName(beanOrTableName)
return session.isTableExist(tableName) return session.isTableExist(tableName)
} }
func (session *Session) isTableExist(tableName string) (bool, error) { func (session *Session) isTableExist(tableName string) (bool, error) {
return session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) return session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName)
} }
// IsTableEmpty if table have any records // IsTableEmpty if table have any records
func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
return session.isTableEmpty(session.engine.TableName(bean)) return session.isTableEmpty(session.engine.TableName(bean))
} }
func (session *Session) isTableEmpty(tableName string) (bool, error) { func (session *Session) isTableEmpty(tableName string) (bool, error) {
var total int64 var total int64
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true))) sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true)))
err := session.queryRow(sqlStr).Scan(&total) err := session.queryRow(sqlStr).Scan(&total)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = nil err = nil
} }
return true, err return true, err
} }
return total == 0, nil return total == 0, nil
} }
func (session *Session) addColumn(colName string) error { func (session *Session) addColumn(colName string) error {
col := session.statement.RefTable.GetColumn(colName) col := session.statement.RefTable.GetColumn(colName)
sql := session.engine.dialect.AddColumnSQL(session.statement.TableName(), col) sql := session.engine.dialect.AddColumnSQL(session.statement.TableName(), col)
_, err := session.exec(sql) _, err := session.exec(sql)
return err return err
} }
func (session *Session) addIndex(tableName, idxName string) error { func (session *Session) addIndex(tableName, idxName string) error {
index := session.statement.RefTable.Indexes[idxName] index := session.statement.RefTable.Indexes[idxName]
sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index) sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index)
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
return err return err
} }
func (session *Session) addUnique(tableName, uqeName string) error { func (session *Session) addUnique(tableName, uqeName string) error {
index := session.statement.RefTable.Indexes[uqeName] index := session.statement.RefTable.Indexes[uqeName]
sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index) sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index)
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
return err return err
} }
// Sync2 synchronize structs to database tables // Sync2 synchronize structs to database tables
// Depricated // Depricated
func (session *Session) Sync2(beans ...interface{}) error { func (session *Session) Sync2(beans ...interface{}) error {
return session.Sync(beans...) return session.Sync(beans...)
} }
// Sync synchronize structs to database tables // Sync synchronize structs to database tables
func (session *Session) Sync(beans ...interface{}) error { func (session *Session) Sync(beans ...interface{}) error {
engine := session.engine engine := session.engine
if session.isAutoClose { if session.isAutoClose {
session.isAutoClose = false session.isAutoClose = false
defer session.Close() defer session.Close()
} }
tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx) tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx)
if err != nil { if err != nil {
return err return err
} }
session.autoResetStatement = false session.autoResetStatement = false
defer func() { defer func() {
session.autoResetStatement = true session.autoResetStatement = true
session.resetStatement() session.resetStatement()
}() }()
for _, bean := range beans { for _, bean := range beans {
v := utils.ReflectValue(bean) v := utils.ReflectValue(bean)
table, err := engine.tagParser.ParseWithCache(v) table, err := engine.tagParser.ParseWithCache(v)
if err != nil { if err != nil {
return err return err
} }
var tbName string var tbName string
if len(session.statement.AltTableName) > 0 { if len(session.statement.AltTableName) > 0 {
tbName = session.statement.AltTableName tbName = session.statement.AltTableName
} else { } else {
tbName = engine.TableName(bean) tbName = engine.TableName(bean)
} }
tbNameWithSchema := engine.tbNameWithSchema(tbName) tbNameWithSchema := engine.tbNameWithSchema(tbName)
var oriTable *schemas.Table var oriTable *schemas.Table
for _, tb := range tables { for _, tb := range tables {
if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) { if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) {
oriTable = tb oriTable = tb
break break
} }
} }
// this is a new table // this is a new table
if oriTable == nil { if oriTable == nil {
err = session.StoreEngine(session.statement.StoreEngine).createTable(bean) err = session.StoreEngine(session.statement.StoreEngine).createTable(bean)
if err != nil { if err != nil {
return err return err
} }
err = session.createUniques(bean) err = session.createUniques(bean)
if err != nil { if err != nil {
return err return err
} }
err = session.createIndexes(bean) err = session.createIndexes(bean)
if err != nil { if err != nil {
return err return err
} }
continue continue
} }
// this will modify an old table // this will modify an old table
if err = engine.loadTableInfo(oriTable); err != nil { if err = engine.loadTableInfo(oriTable); err != nil {
return err return err
} }
// check columns // check columns
for _, col := range table.Columns() { for _, col := range table.Columns() {
var oriCol *schemas.Column var oriCol *schemas.Column
for _, col2 := range oriTable.Columns() { for _, col2 := range oriTable.Columns() {
if strings.EqualFold(col.Name, col2.Name) { if strings.EqualFold(col.Name, col2.Name) {
oriCol = col2 oriCol = col2
break break
} }
} }
// column is not exist on table // column is not exist on table
if oriCol == nil { if oriCol == nil {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema) session.statement.SetTableName(tbNameWithSchema)
if err = session.addColumn(col.Name); err != nil { if err = session.addColumn(col.Name); err != nil {
return err return err
} }
continue continue
} }
err = nil err = nil
expectedType := engine.dialect.SQLType(col) expectedType := engine.dialect.SQLType(col)
curType := engine.dialect.SQLType(oriCol) curType := engine.dialect.SQLType(oriCol)
if expectedType != curType { if expectedType != curType {
if expectedType == schemas.Text && if expectedType == schemas.Text &&
strings.HasPrefix(curType, schemas.Varchar) { strings.HasPrefix(curType, schemas.Varchar) {
// currently only support mysql & postgres // currently only support mysql & postgres
if engine.dialect.URI().DBType == schemas.MYSQL || if engine.dialect.URI().DBType == schemas.MYSQL ||
engine.dialect.URI().DBType == schemas.POSTGRES { engine.dialect.URI().DBType == schemas.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n", engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbNameWithSchema, col.Name, curType, expectedType) tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
} else { } else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbNameWithSchema, col.Name, curType, expectedType) tbNameWithSchema, col.Name, curType, expectedType)
} }
} else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) { } else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) {
if engine.dialect.URI().DBType == schemas.MYSQL { if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length { if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length) tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
} }
} }
} else { } else {
if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
if !strings.EqualFold(schemas.SQLTypeName(curType), engine.dialect.Alias(schemas.SQLTypeName(expectedType))) { if !strings.EqualFold(schemas.SQLTypeName(curType), engine.dialect.Alias(schemas.SQLTypeName(expectedType))) {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
tbNameWithSchema, col.Name, curType, expectedType) tbNameWithSchema, col.Name, curType, expectedType)
} }
} }
} }
} else if expectedType == schemas.Varchar { } else if expectedType == schemas.Varchar {
if engine.dialect.URI().DBType == schemas.MYSQL { if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length { if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length) tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
} }
} }
} else if col.Comment != oriCol.Comment { } else if col.Comment != oriCol.Comment {
if engine.dialect.URI().DBType == schemas.POSTGRES { if engine.dialect.URI().DBType == schemas.POSTGRES {
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
} }
} }
if col.Default != oriCol.Default { if col.Default != oriCol.Default {
switch { switch {
case col.IsAutoIncrement: // For autoincrement column, don't check default case col.IsAutoIncrement: // For autoincrement column, don't check default
case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) && case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) &&
((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") || ((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") ||
(strings.EqualFold(col.Default, "false") && oriCol.Default == "0")): (strings.EqualFold(col.Default, "false") && oriCol.Default == "0")):
default: default:
engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s", engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s",
tbName, col.Name, oriCol.Default, col.Default) tbName, col.Name, oriCol.Default, col.Default)
} }
} }
if col.Nullable != oriCol.Nullable { if col.Nullable != oriCol.Nullable {
engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v", engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v",
tbName, col.Name, oriCol.Nullable, col.Nullable) tbName, col.Name, oriCol.Nullable, col.Nullable)
} }
if err != nil { if err != nil {
return err return err
} }
} }
var foundIndexNames = make(map[string]bool) var foundIndexNames = make(map[string]bool)
var addedNames = make(map[string]*schemas.Index) var addedNames = make(map[string]*schemas.Index)
for name, index := range table.Indexes { for name, index := range table.Indexes {
var oriIndex *schemas.Index var oriIndex *schemas.Index
for name2, index2 := range oriTable.Indexes { for name2, index2 := range oriTable.Indexes {
if index.Equal(index2) { if index.Equal(index2) {
oriIndex = index2 oriIndex = index2
foundIndexNames[name2] = true foundIndexNames[name2] = true
break break
} }
} }
if oriIndex != nil { if oriIndex != nil {
if oriIndex.Type != index.Type { if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex) sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex)
_, err = session.exec(sql) _, err = session.exec(sql)
if err != nil { if err != nil {
return err return err
} }
oriIndex = nil oriIndex = nil
} }
} }
if oriIndex == nil { if oriIndex == nil {
addedNames[name] = index addedNames[name] = index
} }
} }
for name2, index2 := range oriTable.Indexes { for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok { if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2) sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2)
_, err = session.exec(sql) _, err = session.exec(sql)
if err != nil { if err != nil {
return err return err
} }
} }
} }
for name, index := range addedNames { for name, index := range addedNames {
if index.Type == schemas.UniqueType { if index.Type == schemas.UniqueType {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema) session.statement.SetTableName(tbNameWithSchema)
err = session.addUnique(tbNameWithSchema, name) err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == schemas.IndexType { } else if index.Type == schemas.IndexType {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema) session.statement.SetTableName(tbNameWithSchema)
err = session.addIndex(tbNameWithSchema, name) err = session.addIndex(tbNameWithSchema, name)
} }
if err != nil { if err != nil {
return err return err
} }
} }
// check all the columns which removed from struct fields but left on database tables. // check all the columns which removed from struct fields but left on database tables.
for _, colName := range oriTable.ColumnsSeq() { for _, colName := range oriTable.ColumnsSeq() {
if table.GetColumn(colName) == nil { if table.GetColumn(colName) == nil {
engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(oriTable.Name, true), colName) engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(oriTable.Name, true), colName)
} }
} }
} }
return nil return nil
} }
// ImportFile SQL DDL file // ImportFile SQL DDL file
func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) { func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) {
file, err := os.Open(ddlPath) file, err := os.Open(ddlPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer file.Close() defer file.Close()
return session.Import(file) return session.Import(file)
} }
// Import SQL DDL from io.Reader // Import SQL DDL from io.Reader
func (session *Session) Import(r io.Reader) ([]sql.Result, error) { func (session *Session) Import(r io.Reader) ([]sql.Result, error) {
var ( var (
results []sql.Result results []sql.Result
lastError error lastError error
inSingleQuote bool inSingleQuote bool
startComment bool startComment bool
) )
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) { semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
return 0, nil, nil return 0, nil, nil
} }
var oriInSingleQuote = inSingleQuote var oriInSingleQuote = inSingleQuote
for i, b := range data { for i, b := range data {
if startComment { if startComment {
if b == '\n' { if b == '\n' {
startComment = false startComment = false
} }
} else { } else {
if !inSingleQuote && i > 0 && data[i-1] == '-' && data[i] == '-' { if !inSingleQuote && i > 0 && data[i-1] == '-' && data[i] == '-' {
startComment = true startComment = true
continue continue
} }
if b == '\'' { if b == '\'' {
inSingleQuote = !inSingleQuote inSingleQuote = !inSingleQuote
} }
if !inSingleQuote && b == ';' { if !inSingleQuote && b == ';' {
return i + 1, data[0:i], nil return i + 1, data[0:i], nil
} }
} }
} }
// If we're at EOF, we have a final, non-terminated line. Return it. // If we're at EOF, we have a final, non-terminated line. Return it.
if atEOF { if atEOF {
return len(data), data, nil return len(data), data, nil
} }
inSingleQuote = oriInSingleQuote inSingleQuote = oriInSingleQuote
// Request more data. // Request more data.
return 0, nil, nil return 0, nil, nil
} }
scanner.Split(semiColSpliter) scanner.Split(semiColSpliter)
for scanner.Scan() { for scanner.Scan() {
query := strings.Trim(scanner.Text(), " \t\n\r") query := strings.Trim(scanner.Text(), " \t\n\r")
if len(query) > 0 { if len(query) > 0 {
result, err := session.Exec(query) result, err := session.Exec(query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
results = append(results, result) results = append(results, result)
} }
} }
return results, lastError return results, lastError
} }