diff --git a/pq_driver.go b/pq_driver.go index a5bb6718..c8dd5aa0 100644 --- a/pq_driver.go +++ b/pq_driver.go @@ -3,6 +3,8 @@ package xorm import ( "errors" "fmt" + "net/url" + "sort" "strings" "github.com/go-xorm/core" @@ -29,6 +31,53 @@ func errorf(s string, args ...interface{}) { panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) } +func parseURL(connstr string) (string, error) { + u, err := url.Parse(connstr) + if err != nil { + return "", err + } + + if u.Scheme != "postgres" { + return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) + } + + var kvs []string + escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) + accrue := func(k, v string) { + if v != "" { + kvs = append(kvs, k+"="+escaper.Replace(v)) + } + } + + if u.User != nil { + v := u.User.Username() + accrue("user", v) + + v, _ = u.User.Password() + accrue("password", v) + } + + i := strings.Index(u.Host, ":") + if i < 0 { + accrue("host", u.Host) + } else { + accrue("host", u.Host[:i]) + accrue("port", u.Host[i+1:]) + } + + if u.Path != "" { + accrue("dbname", u.Path[1:]) + } + + q := u.Query() + for k := range q { + accrue(k, q.Get(k)) + } + + sort.Strings(kvs) // Makes testing easier (not a performance concern) + return strings.Join(kvs, " "), nil +} + func parseOpts(name string, o values) { if len(name) == 0 { return @@ -49,6 +98,13 @@ func parseOpts(name string, o values) { func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { db := &core.Uri{DbType: core.POSTGRES} o := make(values) + var err error + if strings.HasPrefix(dataSourceName, "postgres://") { + dataSourceName, err = parseURL(dataSourceName) + if err != nil { + return nil, err + } + } parseOpts(dataSourceName, o) db.DbName = o.Get("dbname")