From 51d6afa3300f978c9071952227aa3134ac77ecde Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 4 Mar 2020 16:49:41 +0800 Subject: [PATCH] fix tests --- .drone.yml | 4 +- dialects/dialect.go | 28 ++----- dialects/driver.go | 32 +++++++ dialects/mssql.go | 4 +- dialects/mysql.go | 4 +- dialects/oracle.go | 4 +- dialects/postgres.go | 13 ++- dialects/sqlite3.go | 4 +- engine.go | 11 ++- interface.go | 1 + internal/statements/delete.go | 98 ++++++++++------------ internal/statements/statement_test.go | 116 ++++++++++++++------------ session_delete.go | 7 +- session_get_test.go | 2 +- xorm.go | 4 +- 15 files changed, 178 insertions(+), 154 deletions(-) diff --git a/.drone.yml b/.drone.yml index dac49cdf..9a62c6bd 100644 --- a/.drone.yml +++ b/.drone.yml @@ -22,8 +22,10 @@ steps: commands: - make test-sqlite - TEST_CACHE_ENABLE=true make test-sqlite - - go test ./caches/... ./convert/... ./core/... ./dialects/... \ + - go test ./caches/... ./contexts/... ./convert/... ./core/... ./dialects/... \ + ./internal/json/... ./internal/statements/... ./internal/utils/... \ ./log/... ./migrate/... ./names/... ./schemas/... ./tags/... + when: event: - push diff --git a/dialects/dialect.go b/dialects/dialect.go index a0139d9f..7d816bda 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -31,7 +31,7 @@ type URI struct { // Dialect represents a kind of database type Dialect interface { - Init(*core.DB, *URI, string, string) error + Init(*core.DB, *URI /*, string, string*/) error URI() *URI DB() *core.DB DBType() schemas.DBType @@ -39,9 +39,6 @@ type Dialect interface { FormatBytes(b []byte) string DefaultSchema() string - DriverName() string - DataSourceName() string - IsReserved(string) bool Quoter() schemas.Quoter @@ -77,17 +74,11 @@ type Dialect interface { SetParams(params map[string]string) } -func OpenDialect(dialect Dialect) (*core.DB, error) { - return core.Open(dialect.DriverName(), dialect.DataSourceName()) -} - // Base represents a basic dialect and all real dialects could embed this struct type Base struct { - db *core.DB - dialect Dialect - driverName string - dataSourceName string - uri *URI + db *core.DB + dialect Dialect + uri *URI } func (b *Base) DB() *core.DB { @@ -98,9 +89,8 @@ func (b *Base) DefaultSchema() string { return "" } -func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI, drivername, dataSourceName string) error { +func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI) error { b.db, b.dialect, b.uri = db, dialect, uri - b.driverName, b.dataSourceName = drivername, dataSourceName return nil } @@ -165,18 +155,10 @@ func (b *Base) FormatBytes(bs []byte) string { return fmt.Sprintf("0x%x", bs) } -func (b *Base) DriverName() string { - return b.driverName -} - func (b *Base) ShowCreateNull() bool { return true } -func (b *Base) DataSourceName() string { - return b.dataSourceName -} - func (db *Base) SupportDropIfExists() bool { return true } diff --git a/dialects/driver.go b/dialects/driver.go index 5343d594..89d21bfc 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -4,6 +4,12 @@ package dialects +import ( + "fmt" + + "xorm.io/xorm/core" +) + type Driver interface { Parse(string, string) (*URI, error) } @@ -29,3 +35,29 @@ func QueryDriver(driverName string) Driver { func RegisteredDriverSize() int { return len(drivers) } + +// OpenDialect opens a dialect via driver name and connection string +func OpenDialect(driverName, connstr string) (Dialect, error) { + driver := QueryDriver(driverName) + if driver == nil { + return nil, fmt.Errorf("Unsupported driver name: %v", driverName) + } + + uri, err := driver.Parse(driverName, connstr) + if err != nil { + return nil, err + } + + dialect := QueryDialect(uri.DBType) + if dialect == nil { + return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType) + } + + db, err := core.Open(driverName, connstr) + if err != nil { + return nil, err + } + dialect.Init(db, uri) + + return dialect, nil +} diff --git a/dialects/mssql.go b/dialects/mssql.go index 9963fc4f..3c95dd20 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -210,8 +210,8 @@ type mssql struct { Base } -func (db *mssql) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *mssql) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error { + return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/) } func (db *mssql) SQLType(c *schemas.Column) string { diff --git a/dialects/mysql.go b/dialects/mysql.go index 5ed2d8f1..7c41ecf6 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -177,8 +177,8 @@ type mysql struct { rowFormat string } -func (db *mysql) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *mysql) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error { + return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/) } func (db *mysql) SetParams(params map[string]string) { diff --git a/dialects/oracle.go b/dialects/oracle.go index e5c438bc..49c65837 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -504,8 +504,8 @@ type oracle struct { Base } -func (db *oracle) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *oracle) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error { + return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/) } func (db *oracle) SQLType(c *schemas.Column) string { diff --git a/dialects/postgres.go b/dialects/postgres.go index 623b59ed..ad3c8c68 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -766,30 +766,27 @@ var ( "YES": true, "ZONE": true, } - - // DefaultPostgresSchema default postgres schema - DefaultPostgresSchema = "public" ) -const PostgresPublicSchema = "public" +const postgresPublicSchema = "public" type postgres struct { Base } -func (db *postgres) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - err := db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *postgres) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error { + err := db.Base.Init(d, db, uri /*, drivername, dataSourceName*/) if err != nil { return err } if db.uri.Schema == "" { - db.uri.Schema = DefaultPostgresSchema + db.uri.Schema = postgresPublicSchema } return nil } func (db *postgres) DefaultSchema() string { - return PostgresPublicSchema + return postgresPublicSchema } func (db *postgres) SQLType(c *schemas.Column) string { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 7dfa7fca..3b9cb97c 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -149,8 +149,8 @@ type sqlite3 struct { Base } -func (db *sqlite3) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *sqlite3) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error { + return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/) } func (db *sqlite3) SQLType(c *schemas.Column) string { diff --git a/engine.go b/engine.go index 8b4f3931..31d891dc 100644 --- a/engine.go +++ b/engine.go @@ -39,6 +39,9 @@ type Engine struct { logger log.ContextLogger tagParser *tags.Parser + driverName string + dataSourceName string + TZLocation *time.Location // The timezone of the application DatabaseTZ *time.Location // The timezone of the database } @@ -94,12 +97,12 @@ func (engine *Engine) SetDisableGlobalCache(disable bool) { // DriverName return the current sql driver's name func (engine *Engine) DriverName() string { - return engine.dialect.DriverName() + return engine.driverName } // DataSourceName return the current connection string func (engine *Engine) DataSourceName() string { - return engine.dialect.DataSourceName() + return engine.dataSourceName } // SetMapper set the name mapping rules @@ -210,7 +213,7 @@ func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error { // NewDB provides an interface to operate database directly func (engine *Engine) NewDB() (*core.DB, error) { - return dialects.OpenDialect(engine.dialect) + return core.Open(engine.driverName, engine.dataSourceName) } // DB return the wrapper of sql.DB @@ -364,7 +367,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch if dialect == nil { return errors.New("Unsupported database type") } - dialect.Init(nil, engine.dialect.URI(), "", "") + dialect.Init(nil, engine.dialect.URI()) distDBName = string(tp[0]) } diff --git a/interface.go b/interface.go index 13f1e12a..8d2402f0 100644 --- a/interface.go +++ b/interface.go @@ -82,6 +82,7 @@ type EngineInterface interface { CreateTables(...interface{}) error DBMetas() ([]*schemas.Table, error) Dialect() dialects.Dialect + DriverName() string DropTables(...interface{}) error DumpAllToFile(fp string, tp ...schemas.DBType) error GetCacher(string) caches.Cacher diff --git a/internal/statements/delete.go b/internal/statements/delete.go index de4f9f0f..2cb91f2a 100644 --- a/internal/statements/delete.go +++ b/internal/statements/delete.go @@ -22,14 +22,14 @@ var ( ) // GenDeleteSQL generated delete SQL according conditions -func (statement *Statement) GenDeleteSQL(bean interface{}) (string, string, []interface{}, error) { +func (statement *Statement) GenDeleteSQL(bean interface{}) (string, string, []interface{}, *time.Time, error) { condSQL, condArgs, err := statement.GenConds(bean) if err != nil { - return "", "", nil, err + return "", "", nil, nil, err } pLimitN := statement.LimitN if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { - return "", "", nil, ErrNeedDeletedCond + return "", "", nil, nil, ErrNeedDeletedCond } var tableNameNoQuote = statement.TableName() @@ -69,63 +69,57 @@ func (statement *Statement) GenDeleteSQL(bean interface{}) (string, string, []in } // TODO: how to handle delete limit on mssql? case schemas.MSSQL: - return "", "", nil, ErrNotImplemented + return "", "", nil, nil, ErrNotImplemented default: deleteSQL += orderSQL } } var realSQL string - argsForCache := make([]interface{}, 0, len(condArgs)*2) if statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled - realSQL = deleteSQL - copy(argsForCache, condArgs) - argsForCache = append(condArgs, argsForCache...) - } else { - // !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for caches. - copy(argsForCache, condArgs) - argsForCache = append(condArgs, argsForCache...) - - deletedColumn := table.DeletedColumn() - realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", - statement.quote(statement.TableName()), - statement.quote(deletedColumn.Name), - condSQL) - - if len(orderSQL) > 0 { - switch statement.dialect.DBType() { - case schemas.POSTGRES: - inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - realSQL += " AND " + inSQL - } else { - realSQL += " WHERE " + inSQL - } - case schemas.SQLITE: - inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - realSQL += " AND " + inSQL - } else { - realSQL += " WHERE " + inSQL - } - // TODO: how to handle delete limit on mssql? - case schemas.MSSQL: - return "", "", nil, ErrNotImplemented - default: - realSQL += orderSQL - } - } - - // !oinume! Insert nowTime to the head of statement.Params - condArgs = append(condArgs, "") - paramsLen := len(condArgs) - copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1]) - - now := ColumnNow(deletedColumn, statement.defaultTimeZone) - val := dialects.FormatTime(statement.dialect, deletedColumn.SQLType.Name, now) - condArgs[0] = val + return deleteSQL, deleteSQL, condArgs, nil, nil } - return realSQL, deleteSQL, condArgs, nil + + deletedColumn := table.DeletedColumn() + realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", + statement.quote(statement.TableName()), + statement.quote(deletedColumn.Name), + condSQL) + + if len(orderSQL) > 0 { + switch statement.dialect.DBType() { + case schemas.POSTGRES: + inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + realSQL += " AND " + inSQL + } else { + realSQL += " WHERE " + inSQL + } + case schemas.SQLITE: + inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + realSQL += " AND " + inSQL + } else { + realSQL += " WHERE " + inSQL + } + // TODO: how to handle delete limit on mssql? + case schemas.MSSQL: + return "", "", nil, nil, ErrNotImplemented + default: + realSQL += orderSQL + } + } + + // !oinume! Insert nowTime to the head of statement.Params + condArgs = append(condArgs, "") + paramsLen := len(condArgs) + copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1]) + + now := ColumnNow(deletedColumn, statement.defaultTimeZone) + val := dialects.FormatTime(statement.dialect, deletedColumn.SQLType.Name, now) + condArgs[0] = val + + return realSQL, deleteSQL, condArgs, &now, nil } // ColumnNow returns the current time for a column diff --git a/internal/statements/statement_test.go b/internal/statements/statement_test.go index 3b6e3ae2..15f446f4 100644 --- a/internal/statements/statement_test.go +++ b/internal/statements/statement_test.go @@ -8,10 +8,37 @@ import ( "reflect" "strings" "testing" + "time" + "github.com/stretchr/testify/assert" + "xorm.io/xorm/caches" + "xorm.io/xorm/dialects" + "xorm.io/xorm/names" "xorm.io/xorm/schemas" + "xorm.io/xorm/tags" + + _ "github.com/mattn/go-sqlite3" ) +var ( + dialect dialects.Dialect + tagParser *tags.Parser +) + +func TestMain(m *testing.M) { + var err error + dialect, err = dialects.OpenDialect("sqlite3", "./test.db") + if err != nil { + panic("unknow dialect") + } + + tagParser = tags.NewParser("xorm", dialect, names.SnakeMapper{}, names.SnakeMapper{}, caches.NewManager()) + if tagParser == nil { + panic("tags parser is nil") + } + m.Run() +} + var colStrTests = []struct { omitColumn string onlyToDBColumnNdx int @@ -26,14 +53,9 @@ var colStrTests = []struct { } func TestColumnsStringGeneration(t *testing.T) { - if dbType == "postgres" || dbType == "mssql" { - return - } - - var statement *Statement - for ndx, testCase := range colStrTests { - statement = createTestStatement() + statement, err := createTestStatement() + assert.NoError(t, err) if testCase.omitColumn != "" { statement.Omit(testCase.omitColumn) @@ -55,33 +77,6 @@ func TestColumnsStringGeneration(t *testing.T) { } } -func BenchmarkColumnsStringGeneration(b *testing.B) { - b.StopTimer() - - statement := createTestStatement() - - testCase := colStrTests[0] - - if testCase.omitColumn != "" { - statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped - } - - if testCase.onlyToDBColumnNdx >= 0 { - columns := statement.RefTable.Columns() - columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped - } - - b.StartTimer() - - for i := 0; i < b.N; i++ { - actual := statement.genColumnStr() - - if actual != testCase.expected { - b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual) - } - } -} - func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { b.StopTimer() @@ -162,23 +157,40 @@ func (TestType) TableName() string { return "TestTable" } -func createTestStatement() *Statement { - if engine, ok := testEngine.(*Engine); ok { - statement := &Statement{} - statement.Reset() - statement.Engine = engine - statement.dialect = engine.dialect - statement.SetRefValue(reflect.ValueOf(TestType{})) - - return statement - } else if eg, ok := testEngine.(*EngineGroup); ok { - statement := &Statement{} - statement.Reset() - statement.Engine = eg.Engine - statement.dialect = eg.Engine.dialect - statement.SetRefValue(reflect.ValueOf(TestType{})) - - return statement +func createTestStatement() (*Statement, error) { + statement := NewStatement(dialect, tagParser, time.Local) + if err := statement.SetRefValue(reflect.ValueOf(TestType{})); err != nil { + return nil, err + } + return statement, nil +} + +func BenchmarkColumnsStringGeneration(b *testing.B) { + b.StopTimer() + + statement, err := createTestStatement() + if err != nil { + panic(err) + } + + testCase := colStrTests[0] + + if testCase.omitColumn != "" { + statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped + } + + if testCase.onlyToDBColumnNdx >= 0 { + columns := statement.RefTable.Columns() + columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + actual := statement.genColumnStr() + + if actual != testCase.expected { + b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual) + } } - return nil } diff --git a/session_delete.go b/session_delete.go index 3373d89e..16434bac 100644 --- a/session_delete.go +++ b/session_delete.go @@ -98,7 +98,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { processor.BeforeDelete() } - realSQL, deleteSQL, condArgs, err := session.statement.GenDeleteSQL(bean) + realSQL, deleteSQL, condArgs, now, err := session.statement.GenDeleteSQL(bean) if err != nil { return 0, err } @@ -110,12 +110,11 @@ func (session *Session) Delete(bean interface{}) (int64, error) { if !session.statement.GetUnscoped() && session.statement.RefTable.DeletedColumn() != nil { deletedColumn := session.statement.RefTable.DeletedColumn() - session.afterClosures = append(session.afterClosures, func(col *schemas.Column, tz *time.Location) func(interface{}) { + session.afterClosures = append(session.afterClosures, func(col *schemas.Column, t time.Time) func(interface{}) { return func(bean interface{}) { - t := time.Now().In(tz) setColumnTime(bean, col, t) } - }(deletedColumn, session.engine.TZLocation)) + }(deletedColumn, now.In(session.engine.TZLocation))) } var tableNameNoQuote = session.statement.TableName() diff --git a/session_get_test.go b/session_get_test.go index 5bac9cd7..7e10bf54 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -179,7 +179,7 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "1.5", valuesString["money"]) // for mymysql driver, interface{} will be []byte, so ignore it currently - if testEngine.Dialect().DriverName() != "mymysql" { + if testEngine.DriverName() != "mymysql" { var valuesInter = make(map[string]interface{}) has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter) assert.NoError(t, err) diff --git a/xorm.go b/xorm.go index 724a37cb..51915940 100644 --- a/xorm.go +++ b/xorm.go @@ -54,7 +54,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { return nil, err } - err = dialect.Init(db, uri, driverName, dataSourceName) + err = dialect.Init(db, uri /*, driverName, dataSourceName*/) if err != nil { return nil, err } @@ -70,6 +70,8 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { defaultContext: context.Background(), cacherMgr: cacherMgr, tagParser: tagParser, + driverName: driverName, + dataSourceName: dataSourceName, } if uri.DBType == schemas.SQLITE {