This commit is contained in:
Lunny Xiao 2014-07-08 09:34:56 +08:00
parent a404099f25
commit 12a68f628a
1 changed files with 56 additions and 0 deletions

View File

@ -3,6 +3,8 @@ package xorm
import ( import (
"errors" "errors"
"fmt" "fmt"
"net/url"
"sort"
"strings" "strings"
"github.com/go-xorm/core" "github.com/go-xorm/core"
@ -29,6 +31,53 @@ func errorf(s string, args ...interface{}) {
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
} }
func parseURL(connstr string) (string, error) {
u, err := url.Parse(connstr)
if err != nil {
return "", err
}
if u.Scheme != "postgres" {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}
var kvs []string
escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
accrue := func(k, v string) {
if v != "" {
kvs = append(kvs, k+"="+escaper.Replace(v))
}
}
if u.User != nil {
v := u.User.Username()
accrue("user", v)
v, _ = u.User.Password()
accrue("password", v)
}
i := strings.Index(u.Host, ":")
if i < 0 {
accrue("host", u.Host)
} else {
accrue("host", u.Host[:i])
accrue("port", u.Host[i+1:])
}
if u.Path != "" {
accrue("dbname", u.Path[1:])
}
q := u.Query()
for k := range q {
accrue(k, q.Get(k))
}
sort.Strings(kvs) // Makes testing easier (not a performance concern)
return strings.Join(kvs, " "), nil
}
func parseOpts(name string, o values) { func parseOpts(name string, o values) {
if len(name) == 0 { if len(name) == 0 {
return return
@ -49,6 +98,13 @@ func parseOpts(name string, o values) {
func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &core.Uri{DbType: core.POSTGRES} db := &core.Uri{DbType: core.POSTGRES}
o := make(values) o := make(values)
var err error
if strings.HasPrefix(dataSourceName, "postgres://") {
dataSourceName, err = parseURL(dataSourceName)
if err != nil {
return nil, err
}
}
parseOpts(dataSourceName, o) parseOpts(dataSourceName, o)
db.DbName = o.Get("dbname") db.DbName = o.Get("dbname")