Fix postgres driver datasource name parse (#2012)
Fix #2010 Reviewed-on: https://gitea.com/xorm/xorm/pulls/2012 Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com> Co-committed-by: Lunny Xiao <xiaolunwen@gmail.com>
This commit is contained in:
parent
d973423802
commit
c02a1bf00c
|
@ -1341,14 +1341,6 @@ type pqDriver struct {
|
|||
|
||||
type values map[string]string
|
||||
|
||||
func (vs values) Set(k, v string) {
|
||||
vs[k] = v
|
||||
}
|
||||
|
||||
func (vs values) Get(k string) (v string) {
|
||||
return vs[k]
|
||||
}
|
||||
|
||||
func parseURL(connstr string) (string, error) {
|
||||
u, err := url.Parse(connstr)
|
||||
if err != nil {
|
||||
|
@ -1368,20 +1360,74 @@ func parseURL(connstr string) (string, error) {
|
|||
return "", nil
|
||||
}
|
||||
|
||||
func parseOpts(name string, o values) error {
|
||||
if len(name) == 0 {
|
||||
return fmt.Errorf("invalid options: %s", name)
|
||||
func parseOpts(urlStr string, o values) error {
|
||||
if len(urlStr) == 0 {
|
||||
return fmt.Errorf("invalid options: %s", urlStr)
|
||||
}
|
||||
|
||||
name = strings.TrimSpace(name)
|
||||
urlStr = strings.TrimSpace(urlStr)
|
||||
|
||||
ps := strings.Split(name, " ")
|
||||
for _, p := range ps {
|
||||
kv := strings.Split(p, "=")
|
||||
if len(kv) < 2 {
|
||||
return fmt.Errorf("invalid option: %q", p)
|
||||
var (
|
||||
inQuote bool
|
||||
state int // 0 key, 1 space, 2 value, 3 equal
|
||||
start int
|
||||
key string
|
||||
)
|
||||
for i, c := range urlStr {
|
||||
switch c {
|
||||
case ' ':
|
||||
if !inQuote {
|
||||
if state == 2 {
|
||||
state = 1
|
||||
v := urlStr[start:i]
|
||||
if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") {
|
||||
v = v[1 : len(v)-1]
|
||||
} else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") {
|
||||
return fmt.Errorf("wrong single quote in %d of %s", i, urlStr)
|
||||
}
|
||||
o[key] = v
|
||||
} else if state != 1 {
|
||||
return fmt.Errorf("wrong format: %v", urlStr)
|
||||
}
|
||||
}
|
||||
case '\'':
|
||||
if state == 3 {
|
||||
state = 2
|
||||
start = i
|
||||
} else if state != 2 {
|
||||
return fmt.Errorf("wrong format: %v", urlStr)
|
||||
}
|
||||
inQuote = !inQuote
|
||||
case '=':
|
||||
if !inQuote {
|
||||
if state != 0 {
|
||||
return fmt.Errorf("wrong format: %v", urlStr)
|
||||
}
|
||||
key = urlStr[start:i]
|
||||
state = 3
|
||||
}
|
||||
default:
|
||||
if state == 3 {
|
||||
state = 2
|
||||
start = i
|
||||
} else if state == 1 {
|
||||
state = 0
|
||||
start = i
|
||||
}
|
||||
}
|
||||
|
||||
if i == len(urlStr)-1 {
|
||||
if state != 2 {
|
||||
return errors.New("no value matched key")
|
||||
}
|
||||
v := urlStr[start : i+1]
|
||||
if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") {
|
||||
v = v[1 : len(v)-1]
|
||||
} else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") {
|
||||
return fmt.Errorf("wrong single quote in %d of %s", i, urlStr)
|
||||
}
|
||||
o[key] = v
|
||||
}
|
||||
o.Set(kv[0], kv[1])
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -1395,9 +1441,13 @@ func (p *pqDriver) Features() *DriverFeatures {
|
|||
|
||||
func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||
db := &URI{DBType: schemas.POSTGRES}
|
||||
var err error
|
||||
|
||||
if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") {
|
||||
var err error
|
||||
if strings.Contains(dataSourceName, "://") {
|
||||
if !strings.HasPrefix(dataSourceName, "postgresql://") && !strings.HasPrefix(dataSourceName, "postgres://") {
|
||||
return nil, fmt.Errorf("unsupported protocol %v", dataSourceName)
|
||||
}
|
||||
|
||||
db.DBName, err = parseURL(dataSourceName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -1409,7 +1459,7 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
db.DBName = o.Get("dbname")
|
||||
db.DBName = o["dbname"]
|
||||
}
|
||||
|
||||
if db.DBName == "" {
|
||||
|
|
|
@ -22,20 +22,24 @@ func TestParsePostgres(t *testing.T) {
|
|||
//{"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true},
|
||||
{"dbname=db sslmode=disable", "db", true},
|
||||
{"user=auser password=password dbname=db sslmode=disable", "db", true},
|
||||
{"user=auser password='pass word' dbname=db sslmode=disable", "db", true},
|
||||
{"user=auser password='pass word' sslmode=disable dbname='db'", "db", true},
|
||||
{"user=auser password='pass word' sslmode='disable dbname=db'", "db", false},
|
||||
{"", "db", false},
|
||||
{"dbname=db =disable", "db", false},
|
||||
}
|
||||
|
||||
driver := QueryDriver("postgres")
|
||||
|
||||
for _, test := range tests {
|
||||
uri, err := driver.Parse("postgres", test.in)
|
||||
t.Run(test.in, func(t *testing.T) {
|
||||
uri, err := driver.Parse("postgres", test.in)
|
||||
|
||||
if err != nil && test.valid {
|
||||
t.Errorf("%q got unexpected error: %s", test.in, err)
|
||||
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) {
|
||||
t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected)
|
||||
}
|
||||
if err != nil && test.valid {
|
||||
t.Errorf("%q got unexpected error: %s", test.in, err)
|
||||
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) {
|
||||
t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue