From 77a4ff63c53660f36262be4d7bd9359ec1bec595 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 24 Jun 2020 21:20:55 +0800 Subject: [PATCH] Add DBVersion --- dialects/dialect.go | 1 + dialects/mssql.go | 18 ++++++++++++++++++ dialects/mysql.go | 18 ++++++++++++++++++ dialects/oracle.go | 18 ++++++++++++++++++ dialects/postgres.go | 18 ++++++++++++++++++ dialects/sqlite3.go | 18 ++++++++++++++++++ engine.go | 12 +++--------- integrations/engine_test.go | 9 +++++++++ interface.go | 1 + 9 files changed, 104 insertions(+), 9 deletions(-) diff --git a/dialects/dialect.go b/dialects/dialect.go index 52655e6b..e75ffbfa 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -44,6 +44,7 @@ type Dialect interface { URI() *URI SQLType(*schemas.Column) string FormatBytes(b []byte) string + Version(ctx context.Context, queryer core.Queryer) (string, error) IsReserved(string) bool Quoter() schemas.Quoter diff --git a/dialects/mssql.go b/dialects/mssql.go index 32e7ac50..ae77910f 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -253,6 +253,24 @@ func (db *mssql) SetParams(params map[string]string) { } } +func (db *mssql) Version(ctx context.Context, queryer core.Queryer) (string, error) { + rows, err := queryer.QueryContext(ctx, "SELECT @@VERSION") + if err != nil { + return "", err + } + defer rows.Close() + + var version string + if !rows.Next() { + return "", errors.New("Unknow version") + } + + if err := rows.Scan(&version); err != nil { + return "", err + } + return version, nil +} + func (db *mssql) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { diff --git a/dialects/mysql.go b/dialects/mysql.go index 2b530daf..3e786de7 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -188,6 +188,24 @@ func (db *mysql) Init(uri *URI) error { return db.Base.Init(db, uri) } +func (db *mysql) Version(ctx context.Context, queryer core.Queryer) (string, error) { + rows, err := queryer.QueryContext(ctx, "SELECT VERSION()") + if err != nil { + return "", err + } + defer rows.Close() + + var version string + if !rows.Next() { + return "", errors.New("Unknow version") + } + + if err := rows.Scan(&version); err != nil { + return "", err + } + return version, nil +} + func (db *mysql) SetParams(params map[string]string) { rowFormat, ok := params["rowFormat"] if ok { diff --git a/dialects/oracle.go b/dialects/oracle.go index 72bbe54d..d857aa81 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -515,6 +515,24 @@ func (db *oracle) Init(uri *URI) error { return db.Base.Init(db, uri) } +func (db *oracle) Version(ctx context.Context, queryer core.Queryer) (string, error) { + rows, err := queryer.QueryContext(ctx, "select * from v$version where banner like 'Oracle%'") + if err != nil { + return "", err + } + defer rows.Close() + + var version string + if !rows.Next() { + return "", errors.New("Unknow version") + } + + if err := rows.Scan(&version); err != nil { + return "", err + } + return version, nil +} + func (db *oracle) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { diff --git a/dialects/postgres.go b/dialects/postgres.go index 544c98e9..83e49c19 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -788,6 +788,24 @@ func (db *postgres) Init(uri *URI) error { return db.Base.Init(db, uri) } +func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (string, error) { + rows, err := queryer.QueryContext(ctx, "SELECT version()") + if err != nil { + return "", err + } + defer rows.Close() + + var version string + if !rows.Next() { + return "", errors.New("Unknow version") + } + + if err := rows.Scan(&version); err != nil { + return "", err + } + return version, nil +} + func (db *postgres) getSchema() string { if db.uri.Schema != "" { return db.uri.Schema diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 82683606..a666fe72 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -160,6 +160,24 @@ func (db *sqlite3) Init(uri *URI) error { return db.Base.Init(db, uri) } +func (db *sqlite3) Version(ctx context.Context, queryer core.Queryer) (string, error) { + rows, err := queryer.QueryContext(ctx, "SELECT sqlite_version()") + if err != nil { + return "", err + } + defer rows.Close() + + var version string + if !rows.Next() { + return "", errors.New("Unknow version") + } + + if err := rows.Scan(&version); err != nil { + return "", err + } + return version, nil +} + func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) { switch quotePolicy { case QuotePolicyNone: diff --git a/engine.go b/engine.go index d49eea9a..14726638 100644 --- a/engine.go +++ b/engine.go @@ -925,15 +925,9 @@ func (engine *Engine) Having(conditions string) *Session { return session.Having(conditions) } -// Table table struct -type Table struct { - *schemas.Table - Name string -} - -// IsValid if table is valid -func (t *Table) IsValid() bool { - return t.Table != nil && len(t.Name) > 0 +// DBVersion returns the database version +func (engine *Engine) DBVersion() (string, error) { + return engine.dialect.Version(engine.defaultContext, engine.db) } // TableInfo get table info according to bean's content diff --git a/integrations/engine_test.go b/integrations/engine_test.go index 3b843f16..d01dc7e8 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -200,3 +200,12 @@ func TestImport(t *testing.T) { assert.NoError(t, err) assert.NoError(t, sess.Commit()) } + +func TestDBVersion(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + version, err := testEngine.DBVersion() + assert.NoError(t, err) + + fmt.Println(testEngine.Dialect().URI().DBType, "version", version) +} diff --git a/interface.go b/interface.go index 55162c8c..45d54905 100644 --- a/interface.go +++ b/interface.go @@ -79,6 +79,7 @@ type EngineInterface interface { Before(func(interface{})) *Session Charset(charset string) *Session + DBVersion() (string, error) ClearCache(...interface{}) error Context(context.Context) *Session CreateTables(...interface{}) error