diff --git a/convert/string.go b/convert/string.go index a9d9ee98..193052d5 100644 --- a/convert/string.go +++ b/convert/string.go @@ -10,29 +10,55 @@ import ( "strconv" ) +// ConvertAssignString converts an interface to string func ConvertAssignString(v interface{}) (string, error) { + if v == nil { + return "", nil + } switch vv := v.(type) { + case *int64: + return strconv.FormatInt(*vv, 10), nil + case *int8: + return strconv.FormatInt(int64(*vv), 10), nil case *sql.NullString: if vv.Valid { return vv.String, nil } return "", nil - case *int64: - if vv != nil { - return strconv.FormatInt(*vv, 10), nil - } - return "", nil - case *int8: - if vv != nil { - return strconv.FormatInt(int64(*vv), 10), nil - } - return "", nil case *sql.RawBytes: - if vv != nil && len([]byte(*vv)) > 0 { + if len([]byte(*vv)) > 0 { return string(*vv), nil } return "", nil + case *sql.NullInt32: + if vv.Valid { + return fmt.Sprintf("%d", vv.Int32), nil + } + return "", nil + case *sql.NullInt64: + if vv.Valid { + return fmt.Sprintf("%d", vv.Int64), nil + } + return "", nil + case *sql.NullFloat64: + if vv.Valid { + return fmt.Sprintf("%g", vv.Float64), nil + } + return "", nil + case *sql.NullBool: + if vv.Valid { + if vv.Bool { + return "true", nil + } + return "false", nil + } + return "", nil + case *sql.NullTime: + if vv.Valid { + return vv.Time.Format("2006-01-02 15:04:05"), nil + } + return "", nil default: - return "", fmt.Errorf("unsupported type: %#v", vv) + return "", fmt.Errorf("convert assign string unsupported type: %#v", vv) } } diff --git a/engine.go b/engine.go index c40eb042..fa6f0da5 100644 --- a/engine.go +++ b/engine.go @@ -446,6 +446,31 @@ func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d return "NULL" } + switch t := d.(type) { + case *sql.NullInt64: + if t.Valid { + return fmt.Sprintf("%d", t.Int64) + } + return "NULL" + case *sql.NullTime: + if t.Valid { + return fmt.Sprintf("'%s'", t.Time.Format("2006-1-02 15:04:05")) + } + return "NULL" + case *sql.NullString: + if t.Valid { + return "'" + strings.Replace(t.String, "'", "''", -1) + "'" + } + return "NULL" + case *sql.NullInt32: + if t.Valid { + return fmt.Sprintf("%d", t.Int32) + } + return "NULL" + } + + fmt.Printf("%#v------%v\n", d, col.Name) + if dq, ok := d.(bool); ok && (dstDialect.URI().DBType == schemas.SQLITE || dstDialect.URI().DBType == schemas.MSSQL) { if dq { @@ -604,6 +629,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch cols := table.ColumnsSeq() dstCols := dstTable.ColumnsSeq() + dstColumns := dstTable.Columns() colNames := engine.dialect.Quoter().Join(cols, ", ") destColNames := dstDialect.Quoter().Join(dstCols, ", ") @@ -614,37 +640,29 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } defer rows.Close() - if table.Type != nil { - sess := engine.NewSession() - defer sess.Close() - for rows.Next() { - beanValue := reflect.New(table.Type) - bean := beanValue.Interface() - fields, err := rows.Columns() - if err != nil { - return err - } - scanResults, err := sess.row2Slice(rows, fields, bean) - if err != nil { - return err - } + coltypes, err := rows.ColumnTypes() + if err != nil { + return err + } + for rows.Next() { + _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") + if err != nil { + return err + } - dataStruct := utils.ReflectValue(bean) - _, err = sess.slice2Bean(scanResults, fields, bean, &dataStruct, table) - if err != nil { + row, err := engine.scanInterfaceResults(rows, coltypes, dstCols) + if err != nil { + return err + } + + for i, cell := range row { + s := engine.formatColumnValue(dstDialect, cell, dstColumns[i]) + if _, err = io.WriteString(w, s); err != nil { return err } - - _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") - if err != nil { - return err - } - - var temp string - for _, d := range dstCols { - col := table.GetColumn(d) - if col == nil { - return errors.New("unknown column error") + if i < len(row)-1 { + if _, err = io.WriteString(w, ","); err != nil { + return err } field := dataStruct.FieldByIndex(col.FieldIndex) @@ -655,25 +673,6 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } } - } else { - for rows.Next() { - dest := make([]interface{}, len(cols)) - err = rows.ScanSlice(&dest) - if err != nil { - return err - } - - _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") - if err != nil { - return err - } - - var temp string - for i, d := range dest { - col := table.GetColumn(cols[i]) - if col == nil { - return errors.New("unknow column error") - } temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, d, col) } diff --git a/integrations/session_query_test.go b/integrations/session_query_test.go index 1e0b0dc6..831091d5 100644 --- a/integrations/session_query_test.go +++ b/integrations/session_query_test.go @@ -5,6 +5,7 @@ package integrations import ( + "database/sql" "fmt" "strconv" "testing" @@ -72,30 +73,36 @@ func TestQueryString2(t *testing.T) { } func toString(i interface{}) string { - switch i.(type) { + switch t := i.(type) { case []byte: - return string(i.([]byte)) + return string(t) case string: - return i.(string) + return t } return fmt.Sprintf("%v", i) } func toInt64(i interface{}) int64 { - switch i.(type) { + switch t := i.(type) { case []byte: n, _ := strconv.ParseInt(string(i.([]byte)), 10, 64) return n case int: - return int64(i.(int)) + return int64(t) + case int32: + return int64(t) case int64: - return i.(int64) + return t + case *sql.NullInt64: + return t.Int64 + case *sql.NullInt32: + return int64(t.Int32) } return 0 } func toFloat64(i interface{}) float64 { - switch i.(type) { + switch t := i.(type) { case []byte: n, _ := strconv.ParseFloat(string(i.([]byte)), 64) return n @@ -103,6 +110,12 @@ func toFloat64(i interface{}) float64 { return i.(float64) case float32: return float64(i.(float32)) + case *sql.NullInt32: + return float64(t.Int32) + case *sql.NullInt64: + return float64(t.Int64) + case *sql.NullFloat64: + return t.Float64 } return 0 } diff --git a/session_query.go b/session_query.go index 9594da25..c24b4e64 100644 --- a/session_query.go +++ b/session_query.go @@ -95,7 +95,7 @@ func (session *Session) QueryString(sqlOrArgs ...interface{}) ([]map[string]stri } defer rows.Close() - return session.rows2Strings(rows) + return session.engine.rows2Strings(rows) } // QuerySliceString runs a raw sql and return records as [][]string