@@ -11,8 +11,12 @@ package topology
1111
1212import (
1313 "context"
14+ "crypto/tls"
15+ "crypto/x509"
1416 "errors"
17+ "io/ioutil"
1518 "net"
19+ "os"
1620 "runtime"
1721 "sync"
1822 "sync/atomic"
@@ -49,6 +53,144 @@ func (cncd *channelNetConnDialer) DialContext(_ context.Context, _, _ string) (n
4953 return cnc , nil
5054}
5155
56+ type errorQueue struct {
57+ errors []error
58+ mutex sync.Mutex
59+ }
60+
61+ func (eq * errorQueue ) head () error {
62+ eq .mutex .Lock ()
63+ defer eq .mutex .Unlock ()
64+ if len (eq .errors ) > 0 {
65+ return eq .errors [0 ]
66+ }
67+ return nil
68+ }
69+
70+ func (eq * errorQueue ) dequeue () bool {
71+ eq .mutex .Lock ()
72+ defer eq .mutex .Unlock ()
73+ if len (eq .errors ) > 0 {
74+ eq .errors = eq .errors [1 :]
75+ return true
76+ }
77+ return false
78+ }
79+
80+ type timeoutConn struct {
81+ net.Conn
82+ errors * errorQueue
83+ }
84+
85+ func (c * timeoutConn ) Read (b []byte ) (int , error ) {
86+ n , err := 0 , c .errors .head ()
87+ if err == nil {
88+ n , err = c .Conn .Read (b )
89+ }
90+ return n , err
91+ }
92+
93+ func (c * timeoutConn ) Write (b []byte ) (int , error ) {
94+ n , err := 0 , c .errors .head ()
95+ if err == nil {
96+ n , err = c .Conn .Write (b )
97+ }
98+ return n , err
99+ }
100+
101+ type timeoutDialer struct {
102+ Dialer
103+ errors * errorQueue
104+ }
105+
106+ func (d * timeoutDialer ) DialContext (ctx context.Context , network , address string ) (net.Conn , error ) {
107+ c , e := d .Dialer .DialContext (ctx , network , address )
108+
109+ if caFile := os .Getenv ("MONGO_GO_DRIVER_CA_FILE" ); len (caFile ) > 0 {
110+ pem , err := ioutil .ReadFile (caFile )
111+ if err != nil {
112+ return nil , err
113+ }
114+
115+ ca := x509 .NewCertPool ()
116+ if ! ca .AppendCertsFromPEM (pem ) {
117+ return nil , errors .New ("unable to load CA file" )
118+ }
119+
120+ config := & tls.Config {
121+ InsecureSkipVerify : true ,
122+ RootCAs : ca ,
123+ }
124+ c = tls .Client (c , config )
125+ }
126+ return & timeoutConn {c , d .errors }, e
127+ }
128+
129+ // TestServerHeartbeatTimeout tests timeout retry for GODRIVER-2577.
130+ func TestServerHeartbeatTimeout (t * testing.T ) {
131+ networkTimeoutError := & net.DNSError {
132+ IsTimeout : true ,
133+ }
134+
135+ testCases := []struct {
136+ desc string
137+ ioErrors []error
138+ expectPoolCleared bool
139+ }{
140+ {
141+ desc : "one single timeout should not clear the pool" ,
142+ ioErrors : []error {nil , networkTimeoutError , nil , networkTimeoutError , nil },
143+ expectPoolCleared : false ,
144+ },
145+ {
146+ desc : "continuous timeouts should clear the pool" ,
147+ ioErrors : []error {nil , networkTimeoutError , networkTimeoutError , nil },
148+ expectPoolCleared : true ,
149+ },
150+ }
151+ for _ , tc := range testCases {
152+ tc := tc
153+ t .Run (tc .desc , func (t * testing.T ) {
154+ t .Parallel ()
155+
156+ var wg sync.WaitGroup
157+ wg .Add (1 )
158+
159+ errors := & errorQueue {errors : tc .ioErrors }
160+ tpm := monitor .NewTestPoolMonitor ()
161+ server := NewServer (
162+ address .Address ("localhost:27017" ),
163+ primitive .NewObjectID (),
164+ WithConnectionPoolMonitor (func (* event.PoolMonitor ) * event.PoolMonitor {
165+ return tpm .PoolMonitor
166+ }),
167+ WithConnectionOptions (func (opts ... ConnectionOption ) []ConnectionOption {
168+ return append (opts ,
169+ WithDialer (func (d Dialer ) Dialer {
170+ var dialer net.Dialer
171+ return & timeoutDialer {& dialer , errors }
172+ }))
173+ }),
174+ WithServerMonitor (func (* event.ServerMonitor ) * event.ServerMonitor {
175+ return & event.ServerMonitor {
176+ ServerHeartbeatStarted : func (e * event.ServerHeartbeatStartedEvent ) {
177+ if ! errors .dequeue () {
178+ wg .Done ()
179+ }
180+ },
181+ }
182+ }),
183+ WithHeartbeatInterval (func (time.Duration ) time.Duration {
184+ return 200 * time .Millisecond
185+ }),
186+ )
187+ require .NoError (t , server .Connect (nil ))
188+ wg .Wait ()
189+ assert .Equal (t , tc .expectPoolCleared , tpm .IsPoolCleared (), "expected pool cleared to be %v but was %v" , tc .expectPoolCleared , tpm .IsPoolCleared ())
190+ })
191+ }
192+ }
193+
52194// TestServerConnectionTimeout tests how different timeout errors are handled during connection
53195// creation and server handshake.
54196func TestServerConnectionTimeout (t * testing.T ) {
0 commit comments