improve scan values

This commit is contained in:
Lunny Xiao 2021-06-11 23:48:58 +08:00
parent 7b84aa150b
commit 1139445b2e
4 changed files with 105 additions and 67 deletions

View File

@ -10,29 +10,55 @@ import (
"strconv"
)
// ConvertAssignString converts an interface to string
func ConvertAssignString(v interface{}) (string, error) {
if v == nil {
return "", nil
}
switch vv := v.(type) {
case *int64:
return strconv.FormatInt(*vv, 10), nil
case *int8:
return strconv.FormatInt(int64(*vv), 10), nil
case *sql.NullString:
if vv.Valid {
return vv.String, nil
}
return "", nil
case *int64:
if vv != nil {
return strconv.FormatInt(*vv, 10), nil
}
return "", nil
case *int8:
if vv != nil {
return strconv.FormatInt(int64(*vv), 10), nil
}
return "", nil
case *sql.RawBytes:
if vv != nil && len([]byte(*vv)) > 0 {
if len([]byte(*vv)) > 0 {
return string(*vv), nil
}
return "", nil
case *sql.NullInt32:
if vv.Valid {
return fmt.Sprintf("%d", vv.Int32), nil
}
return "", nil
case *sql.NullInt64:
if vv.Valid {
return fmt.Sprintf("%d", vv.Int64), nil
}
return "", nil
case *sql.NullFloat64:
if vv.Valid {
return fmt.Sprintf("%g", vv.Float64), nil
}
return "", nil
case *sql.NullBool:
if vv.Valid {
if vv.Bool {
return "true", nil
}
return "false", nil
}
return "", nil
case *sql.NullTime:
if vv.Valid {
return vv.Time.Format("2006-01-02 15:04:05"), nil
}
return "", nil
default:
return "", fmt.Errorf("unsupported type: %#v", vv)
return "", fmt.Errorf("convert assign string unsupported type: %#v", vv)
}
}

View File

@ -446,6 +446,31 @@ func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d
return "NULL"
}
switch t := d.(type) {
case *sql.NullInt64:
if t.Valid {
return fmt.Sprintf("%d", t.Int64)
}
return "NULL"
case *sql.NullTime:
if t.Valid {
return fmt.Sprintf("'%s'", t.Time.Format("2006-1-02 15:04:05"))
}
return "NULL"
case *sql.NullString:
if t.Valid {
return "'" + strings.Replace(t.String, "'", "''", -1) + "'"
}
return "NULL"
case *sql.NullInt32:
if t.Valid {
return fmt.Sprintf("%d", t.Int32)
}
return "NULL"
}
fmt.Printf("%#v------%v\n", d, col.Name)
if dq, ok := d.(bool); ok && (dstDialect.URI().DBType == schemas.SQLITE ||
dstDialect.URI().DBType == schemas.MSSQL) {
if dq {
@ -604,6 +629,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
cols := table.ColumnsSeq()
dstCols := dstTable.ColumnsSeq()
dstColumns := dstTable.Columns()
colNames := engine.dialect.Quoter().Join(cols, ", ")
destColNames := dstDialect.Quoter().Join(dstCols, ", ")
@ -614,37 +640,29 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
}
defer rows.Close()
if table.Type != nil {
sess := engine.NewSession()
defer sess.Close()
for rows.Next() {
beanValue := reflect.New(table.Type)
bean := beanValue.Interface()
fields, err := rows.Columns()
if err != nil {
return err
}
scanResults, err := sess.row2Slice(rows, fields, bean)
if err != nil {
return err
}
coltypes, err := rows.ColumnTypes()
if err != nil {
return err
}
for rows.Next() {
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
if err != nil {
return err
}
dataStruct := utils.ReflectValue(bean)
_, err = sess.slice2Bean(scanResults, fields, bean, &dataStruct, table)
if err != nil {
row, err := engine.scanInterfaceResults(rows, coltypes, dstCols)
if err != nil {
return err
}
for i, cell := range row {
s := engine.formatColumnValue(dstDialect, cell, dstColumns[i])
if _, err = io.WriteString(w, s); err != nil {
return err
}
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
if err != nil {
return err
}
var temp string
for _, d := range dstCols {
col := table.GetColumn(d)
if col == nil {
return errors.New("unknown column error")
if i < len(row)-1 {
if _, err = io.WriteString(w, ","); err != nil {
return err
}
field := dataStruct.FieldByIndex(col.FieldIndex)
@ -655,25 +673,6 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return err
}
}
} else {
for rows.Next() {
dest := make([]interface{}, len(cols))
err = rows.ScanSlice(&dest)
if err != nil {
return err
}
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
if err != nil {
return err
}
var temp string
for i, d := range dest {
col := table.GetColumn(cols[i])
if col == nil {
return errors.New("unknow column error")
}
temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, d, col)
}

View File

@ -5,6 +5,7 @@
package integrations
import (
"database/sql"
"fmt"
"strconv"
"testing"
@ -72,30 +73,36 @@ func TestQueryString2(t *testing.T) {
}
func toString(i interface{}) string {
switch i.(type) {
switch t := i.(type) {
case []byte:
return string(i.([]byte))
return string(t)
case string:
return i.(string)
return t
}
return fmt.Sprintf("%v", i)
}
func toInt64(i interface{}) int64 {
switch i.(type) {
switch t := i.(type) {
case []byte:
n, _ := strconv.ParseInt(string(i.([]byte)), 10, 64)
return n
case int:
return int64(i.(int))
return int64(t)
case int32:
return int64(t)
case int64:
return i.(int64)
return t
case *sql.NullInt64:
return t.Int64
case *sql.NullInt32:
return int64(t.Int32)
}
return 0
}
func toFloat64(i interface{}) float64 {
switch i.(type) {
switch t := i.(type) {
case []byte:
n, _ := strconv.ParseFloat(string(i.([]byte)), 64)
return n
@ -103,6 +110,12 @@ func toFloat64(i interface{}) float64 {
return i.(float64)
case float32:
return float64(i.(float32))
case *sql.NullInt32:
return float64(t.Int32)
case *sql.NullInt64:
return float64(t.Int64)
case *sql.NullFloat64:
return t.Float64
}
return 0
}

View File

@ -95,7 +95,7 @@ func (session *Session) QueryString(sqlOrArgs ...interface{}) ([]map[string]stri
}
defer rows.Close()
return session.rows2Strings(rows)
return session.engine.rows2Strings(rows)
}
// QuerySliceString runs a raw sql and return records as [][]string