From b922c1c01550d723c47f1c50d89dfcd8119a7ba4 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Tue, 29 Oct 2024 16:19:16 -0400 Subject: [PATCH 1/3] GODRIVER-3284 Allow valid SRV hostnames with less than 3 parts. --- ...itial_dns_seedlist_discovery_prose_test.go | 116 ++++++++++++++++++ x/mongo/driver/dns/dns.go | 11 +- 2 files changed, 123 insertions(+), 4 deletions(-) create mode 100644 x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go diff --git a/x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go b/x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go new file mode 100644 index 0000000000..a2ddf14a88 --- /dev/null +++ b/x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go @@ -0,0 +1,116 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package connstring + +import ( + "net" + "testing" + + "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/dns" +) + +func TestInitialDNSSeedlistDiscoveryProse(t *testing.T) { + newTestParser := func(record string) *parser { + return &parser{&dns.Resolver{ + LookupSRV: func(_, _, _ string) (string, []*net.SRV, error) { + return "", []*net.SRV{ + { + Target: record, + Port: 27017, + }, + }, nil + }, + LookupTXT: func(string) ([]string, error) { + return nil, nil + }, + }} + } + + t.Run("1. Allow SRVs with fewer than 3 . separated parts", func(t *testing.T) { + t.Parallel() + + cases := []struct { + record string + uri string + }{ + {"test_1.localhost", "mongodb+srv://localhost"}, + {"test_1.mongo.local", "mongodb+srv://mongo.local"}, + } + for _, c := range cases { + c := c + t.Run(c.uri, func(t *testing.T) { + t.Parallel() + + _, err := newTestParser(c.record).parse(c.uri) + assert.NoError(t, err, "expected no URI parsing error, got %v", err) + }) + } + }) + t.Run("2. Throw when return address does not end with SRV domain", func(t *testing.T) { + t.Parallel() + + cases := []struct { + record string + uri string + }{ + {"localhost.mongodb", "mongodb+srv://localhost"}, + {"test_1.evil.local", "mongodb+srv://mongo.local"}, + {"blogs.evil.com", "mongodb+srv://blogs.mongodb.com"}, + } + for _, c := range cases { + c := c + t.Run(c.uri, func(t *testing.T) { + t.Parallel() + + _, err := newTestParser(c.record).parse(c.uri) + assert.ErrorContains(t, err, "Domain suffix from SRV record not matched input domain") + }) + } + }) + t.Run("3. Throw when return address is identical to SRV hostname", func(t *testing.T) { + t.Parallel() + + cases := []struct { + record string + uri string + }{ + {"localhost", "mongodb+srv://localhost"}, + {"mongo.local", "mongodb+srv://mongo.local"}, + } + for _, c := range cases { + c := c + t.Run(c.uri, func(t *testing.T) { + t.Parallel() + + _, err := newTestParser(c.record).parse(c.uri) + assert.ErrorContains(t, err, "DNS name must contain at least") + }) + } + }) + t.Run("4. Throw when return address does not contain . separating shared part of domain", func(t *testing.T) { + t.Parallel() + + cases := []struct { + record string + uri string + }{ + {"test_1.cluster_1localhost", "mongodb+srv://localhost"}, + {"test_1.my_hostmongo.local", "mongodb+srv://mongo.local"}, + {"cluster.testmongodb.com", "mongodb+srv://blogs.mongodb.com"}, + } + for _, c := range cases { + c := c + t.Run(c.uri, func(t *testing.T) { + t.Parallel() + + _, err := newTestParser(c.record).parse(c.uri) + assert.ErrorContains(t, err, "Domain suffix from SRV record not matched input domain") + }) + } + }) +} diff --git a/x/mongo/driver/dns/dns.go b/x/mongo/driver/dns/dns.go index 9334d493ed..4524af2794 100644 --- a/x/mongo/driver/dns/dns.go +++ b/x/mongo/driver/dns/dns.go @@ -113,15 +113,18 @@ func (r *Resolver) fetchSeedlistFromSRV(host string, srvName string, stopOnErr b func validateSRVResult(recordFromSRV, inputHostName string) error { separatedInputDomain := strings.Split(strings.ToLower(inputHostName), ".") separatedRecord := strings.Split(strings.ToLower(recordFromSRV), ".") - if len(separatedRecord) < 2 { - return errors.New("DNS name must contain at least 2 labels") + if l := len(separatedInputDomain); l < 3 && len(separatedRecord) <= l { + return fmt.Errorf("DNS name must contain at least %d labels", l+1) } if len(separatedRecord) < len(separatedInputDomain) { return errors.New("Domain suffix from SRV record not matched input domain") } - inputDomainSuffix := separatedInputDomain[1:] - domainSuffixOffset := len(separatedRecord) - (len(separatedInputDomain) - 1) + inputDomainSuffix := separatedInputDomain + if len(inputDomainSuffix) > 2 { + inputDomainSuffix = inputDomainSuffix[1:] + } + domainSuffixOffset := len(separatedRecord) - len(inputDomainSuffix) recordDomainSuffix := separatedRecord[domainSuffixOffset:] for ix, label := range inputDomainSuffix { From 99d82efc204b1cbb29704f6444669b43a0d99214 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Thu, 5 Dec 2024 17:31:29 -0500 Subject: [PATCH 2/3] improve tests --- .../initial_dns_seedlist_discovery_prose_test.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go b/x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go index a2ddf14a88..89744d91bb 100644 --- a/x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go +++ b/x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go @@ -7,6 +7,7 @@ package connstring import ( + "fmt" "net" "testing" @@ -78,9 +79,10 @@ func TestInitialDNSSeedlistDiscoveryProse(t *testing.T) { cases := []struct { record string uri string + labels int }{ - {"localhost", "mongodb+srv://localhost"}, - {"mongo.local", "mongodb+srv://mongo.local"}, + {"localhost", "mongodb+srv://localhost", 2}, + {"mongo.local", "mongodb+srv://mongo.local", 3}, } for _, c := range cases { c := c @@ -88,7 +90,7 @@ func TestInitialDNSSeedlistDiscoveryProse(t *testing.T) { t.Parallel() _, err := newTestParser(c.record).parse(c.uri) - assert.ErrorContains(t, err, "DNS name must contain at least") + assert.ErrorContains(t, err, fmt.Sprintf("DNS name must contain at least %d labels", c.labels)) }) } }) From deae65b5e0a02f6e116fa24389eb3d56587cd0f4 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Fri, 10 Jan 2025 18:54:05 -0500 Subject: [PATCH 3/3] update error message --- .../initial_dns_seedlist_discovery_prose_test.go | 10 +++++++--- x/mongo/driver/dns/dns.go | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go b/x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go index 89744d91bb..ecd9eccc17 100644 --- a/x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go +++ b/x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go @@ -81,8 +81,8 @@ func TestInitialDNSSeedlistDiscoveryProse(t *testing.T) { uri string labels int }{ - {"localhost", "mongodb+srv://localhost", 2}, - {"mongo.local", "mongodb+srv://mongo.local", 3}, + {"localhost", "mongodb+srv://localhost", 1}, + {"mongo.local", "mongodb+srv://mongo.local", 2}, } for _, c := range cases { c := c @@ -90,7 +90,11 @@ func TestInitialDNSSeedlistDiscoveryProse(t *testing.T) { t.Parallel() _, err := newTestParser(c.record).parse(c.uri) - assert.ErrorContains(t, err, fmt.Sprintf("DNS name must contain at least %d labels", c.labels)) + expected := fmt.Sprintf( + "Server record (%d levels) should have more domain levels than parent URI (%d levels)", + c.labels, c.labels, + ) + assert.ErrorContains(t, err, expected) }) } }) diff --git a/x/mongo/driver/dns/dns.go b/x/mongo/driver/dns/dns.go index 4524af2794..2d599db6de 100644 --- a/x/mongo/driver/dns/dns.go +++ b/x/mongo/driver/dns/dns.go @@ -114,7 +114,7 @@ func validateSRVResult(recordFromSRV, inputHostName string) error { separatedInputDomain := strings.Split(strings.ToLower(inputHostName), ".") separatedRecord := strings.Split(strings.ToLower(recordFromSRV), ".") if l := len(separatedInputDomain); l < 3 && len(separatedRecord) <= l { - return fmt.Errorf("DNS name must contain at least %d labels", l+1) + return fmt.Errorf("Server record (%d levels) should have more domain levels than parent URI (%d levels)", l, len(separatedRecord)) } if len(separatedRecord) < len(separatedInputDomain) { return errors.New("Domain suffix from SRV record not matched input domain")