From 30dcce510d75def4fa156af1a8cf845c2d32af19 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 10 Apr 2017 23:10:59 +0800 Subject: [PATCH] go test add mysql and postgres drivers --- session.go | 1 - session_get_test.go | 67 +++++++++++++++++++++++++++++++++++-- statement_test.go | 4 +++ types_test.go | 2 +- xorm_test.go | 81 +++++++++++++++++++++++---------------------- 5 files changed, 110 insertions(+), 45 deletions(-) diff --git a/session.go b/session.go index 60ad9f0d..d7694dc7 100644 --- a/session.go +++ b/session.go @@ -43,7 +43,6 @@ type Session struct { prepareStmt bool stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) - cascadeDeep int // !evalphobia! stored the last executed query on this session //beforeSQLExec func(string, ...interface{}) diff --git a/session_get_test.go b/session_get_test.go index 4c25dbd4..a41fbccb 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -6,6 +6,7 @@ package xorm import ( "fmt" + "strconv" "testing" "time" @@ -84,8 +85,68 @@ func TestGetVar(t *testing.T) { has, err = testEngine.Table("get_var").Get(&valuesSliceInter) assert.NoError(t, err) assert.Equal(t, true, has) - assert.EqualValues(t, 1, valuesSliceInter[0]) + + v1, err := convertInt(valuesSliceInter[0]) + assert.NoError(t, err) + assert.EqualValues(t, 1, v1) + assert.Equal(t, "hi", fmt.Sprintf("%s", valuesSliceInter[1])) - assert.EqualValues(t, 28, valuesSliceInter[2]) - assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesSliceInter[3])) + + v3, err := convertInt(valuesSliceInter[2]) + assert.NoError(t, err) + assert.EqualValues(t, 28, v3) + + v4, err := convertFloat(valuesSliceInter[3]) + assert.NoError(t, err) + assert.Equal(t, "1.5", fmt.Sprintf("%v", v4)) +} + +func convertFloat(v interface{}) (float64, error) { + switch v.(type) { + case float32: + return float64(v.(float32)), nil + case float64: + return v.(float64), nil + case string: + i, err := strconv.ParseFloat(v.(string), 64) + if err != nil { + return 0, err + } + return i, nil + case []byte: + i, err := strconv.ParseFloat(string(v.([]byte)), 64) + if err != nil { + return 0, err + } + return i, nil + } + return 0, fmt.Errorf("unsupported type: %v", v) +} + +func convertInt(v interface{}) (int64, error) { + switch v.(type) { + case int: + return int64(v.(int)), nil + case int8: + return int64(v.(int8)), nil + case int16: + return int64(v.(int16)), nil + case int32: + return int64(v.(int32)), nil + case int64: + return v.(int64), nil + case []byte: + i, err := strconv.ParseInt(string(v.([]byte)), 10, 64) + if err != nil { + return 0, err + } + return i, nil + case string: + i, err := strconv.ParseInt(v.(string), 10, 64) + if err != nil { + return 0, err + } + return i, nil + } + return 0, fmt.Errorf("unsupported type: %v", v) } diff --git a/statement_test.go b/statement_test.go index cb3730ef..01a09afc 100644 --- a/statement_test.go +++ b/statement_test.go @@ -26,6 +26,10 @@ var colStrTests = []struct { } func TestColumnsStringGeneration(t *testing.T) { + if *db == "postgres" { + return + } + var statement *Statement for ndx, testCase := range colStrTests { diff --git a/types_test.go b/types_test.go index 47c59de7..fae2f0c0 100644 --- a/types_test.go +++ b/types_test.go @@ -15,7 +15,7 @@ func TestArrayField(t *testing.T) { type ArrayStruct struct { Id int64 - Name [20]byte `xorm:"char(20)"` + Name [20]byte `xorm:"char(80)"` } assert.NoError(t, testEngine.Sync2(new(ArrayStruct))) diff --git a/xorm_test.go b/xorm_test.go index b124deec..98b42b66 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -1,72 +1,73 @@ package xorm import ( - "errors" "flag" + "fmt" "os" "testing" + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) var ( testEngine *Engine - dbType string - connStr string + connString string + + db = flag.String("db", "sqlite3", "the tested database") + showSQL = flag.Bool("show_sql", true, "show generated SQLs") + ptrConnStr = flag.String("conn_str", "", "test database connection string") + mapType = flag.String("map_type", "snake", "indicate the name mapping") + cache = flag.Bool("cache", false, "if enable cache") ) -func prepareSqlite3Engine() error { - //if testEngine == nil { - os.Remove("./test.db") - var err error - testEngine, err = NewEngine("sqlite3", "./test.db") +func createEngine(dbType, connStr string) error { + if testEngine == nil { + var err error + testEngine, err = NewEngine(dbType, connStr) + if err != nil { + return err + } + + testEngine.ShowSQL(*showSQL) + } + + tables, err := testEngine.DBMetas() if err != nil { return err } - testEngine.ShowSQL(*showSQL) - //} - return nil -} - -func prepareMysqlEngine() error { - if testEngine == nil { - var err error - testEngine, err = NewEngine("mysql", connStr) - if err != nil { - return err - } - testEngine.ShowSQL(*showSQL) - _, err = testEngine.Exec("DROP DATABASE") - if err != nil { - return err - } + var tableNames = make([]interface{}, 0, len(tables)) + for _, table := range tables { + tableNames = append(tableNames, table.Name) } - return nil + return testEngine.DropTables(tableNames...) } func prepareEngine() error { - if dbType == "sqlite" { - return prepareSqlite3Engine() - } else if dbType == "mysql" { - return prepareMysqlEngine() - } - return errors.New("Unknown test database driver") + return createEngine(*db, connString) } -var ( - db = flag.String("db", "sqlite", "the tested database") - showSQL = flag.Bool("show_sql", true, "show generated SQLs") -) - func TestMain(m *testing.M) { flag.Parse() - if db != nil { - dbType = *db + if *db == "sqlite3" { + if ptrConnStr == nil { + connString = "./test.db" + } else { + connString = *ptrConnStr + } + } else { + if ptrConnStr == nil { + fmt.Println("you should indicate conn string") + return + } + connString = *ptrConnStr } if err := prepareEngine(); err != nil { - panic(err) + fmt.Println(err) + return } os.Exit(m.Run()) }