Fix postgres driver datasource name parse (#2012)

Fix #2010

Reviewed-on: https://gitea.com/xorm/xorm/pulls/2012
Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com>
Co-committed-by: Lunny Xiao <xiaolunwen@gmail.com>
This commit is contained in:
Lunny Xiao 2021-07-28 10:03:09 +08:00
parent d973423802
commit c02a1bf00c
2 changed files with 82 additions and 28 deletions

View File

@ -1341,14 +1341,6 @@ type pqDriver struct {
type values map[string]string type values map[string]string
func (vs values) Set(k, v string) {
vs[k] = v
}
func (vs values) Get(k string) (v string) {
return vs[k]
}
func parseURL(connstr string) (string, error) { func parseURL(connstr string) (string, error) {
u, err := url.Parse(connstr) u, err := url.Parse(connstr)
if err != nil { if err != nil {
@ -1368,20 +1360,74 @@ func parseURL(connstr string) (string, error) {
return "", nil return "", nil
} }
func parseOpts(name string, o values) error { func parseOpts(urlStr string, o values) error {
if len(name) == 0 { if len(urlStr) == 0 {
return fmt.Errorf("invalid options: %s", name) return fmt.Errorf("invalid options: %s", urlStr)
} }
name = strings.TrimSpace(name) urlStr = strings.TrimSpace(urlStr)
ps := strings.Split(name, " ") var (
for _, p := range ps { inQuote bool
kv := strings.Split(p, "=") state int // 0 key, 1 space, 2 value, 3 equal
if len(kv) < 2 { start int
return fmt.Errorf("invalid option: %q", p) key string
)
for i, c := range urlStr {
switch c {
case ' ':
if !inQuote {
if state == 2 {
state = 1
v := urlStr[start:i]
if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") {
v = v[1 : len(v)-1]
} else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") {
return fmt.Errorf("wrong single quote in %d of %s", i, urlStr)
}
o[key] = v
} else if state != 1 {
return fmt.Errorf("wrong format: %v", urlStr)
}
}
case '\'':
if state == 3 {
state = 2
start = i
} else if state != 2 {
return fmt.Errorf("wrong format: %v", urlStr)
}
inQuote = !inQuote
case '=':
if !inQuote {
if state != 0 {
return fmt.Errorf("wrong format: %v", urlStr)
}
key = urlStr[start:i]
state = 3
}
default:
if state == 3 {
state = 2
start = i
} else if state == 1 {
state = 0
start = i
}
}
if i == len(urlStr)-1 {
if state != 2 {
return errors.New("no value matched key")
}
v := urlStr[start : i+1]
if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") {
v = v[1 : len(v)-1]
} else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") {
return fmt.Errorf("wrong single quote in %d of %s", i, urlStr)
}
o[key] = v
} }
o.Set(kv[0], kv[1])
} }
return nil return nil
@ -1395,9 +1441,13 @@ func (p *pqDriver) Features() *DriverFeatures {
func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &URI{DBType: schemas.POSTGRES} db := &URI{DBType: schemas.POSTGRES}
var err error
if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") { var err error
if strings.Contains(dataSourceName, "://") {
if !strings.HasPrefix(dataSourceName, "postgresql://") && !strings.HasPrefix(dataSourceName, "postgres://") {
return nil, fmt.Errorf("unsupported protocol %v", dataSourceName)
}
db.DBName, err = parseURL(dataSourceName) db.DBName, err = parseURL(dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
@ -1409,7 +1459,7 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
return nil, err return nil, err
} }
db.DBName = o.Get("dbname") db.DBName = o["dbname"]
} }
if db.DBName == "" { if db.DBName == "" {

View File

@ -22,13 +22,16 @@ func TestParsePostgres(t *testing.T) {
//{"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true}, //{"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true},
{"dbname=db sslmode=disable", "db", true}, {"dbname=db sslmode=disable", "db", true},
{"user=auser password=password dbname=db sslmode=disable", "db", true}, {"user=auser password=password dbname=db sslmode=disable", "db", true},
{"user=auser password='pass word' dbname=db sslmode=disable", "db", true},
{"user=auser password='pass word' sslmode=disable dbname='db'", "db", true},
{"user=auser password='pass word' sslmode='disable dbname=db'", "db", false},
{"", "db", false}, {"", "db", false},
{"dbname=db =disable", "db", false}, {"dbname=db =disable", "db", false},
} }
driver := QueryDriver("postgres") driver := QueryDriver("postgres")
for _, test := range tests { for _, test := range tests {
t.Run(test.in, func(t *testing.T) {
uri, err := driver.Parse("postgres", test.in) uri, err := driver.Parse("postgres", test.in)
if err != nil && test.valid { if err != nil && test.valid {
@ -36,6 +39,7 @@ func TestParsePostgres(t *testing.T) {
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected)
} }
})
} }
} }