From c02a1bf00c44574eeff0e86c16138fb94a47baae Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 28 Jul 2021 10:03:09 +0800 Subject: [PATCH] Fix postgres driver datasource name parse (#2012) Fix #2010 Reviewed-on: https://gitea.com/xorm/xorm/pulls/2012 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- dialects/postgres.go | 92 ++++++++++++++++++++++++++++++--------- dialects/postgres_test.go | 18 +++++--- 2 files changed, 82 insertions(+), 28 deletions(-) diff --git a/dialects/postgres.go b/dialects/postgres.go index 8a0dd7a8..cf760e18 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -1341,14 +1341,6 @@ type pqDriver struct { type values map[string]string -func (vs values) Set(k, v string) { - vs[k] = v -} - -func (vs values) Get(k string) (v string) { - return vs[k] -} - func parseURL(connstr string) (string, error) { u, err := url.Parse(connstr) if err != nil { @@ -1368,20 +1360,74 @@ func parseURL(connstr string) (string, error) { return "", nil } -func parseOpts(name string, o values) error { - if len(name) == 0 { - return fmt.Errorf("invalid options: %s", name) +func parseOpts(urlStr string, o values) error { + if len(urlStr) == 0 { + return fmt.Errorf("invalid options: %s", urlStr) } - name = strings.TrimSpace(name) + urlStr = strings.TrimSpace(urlStr) - ps := strings.Split(name, " ") - for _, p := range ps { - kv := strings.Split(p, "=") - if len(kv) < 2 { - return fmt.Errorf("invalid option: %q", p) + var ( + inQuote bool + state int // 0 key, 1 space, 2 value, 3 equal + start int + key string + ) + for i, c := range urlStr { + switch c { + case ' ': + if !inQuote { + if state == 2 { + state = 1 + v := urlStr[start:i] + if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") { + v = v[1 : len(v)-1] + } else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") { + return fmt.Errorf("wrong single quote in %d of %s", i, urlStr) + } + o[key] = v + } else if state != 1 { + return fmt.Errorf("wrong format: %v", urlStr) + } + } + case '\'': + if state == 3 { + state = 2 + start = i + } else if state != 2 { + return fmt.Errorf("wrong format: %v", urlStr) + } + inQuote = !inQuote + case '=': + if !inQuote { + if state != 0 { + return fmt.Errorf("wrong format: %v", urlStr) + } + key = urlStr[start:i] + state = 3 + } + default: + if state == 3 { + state = 2 + start = i + } else if state == 1 { + state = 0 + start = i + } + } + + if i == len(urlStr)-1 { + if state != 2 { + return errors.New("no value matched key") + } + v := urlStr[start : i+1] + if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") { + v = v[1 : len(v)-1] + } else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") { + return fmt.Errorf("wrong single quote in %d of %s", i, urlStr) + } + o[key] = v } - o.Set(kv[0], kv[1]) } return nil @@ -1395,9 +1441,13 @@ func (p *pqDriver) Features() *DriverFeatures { func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { db := &URI{DBType: schemas.POSTGRES} - var err error - if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") { + var err error + if strings.Contains(dataSourceName, "://") { + if !strings.HasPrefix(dataSourceName, "postgresql://") && !strings.HasPrefix(dataSourceName, "postgres://") { + return nil, fmt.Errorf("unsupported protocol %v", dataSourceName) + } + db.DBName, err = parseURL(dataSourceName) if err != nil { return nil, err @@ -1409,7 +1459,7 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { return nil, err } - db.DBName = o.Get("dbname") + db.DBName = o["dbname"] } if db.DBName == "" { diff --git a/dialects/postgres_test.go b/dialects/postgres_test.go index e0c36f92..bed8f307 100644 --- a/dialects/postgres_test.go +++ b/dialects/postgres_test.go @@ -22,20 +22,24 @@ func TestParsePostgres(t *testing.T) { //{"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true}, {"dbname=db sslmode=disable", "db", true}, {"user=auser password=password dbname=db sslmode=disable", "db", true}, + {"user=auser password='pass word' dbname=db sslmode=disable", "db", true}, + {"user=auser password='pass word' sslmode=disable dbname='db'", "db", true}, + {"user=auser password='pass word' sslmode='disable dbname=db'", "db", false}, {"", "db", false}, {"dbname=db =disable", "db", false}, } driver := QueryDriver("postgres") - for _, test := range tests { - uri, err := driver.Parse("postgres", test.in) + t.Run(test.in, func(t *testing.T) { + uri, err := driver.Parse("postgres", test.in) - if err != nil && test.valid { - t.Errorf("%q got unexpected error: %s", test.in, err) - } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { - t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) - } + if err != nil && test.valid { + t.Errorf("%q got unexpected error: %s", test.in, err) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) + } + }) } }