@@ -123,89 +123,7 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
123123_IS_SYNC = False
124124
125125
126- class AsyncBaseConnection :
127- """A base connection object for server and kms connections."""
128-
129- def __init__ (self , conn : AsyncNetworkingInterface , opts : PoolOptions ):
130- self .conn = conn
131- self .socket_checker : SocketChecker = SocketChecker ()
132- self .cancel_context : _CancellationContext = _CancellationContext ()
133- self .is_sdam = False
134- self .closed = False
135- self .last_timeout : float | None = None
136- self .more_to_come = False
137- self .opts = opts
138- self .max_wire_version = - 1
139-
140- def set_conn_timeout (self , timeout : Optional [float ]) -> None :
141- """Cache last timeout to avoid duplicate calls to conn.settimeout."""
142- if timeout == self .last_timeout :
143- return
144- self .last_timeout = timeout
145- self .conn .get_conn .settimeout (timeout )
146-
147- def apply_timeout (
148- self , client : AsyncMongoClient [Any ], cmd : Optional [MutableMapping [str , Any ]]
149- ) -> Optional [float ]:
150- # CSOT: use remaining timeout when set.
151- timeout = _csot .remaining ()
152- if timeout is None :
153- # Reset the socket timeout unless we're performing a streaming monitor check.
154- if not self .more_to_come :
155- self .set_conn_timeout (self .opts .socket_timeout )
156- return None
157- # RTT validation.
158- rtt = _csot .get_rtt ()
159- if rtt is None :
160- rtt = self .connect_rtt
161- max_time_ms = timeout - rtt
162- if max_time_ms < 0 :
163- timeout_details = _get_timeout_details (self .opts )
164- formatted = format_timeout_details (timeout_details )
165- # CSOT: raise an error without running the command since we know it will time out.
166- errmsg = f"operation would exceed time limit, remaining timeout:{ timeout :.5f} <= network round trip time:{ rtt :.5f} { formatted } "
167- if self .max_wire_version != - 1 :
168- raise ExecutionTimeout (
169- errmsg ,
170- 50 ,
171- {"ok" : 0 , "errmsg" : errmsg , "code" : 50 },
172- self .max_wire_version ,
173- )
174- else :
175- raise TimeoutError (errmsg )
176- if cmd is not None :
177- cmd ["maxTimeMS" ] = int (max_time_ms * 1000 )
178- self .set_conn_timeout (timeout )
179- return timeout
180-
181- async def close_conn (self , reason : Optional [str ]) -> None :
182- """Close this connection with a reason."""
183- if self .closed :
184- return
185- await self ._close_conn ()
186-
187- async def _close_conn (self ) -> None :
188- """Close this connection."""
189- if self .closed :
190- return
191- self .closed = True
192- self .cancel_context .cancel ()
193- # Note: We catch exceptions to avoid spurious errors on interpreter
194- # shutdown.
195- try :
196- await self .conn .close ()
197- except Exception : # noqa: S110
198- pass
199-
200- def conn_closed (self ) -> bool :
201- """Return True if we know socket has been closed, False otherwise."""
202- if _IS_SYNC :
203- return self .socket_checker .socket_closed (self .conn .get_conn )
204- else :
205- return self .conn .is_closing ()
206-
207-
208- class AsyncConnection (AsyncBaseConnection ):
126+ class AsyncConnection :
209127 """Store a connection with some metadata.
210128
211129 :param conn: a raw connection object
@@ -223,27 +141,29 @@ def __init__(
223141 id : int ,
224142 is_sdam : bool ,
225143 ):
226- super ().__init__ (conn , pool .opts )
227144 self .pool_ref = weakref .ref (pool )
228- self .address : tuple [str , int ] = address
229- self .id : int = id
145+ self .conn = conn
146+ self .address = address
147+ self .id = id
230148 self .is_sdam = is_sdam
149+ self .closed = False
231150 self .last_checkin_time = time .monotonic ()
232151 self .performed_handshake = False
233152 self .is_writable : bool = False
234153 self .max_wire_version = MAX_WIRE_VERSION
235- self .max_bson_size : int = MAX_BSON_SIZE
236- self .max_message_size : int = MAX_MESSAGE_SIZE
237- self .max_write_batch_size : int = MAX_WRITE_BATCH_SIZE
154+ self .max_bson_size = MAX_BSON_SIZE
155+ self .max_message_size = MAX_MESSAGE_SIZE
156+ self .max_write_batch_size = MAX_WRITE_BATCH_SIZE
238157 self .supports_sessions = False
239158 self .hello_ok : bool = False
240- self .is_mongos : bool = False
159+ self .is_mongos = False
241160 self .op_msg_enabled = False
242161 self .listeners = pool .opts ._event_listeners
243162 self .enabled_for_cmap = pool .enabled_for_cmap
244163 self .enabled_for_logging = pool .enabled_for_logging
245164 self .compression_settings = pool .opts ._compression_settings
246165 self .compression_context : Union [SnappyContext , ZlibContext , ZstdContext , None ] = None
166+ self .socket_checker : SocketChecker = SocketChecker ()
247167 self .oidc_token_gen_id : Optional [int ] = None
248168 # Support for mechanism negotiation on the initial handshake.
249169 self .negotiated_mechs : Optional [list [str ]] = None
@@ -254,6 +174,9 @@ def __init__(
254174 self .pool_gen = pool .gen
255175 self .generation = self .pool_gen .get_overall ()
256176 self .ready = False
177+ self .cancel_context : _CancellationContext = _CancellationContext ()
178+ self .opts = pool .opts
179+ self .more_to_come : bool = False
257180 # For load balancer support.
258181 self .service_id : Optional [ObjectId ] = None
259182 self .server_connection_id : Optional [int ] = None
@@ -269,6 +192,44 @@ def __init__(
269192 # For gossiping $clusterTime from the connection handshake to the client.
270193 self ._cluster_time = None
271194
195+ def set_conn_timeout (self , timeout : Optional [float ]) -> None :
196+ """Cache last timeout to avoid duplicate calls to conn.settimeout."""
197+ if timeout == self .last_timeout :
198+ return
199+ self .last_timeout = timeout
200+ self .conn .get_conn .settimeout (timeout )
201+
202+ def apply_timeout (
203+ self , client : AsyncMongoClient [Any ], cmd : Optional [MutableMapping [str , Any ]]
204+ ) -> Optional [float ]:
205+ # CSOT: use remaining timeout when set.
206+ timeout = _csot .remaining ()
207+ if timeout is None :
208+ # Reset the socket timeout unless we're performing a streaming monitor check.
209+ if not self .more_to_come :
210+ self .set_conn_timeout (self .opts .socket_timeout )
211+ return None
212+ # RTT validation.
213+ rtt = _csot .get_rtt ()
214+ if rtt is None :
215+ rtt = self .connect_rtt
216+ max_time_ms = timeout - rtt
217+ if max_time_ms < 0 :
218+ timeout_details = _get_timeout_details (self .opts )
219+ formatted = format_timeout_details (timeout_details )
220+ # CSOT: raise an error without running the command since we know it will time out.
221+ errmsg = f"operation would exceed time limit, remaining timeout:{ timeout :.5f} <= network round trip time:{ rtt :.5f} { formatted } "
222+ raise ExecutionTimeout (
223+ errmsg ,
224+ 50 ,
225+ {"ok" : 0 , "errmsg" : errmsg , "code" : 50 },
226+ self .max_wire_version ,
227+ )
228+ if cmd is not None :
229+ cmd ["maxTimeMS" ] = int (max_time_ms * 1000 )
230+ self .set_conn_timeout (timeout )
231+ return timeout
232+
272233 def pin_txn (self ) -> None :
273234 self .pinned_txn = True
274235 assert not self .pinned_cursor
@@ -612,6 +573,26 @@ async def close_conn(self, reason: Optional[str]) -> None:
612573 error = reason ,
613574 )
614575
576+ async def _close_conn (self ) -> None :
577+ """Close this connection."""
578+ if self .closed :
579+ return
580+ self .closed = True
581+ self .cancel_context .cancel ()
582+ # Note: We catch exceptions to avoid spurious errors on interpreter
583+ # shutdown.
584+ try :
585+ await self .conn .close ()
586+ except Exception : # noqa: S110
587+ pass
588+
589+ def conn_closed (self ) -> bool :
590+ """Return True if we know socket has been closed, False otherwise."""
591+ if _IS_SYNC :
592+ return self .socket_checker .socket_closed (self .conn .get_conn )
593+ else :
594+ return self .conn .is_closing ()
595+
615596 def send_cluster_time (
616597 self ,
617598 command : MutableMapping [str , Any ],
0 commit comments