improve scan values

This commit is contained in:
Lunny Xiao 2021-06-11 23:48:58 +08:00
parent 7b84aa150b
commit 1139445b2e
4 changed files with 105 additions and 67 deletions

View File

@ -10,29 +10,55 @@ import (
"strconv" "strconv"
) )
// ConvertAssignString converts an interface to string
func ConvertAssignString(v interface{}) (string, error) { func ConvertAssignString(v interface{}) (string, error) {
if v == nil {
return "", nil
}
switch vv := v.(type) { switch vv := v.(type) {
case *int64:
return strconv.FormatInt(*vv, 10), nil
case *int8:
return strconv.FormatInt(int64(*vv), 10), nil
case *sql.NullString: case *sql.NullString:
if vv.Valid { if vv.Valid {
return vv.String, nil return vv.String, nil
} }
return "", nil return "", nil
case *int64:
if vv != nil {
return strconv.FormatInt(*vv, 10), nil
}
return "", nil
case *int8:
if vv != nil {
return strconv.FormatInt(int64(*vv), 10), nil
}
return "", nil
case *sql.RawBytes: case *sql.RawBytes:
if vv != nil && len([]byte(*vv)) > 0 { if len([]byte(*vv)) > 0 {
return string(*vv), nil return string(*vv), nil
} }
return "", nil return "", nil
case *sql.NullInt32:
if vv.Valid {
return fmt.Sprintf("%d", vv.Int32), nil
}
return "", nil
case *sql.NullInt64:
if vv.Valid {
return fmt.Sprintf("%d", vv.Int64), nil
}
return "", nil
case *sql.NullFloat64:
if vv.Valid {
return fmt.Sprintf("%g", vv.Float64), nil
}
return "", nil
case *sql.NullBool:
if vv.Valid {
if vv.Bool {
return "true", nil
}
return "false", nil
}
return "", nil
case *sql.NullTime:
if vv.Valid {
return vv.Time.Format("2006-01-02 15:04:05"), nil
}
return "", nil
default: default:
return "", fmt.Errorf("unsupported type: %#v", vv) return "", fmt.Errorf("convert assign string unsupported type: %#v", vv)
} }
} }

View File

@ -446,6 +446,31 @@ func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d
return "NULL" return "NULL"
} }
switch t := d.(type) {
case *sql.NullInt64:
if t.Valid {
return fmt.Sprintf("%d", t.Int64)
}
return "NULL"
case *sql.NullTime:
if t.Valid {
return fmt.Sprintf("'%s'", t.Time.Format("2006-1-02 15:04:05"))
}
return "NULL"
case *sql.NullString:
if t.Valid {
return "'" + strings.Replace(t.String, "'", "''", -1) + "'"
}
return "NULL"
case *sql.NullInt32:
if t.Valid {
return fmt.Sprintf("%d", t.Int32)
}
return "NULL"
}
fmt.Printf("%#v------%v\n", d, col.Name)
if dq, ok := d.(bool); ok && (dstDialect.URI().DBType == schemas.SQLITE || if dq, ok := d.(bool); ok && (dstDialect.URI().DBType == schemas.SQLITE ||
dstDialect.URI().DBType == schemas.MSSQL) { dstDialect.URI().DBType == schemas.MSSQL) {
if dq { if dq {
@ -604,6 +629,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
cols := table.ColumnsSeq() cols := table.ColumnsSeq()
dstCols := dstTable.ColumnsSeq() dstCols := dstTable.ColumnsSeq()
dstColumns := dstTable.Columns()
colNames := engine.dialect.Quoter().Join(cols, ", ") colNames := engine.dialect.Quoter().Join(cols, ", ")
destColNames := dstDialect.Quoter().Join(dstCols, ", ") destColNames := dstDialect.Quoter().Join(dstCols, ", ")
@ -614,37 +640,29 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
} }
defer rows.Close() defer rows.Close()
if table.Type != nil { coltypes, err := rows.ColumnTypes()
sess := engine.NewSession() if err != nil {
defer sess.Close() return err
for rows.Next() { }
beanValue := reflect.New(table.Type) for rows.Next() {
bean := beanValue.Interface() _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
fields, err := rows.Columns() if err != nil {
if err != nil { return err
return err }
}
scanResults, err := sess.row2Slice(rows, fields, bean)
if err != nil {
return err
}
dataStruct := utils.ReflectValue(bean) row, err := engine.scanInterfaceResults(rows, coltypes, dstCols)
_, err = sess.slice2Bean(scanResults, fields, bean, &dataStruct, table) if err != nil {
if err != nil { return err
}
for i, cell := range row {
s := engine.formatColumnValue(dstDialect, cell, dstColumns[i])
if _, err = io.WriteString(w, s); err != nil {
return err return err
} }
if i < len(row)-1 {
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") if _, err = io.WriteString(w, ","); err != nil {
if err != nil { return err
return err
}
var temp string
for _, d := range dstCols {
col := table.GetColumn(d)
if col == nil {
return errors.New("unknown column error")
} }
field := dataStruct.FieldByIndex(col.FieldIndex) field := dataStruct.FieldByIndex(col.FieldIndex)
@ -655,25 +673,6 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return err return err
} }
} }
} else {
for rows.Next() {
dest := make([]interface{}, len(cols))
err = rows.ScanSlice(&dest)
if err != nil {
return err
}
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
if err != nil {
return err
}
var temp string
for i, d := range dest {
col := table.GetColumn(cols[i])
if col == nil {
return errors.New("unknow column error")
}
temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, d, col) temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, d, col)
} }

View File

@ -5,6 +5,7 @@
package integrations package integrations
import ( import (
"database/sql"
"fmt" "fmt"
"strconv" "strconv"
"testing" "testing"
@ -72,30 +73,36 @@ func TestQueryString2(t *testing.T) {
} }
func toString(i interface{}) string { func toString(i interface{}) string {
switch i.(type) { switch t := i.(type) {
case []byte: case []byte:
return string(i.([]byte)) return string(t)
case string: case string:
return i.(string) return t
} }
return fmt.Sprintf("%v", i) return fmt.Sprintf("%v", i)
} }
func toInt64(i interface{}) int64 { func toInt64(i interface{}) int64 {
switch i.(type) { switch t := i.(type) {
case []byte: case []byte:
n, _ := strconv.ParseInt(string(i.([]byte)), 10, 64) n, _ := strconv.ParseInt(string(i.([]byte)), 10, 64)
return n return n
case int: case int:
return int64(i.(int)) return int64(t)
case int32:
return int64(t)
case int64: case int64:
return i.(int64) return t
case *sql.NullInt64:
return t.Int64
case *sql.NullInt32:
return int64(t.Int32)
} }
return 0 return 0
} }
func toFloat64(i interface{}) float64 { func toFloat64(i interface{}) float64 {
switch i.(type) { switch t := i.(type) {
case []byte: case []byte:
n, _ := strconv.ParseFloat(string(i.([]byte)), 64) n, _ := strconv.ParseFloat(string(i.([]byte)), 64)
return n return n
@ -103,6 +110,12 @@ func toFloat64(i interface{}) float64 {
return i.(float64) return i.(float64)
case float32: case float32:
return float64(i.(float32)) return float64(i.(float32))
case *sql.NullInt32:
return float64(t.Int32)
case *sql.NullInt64:
return float64(t.Int64)
case *sql.NullFloat64:
return t.Float64
} }
return 0 return 0
} }

View File

@ -95,7 +95,7 @@ func (session *Session) QueryString(sqlOrArgs ...interface{}) ([]map[string]stri
} }
defer rows.Close() defer rows.Close()
return session.rows2Strings(rows) return session.engine.rows2Strings(rows)
} }
// QuerySliceString runs a raw sql and return records as [][]string // QuerySliceString runs a raw sql and return records as [][]string