From 86cf0e8c3cff89a7f5044a3a995526577f3b66ec Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 23 Mar 2017 14:05:32 +0800 Subject: [PATCH] add test framework --- .gitignore | 3 +- goracle_driver.go | 42 --------------- mssql_dialect.go | 23 +++++++++ mymysql_driver.go | 65 ----------------------- mysql_dialect.go | 92 +++++++++++++++++++++++++++++++++ mysql_driver.go | 50 ------------------ oci8_driver.go | 37 -------------- odbc_driver.go | 34 ------------- oracle_dialect.go | 53 +++++++++++++++++++ postgres_dialect.go | 107 ++++++++++++++++++++++++++++++++++++++ pq_driver.go | 119 ------------------------------------------- session_cols_test.go | 16 ++++++ sqlite3_dialect.go | 7 +++ sqlite3_driver.go | 20 -------- statement_test.go | 54 +++----------------- xorm_test.go | 59 +++++++++++++++++++++ 16 files changed, 367 insertions(+), 414 deletions(-) delete mode 100644 goracle_driver.go delete mode 100644 mymysql_driver.go delete mode 100644 mysql_driver.go delete mode 100644 oci8_driver.go delete mode 100644 odbc_driver.go delete mode 100644 pq_driver.go create mode 100644 session_cols_test.go delete mode 100644 sqlite3_driver.go create mode 100644 xorm_test.go diff --git a/.gitignore b/.gitignore index c7f7dc2b..fa31dd78 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,5 @@ vendor *.log .vendor temp_test.go -.vscode \ No newline at end of file +.vscode +xorm.test \ No newline at end of file diff --git a/goracle_driver.go b/goracle_driver.go deleted file mode 100644 index 9fcde48f..00000000 --- a/goracle_driver.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2015 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "errors" - "regexp" - - "github.com/go-xorm/core" -) - -// func init() { -// core.RegisterDriver("goracle", &goracleDriver{}) -// } - -type goracleDriver struct { -} - -func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - db := &core.Uri{DbType: core.ORACLE} - dsnPattern := regexp.MustCompile( - `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] - `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] - `\/(?P.*?)` + // /dbname - `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] - matches := dsnPattern.FindStringSubmatch(dataSourceName) - //tlsConfigRegister := make(map[string]*tls.Config) - names := dsnPattern.SubexpNames() - - for i, match := range matches { - switch names[i] { - case "dbname": - db.DbName = match - } - } - if db.DbName == "" { - return nil, errors.New("dbname is empty") - } - return db, nil -} diff --git a/mssql_dialect.go b/mssql_dialect.go index e9bda1fd..70fcaf6e 100644 --- a/mssql_dialect.go +++ b/mssql_dialect.go @@ -5,6 +5,7 @@ package xorm import ( + "errors" "fmt" "strconv" "strings" @@ -526,3 +527,25 @@ func (db *mssql) ForUpdateSql(query string) string { func (db *mssql) Filters() []core.Filter { return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}} } + +type odbcDriver struct { +} + +func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { + kv := strings.Split(dataSourceName, ";") + var dbName string + + for _, c := range kv { + vv := strings.Split(strings.TrimSpace(c), "=") + if len(vv) == 2 { + switch strings.ToLower(vv[0]) { + case "database": + dbName = vv[1] + } + } + } + if dbName == "" { + return nil, errors.New("no db name provided") + } + return &core.Uri{DbName: dbName, DbType: core.MSSQL}, nil +} diff --git a/mymysql_driver.go b/mymysql_driver.go deleted file mode 100644 index ef3086a4..00000000 --- a/mymysql_driver.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2015 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "errors" - "strings" - "time" - - "github.com/go-xorm/core" -) - -type mymysqlDriver struct { -} - -func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - db := &core.Uri{DbType: core.MYSQL} - - pd := strings.SplitN(dataSourceName, "*", 2) - if len(pd) == 2 { - // Parse protocol part of URI - p := strings.SplitN(pd[0], ":", 2) - if len(p) != 2 { - return nil, errors.New("Wrong protocol part of URI") - } - db.Proto = p[0] - options := strings.Split(p[1], ",") - db.Raddr = options[0] - for _, o := range options[1:] { - kv := strings.SplitN(o, "=", 2) - var k, v string - if len(kv) == 2 { - k, v = kv[0], kv[1] - } else { - k, v = o, "true" - } - switch k { - case "laddr": - db.Laddr = v - case "timeout": - to, err := time.ParseDuration(v) - if err != nil { - return nil, err - } - db.Timeout = to - default: - return nil, errors.New("Unknown option: " + k) - } - } - // Remove protocol part - pd = pd[1:] - } - // Parse database part of URI - dup := strings.SplitN(pd[0], "/", 3) - if len(dup) != 3 { - return nil, errors.New("Wrong database part of URI") - } - db.DbName = dup[0] - db.User = dup[1] - db.Passwd = dup[2] - - return db, nil -} diff --git a/mysql_dialect.go b/mysql_dialect.go index ab756f35..55cfdd76 100644 --- a/mysql_dialect.go +++ b/mysql_dialect.go @@ -6,7 +6,9 @@ package xorm import ( "crypto/tls" + "errors" "fmt" + "regexp" "strconv" "strings" "time" @@ -486,3 +488,93 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { func (db *mysql) Filters() []core.Filter { return []core.Filter{&core.IdFilter{}} } + +type mymysqlDriver struct { +} + +func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { + db := &core.Uri{DbType: core.MYSQL} + + pd := strings.SplitN(dataSourceName, "*", 2) + if len(pd) == 2 { + // Parse protocol part of URI + p := strings.SplitN(pd[0], ":", 2) + if len(p) != 2 { + return nil, errors.New("Wrong protocol part of URI") + } + db.Proto = p[0] + options := strings.Split(p[1], ",") + db.Raddr = options[0] + for _, o := range options[1:] { + kv := strings.SplitN(o, "=", 2) + var k, v string + if len(kv) == 2 { + k, v = kv[0], kv[1] + } else { + k, v = o, "true" + } + switch k { + case "laddr": + db.Laddr = v + case "timeout": + to, err := time.ParseDuration(v) + if err != nil { + return nil, err + } + db.Timeout = to + default: + return nil, errors.New("Unknown option: " + k) + } + } + // Remove protocol part + pd = pd[1:] + } + // Parse database part of URI + dup := strings.SplitN(pd[0], "/", 3) + if len(dup) != 3 { + return nil, errors.New("Wrong database part of URI") + } + db.DbName = dup[0] + db.User = dup[1] + db.Passwd = dup[2] + + return db, nil +} + +type mysqlDriver struct { +} + +func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { + dsnPattern := regexp.MustCompile( + `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] + `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] + `\/(?P.*?)` + // /dbname + `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] + matches := dsnPattern.FindStringSubmatch(dataSourceName) + //tlsConfigRegister := make(map[string]*tls.Config) + names := dsnPattern.SubexpNames() + + uri := &core.Uri{DbType: core.MYSQL} + + for i, match := range matches { + switch names[i] { + case "dbname": + uri.DbName = match + case "params": + if len(match) > 0 { + kvs := strings.Split(match, "&") + for _, kv := range kvs { + splits := strings.Split(kv, "=") + if len(splits) == 2 { + switch splits[0] { + case "charset": + uri.Charset = splits[1] + } + } + } + } + + } + } + return uri, nil +} diff --git a/mysql_driver.go b/mysql_driver.go deleted file mode 100644 index 6ceeed58..00000000 --- a/mysql_driver.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2015 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "regexp" - "strings" - - "github.com/go-xorm/core" -) - -type mysqlDriver struct { -} - -func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - dsnPattern := regexp.MustCompile( - `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] - `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] - `\/(?P.*?)` + // /dbname - `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] - matches := dsnPattern.FindStringSubmatch(dataSourceName) - //tlsConfigRegister := make(map[string]*tls.Config) - names := dsnPattern.SubexpNames() - - uri := &core.Uri{DbType: core.MYSQL} - - for i, match := range matches { - switch names[i] { - case "dbname": - uri.DbName = match - case "params": - if len(match) > 0 { - kvs := strings.Split(match, "&") - for _, kv := range kvs { - splits := strings.Split(kv, "=") - if len(splits) == 2 { - switch splits[0] { - case "charset": - uri.Charset = splits[1] - } - } - } - } - - } - } - return uri, nil -} diff --git a/oci8_driver.go b/oci8_driver.go deleted file mode 100644 index ec5f2022..00000000 --- a/oci8_driver.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2015 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "errors" - "regexp" - - "github.com/go-xorm/core" -) - -type oci8Driver struct { -} - -//dataSourceName=user/password@ipv4:port/dbname -//dataSourceName=user/password@[ipv6]:port/dbname -func (p *oci8Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - db := &core.Uri{DbType: core.ORACLE} - dsnPattern := regexp.MustCompile( - `^(?P.*)\/(?P.*)@` + // user:password@ - `(?P.*)` + // ip:port - `\/(?P.*)`) // dbname - matches := dsnPattern.FindStringSubmatch(dataSourceName) - names := dsnPattern.SubexpNames() - for i, match := range matches { - switch names[i] { - case "dbname": - db.DbName = match - } - } - if db.DbName == "" { - return nil, errors.New("dbname is empty") - } - return db, nil -} diff --git a/odbc_driver.go b/odbc_driver.go deleted file mode 100644 index 6770de60..00000000 --- a/odbc_driver.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2015 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "errors" - "strings" - - "github.com/go-xorm/core" -) - -type odbcDriver struct { -} - -func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - kv := strings.Split(dataSourceName, ";") - var dbName string - - for _, c := range kv { - vv := strings.Split(strings.TrimSpace(c), "=") - if len(vv) == 2 { - switch strings.ToLower(vv[0]) { - case "database": - dbName = vv[1] - } - } - } - if dbName == "" { - return nil, errors.New("no db name provided") - } - return &core.Uri{DbName: dbName, DbType: core.MSSQL}, nil -} diff --git a/oracle_dialect.go b/oracle_dialect.go index b19ea38b..8c43aa4c 100644 --- a/oracle_dialect.go +++ b/oracle_dialect.go @@ -5,7 +5,9 @@ package xorm import ( + "errors" "fmt" + "regexp" "strconv" "strings" @@ -844,3 +846,54 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) { func (db *oracle) Filters() []core.Filter { return []core.Filter{&core.QuoteFilter{}, &core.SeqFilter{Prefix: ":", Start: 1}, &core.IdFilter{}} } + +type goracleDriver struct { +} + +func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { + db := &core.Uri{DbType: core.ORACLE} + dsnPattern := regexp.MustCompile( + `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] + `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] + `\/(?P.*?)` + // /dbname + `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] + matches := dsnPattern.FindStringSubmatch(dataSourceName) + //tlsConfigRegister := make(map[string]*tls.Config) + names := dsnPattern.SubexpNames() + + for i, match := range matches { + switch names[i] { + case "dbname": + db.DbName = match + } + } + if db.DbName == "" { + return nil, errors.New("dbname is empty") + } + return db, nil +} + +type oci8Driver struct { +} + +//dataSourceName=user/password@ipv4:port/dbname +//dataSourceName=user/password@[ipv6]:port/dbname +func (p *oci8Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) { + db := &core.Uri{DbType: core.ORACLE} + dsnPattern := regexp.MustCompile( + `^(?P.*)\/(?P.*)@` + // user:password@ + `(?P.*)` + // ip:port + `\/(?P.*)`) // dbname + matches := dsnPattern.FindStringSubmatch(dataSourceName) + names := dsnPattern.SubexpNames() + for i, match := range matches { + switch names[i] { + case "dbname": + db.DbName = match + } + } + if db.DbName == "" { + return nil, errors.New("dbname is empty") + } + return db, nil +} diff --git a/postgres_dialect.go b/postgres_dialect.go index c23ab6f3..05fc1235 100644 --- a/postgres_dialect.go +++ b/postgres_dialect.go @@ -5,7 +5,10 @@ package xorm import ( + "errors" "fmt" + "net/url" + "sort" "strconv" "strings" @@ -1095,3 +1098,107 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) func (db *postgres) Filters() []core.Filter { return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}, &core.SeqFilter{Prefix: "$", Start: 1}} } + +type pqDriver struct { +} + +type values map[string]string + +func (vs values) Set(k, v string) { + vs[k] = v +} + +func (vs values) Get(k string) (v string) { + return vs[k] +} + +func errorf(s string, args ...interface{}) { + panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) +} + +func parseURL(connstr string) (string, error) { + u, err := url.Parse(connstr) + if err != nil { + return "", err + } + + if u.Scheme != "postgresql" && u.Scheme != "postgres" { + return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) + } + + var kvs []string + escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) + accrue := func(k, v string) { + if v != "" { + kvs = append(kvs, k+"="+escaper.Replace(v)) + } + } + + if u.User != nil { + v := u.User.Username() + accrue("user", v) + + v, _ = u.User.Password() + accrue("password", v) + } + + i := strings.Index(u.Host, ":") + if i < 0 { + accrue("host", u.Host) + } else { + accrue("host", u.Host[:i]) + accrue("port", u.Host[i+1:]) + } + + if u.Path != "" { + accrue("dbname", u.Path[1:]) + } + + q := u.Query() + for k := range q { + accrue(k, q.Get(k)) + } + + sort.Strings(kvs) // Makes testing easier (not a performance concern) + return strings.Join(kvs, " "), nil +} + +func parseOpts(name string, o values) { + if len(name) == 0 { + return + } + + name = strings.TrimSpace(name) + + ps := strings.Split(name, " ") + for _, p := range ps { + kv := strings.Split(p, "=") + if len(kv) < 2 { + errorf("invalid option: %q", p) + } + o.Set(kv[0], kv[1]) + } +} + +func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { + db := &core.Uri{DbType: core.POSTGRES} + o := make(values) + var err error + if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") { + dataSourceName, err = parseURL(dataSourceName) + if err != nil { + return nil, err + } + } + parseOpts(dataSourceName, o) + + db.DbName = o.Get("dbname") + if db.DbName == "" { + return nil, errors.New("dbname is empty") + } + /*db.Schema = o.Get("schema") + if len(db.Schema) == 0 { + db.Schema = "public" + }*/ + return db, nil +} diff --git a/pq_driver.go b/pq_driver.go deleted file mode 100644 index 5d608f25..00000000 --- a/pq_driver.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2015 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "errors" - "fmt" - "net/url" - "sort" - "strings" - - "github.com/go-xorm/core" -) - -type pqDriver struct { -} - -type values map[string]string - -func (vs values) Set(k, v string) { - vs[k] = v -} - -func (vs values) Get(k string) (v string) { - return vs[k] -} - -func errorf(s string, args ...interface{}) { - panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) -} - -func parseURL(connstr string) (string, error) { - u, err := url.Parse(connstr) - if err != nil { - return "", err - } - - if u.Scheme != "postgresql" && u.Scheme != "postgres" { - return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) - } - - var kvs []string - escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) - accrue := func(k, v string) { - if v != "" { - kvs = append(kvs, k+"="+escaper.Replace(v)) - } - } - - if u.User != nil { - v := u.User.Username() - accrue("user", v) - - v, _ = u.User.Password() - accrue("password", v) - } - - i := strings.Index(u.Host, ":") - if i < 0 { - accrue("host", u.Host) - } else { - accrue("host", u.Host[:i]) - accrue("port", u.Host[i+1:]) - } - - if u.Path != "" { - accrue("dbname", u.Path[1:]) - } - - q := u.Query() - for k := range q { - accrue(k, q.Get(k)) - } - - sort.Strings(kvs) // Makes testing easier (not a performance concern) - return strings.Join(kvs, " "), nil -} - -func parseOpts(name string, o values) { - if len(name) == 0 { - return - } - - name = strings.TrimSpace(name) - - ps := strings.Split(name, " ") - for _, p := range ps { - kv := strings.Split(p, "=") - if len(kv) < 2 { - errorf("invalid option: %q", p) - } - o.Set(kv[0], kv[1]) - } -} - -func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - db := &core.Uri{DbType: core.POSTGRES} - o := make(values) - var err error - if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") { - dataSourceName, err = parseURL(dataSourceName) - if err != nil { - return nil, err - } - } - parseOpts(dataSourceName, o) - - db.DbName = o.Get("dbname") - if db.DbName == "" { - return nil, errors.New("dbname is empty") - } - /*db.Schema = o.Get("schema") - if len(db.Schema) == 0 { - db.Schema = "public" - }*/ - return db, nil -} diff --git a/session_cols_test.go b/session_cols_test.go new file mode 100644 index 00000000..8bef8bd7 --- /dev/null +++ b/session_cols_test.go @@ -0,0 +1,16 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import "testing" + +func TestSetExpr(t *testing.T) { + type User struct { + Id int64 + Show bool + } + + testEngine.SetExpr("show", "NOT show").Id(1).Update(new(User)) +} diff --git a/sqlite3_dialect.go b/sqlite3_dialect.go index b72459ae..c190c4d9 100644 --- a/sqlite3_dialect.go +++ b/sqlite3_dialect.go @@ -434,3 +434,10 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error) func (db *sqlite3) Filters() []core.Filter { return []core.Filter{&core.IdFilter{}} } + +type sqlite3Driver struct { +} + +func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) { + return &core.Uri{DbType: core.SQLITE, DbName: dataSourceName}, nil +} diff --git a/sqlite3_driver.go b/sqlite3_driver.go deleted file mode 100644 index 6ae19569..00000000 --- a/sqlite3_driver.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2015 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "github.com/go-xorm/core" -) - -// func init() { -// core.RegisterDriver("sqlite3", &sqlite3Driver{}) -// } - -type sqlite3Driver struct { -} - -func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - return &core.Uri{DbType: core.SQLITE, DbName: dataSourceName}, nil -} diff --git a/statement_test.go b/statement_test.go index 7e8d6c0c..8fc092d2 100644 --- a/statement_test.go +++ b/statement_test.go @@ -2,11 +2,8 @@ package xorm import ( "reflect" - "sync" - "testing" - "time" - "strings" + "testing" "github.com/go-xorm/core" ) @@ -24,14 +21,6 @@ var colStrTests = []struct { {"", 8, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`"}, } -// !nemec784! Only for Statement object creation -const driverName = "mysql" -const dataSourceName = "Server=TestServer;Database=TestDB;Uid=testUser;Pwd=testPassword;" - -func init() { - core.RegisterDriver(driverName, &mysqlDriver{}) -} - func TestColumnsStringGeneration(t *testing.T) { var statement *Statement @@ -41,12 +30,12 @@ func TestColumnsStringGeneration(t *testing.T) { statement = createTestStatement() if testCase.omitColumn != "" { - statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped + statement.Omit(testCase.omitColumn) } + columns := statement.RefTable.Columns() if testCase.onlyToDBColumnNdx >= 0 { - columns := statement.RefTable.Columns() - columns[testCase.onlyToDBColumnNdx].MapType = core.ONLYTODB // !nemec784! Column must be skipped + columns[testCase.onlyToDBColumnNdx].MapType = core.ONLYTODB } actual := statement.genColumnStr() @@ -54,6 +43,9 @@ func TestColumnsStringGeneration(t *testing.T) { if actual != testCase.expected { t.Errorf("[test #%d] Unexpected columns string:\nwant:\t%s\nhave:\t%s", ndx, testCase.expected, actual) } + if testCase.onlyToDBColumnNdx >= 0 { + columns[testCase.onlyToDBColumnNdx].MapType = core.TWOSIDES + } } } @@ -166,40 +158,10 @@ func (TestType) TableName() string { } func createTestStatement() *Statement { - - engine := createTestEngine() - statement := &Statement{} statement.Init() - statement.Engine = engine + statement.Engine = testEngine statement.setRefValue(reflect.ValueOf(TestType{})) return statement } - -func createTestEngine() *Engine { - driver := core.QueryDriver(driverName) - uri, err := driver.Parse(driverName, dataSourceName) - - if err != nil { - panic(err) - } - - dialect := &mysql{} - err = dialect.Init(nil, uri, driverName, dataSourceName) - - if err != nil { - panic(err) - } - - engine := &Engine{ - dialect: dialect, - Tables: make(map[reflect.Type]*core.Table), - mutex: &sync.RWMutex{}, - TagIdentifier: "xorm", - TZLocation: time.Local, - } - engine.SetMapper(core.NewCacheMapper(new(core.SnakeMapper))) - - return engine -} diff --git a/xorm_test.go b/xorm_test.go new file mode 100644 index 00000000..b9adaee4 --- /dev/null +++ b/xorm_test.go @@ -0,0 +1,59 @@ +package xorm + +import ( + "errors" + "flag" + "os" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +var ( + testEngine *Engine + dbType string +) + +func prepareSqlite3Engine() error { + if testEngine == nil { + os.Remove("./test.db") + var err error + testEngine, err = NewEngine("sqlite3", "./test.db") + if err != nil { + return err + } + testEngine.ShowSQL(*showSQL) + } + return nil +} + +func prepareEngine() error { + if dbType == "sqlite" { + return prepareSqlite3Engine() + } + return errors.New("Unknown test database driver") +} + +var ( + db = flag.String("db", "sqlite", "the tested database") + showSQL = flag.Bool("show_sql", true, "show generated SQLs") +) + +func TestMain(m *testing.M) { + flag.Parse() + + if db != nil { + dbType = *db + } + + if err := prepareEngine(); err != nil { + panic(err) + } + os.Exit(m.Run()) +} + +func TestPing(t *testing.T) { + if err := testEngine.Ping(); err != nil { + t.Fatal(err) + } +}