diff --git a/convert.go b/convert.go index 20a6e373..69277734 100644 --- a/convert.go +++ b/convert.go @@ -348,7 +348,6 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve *d = cloneBytes(s) return nil } - case time.Time: switch d := dest.(type) { case *string: diff --git a/convert/time.go b/convert/time.go index 8901279b..696b301c 100644 --- a/convert/time.go +++ b/convert/time.go @@ -19,7 +19,14 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t dt = dt.In(convertedLocation) return &dt, nil } else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' { - dt, err := time.ParseInLocation("2006-01-02T15:04:05Z", s, originalLocation) + dt, err := time.ParseInLocation(time.RFC3339, s, originalLocation) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } else if len(s) == 25 && s[10] == 'T' && s[19] == '+' && s[22] == ':' { + dt, err := time.Parse(time.RFC3339, s) if err != nil { return nil, err } diff --git a/convert/time_test.go b/convert/time_test.go new file mode 100644 index 00000000..ef01b362 --- /dev/null +++ b/convert/time_test.go @@ -0,0 +1,30 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestString2Time(t *testing.T) { + expectedLoc, err := time.LoadLocation("Asia/Shanghai") + assert.NoError(t, err) + + var kases = map[string]time.Time{ + "2021-06-06T22:58:20+08:00": time.Date(2021, 6, 6, 22, 58, 20, 0, expectedLoc), + "2021-07-11 10:44:00": time.Date(2021, 7, 11, 18, 44, 0, 0, expectedLoc), + "2021-08-10T10:33:04Z": time.Date(2021, 8, 10, 18, 33, 04, 0, expectedLoc), + } + for layout, tm := range kases { + t.Run(layout, func(t *testing.T) { + target, err := String2Time(layout, time.UTC, expectedLoc) + assert.NoError(t, err) + assert.EqualValues(t, tm, *target) + }) + } +} diff --git a/engine.go b/engine.go index a45771a2..d3ee8a8c 100644 --- a/engine.go +++ b/engine.go @@ -13,7 +13,6 @@ import ( "os" "reflect" "runtime" - "strconv" "strings" "time" @@ -21,7 +20,6 @@ import ( "xorm.io/xorm/contexts" "xorm.io/xorm/core" "xorm.io/xorm/dialects" - "xorm.io/xorm/internal/json" "xorm.io/xorm/internal/utils" "xorm.io/xorm/log" "xorm.io/xorm/names" @@ -446,93 +444,14 @@ func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return engine.dumpTables(tables, w, tp...) } -func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d interface{}, col *schemas.Column) string { - if d == nil { - return "NULL" - } - - if dq, ok := d.(bool); ok && (dstDialect.URI().DBType == schemas.SQLITE || - dstDialect.URI().DBType == schemas.MSSQL) { - if dq { +func formatBool(s string, dstDialect dialects.Dialect) string { + if dstDialect.URI().DBType == schemas.MSSQL { + switch s { + case "true": return "1" + case "false": + return "0" } - return "0" - } - - if col.SQLType.IsText() { - var v string - switch reflect.TypeOf(d).Kind() { - case reflect.Struct, reflect.Array, reflect.Slice, reflect.Map: - bytes, err := json.DefaultJSONHandler.Marshal(d) - if err != nil { - v = fmt.Sprintf("%s", d) - } else { - v = string(bytes) - } - default: - v = fmt.Sprintf("%s", d) - } - - return "'" + strings.Replace(v, "'", "''", -1) + "'" - } else if col.SQLType.IsTime() { - if t, ok := d.(time.Time); ok { - return "'" + t.In(dbLocation).Format("2006-01-02 15:04:05") + "'" - } - var v = fmt.Sprintf("%s", d) - if strings.HasSuffix(v, " +0000 UTC") { - return fmt.Sprintf("'%s'", v[0:len(v)-len(" +0000 UTC")]) - } else if strings.HasSuffix(v, " +0000 +0000") { - return fmt.Sprintf("'%s'", v[0:len(v)-len(" +0000 +0000")]) - } - return "'" + strings.Replace(v, "'", "''", -1) + "'" - } else if col.SQLType.IsBlob() { - if reflect.TypeOf(d).Kind() == reflect.Slice { - return fmt.Sprintf("%s", dstDialect.FormatBytes(d.([]byte))) - } else if reflect.TypeOf(d).Kind() == reflect.String { - return fmt.Sprintf("'%s'", d.(string)) - } - } else if col.SQLType.IsNumeric() { - switch reflect.TypeOf(d).Kind() { - case reflect.Slice: - if col.SQLType.Name == schemas.Bool { - return fmt.Sprintf("%v", strconv.FormatBool(d.([]byte)[0] != byte('0'))) - } - return fmt.Sprintf("%s", string(d.([]byte))) - case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int: - if col.SQLType.Name == schemas.Bool { - v := reflect.ValueOf(d).Int() > 0 - if dstDialect.URI().DBType == schemas.SQLITE { - if v { - return "1" - } - return "0" - } - return fmt.Sprintf("%v", strconv.FormatBool(v)) - } - return fmt.Sprintf("%d", d) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if col.SQLType.Name == schemas.Bool { - v := reflect.ValueOf(d).Uint() > 0 - if dstDialect.URI().DBType == schemas.SQLITE { - if v { - return "1" - } - return "0" - } - return fmt.Sprintf("%v", strconv.FormatBool(v)) - } - return fmt.Sprintf("%d", d) - default: - return fmt.Sprintf("%v", d) - } - } - - s := fmt.Sprintf("%v", d) - if strings.Contains(s, ":") || strings.Contains(s, "-") { - if strings.HasSuffix(s, " +0000 UTC") { - return fmt.Sprintf("'%s'", s[0:len(s)-len(" +0000 UTC")]) - } - return fmt.Sprintf("'%s'", s) } return s } @@ -545,7 +464,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } else { dstDialect = dialects.QueryDialect(tp[0]) if dstDialect == nil { - return errors.New("Unsupported database type") + return fmt.Errorf("unsupported database type %v", tp[0]) } uri := engine.dialect.URI() @@ -619,73 +538,68 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } defer rows.Close() - if table.Type != nil { - sess := engine.NewSession() - defer sess.Close() - for rows.Next() { - beanValue := reflect.New(table.Type) - bean := beanValue.Interface() - fields, err := rows.Columns() - if err != nil { - return err - } - scanResults, err := sess.row2Slice(rows, fields, bean) - if err != nil { - return err - } + types, err := rows.ColumnTypes() + if err != nil { + return err + } - dataStruct := utils.ReflectValue(bean) - _, err = sess.slice2Bean(scanResults, fields, bean, &dataStruct, table) - if err != nil { - return err - } + sess := engine.NewSession() + defer sess.Close() + for rows.Next() { + _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") + if err != nil { + return err + } - _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") - if err != nil { - return err - } - - var temp string - for _, d := range dstCols { - col := table.GetColumn(d) - if col == nil { - return errors.New("unknown column error") + scanResults, err := sess.engine.scanStringInterface(rows, types) + if err != nil { + return err + } + for i, scanResult := range scanResults { + stp := schemas.SQLType{Name: types[i].DatabaseTypeName()} + if stp.IsNumeric() { + s := scanResult.(*sql.NullString) + if s.Valid { + if _, err = io.WriteString(w, formatBool(s.String, dstDialect)); err != nil { + return err + } + } else { + if _, err = io.WriteString(w, "NULL"); err != nil { + return err + } + } + } else if stp.IsBool() { + s := scanResult.(*sql.NullString) + if s.Valid { + if _, err = io.WriteString(w, formatBool(s.String, dstDialect)); err != nil { + return err + } + } else { + if _, err = io.WriteString(w, "NULL"); err != nil { + return err + } + } + } else { + s := scanResult.(*sql.NullString) + if s.Valid { + if _, err = io.WriteString(w, "'"+strings.ReplaceAll(s.String, "'", "''")+"'"); err != nil { + return err + } + } else { + if _, err = io.WriteString(w, "NULL"); err != nil { + return err + } } - - field := dataStruct.FieldByIndex(col.FieldIndex) - temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, field.Interface(), col) } - _, err = io.WriteString(w, temp[1:]+");\n") - if err != nil { - return err + if i < len(scanResults)-1 { + if _, err = io.WriteString(w, ","); err != nil { + return err + } } } - } else { - for rows.Next() { - dest := make([]interface{}, len(cols)) - err = rows.ScanSlice(&dest) - if err != nil { - return err - } - - _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") - if err != nil { - return err - } - - var temp string - for i, d := range dest { - col := table.GetColumn(cols[i]) - if col == nil { - return errors.New("unknow column error") - } - - temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, d, col) - } - _, err = io.WriteString(w, temp[1:]+");\n") - if err != nil { - return err - } + _, err = io.WriteString(w, ");\n") + if err != nil { + return err } } diff --git a/integrations/engine_test.go b/integrations/engine_test.go index a06d91aa..a594ee46 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -172,8 +172,21 @@ func TestDumpTables(t *testing.T) { name := fmt.Sprintf("dump_%v-table.sql", tp) t.Run(name, func(t *testing.T) { assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, name, tp)) + }) } + + assert.NoError(t, testEngine.DropTables(new(TestDumpTableStruct))) + + importPath := fmt.Sprintf("dump_%v-table.sql", testEngine.Dialect().URI().DBType) + t.Run("import_"+importPath, func(t *testing.T) { + sess := testEngine.NewSession() + defer sess.Close() + assert.NoError(t, sess.Begin()) + _, err = sess.ImportFile(importPath) + assert.NoError(t, err) + assert.NoError(t, sess.Commit()) + }) } func TestDumpTables2(t *testing.T) { diff --git a/schemas/type.go b/schemas/type.go index f49348be..62e66c2e 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -39,6 +39,7 @@ const ( TIME_TYPE NUMERIC_TYPE ARRAY_TYPE + BOOL_TYPE ) // IsType reutrns ture if the column type is the same as the parameter @@ -64,6 +65,10 @@ func (s *SQLType) IsTime() bool { return s.IsType(TIME_TYPE) } +func (s *SQLType) IsBool() bool { + return s.IsType(BOOL_TYPE) +} + // IsNumeric returns true if column is a numeric type func (s *SQLType) IsNumeric() bool { return s.IsType(NUMERIC_TYPE) @@ -209,7 +214,8 @@ var ( Bytea: BLOB_TYPE, UniqueIdentifier: BLOB_TYPE, - Bool: NUMERIC_TYPE, + Bool: BOOL_TYPE, + Boolean: BOOL_TYPE, Serial: NUMERIC_TYPE, BigSerial: NUMERIC_TYPE,