refactoring db interface
This commit is contained in:
parent
99c7031b50
commit
4c7e98d8d6
33
mymysql.go
33
mymysql.go
|
@ -8,23 +8,20 @@ import (
|
||||||
|
|
||||||
type mymysql struct {
|
type mymysql struct {
|
||||||
mysql
|
mysql
|
||||||
proto string
|
|
||||||
raddr string
|
|
||||||
laddr string
|
|
||||||
timeout time.Duration
|
|
||||||
db string
|
|
||||||
user string
|
|
||||||
passwd string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mymysql) Init(drivername, uri string) error {
|
type mymysqlParser struct {
|
||||||
db.mysql.base.init(drivername, uri)
|
}
|
||||||
pd := strings.SplitN(uri, "*", 2)
|
|
||||||
|
func (p *mymysqlParser) parse(driverName, dataSourceName string) (*uri, error) {
|
||||||
|
db := &uri{dbType: MYSQL}
|
||||||
|
|
||||||
|
pd := strings.SplitN(dataSourceName, "*", 2)
|
||||||
if len(pd) == 2 {
|
if len(pd) == 2 {
|
||||||
// Parse protocol part of URI
|
// Parse protocol part of URI
|
||||||
p := strings.SplitN(pd[0], ":", 2)
|
p := strings.SplitN(pd[0], ":", 2)
|
||||||
if len(p) != 2 {
|
if len(p) != 2 {
|
||||||
return errors.New("Wrong protocol part of URI")
|
return nil, errors.New("Wrong protocol part of URI")
|
||||||
}
|
}
|
||||||
db.proto = p[0]
|
db.proto = p[0]
|
||||||
options := strings.Split(p[1], ",")
|
options := strings.Split(p[1], ",")
|
||||||
|
@ -43,11 +40,11 @@ func (db *mymysql) Init(drivername, uri string) error {
|
||||||
case "timeout":
|
case "timeout":
|
||||||
to, err := time.ParseDuration(v)
|
to, err := time.ParseDuration(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
db.timeout = to
|
db.timeout = to
|
||||||
default:
|
default:
|
||||||
return errors.New("Unknown option: " + k)
|
return nil, errors.New("Unknown option: " + k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Remove protocol part
|
// Remove protocol part
|
||||||
|
@ -56,11 +53,15 @@ func (db *mymysql) Init(drivername, uri string) error {
|
||||||
// Parse database part of URI
|
// Parse database part of URI
|
||||||
dup := strings.SplitN(pd[0], "/", 3)
|
dup := strings.SplitN(pd[0], "/", 3)
|
||||||
if len(dup) != 3 {
|
if len(dup) != 3 {
|
||||||
return errors.New("Wrong database part of URI")
|
return nil, errors.New("Wrong database part of URI")
|
||||||
}
|
}
|
||||||
db.dbname = dup[0]
|
db.dbName = dup[0]
|
||||||
db.user = dup[1]
|
db.user = dup[1]
|
||||||
db.passwd = dup[2]
|
db.passwd = dup[2]
|
||||||
|
|
||||||
return nil
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *mymysql) Init(drivername, uri string) error {
|
||||||
|
return db.mysql.base.init(&mymysqlParser{}, drivername, uri)
|
||||||
}
|
}
|
||||||
|
|
114
mysql.go
114
mysql.go
|
@ -11,22 +11,67 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type base struct {
|
type uri struct {
|
||||||
drivername string
|
dbType string
|
||||||
dataSourceName string
|
proto string
|
||||||
|
host string
|
||||||
|
port string
|
||||||
|
dbName string
|
||||||
|
user string
|
||||||
|
passwd string
|
||||||
|
charset string
|
||||||
|
laddr string
|
||||||
|
raddr string
|
||||||
|
timeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *base) init(drivername, dataSourceName string) {
|
type parser interface {
|
||||||
b.drivername, b.dataSourceName = drivername, dataSourceName
|
parse(driverName, dataSourceName string) (*uri, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type mysqlParser struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mysqlParser) parse(driverName, dataSourceName string) (*uri, error) {
|
||||||
|
//cfg.params = make(map[string]string)
|
||||||
|
dsnPattern := regexp.MustCompile(
|
||||||
|
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
|
||||||
|
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
|
||||||
|
`\/(?P<dbname>.*?)` + // /dbname
|
||||||
|
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1¶mN=valueN]
|
||||||
|
matches := dsnPattern.FindStringSubmatch(dataSourceName)
|
||||||
|
//tlsConfigRegister := make(map[string]*tls.Config)
|
||||||
|
names := dsnPattern.SubexpNames()
|
||||||
|
|
||||||
|
uri := &uri{dbType: MYSQL}
|
||||||
|
|
||||||
|
for i, match := range matches {
|
||||||
|
switch names[i] {
|
||||||
|
case "dbname":
|
||||||
|
uri.dbName = match
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return uri, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type base struct {
|
||||||
|
parser parser
|
||||||
|
driverName string
|
||||||
|
dataSourceName string
|
||||||
|
*uri
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *base) init(parser parser, drivername, dataSourceName string) (err error) {
|
||||||
|
b.parser = parser
|
||||||
|
b.driverName, b.dataSourceName = drivername, dataSourceName
|
||||||
|
b.uri, err = b.parser.parse(b.driverName, b.dataSourceName)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
type mysql struct {
|
type mysql struct {
|
||||||
base
|
base
|
||||||
user string
|
|
||||||
passwd string
|
|
||||||
net string
|
net string
|
||||||
addr string
|
addr string
|
||||||
dbname string
|
|
||||||
params map[string]string
|
params map[string]string
|
||||||
loc *time.Location
|
loc *time.Location
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
|
@ -36,41 +81,8 @@ type mysql struct {
|
||||||
clientFoundRows bool
|
clientFoundRows bool
|
||||||
}
|
}
|
||||||
|
|
||||||
/*func readBool(input string) (value bool, valid bool) {
|
|
||||||
switch input {
|
|
||||||
case "1", "true", "TRUE", "True":
|
|
||||||
return true, true
|
|
||||||
case "0", "false", "FALSE", "False":
|
|
||||||
return false, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Not a valid bool value
|
|
||||||
return
|
|
||||||
}*/
|
|
||||||
|
|
||||||
func (cfg *mysql) parseDSN(dsn string) (err error) {
|
|
||||||
//cfg.params = make(map[string]string)
|
|
||||||
dsnPattern := regexp.MustCompile(
|
|
||||||
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
|
|
||||||
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
|
|
||||||
`\/(?P<dbname>.*?)` + // /dbname
|
|
||||||
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1¶mN=valueN]
|
|
||||||
matches := dsnPattern.FindStringSubmatch(dsn)
|
|
||||||
//tlsConfigRegister := make(map[string]*tls.Config)
|
|
||||||
names := dsnPattern.SubexpNames()
|
|
||||||
|
|
||||||
for i, match := range matches {
|
|
||||||
switch names[i] {
|
|
||||||
case "dbname":
|
|
||||||
cfg.dbname = match
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *mysql) Init(drivername, uri string) error {
|
func (db *mysql) Init(drivername, uri string) error {
|
||||||
db.base.init(drivername, uri)
|
return db.base.init(&mysqlParser{}, drivername, uri)
|
||||||
return db.parseDSN(uri)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) SqlType(c *Column) string {
|
func (db *mysql) SqlType(c *Column) string {
|
||||||
|
@ -132,29 +144,29 @@ func (db *mysql) IndexOnTable() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
|
func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
|
||||||
args := []interface{}{db.dbname, tableName, idxName}
|
args := []interface{}{db.dbName, tableName, idxName}
|
||||||
sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`"
|
sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`"
|
||||||
sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?"
|
sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?"
|
||||||
return sql, args
|
return sql, args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
|
func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
|
||||||
args := []interface{}{db.dbname, tableName, colName}
|
args := []interface{}{db.dbName, tableName, colName}
|
||||||
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
|
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
|
||||||
return sql, args
|
return sql, args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
|
func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
|
||||||
args := []interface{}{db.dbname, tableName}
|
args := []interface{}{db.dbName, tableName}
|
||||||
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
|
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
|
||||||
return sql, args
|
return sql, args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) {
|
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) {
|
||||||
args := []interface{}{db.dbname, tableName}
|
args := []interface{}{db.dbName, tableName}
|
||||||
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
|
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
|
||||||
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
|
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -232,9 +244,9 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) GetTables() ([]*Table, error) {
|
func (db *mysql) GetTables() ([]*Table, error) {
|
||||||
args := []interface{}{db.dbname}
|
args := []interface{}{db.dbName}
|
||||||
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?"
|
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?"
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -260,9 +272,9 @@ func (db *mysql) GetTables() ([]*Table, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
|
func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
|
||||||
args := []interface{}{db.dbname, tableName}
|
args := []interface{}{db.dbName, tableName}
|
||||||
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
|
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
33
postgres.go
33
postgres.go
|
@ -10,7 +10,6 @@ import (
|
||||||
|
|
||||||
type postgres struct {
|
type postgres struct {
|
||||||
base
|
base
|
||||||
dbname string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type values map[string]string
|
type values map[string]string
|
||||||
|
@ -44,17 +43,23 @@ func parseOpts(name string, o values) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) Init(drivername, uri string) error {
|
type postgresParser struct {
|
||||||
db.base.init(drivername, uri)
|
|
||||||
|
|
||||||
o := make(values)
|
|
||||||
parseOpts(uri, o)
|
|
||||||
|
|
||||||
db.dbname = o.Get("dbname")
|
|
||||||
if db.dbname == "" {
|
|
||||||
return errors.New("dbname is empty")
|
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
func (p *postgresParser) parse(driverName, dataSourceName string) (*uri, error) {
|
||||||
|
db := &uri{dbType: POSTGRES}
|
||||||
|
o := make(values)
|
||||||
|
parseOpts(dataSourceName, o)
|
||||||
|
|
||||||
|
db.dbName = o.Get("dbname")
|
||||||
|
if db.dbName == "" {
|
||||||
|
return nil, errors.New("dbname is empty")
|
||||||
|
}
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *postgres) Init(drivername, uri string) error {
|
||||||
|
return db.base.init(&postgresParser{}, drivername, uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) SqlType(c *Column) string {
|
func (db *postgres) SqlType(c *Column) string {
|
||||||
|
@ -145,7 +150,7 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column,
|
||||||
s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" +
|
s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" +
|
||||||
", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1"
|
", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1"
|
||||||
|
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -221,7 +226,7 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column,
|
||||||
func (db *postgres) GetTables() ([]*Table, error) {
|
func (db *postgres) GetTables() ([]*Table, error) {
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
s := "SELECT tablename FROM pg_tables where schemaname = 'public'"
|
s := "SELECT tablename FROM pg_tables where schemaname = 'public'"
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -249,7 +254,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1"
|
s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1"
|
||||||
|
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
16
sqlite3.go
16
sqlite3.go
|
@ -9,9 +9,15 @@ type sqlite3 struct {
|
||||||
base
|
base
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type sqlite3Parser struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sqlite3Parser) parse(driverName, dataSourceName string) (*uri, error) {
|
||||||
|
return &uri{dbType: SQLITE, dbName: dataSourceName}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (db *sqlite3) Init(drivername, dataSourceName string) error {
|
func (db *sqlite3) Init(drivername, dataSourceName string) error {
|
||||||
db.base.init(drivername, dataSourceName)
|
return db.base.init(&sqlite3Parser{}, drivername, dataSourceName)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) SqlType(c *Column) string {
|
func (db *sqlite3) SqlType(c *Column) string {
|
||||||
|
@ -83,7 +89,7 @@ func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interfac
|
||||||
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) {
|
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
|
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -142,7 +148,7 @@ func (db *sqlite3) GetTables() ([]*Table, error) {
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
s := "SELECT name FROM sqlite_master WHERE type='table'"
|
s := "SELECT name FROM sqlite_master WHERE type='table'"
|
||||||
|
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -172,7 +178,7 @@ func (db *sqlite3) GetTables() ([]*Table, error) {
|
||||||
func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
|
func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
|
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue