From f51d28304a2d89edfb5a57b42d91080a79cfa6c7 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 6 Mar 2020 06:43:49 +0000 Subject: [PATCH] Move some codes to statement sub package (#1574) revert change for delete refactor new engine fix tests Move some codes to statement sub package Reviewed-on: https://gitea.com/xorm/xorm/pulls/1574 --- .drone.yml | 4 +- dialects/dialect.go | 28 ++----- dialects/driver.go | 32 +++++++ dialects/mssql.go | 4 +- dialects/mysql.go | 4 +- dialects/oracle.go | 4 +- dialects/postgres.go | 13 ++- dialects/sqlite3.go | 4 +- engine.go | 29 ++++--- error.go | 4 - interface.go | 1 + internal/statements/query.go | 50 +++++------ internal/statements/statement_test.go | 116 ++++++++++++++------------ session.go | 2 +- session_convert.go | 3 +- session_delete.go | 8 ++ session_get_test.go | 2 +- session_tx.go | 4 +- xorm.go | 29 +------ 19 files changed, 169 insertions(+), 172 deletions(-) diff --git a/.drone.yml b/.drone.yml index dac49cdf..9a62c6bd 100644 --- a/.drone.yml +++ b/.drone.yml @@ -22,8 +22,10 @@ steps: commands: - make test-sqlite - TEST_CACHE_ENABLE=true make test-sqlite - - go test ./caches/... ./convert/... ./core/... ./dialects/... \ + - go test ./caches/... ./contexts/... ./convert/... ./core/... ./dialects/... \ + ./internal/json/... ./internal/statements/... ./internal/utils/... \ ./log/... ./migrate/... ./names/... ./schemas/... ./tags/... + when: event: - push diff --git a/dialects/dialect.go b/dialects/dialect.go index a0139d9f..c591cc7b 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -31,7 +31,7 @@ type URI struct { // Dialect represents a kind of database type Dialect interface { - Init(*core.DB, *URI, string, string) error + Init(*core.DB, *URI) error URI() *URI DB() *core.DB DBType() schemas.DBType @@ -39,9 +39,6 @@ type Dialect interface { FormatBytes(b []byte) string DefaultSchema() string - DriverName() string - DataSourceName() string - IsReserved(string) bool Quoter() schemas.Quoter @@ -77,17 +74,11 @@ type Dialect interface { SetParams(params map[string]string) } -func OpenDialect(dialect Dialect) (*core.DB, error) { - return core.Open(dialect.DriverName(), dialect.DataSourceName()) -} - // Base represents a basic dialect and all real dialects could embed this struct type Base struct { - db *core.DB - dialect Dialect - driverName string - dataSourceName string - uri *URI + db *core.DB + dialect Dialect + uri *URI } func (b *Base) DB() *core.DB { @@ -98,9 +89,8 @@ func (b *Base) DefaultSchema() string { return "" } -func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI, drivername, dataSourceName string) error { +func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI) error { b.db, b.dialect, b.uri = db, dialect, uri - b.driverName, b.dataSourceName = drivername, dataSourceName return nil } @@ -165,18 +155,10 @@ func (b *Base) FormatBytes(bs []byte) string { return fmt.Sprintf("0x%x", bs) } -func (b *Base) DriverName() string { - return b.driverName -} - func (b *Base) ShowCreateNull() bool { return true } -func (b *Base) DataSourceName() string { - return b.dataSourceName -} - func (db *Base) SupportDropIfExists() bool { return true } diff --git a/dialects/driver.go b/dialects/driver.go index 5343d594..89d21bfc 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -4,6 +4,12 @@ package dialects +import ( + "fmt" + + "xorm.io/xorm/core" +) + type Driver interface { Parse(string, string) (*URI, error) } @@ -29,3 +35,29 @@ func QueryDriver(driverName string) Driver { func RegisteredDriverSize() int { return len(drivers) } + +// OpenDialect opens a dialect via driver name and connection string +func OpenDialect(driverName, connstr string) (Dialect, error) { + driver := QueryDriver(driverName) + if driver == nil { + return nil, fmt.Errorf("Unsupported driver name: %v", driverName) + } + + uri, err := driver.Parse(driverName, connstr) + if err != nil { + return nil, err + } + + dialect := QueryDialect(uri.DBType) + if dialect == nil { + return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType) + } + + db, err := core.Open(driverName, connstr) + if err != nil { + return nil, err + } + dialect.Init(db, uri) + + return dialect, nil +} diff --git a/dialects/mssql.go b/dialects/mssql.go index 9963fc4f..558abdfc 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -210,8 +210,8 @@ type mssql struct { Base } -func (db *mssql) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *mssql) Init(d *core.DB, uri *URI) error { + return db.Base.Init(d, db, uri) } func (db *mssql) SQLType(c *schemas.Column) string { diff --git a/dialects/mysql.go b/dialects/mysql.go index 5ed2d8f1..939a7cf1 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -177,8 +177,8 @@ type mysql struct { rowFormat string } -func (db *mysql) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *mysql) Init(d *core.DB, uri *URI) error { + return db.Base.Init(d, db, uri) } func (db *mysql) SetParams(params map[string]string) { diff --git a/dialects/oracle.go b/dialects/oracle.go index e5c438bc..4a8162ac 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -504,8 +504,8 @@ type oracle struct { Base } -func (db *oracle) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *oracle) Init(d *core.DB, uri *URI) error { + return db.Base.Init(d, db, uri) } func (db *oracle) SQLType(c *schemas.Column) string { diff --git a/dialects/postgres.go b/dialects/postgres.go index 623b59ed..f92202cd 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -766,30 +766,27 @@ var ( "YES": true, "ZONE": true, } - - // DefaultPostgresSchema default postgres schema - DefaultPostgresSchema = "public" ) -const PostgresPublicSchema = "public" +const postgresPublicSchema = "public" type postgres struct { Base } -func (db *postgres) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - err := db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *postgres) Init(d *core.DB, uri *URI) error { + err := db.Base.Init(d, db, uri) if err != nil { return err } if db.uri.Schema == "" { - db.uri.Schema = DefaultPostgresSchema + db.uri.Schema = postgresPublicSchema } return nil } func (db *postgres) DefaultSchema() string { - return PostgresPublicSchema + return postgresPublicSchema } func (db *postgres) SQLType(c *schemas.Column) string { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 7dfa7fca..39138b13 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -149,8 +149,8 @@ type sqlite3 struct { Base } -func (db *sqlite3) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *sqlite3) Init(d *core.DB, uri *URI) error { + return db.Base.Init(d, db, uri) } func (db *sqlite3) SQLType(c *schemas.Column) string { diff --git a/engine.go b/engine.go index 221b7488..cc8a74a0 100644 --- a/engine.go +++ b/engine.go @@ -39,6 +39,9 @@ type Engine struct { logger log.ContextLogger tagParser *tags.Parser + driverName string + dataSourceName string + TZLocation *time.Location // The timezone of the application DatabaseTZ *time.Location // The timezone of the database } @@ -61,7 +64,7 @@ func (engine *Engine) BufferSize(size int) *Session { // ShowSQL show SQL statement or not on logger if log level is great than INFO func (engine *Engine) ShowSQL(show ...bool) { engine.logger.ShowSQL(show...) - engine.db.Logger = engine.logger + engine.DB().Logger = engine.logger } // Logger return the logger interface @@ -79,7 +82,7 @@ func (engine *Engine) SetLogger(logger interface{}) { realLogger = t } engine.logger = realLogger - engine.db.Logger = realLogger + engine.DB().Logger = realLogger } // SetLogLevel sets the logger level @@ -94,12 +97,12 @@ func (engine *Engine) SetDisableGlobalCache(disable bool) { // DriverName return the current sql driver's name func (engine *Engine) DriverName() string { - return engine.dialect.DriverName() + return engine.driverName } // DataSourceName return the current connection string func (engine *Engine) DataSourceName() string { - return engine.dialect.DataSourceName() + return engine.dataSourceName } // SetMapper set the name mapping rules @@ -164,17 +167,17 @@ func (engine *Engine) AutoIncrStr() string { // SetConnMaxLifetime sets the maximum amount of time a connection may be reused. func (engine *Engine) SetConnMaxLifetime(d time.Duration) { - engine.db.SetConnMaxLifetime(d) + engine.DB().SetConnMaxLifetime(d) } // SetMaxOpenConns is only available for go 1.2+ func (engine *Engine) SetMaxOpenConns(conns int) { - engine.db.SetMaxOpenConns(conns) + engine.DB().SetMaxOpenConns(conns) } // SetMaxIdleConns set the max idle connections on pool, default is 2 func (engine *Engine) SetMaxIdleConns(conns int) { - engine.db.SetMaxIdleConns(conns) + engine.DB().SetMaxIdleConns(conns) } // SetDefaultCacher set the default cacher. Xorm's default not enable cacher. @@ -210,12 +213,12 @@ func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error { // NewDB provides an interface to operate database directly func (engine *Engine) NewDB() (*core.DB, error) { - return dialects.OpenDialect(engine.dialect) + return core.Open(engine.driverName, engine.dataSourceName) } // DB return the wrapper of sql.DB func (engine *Engine) DB() *core.DB { - return engine.db + return engine.dialect.DB() } // Dialect return database dialect @@ -232,7 +235,7 @@ func (engine *Engine) NewSession() *Session { // Close the engine func (engine *Engine) Close() error { - return engine.db.Close() + return engine.DB().Close() } // Ping tests if database is alive @@ -364,7 +367,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch if dialect == nil { return errors.New("Unsupported database type") } - dialect.Init(nil, engine.dialect.URI(), "", "") + dialect.Init(nil, engine.dialect.URI()) distDBName = string(tp[0]) } @@ -1211,10 +1214,6 @@ func (engine *Engine) nowTime(col *schemas.Column) (interface{}, time.Time) { return dialects.FormatTime(engine.dialect, col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation) } -func (engine *Engine) formatColTime(col *schemas.Column, t time.Time) (v interface{}) { - return dialects.FormatColumnTime(engine.dialect, engine.DatabaseTZ, col, t) -} - // GetColumnMapper returns the column name mapper func (engine *Engine) GetColumnMapper() names.Mapper { return engine.tagParser.GetColumnMapper() diff --git a/error.go b/error.go index a19860e3..21a83f47 100644 --- a/error.go +++ b/error.go @@ -20,10 +20,6 @@ var ( ErrNotExist = errors.New("Record does not exist") // ErrCacheFailed cache failed error ErrCacheFailed = errors.New("Cache failed") - // ErrNeedDeletedCond delete needs less one condition error - ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") - // ErrNotImplemented not implemented - ErrNotImplemented = errors.New("Not implemented") // ErrConditionType condition type unsupported ErrConditionType = errors.New("Unsupported condition type") ) diff --git a/interface.go b/interface.go index 13f1e12a..8d2402f0 100644 --- a/interface.go +++ b/interface.go @@ -82,6 +82,7 @@ type EngineInterface interface { CreateTables(...interface{}) error DBMetas() ([]*schemas.Table, error) Dialect() dialects.Dialect + DriverName() string DropTables(...interface{}) error DumpAllToFile(fp string, tp ...schemas.DBType) error GetCacher(string) caches.Cacher diff --git a/internal/statements/query.go b/internal/statements/query.go index 1519cb08..a058f752 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -57,16 +57,12 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int return "", nil, err } - condSQL, condArgs, err := builder.ToSQL(statement.cond) + sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) if err != nil { return "", nil, err } - args := append(statement.joinArgs, condArgs...) - sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true) - if err != nil { - return "", nil, err - } + // for mssql and use limit qs := strings.Count(sqlStr, "?") if len(args)*2 == qs { @@ -92,12 +88,11 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri } sumSelect := strings.Join(sumStrs, ", ") - condSQL, condArgs, err := statement.GenConds(bean) - if err != nil { + if err := statement.mergeConds(bean); err != nil { return "", nil, err } - sqlStr, err := statement.GenSelectSQL(sumSelect, condSQL, true, true) + sqlStr, condArgs, err := statement.genSelectSQL(sumSelect, true, true) if err != nil { return "", nil, err } @@ -147,12 +142,8 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, return "", nil, err } } - condSQL, condArgs, err := builder.ToSQL(statement.cond) - if err != nil { - return "", nil, err - } - sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true) + sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) if err != nil { return "", nil, err } @@ -165,17 +156,13 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa return statement.RawSQL, statement.RawParams, nil } - var condSQL string var condArgs []interface{} var err error if len(beans) > 0 { statement.SetRefBean(beans[0]) - condSQL, condArgs, err = statement.GenConds(beans[0]) - } else { - condSQL, condArgs, err = builder.ToSQL(statement.cond) - } - if err != nil { - return "", nil, err + if err := statement.mergeConds(beans[0]); err != nil { + return "", nil, err + } } var selectSQL = statement.SelectStr @@ -186,7 +173,7 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa selectSQL = "count(*)" } } - sqlStr, err := statement.GenSelectSQL(selectSQL, condSQL, false, false) + sqlStr, condArgs, err := statement.genSelectSQL(selectSQL, false, false) if err != nil { return "", nil, err } @@ -194,7 +181,7 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) { +func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) { var ( distinct string dialect = statement.dialect @@ -205,6 +192,11 @@ func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, n if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { distinct = "DISTINCT " } + + condSQL, condArgs, err := builder.ToSQL(statement.cond) + if err != nil { + return "", nil, err + } if len(condSQL) > 0 { whereStr = " WHERE " + condSQL } @@ -313,10 +305,10 @@ func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, n } } if statement.IsForUpdate { - return dialect.ForUpdateSQL(buf.String()), nil + return dialect.ForUpdateSQL(buf.String()), condArgs, nil } - return buf.String(), nil + return buf.String(), condArgs, nil } func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) { @@ -428,16 +420,12 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa } statement.cond = statement.cond.And(autoCond) - condSQL, condArgs, err := builder.ToSQL(statement.cond) - if err != nil { - return "", nil, err - } - args = append(statement.joinArgs, condArgs...) - sqlStr, err = statement.GenSelectSQL(columnStr, condSQL, true, true) + sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) if err != nil { return "", nil, err } + args = append(statement.joinArgs, condArgs...) // for mssql and use limit qs := strings.Count(sqlStr, "?") if len(args)*2 == qs { diff --git a/internal/statements/statement_test.go b/internal/statements/statement_test.go index 3b6e3ae2..15f446f4 100644 --- a/internal/statements/statement_test.go +++ b/internal/statements/statement_test.go @@ -8,10 +8,37 @@ import ( "reflect" "strings" "testing" + "time" + "github.com/stretchr/testify/assert" + "xorm.io/xorm/caches" + "xorm.io/xorm/dialects" + "xorm.io/xorm/names" "xorm.io/xorm/schemas" + "xorm.io/xorm/tags" + + _ "github.com/mattn/go-sqlite3" ) +var ( + dialect dialects.Dialect + tagParser *tags.Parser +) + +func TestMain(m *testing.M) { + var err error + dialect, err = dialects.OpenDialect("sqlite3", "./test.db") + if err != nil { + panic("unknow dialect") + } + + tagParser = tags.NewParser("xorm", dialect, names.SnakeMapper{}, names.SnakeMapper{}, caches.NewManager()) + if tagParser == nil { + panic("tags parser is nil") + } + m.Run() +} + var colStrTests = []struct { omitColumn string onlyToDBColumnNdx int @@ -26,14 +53,9 @@ var colStrTests = []struct { } func TestColumnsStringGeneration(t *testing.T) { - if dbType == "postgres" || dbType == "mssql" { - return - } - - var statement *Statement - for ndx, testCase := range colStrTests { - statement = createTestStatement() + statement, err := createTestStatement() + assert.NoError(t, err) if testCase.omitColumn != "" { statement.Omit(testCase.omitColumn) @@ -55,33 +77,6 @@ func TestColumnsStringGeneration(t *testing.T) { } } -func BenchmarkColumnsStringGeneration(b *testing.B) { - b.StopTimer() - - statement := createTestStatement() - - testCase := colStrTests[0] - - if testCase.omitColumn != "" { - statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped - } - - if testCase.onlyToDBColumnNdx >= 0 { - columns := statement.RefTable.Columns() - columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped - } - - b.StartTimer() - - for i := 0; i < b.N; i++ { - actual := statement.genColumnStr() - - if actual != testCase.expected { - b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual) - } - } -} - func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { b.StopTimer() @@ -162,23 +157,40 @@ func (TestType) TableName() string { return "TestTable" } -func createTestStatement() *Statement { - if engine, ok := testEngine.(*Engine); ok { - statement := &Statement{} - statement.Reset() - statement.Engine = engine - statement.dialect = engine.dialect - statement.SetRefValue(reflect.ValueOf(TestType{})) - - return statement - } else if eg, ok := testEngine.(*EngineGroup); ok { - statement := &Statement{} - statement.Reset() - statement.Engine = eg.Engine - statement.dialect = eg.Engine.dialect - statement.SetRefValue(reflect.ValueOf(TestType{})) - - return statement +func createTestStatement() (*Statement, error) { + statement := NewStatement(dialect, tagParser, time.Local) + if err := statement.SetRefValue(reflect.ValueOf(TestType{})); err != nil { + return nil, err + } + return statement, nil +} + +func BenchmarkColumnsStringGeneration(b *testing.B) { + b.StopTimer() + + statement, err := createTestStatement() + if err != nil { + panic(err) + } + + testCase := colStrTests[0] + + if testCase.omitColumn != "" { + statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped + } + + if testCase.onlyToDBColumnNdx >= 0 { + columns := statement.RefTable.Columns() + columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + actual := statement.genColumnStr() + + if actual != testCase.expected { + b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual) + } } - return nil } diff --git a/session.go b/session.go index db990684..07b99594 100644 --- a/session.go +++ b/session.go @@ -284,7 +284,7 @@ func (session *Session) Having(conditions string) *Session { // DB db return the wrapper of sql.DB func (session *Session) DB() *core.DB { if session.db == nil { - session.db = session.engine.db + session.db = session.engine.DB() session.stmtCache = make(map[uint32]*core.Stmt, 0) } return session.db diff --git a/session_convert.go b/session_convert.go index 1cd00627..0776bc45 100644 --- a/session_convert.go +++ b/session_convert.go @@ -15,6 +15,7 @@ import ( "time" "xorm.io/xorm/convert" + "xorm.io/xorm/dialects" "xorm.io/xorm/internal/json" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" @@ -583,7 +584,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. case reflect.Struct: if fieldType.ConvertibleTo(schemas.TimeType) { t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) - tf := session.engine.formatColTime(col, t) + tf := dialects.FormatColumnTime(session.engine.dialect, session.engine.DatabaseTZ, col, t) return tf, nil } else if fieldType.ConvertibleTo(nullFloatType) { t := fieldValue.Convert(nullFloatType).Interface().(sql.NullFloat64) diff --git a/session_delete.go b/session_delete.go index 04200035..eb5e2aea 100644 --- a/session_delete.go +++ b/session_delete.go @@ -13,6 +13,14 @@ import ( "xorm.io/xorm/schemas" ) +var ( + // ErrNeedDeletedCond delete needs less one condition error + ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") + + // ErrNotImplemented not implemented + ErrNotImplemented = errors.New("Not implemented") +) + func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error { if table == nil || session.tx != nil { diff --git a/session_get_test.go b/session_get_test.go index 5bac9cd7..7e10bf54 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -179,7 +179,7 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "1.5", valuesString["money"]) // for mymysql driver, interface{} will be []byte, so ignore it currently - if testEngine.Dialect().DriverName() != "mymysql" { + if testEngine.DriverName() != "mymysql" { var valuesInter = make(map[string]interface{}) has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter) assert.NoError(t, err) diff --git a/session_tx.go b/session_tx.go index 489489f3..cd23cf89 100644 --- a/session_tx.go +++ b/session_tx.go @@ -34,7 +34,7 @@ func (session *Session) Rollback() error { session.isAutoCommit = true start := time.Now() - needSQL := session.engine.db.NeedLogSQL(session.ctx) + needSQL := session.DB().NeedLogSQL(session.ctx) if needSQL { session.engine.logger.BeforeSQL(log.LogContext{ Ctx: session.ctx, @@ -63,7 +63,7 @@ func (session *Session) Commit() error { session.isAutoCommit = true start := time.Now() - needSQL := session.engine.db.NeedLogSQL(session.ctx) + needSQL := session.DB().NeedLogSQL(session.ctx) if needSQL { session.engine.logger.BeforeSQL(log.LogContext{ Ctx: session.ctx, diff --git a/xorm.go b/xorm.go index 724a37cb..3618b718 100644 --- a/xorm.go +++ b/xorm.go @@ -8,13 +8,11 @@ package xorm import ( "context" - "fmt" "os" "runtime" "time" "xorm.io/xorm/caches" - "xorm.io/xorm/core" "xorm.io/xorm/dialects" "xorm.io/xorm/log" "xorm.io/xorm/names" @@ -34,27 +32,7 @@ func close(engine *Engine) { // NewEngine new a db manager according to the parameter. Currently support four // drivers func NewEngine(driverName string, dataSourceName string) (*Engine, error) { - driver := dialects.QueryDriver(driverName) - if driver == nil { - return nil, fmt.Errorf("Unsupported driver name: %v", driverName) - } - - uri, err := driver.Parse(driverName, dataSourceName) - if err != nil { - return nil, err - } - - dialect := dialects.QueryDialect(uri.DBType) - if dialect == nil { - return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType) - } - - db, err := core.Open(driverName, dataSourceName) - if err != nil { - return nil, err - } - - err = dialect.Init(db, uri, driverName, dataSourceName) + dialect, err := dialects.OpenDialect(driverName, dataSourceName) if err != nil { return nil, err } @@ -64,15 +42,16 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr) engine := &Engine{ - db: db, dialect: dialect, TZLocation: time.Local, defaultContext: context.Background(), cacherMgr: cacherMgr, tagParser: tagParser, + driverName: driverName, + dataSourceName: dataSourceName, } - if uri.DBType == schemas.SQLITE { + if dialect.URI().DBType == schemas.SQLITE { engine.DatabaseTZ = time.UTC } else { engine.DatabaseTZ = time.Local