Merge branch 'master' into master

This commit is contained in:
Lunny Xiao 2017-11-20 14:49:29 +08:00 committed by GitHub
commit 951e0ac2eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 53 deletions

View File

@ -32,13 +32,10 @@ proposed functionality.
We appreciate any bug reports, but especially ones with self-contained We appreciate any bug reports, but especially ones with self-contained
(doesn't depend on code outside of xorm), minimal (can't be simplified (doesn't depend on code outside of xorm), minimal (can't be simplified
further) test cases. It's especially helpful if you can submit a pull further) test cases. It's especially helpful if you can submit a pull
request with just the failing test case (you'll probably want to request with just the failing test case(you can find some example test file like [session_get_test.go](https://github.com/go-xorm/xorm/blob/master/session_get_test.go)).
pattern it after the tests in
[base.go](https://github.com/go-xorm/tests/blob/master/base.go) AND
[benchmark.go](https://github.com/go-xorm/tests/blob/master/benchmark.go).
If you implements a new database interface, you maybe need to add a <databasename>_test.go file. If you implements a new database interface, you maybe need to add a test_<databasename>.sh file.
For example, [mysql_test.go](https://github.com/go-xorm/tests/blob/master/mysql/mysql_test.go) For example, [mysql_test.go](https://github.com/go-xorm/xorm/blob/master/test_mysql.sh)
### New functionality ### New functionality

View File

@ -8,7 +8,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
"sort"
"strconv" "strconv"
"strings" "strings"
@ -1117,10 +1116,6 @@ func (vs values) Get(k string) (v string) {
return vs[k] return vs[k]
} }
func errorf(s string, args ...interface{}) {
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
}
func parseURL(connstr string) (string, error) { func parseURL(connstr string) (string, error) {
u, err := url.Parse(connstr) u, err := url.Parse(connstr)
if err != nil { if err != nil {
@ -1131,46 +1126,18 @@ func parseURL(connstr string) (string, error) {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
} }
var kvs []string
escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
accrue := func(k, v string) {
if v != "" {
kvs = append(kvs, k+"="+escaper.Replace(v))
}
}
if u.User != nil {
v := u.User.Username()
accrue("user", v)
v, _ = u.User.Password()
accrue("password", v)
}
i := strings.Index(u.Host, ":")
if i < 0 {
accrue("host", u.Host)
} else {
accrue("host", u.Host[:i])
accrue("port", u.Host[i+1:])
}
if u.Path != "" { if u.Path != "" {
accrue("dbname", u.Path[1:]) return escaper.Replace(u.Path[1:]), nil
} }
q := u.Query() return "", nil
for k := range q {
accrue(k, q.Get(k))
}
sort.Strings(kvs) // Makes testing easier (not a performance concern)
return strings.Join(kvs, " "), nil
} }
func parseOpts(name string, o values) { func parseOpts(name string, o values) error {
if len(name) == 0 { if len(name) == 0 {
return return fmt.Errorf("invalid options: %s", name)
} }
name = strings.TrimSpace(name) name = strings.TrimSpace(name)
@ -1179,31 +1146,36 @@ func parseOpts(name string, o values) {
for _, p := range ps { for _, p := range ps {
kv := strings.Split(p, "=") kv := strings.Split(p, "=")
if len(kv) < 2 { if len(kv) < 2 {
errorf("invalid option: %q", p) return fmt.Errorf("invalid option: %q", p)
} }
o.Set(kv[0], kv[1]) o.Set(kv[0], kv[1])
} }
return nil
} }
func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &core.Uri{DbType: core.POSTGRES} db := &core.Uri{DbType: core.POSTGRES}
o := make(values)
var err error var err error
if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") { if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") {
dataSourceName, err = parseURL(dataSourceName) db.DbName, err = parseURL(dataSourceName)
if err != nil {
return nil, err
}
} else {
o := make(values)
err = parseOpts(dataSourceName, o)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
parseOpts(dataSourceName, o)
db.DbName = o.Get("dbname") db.DbName = o.Get("dbname")
}
if db.DbName == "" { if db.DbName == "" {
return nil, errors.New("dbname is empty") return nil, errors.New("dbname is empty")
} }
/*db.Schema = o.Get("schema")
if len(db.Schema) == 0 {
db.Schema = "public"
}*/
return db, nil return db, nil
} }

44
dialect_postgres_test.go Normal file
View File

@ -0,0 +1,44 @@
package xorm
import (
"reflect"
"testing"
"github.com/go-xorm/core"
)
func TestPostgresDialect(t *testing.T) {
TestParse(t)
}
func TestParse(t *testing.T) {
tests := []struct {
in string
expected string
valid bool
}{
{"postgres://auser:password@localhost:5432/db?sslmode=disable", "db", true},
{"postgresql://auser:password@localhost:5432/db?sslmode=disable", "db", true},
{"postg://auser:password@localhost:5432/db?sslmode=disable", "db", false},
{"postgres://auser:pass with space@localhost:5432/db?sslmode=disable", "db", true},
{"postgres:// auser : password@localhost:5432/db?sslmode=disable", "db", true},
{"postgres://%20auser%20:pass%20with%20space@localhost:5432/db?sslmode=disable", "db", true},
{"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true},
{"dbname=db sslmode=disable", "db", true},
{"user=auser password=password dbname=db sslmode=disable", "db", true},
{"", "db", false},
{"dbname=db =disable", "db", false},
}
driver := core.QueryDriver("postgres")
for _, test := range tests {
uri, err := driver.Parse("postgres", test.in)
if err != nil && test.valid {
t.Errorf("%q got unexpected error: %s", test.in, err)
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected)
}
}
}