Merge branch 'master' into patch

This commit is contained in:
Lunny Xiao 2022-01-16 18:05:07 +08:00
commit d7bf314e72
6 changed files with 226 additions and 17 deletions

View File

@ -48,6 +48,16 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t
} }
dt = dt.In(convertedLocation) dt = dt.In(convertedLocation)
return &dt, nil return &dt, nil
} else if len(s) == 10 && s[4] == '-' {
if s == "0000-00-00" || s == "0001-01-01" {
return &time.Time{}, nil
}
dt, err := time.ParseInLocation("2006-01-02", s, originalLocation)
if err != nil {
return nil, err
}
dt = dt.In(convertedLocation)
return &dt, nil
} else { } else {
i, err := strconv.ParseInt(s, 10, 64) i, err := strconv.ParseInt(s, 10, 64)
if err == nil { if err == nil {

View File

@ -16,6 +16,7 @@ func TestString2Time(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
var kases = map[string]time.Time{ var kases = map[string]time.Time{
"2021-08-10": time.Date(2021, 8, 10, 8, 0, 0, 0, expectedLoc),
"2021-06-06T22:58:20+08:00": time.Date(2021, 6, 6, 22, 58, 20, 0, expectedLoc), "2021-06-06T22:58:20+08:00": time.Date(2021, 6, 6, 22, 58, 20, 0, expectedLoc),
"2021-07-11 10:44:00": time.Date(2021, 7, 11, 18, 44, 0, 0, expectedLoc), "2021-07-11 10:44:00": time.Date(2021, 7, 11, 18, 44, 0, 0, expectedLoc),
"2021-08-10T10:33:04Z": time.Date(2021, 8, 10, 18, 33, 04, 0, expectedLoc), "2021-08-10T10:33:04Z": time.Date(2021, 8, 10, 18, 33, 04, 0, expectedLoc),

213
engine.go
View File

@ -11,7 +11,9 @@ import (
"io" "io"
"os" "os"
"reflect" "reflect"
"regexp"
"runtime" "runtime"
"strconv"
"strings" "strings"
"time" "time"
@ -438,18 +440,18 @@ func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return engine.dumpTables(context.Background(), tables, w, tp...) return engine.dumpTables(context.Background(), tables, w, tp...)
} }
func formatBool(s string, dstDialect dialects.Dialect) string { func formatBool(s bool, dstDialect dialects.Dialect) string {
if dstDialect.URI().DBType == schemas.MSSQL { if dstDialect.URI().DBType != schemas.POSTGRES {
switch s { if s {
case "true":
return "1" return "1"
case "false":
return "0"
} }
return "0"
} }
return s return strconv.FormatBool(s)
} }
var controlCharactersRe = regexp.MustCompile(`[\x00-\x1f\x7f]+`)
// dumpTables dump database all table structs and data to w with specify db type // dumpTables dump database all table structs and data to w with specify db type
func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error {
var dstDialect dialects.Dialect var dstDialect dialects.Dialect
@ -465,7 +467,10 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
destURI := dialects.URI{ destURI := dialects.URI{
DBType: tp[0], DBType: tp[0],
DBName: uri.DBName, DBName: uri.DBName,
Schema: uri.Schema, // DO NOT SET SCHEMA HERE
}
if tp[0] == schemas.POSTGRES {
destURI.Schema = engine.dialect.URI().Schema
} }
if err := dstDialect.Init(&destURI); err != nil { if err := dstDialect.Init(&destURI); err != nil {
return err return err
@ -480,6 +485,13 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
return err return err
} }
if dstDialect.URI().DBType == schemas.MYSQL {
// For MySQL set NO_BACKLASH_ESCAPES so that strings work properly
if _, err := io.WriteString(w, "SET sql_mode='NO_BACKSLASH_ESCAPES';\n"); err != nil {
return err
}
}
for i, table := range tables { for i, table := range tables {
dstTable := table dstTable := table
if table.Type != nil { if table.Type != nil {
@ -581,8 +593,13 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
return err return err
} }
} else { } else {
if stp.IsBool() || (dstDialect.URI().DBType == schemas.MSSQL && strings.EqualFold(stp.Name, schemas.Bit)) { if table.Columns()[i].SQLType.IsBool() || stp.IsBool() || (dstDialect.URI().DBType == schemas.MSSQL && strings.EqualFold(stp.Name, schemas.Bit)) {
if _, err = io.WriteString(w, formatBool(s.String, dstDialect)); err != nil { val, err := strconv.ParseBool(s.String)
if err != nil {
return err
}
if _, err = io.WriteString(w, formatBool(val, dstDialect)); err != nil {
return err return err
} }
} else if stp.IsNumeric() { } else if stp.IsNumeric() {
@ -594,6 +611,182 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
if _, err = io.WriteString(w, "'"+r+"'"); err != nil { if _, err = io.WriteString(w, "'"+r+"'"); err != nil {
return err return err
} }
} else if len(s.String) == 0 {
if _, err := io.WriteString(w, "''"); err != nil {
return err
}
} else if dstDialect.URI().DBType == schemas.POSTGRES {
if dstTable.Columns()[i].SQLType.IsBlob() {
// Postgres has the escape format and we should use that for bytea data
if _, err := fmt.Fprintf(w, "'\\x%x'", s.String); err != nil {
return err
}
} else {
// Postgres concatentates strings using || (NOTE: a NUL byte in a text segment will fail)
toCheck := strings.ReplaceAll(s.String, "'", "''")
for len(toCheck) > 0 {
loc := controlCharactersRe.FindStringIndex(toCheck)
if loc == nil {
if _, err := io.WriteString(w, "'"+toCheck+"'"); err != nil {
return err
}
break
}
if loc[0] > 0 {
if _, err := io.WriteString(w, "'"+toCheck[:loc[0]]+"' || "); err != nil {
return err
}
}
if _, err := io.WriteString(w, "e'"); err != nil {
return err
}
for i := loc[0]; i < loc[1]; i++ {
if _, err := fmt.Fprintf(w, "\\x%02x", toCheck[i]); err != nil {
return err
}
}
toCheck = toCheck[loc[1]:]
if len(toCheck) > 0 {
if _, err := io.WriteString(w, "' || "); err != nil {
return err
}
} else {
if _, err := io.WriteString(w, "'"); err != nil {
return err
}
}
}
}
} else if dstDialect.URI().DBType == schemas.MYSQL {
loc := controlCharactersRe.FindStringIndex(s.String)
if loc == nil {
if _, err := io.WriteString(w, "'"+strings.ReplaceAll(s.String, "'", "''")+"'"); err != nil {
return err
}
} else {
if _, err := io.WriteString(w, "CONCAT("); err != nil {
return err
}
toCheck := strings.ReplaceAll(s.String, "'", "''")
for len(toCheck) > 0 {
loc := controlCharactersRe.FindStringIndex(toCheck)
if loc == nil {
if _, err := io.WriteString(w, "'"+toCheck+"')"); err != nil {
return err
}
break
}
if loc[0] > 0 {
if _, err := io.WriteString(w, "'"+toCheck[:loc[0]]+"', "); err != nil {
return err
}
}
for i := loc[0]; i < loc[1]-1; i++ {
if _, err := io.WriteString(w, "CHAR("+strconv.Itoa(int(toCheck[i]))+"), "); err != nil {
return err
}
}
char := toCheck[loc[1]-1]
toCheck = toCheck[loc[1]:]
if len(toCheck) > 0 {
if _, err := io.WriteString(w, "CHAR("+strconv.Itoa(int(char))+"), "); err != nil {
return err
}
} else {
if _, err = io.WriteString(w, "CHAR("+strconv.Itoa(int(char))+"))"); err != nil {
return err
}
}
}
}
} else if dstDialect.URI().DBType == schemas.SQLITE {
if dstTable.Columns()[i].SQLType.IsBlob() {
// SQLite has its escape format
if _, err := fmt.Fprintf(w, "X'%x'", s.String); err != nil {
return err
}
} else {
// SQLite concatentates strings using || (NOTE: a NUL byte in a text segment will fail)
toCheck := strings.ReplaceAll(s.String, "'", "''")
for len(toCheck) > 0 {
loc := controlCharactersRe.FindStringIndex(toCheck)
if loc == nil {
if _, err := io.WriteString(w, "'"+toCheck+"'"); err != nil {
return err
}
break
}
if loc[0] > 0 {
if _, err := io.WriteString(w, "'"+toCheck[:loc[0]]+"' || "); err != nil {
return err
}
}
if _, err := fmt.Fprintf(w, "X'%x'", toCheck[loc[0]:loc[1]]); err != nil {
return err
}
toCheck = toCheck[loc[1]:]
if len(toCheck) > 0 {
if _, err := io.WriteString(w, " || "); err != nil {
return err
}
}
}
}
} else if dstDialect.URI().DBType == schemas.DAMENG || dstDialect.URI().DBType == schemas.ORACLE {
if dstTable.Columns()[i].SQLType.IsBlob() {
// ORACLE/DAMENG uses HEXTORAW
if _, err := fmt.Fprintf(w, "HEXTORAW('%x')", s.String); err != nil {
return err
}
} else {
// ORACLE/DAMENG concatentates strings in multiple ways but uses CHAR and has CONCAT
// (NOTE: a NUL byte in a text segment will fail)
if _, err := io.WriteString(w, "CONCAT("); err != nil {
return err
}
toCheck := strings.ReplaceAll(s.String, "'", "''")
for len(toCheck) > 0 {
loc := controlCharactersRe.FindStringIndex(toCheck)
if loc == nil {
if _, err := io.WriteString(w, "'"+toCheck+"')"); err != nil {
return err
}
break
}
if loc[0] > 0 {
if _, err := io.WriteString(w, "'"+toCheck[:loc[0]]+"', "); err != nil {
return err
}
}
for i := loc[0]; i < loc[1]-1; i++ {
if _, err := io.WriteString(w, "CHAR("+strconv.Itoa(int(toCheck[i]))+"), "); err != nil {
return err
}
}
char := toCheck[loc[1]-1]
toCheck = toCheck[loc[1]:]
if len(toCheck) > 0 {
if _, err := io.WriteString(w, "CHAR("+strconv.Itoa(int(char))+"), "); err != nil {
return err
}
} else {
if _, err = io.WriteString(w, "CHAR("+strconv.Itoa(int(char))+"))"); err != nil {
return err
}
}
}
}
} else if dstDialect.URI().DBType == schemas.MSSQL {
if dstTable.Columns()[i].SQLType.IsBlob() {
// MSSQL uses CONVERT(VARBINARY(MAX), '0xDEADBEEF', 1)
if _, err := fmt.Fprintf(w, "CONVERT(VARBINARY(MAX), '0x%x', 1)", s.String); err != nil {
return err
}
} else {
if _, err = io.WriteString(w, "N'"+strings.ReplaceAll(s.String, "'", "''")+"'"); err != nil {
return err
}
}
} else { } else {
if _, err = io.WriteString(w, "'"+strings.ReplaceAll(s.String, "'", "''")+"'"); err != nil { if _, err = io.WriteString(w, "'"+strings.ReplaceAll(s.String, "'", "''")+"'"); err != nil {
return err return err

View File

@ -143,6 +143,7 @@ func TestDumpTables(t *testing.T) {
type TestDumpTableStruct struct { type TestDumpTableStruct struct {
Id int64 Id int64
Data []byte `xorm:"BLOB"`
Name string Name string
IsMan bool IsMan bool
Created time.Time `xorm:"created"` Created time.Time `xorm:"created"`
@ -152,10 +153,14 @@ func TestDumpTables(t *testing.T) {
_, err := testEngine.Insert([]TestDumpTableStruct{ _, err := testEngine.Insert([]TestDumpTableStruct{
{Name: "1", IsMan: true}, {Name: "1", IsMan: true},
{Name: "2\n"}, {Name: "2\n", Data: []byte{'\000', '\001', '\002'}},
{Name: "3;"}, {Name: "3;", Data: []byte("0x000102")},
{Name: "4\n;\n''"}, {Name: "4\n;\n''", Data: []byte("Help")},
{Name: "5'\n"}, {Name: "5'\n", Data: []byte("0x48656c70")},
{Name: "6\\n'\n", Data: []byte("48656c70")},
{Name: "7\\n'\r\n", Data: []byte("7\\n'\r\n")},
{Name: "x0809ee"},
{Name: "090a10"},
}) })
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -334,7 +334,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
fmt.Fprint(&buf, " LIMIT ", *pLimitN) fmt.Fprint(&buf, " LIMIT ", *pLimitN)
} }
} else if dialect.URI().DBType == schemas.ORACLE { } else if dialect.URI().DBType == schemas.ORACLE {
if statement.Start != 0 && pLimitN != nil { if pLimitN != nil {
oldString := buf.String() oldString := buf.String()
buf.Reset() buf.Reset()
rawColStr := columnStr rawColStr := columnStr

View File

@ -254,9 +254,9 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
switch elemType.Kind() { switch elemType.Kind() {
case reflect.Slice: case reflect.Slice:
err = rows.ScanSlice(bean) err = session.getSlice(rows, types, fields, bean)
case reflect.Map: case reflect.Map:
err = rows.ScanMap(bean) err = session.getMap(rows, types, fields, bean)
default: default:
err = rows.Scan(bean) err = rows.Scan(bean)
} }