fix tests

This commit is contained in:
Lunny Xiao 2020-03-04 16:49:41 +08:00
parent 9a7b4e7af5
commit 51d6afa330
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
15 changed files with 178 additions and 154 deletions

View File

@ -22,8 +22,10 @@ steps:
commands: commands:
- make test-sqlite - make test-sqlite
- TEST_CACHE_ENABLE=true make test-sqlite - TEST_CACHE_ENABLE=true make test-sqlite
- go test ./caches/... ./convert/... ./core/... ./dialects/... \ - go test ./caches/... ./contexts/... ./convert/... ./core/... ./dialects/... \
./internal/json/... ./internal/statements/... ./internal/utils/... \
./log/... ./migrate/... ./names/... ./schemas/... ./tags/... ./log/... ./migrate/... ./names/... ./schemas/... ./tags/...
when: when:
event: event:
- push - push

View File

@ -31,7 +31,7 @@ type URI struct {
// Dialect represents a kind of database // Dialect represents a kind of database
type Dialect interface { type Dialect interface {
Init(*core.DB, *URI, string, string) error Init(*core.DB, *URI /*, string, string*/) error
URI() *URI URI() *URI
DB() *core.DB DB() *core.DB
DBType() schemas.DBType DBType() schemas.DBType
@ -39,9 +39,6 @@ type Dialect interface {
FormatBytes(b []byte) string FormatBytes(b []byte) string
DefaultSchema() string DefaultSchema() string
DriverName() string
DataSourceName() string
IsReserved(string) bool IsReserved(string) bool
Quoter() schemas.Quoter Quoter() schemas.Quoter
@ -77,17 +74,11 @@ type Dialect interface {
SetParams(params map[string]string) SetParams(params map[string]string)
} }
func OpenDialect(dialect Dialect) (*core.DB, error) {
return core.Open(dialect.DriverName(), dialect.DataSourceName())
}
// Base represents a basic dialect and all real dialects could embed this struct // Base represents a basic dialect and all real dialects could embed this struct
type Base struct { type Base struct {
db *core.DB db *core.DB
dialect Dialect dialect Dialect
driverName string uri *URI
dataSourceName string
uri *URI
} }
func (b *Base) DB() *core.DB { func (b *Base) DB() *core.DB {
@ -98,9 +89,8 @@ func (b *Base) DefaultSchema() string {
return "" return ""
} }
func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI, drivername, dataSourceName string) error { func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI) error {
b.db, b.dialect, b.uri = db, dialect, uri b.db, b.dialect, b.uri = db, dialect, uri
b.driverName, b.dataSourceName = drivername, dataSourceName
return nil return nil
} }
@ -165,18 +155,10 @@ func (b *Base) FormatBytes(bs []byte) string {
return fmt.Sprintf("0x%x", bs) return fmt.Sprintf("0x%x", bs)
} }
func (b *Base) DriverName() string {
return b.driverName
}
func (b *Base) ShowCreateNull() bool { func (b *Base) ShowCreateNull() bool {
return true return true
} }
func (b *Base) DataSourceName() string {
return b.dataSourceName
}
func (db *Base) SupportDropIfExists() bool { func (db *Base) SupportDropIfExists() bool {
return true return true
} }

View File

@ -4,6 +4,12 @@
package dialects package dialects
import (
"fmt"
"xorm.io/xorm/core"
)
type Driver interface { type Driver interface {
Parse(string, string) (*URI, error) Parse(string, string) (*URI, error)
} }
@ -29,3 +35,29 @@ func QueryDriver(driverName string) Driver {
func RegisteredDriverSize() int { func RegisteredDriverSize() int {
return len(drivers) return len(drivers)
} }
// OpenDialect opens a dialect via driver name and connection string
func OpenDialect(driverName, connstr string) (Dialect, error) {
driver := QueryDriver(driverName)
if driver == nil {
return nil, fmt.Errorf("Unsupported driver name: %v", driverName)
}
uri, err := driver.Parse(driverName, connstr)
if err != nil {
return nil, err
}
dialect := QueryDialect(uri.DBType)
if dialect == nil {
return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType)
}
db, err := core.Open(driverName, connstr)
if err != nil {
return nil, err
}
dialect.Init(db, uri)
return dialect, nil
}

View File

@ -210,8 +210,8 @@ type mssql struct {
Base Base
} }
func (db *mssql) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { func (db *mssql) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName) return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/)
} }
func (db *mssql) SQLType(c *schemas.Column) string { func (db *mssql) SQLType(c *schemas.Column) string {

View File

@ -177,8 +177,8 @@ type mysql struct {
rowFormat string rowFormat string
} }
func (db *mysql) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { func (db *mysql) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName) return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/)
} }
func (db *mysql) SetParams(params map[string]string) { func (db *mysql) SetParams(params map[string]string) {

View File

@ -504,8 +504,8 @@ type oracle struct {
Base Base
} }
func (db *oracle) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { func (db *oracle) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName) return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/)
} }
func (db *oracle) SQLType(c *schemas.Column) string { func (db *oracle) SQLType(c *schemas.Column) string {

View File

@ -766,30 +766,27 @@ var (
"YES": true, "YES": true,
"ZONE": true, "ZONE": true,
} }
// DefaultPostgresSchema default postgres schema
DefaultPostgresSchema = "public"
) )
const PostgresPublicSchema = "public" const postgresPublicSchema = "public"
type postgres struct { type postgres struct {
Base Base
} }
func (db *postgres) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { func (db *postgres) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error {
err := db.Base.Init(d, db, uri, drivername, dataSourceName) err := db.Base.Init(d, db, uri /*, drivername, dataSourceName*/)
if err != nil { if err != nil {
return err return err
} }
if db.uri.Schema == "" { if db.uri.Schema == "" {
db.uri.Schema = DefaultPostgresSchema db.uri.Schema = postgresPublicSchema
} }
return nil return nil
} }
func (db *postgres) DefaultSchema() string { func (db *postgres) DefaultSchema() string {
return PostgresPublicSchema return postgresPublicSchema
} }
func (db *postgres) SQLType(c *schemas.Column) string { func (db *postgres) SQLType(c *schemas.Column) string {

View File

@ -149,8 +149,8 @@ type sqlite3 struct {
Base Base
} }
func (db *sqlite3) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { func (db *sqlite3) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName) return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/)
} }
func (db *sqlite3) SQLType(c *schemas.Column) string { func (db *sqlite3) SQLType(c *schemas.Column) string {

View File

@ -39,6 +39,9 @@ type Engine struct {
logger log.ContextLogger logger log.ContextLogger
tagParser *tags.Parser tagParser *tags.Parser
driverName string
dataSourceName string
TZLocation *time.Location // The timezone of the application TZLocation *time.Location // The timezone of the application
DatabaseTZ *time.Location // The timezone of the database DatabaseTZ *time.Location // The timezone of the database
} }
@ -94,12 +97,12 @@ func (engine *Engine) SetDisableGlobalCache(disable bool) {
// DriverName return the current sql driver's name // DriverName return the current sql driver's name
func (engine *Engine) DriverName() string { func (engine *Engine) DriverName() string {
return engine.dialect.DriverName() return engine.driverName
} }
// DataSourceName return the current connection string // DataSourceName return the current connection string
func (engine *Engine) DataSourceName() string { func (engine *Engine) DataSourceName() string {
return engine.dialect.DataSourceName() return engine.dataSourceName
} }
// SetMapper set the name mapping rules // SetMapper set the name mapping rules
@ -210,7 +213,7 @@ func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error {
// NewDB provides an interface to operate database directly // NewDB provides an interface to operate database directly
func (engine *Engine) NewDB() (*core.DB, error) { func (engine *Engine) NewDB() (*core.DB, error) {
return dialects.OpenDialect(engine.dialect) return core.Open(engine.driverName, engine.dataSourceName)
} }
// DB return the wrapper of sql.DB // DB return the wrapper of sql.DB
@ -364,7 +367,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
if dialect == nil { if dialect == nil {
return errors.New("Unsupported database type") return errors.New("Unsupported database type")
} }
dialect.Init(nil, engine.dialect.URI(), "", "") dialect.Init(nil, engine.dialect.URI())
distDBName = string(tp[0]) distDBName = string(tp[0])
} }

View File

@ -82,6 +82,7 @@ type EngineInterface interface {
CreateTables(...interface{}) error CreateTables(...interface{}) error
DBMetas() ([]*schemas.Table, error) DBMetas() ([]*schemas.Table, error)
Dialect() dialects.Dialect Dialect() dialects.Dialect
DriverName() string
DropTables(...interface{}) error DropTables(...interface{}) error
DumpAllToFile(fp string, tp ...schemas.DBType) error DumpAllToFile(fp string, tp ...schemas.DBType) error
GetCacher(string) caches.Cacher GetCacher(string) caches.Cacher

View File

@ -22,14 +22,14 @@ var (
) )
// GenDeleteSQL generated delete SQL according conditions // GenDeleteSQL generated delete SQL according conditions
func (statement *Statement) GenDeleteSQL(bean interface{}) (string, string, []interface{}, error) { func (statement *Statement) GenDeleteSQL(bean interface{}) (string, string, []interface{}, *time.Time, error) {
condSQL, condArgs, err := statement.GenConds(bean) condSQL, condArgs, err := statement.GenConds(bean)
if err != nil { if err != nil {
return "", "", nil, err return "", "", nil, nil, err
} }
pLimitN := statement.LimitN pLimitN := statement.LimitN
if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) {
return "", "", nil, ErrNeedDeletedCond return "", "", nil, nil, ErrNeedDeletedCond
} }
var tableNameNoQuote = statement.TableName() var tableNameNoQuote = statement.TableName()
@ -69,63 +69,57 @@ func (statement *Statement) GenDeleteSQL(bean interface{}) (string, string, []in
} }
// TODO: how to handle delete limit on mssql? // TODO: how to handle delete limit on mssql?
case schemas.MSSQL: case schemas.MSSQL:
return "", "", nil, ErrNotImplemented return "", "", nil, nil, ErrNotImplemented
default: default:
deleteSQL += orderSQL deleteSQL += orderSQL
} }
} }
var realSQL string var realSQL string
argsForCache := make([]interface{}, 0, len(condArgs)*2)
if statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled if statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled
realSQL = deleteSQL return deleteSQL, deleteSQL, condArgs, nil, nil
copy(argsForCache, condArgs)
argsForCache = append(condArgs, argsForCache...)
} else {
// !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for caches.
copy(argsForCache, condArgs)
argsForCache = append(condArgs, argsForCache...)
deletedColumn := table.DeletedColumn()
realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v",
statement.quote(statement.TableName()),
statement.quote(deletedColumn.Name),
condSQL)
if len(orderSQL) > 0 {
switch statement.dialect.DBType() {
case schemas.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
realSQL += " AND " + inSQL
} else {
realSQL += " WHERE " + inSQL
}
case schemas.SQLITE:
inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
realSQL += " AND " + inSQL
} else {
realSQL += " WHERE " + inSQL
}
// TODO: how to handle delete limit on mssql?
case schemas.MSSQL:
return "", "", nil, ErrNotImplemented
default:
realSQL += orderSQL
}
}
// !oinume! Insert nowTime to the head of statement.Params
condArgs = append(condArgs, "")
paramsLen := len(condArgs)
copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1])
now := ColumnNow(deletedColumn, statement.defaultTimeZone)
val := dialects.FormatTime(statement.dialect, deletedColumn.SQLType.Name, now)
condArgs[0] = val
} }
return realSQL, deleteSQL, condArgs, nil
deletedColumn := table.DeletedColumn()
realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v",
statement.quote(statement.TableName()),
statement.quote(deletedColumn.Name),
condSQL)
if len(orderSQL) > 0 {
switch statement.dialect.DBType() {
case schemas.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
realSQL += " AND " + inSQL
} else {
realSQL += " WHERE " + inSQL
}
case schemas.SQLITE:
inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
realSQL += " AND " + inSQL
} else {
realSQL += " WHERE " + inSQL
}
// TODO: how to handle delete limit on mssql?
case schemas.MSSQL:
return "", "", nil, nil, ErrNotImplemented
default:
realSQL += orderSQL
}
}
// !oinume! Insert nowTime to the head of statement.Params
condArgs = append(condArgs, "")
paramsLen := len(condArgs)
copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1])
now := ColumnNow(deletedColumn, statement.defaultTimeZone)
val := dialects.FormatTime(statement.dialect, deletedColumn.SQLType.Name, now)
condArgs[0] = val
return realSQL, deleteSQL, condArgs, &now, nil
} }
// ColumnNow returns the current time for a column // ColumnNow returns the current time for a column

View File

@ -8,10 +8,37 @@ import (
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert"
"xorm.io/xorm/caches"
"xorm.io/xorm/dialects"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
"xorm.io/xorm/tags"
_ "github.com/mattn/go-sqlite3"
) )
var (
dialect dialects.Dialect
tagParser *tags.Parser
)
func TestMain(m *testing.M) {
var err error
dialect, err = dialects.OpenDialect("sqlite3", "./test.db")
if err != nil {
panic("unknow dialect")
}
tagParser = tags.NewParser("xorm", dialect, names.SnakeMapper{}, names.SnakeMapper{}, caches.NewManager())
if tagParser == nil {
panic("tags parser is nil")
}
m.Run()
}
var colStrTests = []struct { var colStrTests = []struct {
omitColumn string omitColumn string
onlyToDBColumnNdx int onlyToDBColumnNdx int
@ -26,14 +53,9 @@ var colStrTests = []struct {
} }
func TestColumnsStringGeneration(t *testing.T) { func TestColumnsStringGeneration(t *testing.T) {
if dbType == "postgres" || dbType == "mssql" {
return
}
var statement *Statement
for ndx, testCase := range colStrTests { for ndx, testCase := range colStrTests {
statement = createTestStatement() statement, err := createTestStatement()
assert.NoError(t, err)
if testCase.omitColumn != "" { if testCase.omitColumn != "" {
statement.Omit(testCase.omitColumn) statement.Omit(testCase.omitColumn)
@ -55,33 +77,6 @@ func TestColumnsStringGeneration(t *testing.T) {
} }
} }
func BenchmarkColumnsStringGeneration(b *testing.B) {
b.StopTimer()
statement := createTestStatement()
testCase := colStrTests[0]
if testCase.omitColumn != "" {
statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped
}
if testCase.onlyToDBColumnNdx >= 0 {
columns := statement.RefTable.Columns()
columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped
}
b.StartTimer()
for i := 0; i < b.N; i++ {
actual := statement.genColumnStr()
if actual != testCase.expected {
b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual)
}
}
}
func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
b.StopTimer() b.StopTimer()
@ -162,23 +157,40 @@ func (TestType) TableName() string {
return "TestTable" return "TestTable"
} }
func createTestStatement() *Statement { func createTestStatement() (*Statement, error) {
if engine, ok := testEngine.(*Engine); ok { statement := NewStatement(dialect, tagParser, time.Local)
statement := &Statement{} if err := statement.SetRefValue(reflect.ValueOf(TestType{})); err != nil {
statement.Reset() return nil, err
statement.Engine = engine }
statement.dialect = engine.dialect return statement, nil
statement.SetRefValue(reflect.ValueOf(TestType{})) }
return statement func BenchmarkColumnsStringGeneration(b *testing.B) {
} else if eg, ok := testEngine.(*EngineGroup); ok { b.StopTimer()
statement := &Statement{}
statement.Reset() statement, err := createTestStatement()
statement.Engine = eg.Engine if err != nil {
statement.dialect = eg.Engine.dialect panic(err)
statement.SetRefValue(reflect.ValueOf(TestType{})) }
return statement testCase := colStrTests[0]
if testCase.omitColumn != "" {
statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped
}
if testCase.onlyToDBColumnNdx >= 0 {
columns := statement.RefTable.Columns()
columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped
}
b.StartTimer()
for i := 0; i < b.N; i++ {
actual := statement.genColumnStr()
if actual != testCase.expected {
b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual)
}
} }
return nil
} }

View File

@ -98,7 +98,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
processor.BeforeDelete() processor.BeforeDelete()
} }
realSQL, deleteSQL, condArgs, err := session.statement.GenDeleteSQL(bean) realSQL, deleteSQL, condArgs, now, err := session.statement.GenDeleteSQL(bean)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -110,12 +110,11 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
if !session.statement.GetUnscoped() && session.statement.RefTable.DeletedColumn() != nil { if !session.statement.GetUnscoped() && session.statement.RefTable.DeletedColumn() != nil {
deletedColumn := session.statement.RefTable.DeletedColumn() deletedColumn := session.statement.RefTable.DeletedColumn()
session.afterClosures = append(session.afterClosures, func(col *schemas.Column, tz *time.Location) func(interface{}) { session.afterClosures = append(session.afterClosures, func(col *schemas.Column, t time.Time) func(interface{}) {
return func(bean interface{}) { return func(bean interface{}) {
t := time.Now().In(tz)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)
} }
}(deletedColumn, session.engine.TZLocation)) }(deletedColumn, now.In(session.engine.TZLocation)))
} }
var tableNameNoQuote = session.statement.TableName() var tableNameNoQuote = session.statement.TableName()

View File

@ -179,7 +179,7 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "1.5", valuesString["money"]) assert.Equal(t, "1.5", valuesString["money"])
// for mymysql driver, interface{} will be []byte, so ignore it currently // for mymysql driver, interface{} will be []byte, so ignore it currently
if testEngine.Dialect().DriverName() != "mymysql" { if testEngine.DriverName() != "mymysql" {
var valuesInter = make(map[string]interface{}) var valuesInter = make(map[string]interface{})
has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter) has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter)
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -54,7 +54,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
return nil, err return nil, err
} }
err = dialect.Init(db, uri, driverName, dataSourceName) err = dialect.Init(db, uri /*, driverName, dataSourceName*/)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -70,6 +70,8 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
defaultContext: context.Background(), defaultContext: context.Background(),
cacherMgr: cacherMgr, cacherMgr: cacherMgr,
tagParser: tagParser, tagParser: tagParser,
driverName: driverName,
dataSourceName: dataSourceName,
} }
if uri.DBType == schemas.SQLITE { if uri.DBType == schemas.SQLITE {