3535
3636
3737if not sys .implementation .name == "circuitpython" :
38- from typing import Optional , Tuple
38+ from typing import List , Optional , Tuple
3939
4040 from circuitpython_typing .socket import (
4141 CircuitPythonSocketType ,
@@ -68,15 +68,14 @@ def connect(self, address: Tuple[str, int]) -> None:
6868 try :
6969 return self ._socket .connect (address , self ._mode )
7070 except RuntimeError as error :
71- raise OSError (errno .ENOMEM ) from error
71+ raise OSError (errno .ENOMEM , str ( error ) ) from error
7272
7373
7474class _FakeSSLContext :
7575 def __init__ (self , iface : InterfaceType ) -> None :
7676 self ._iface = iface
7777
78- # pylint: disable=unused-argument
79- def wrap_socket (
78+ def wrap_socket ( # pylint: disable=unused-argument
8079 self , socket : CircuitPythonSocketType , server_hostname : Optional [str ] = None
8180 ) -> _FakeSSLSocket :
8281 """Return the same socket"""
@@ -106,7 +105,8 @@ def create_fake_ssl_context(
106105 return _FakeSSLContext (iface )
107106
108107
109- _global_socketpool = {}
108+ _global_connection_managers = {}
109+ _global_socketpools = {}
110110_global_ssl_contexts = {}
111111
112112
@@ -127,7 +127,7 @@ def get_radio_socketpool(radio):
127127 * Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing)
128128 """
129129 key = _get_radio_hash_key (radio )
130- if key not in _global_socketpool :
130+ if key not in _global_socketpools :
131131 class_name = radio .__class__ .__name__
132132 if class_name == "Radio" :
133133 import ssl # pylint: disable=import-outside-toplevel
@@ -168,10 +168,10 @@ def get_radio_socketpool(radio):
168168 else :
169169 raise AttributeError (f"Unsupported radio class: { class_name } " )
170170
171- _global_socketpool [key ] = pool
171+ _global_socketpools [key ] = pool
172172 _global_ssl_contexts [key ] = ssl_context
173173
174- return _global_socketpool [key ]
174+ return _global_socketpools [key ]
175175
176176
177177def get_radio_ssl_context (radio ):
@@ -199,42 +199,75 @@ def __init__(
199199 ) -> None :
200200 self ._socket_pool = socket_pool
201201 # Hang onto open sockets so that we can reuse them.
202- self ._available_socket = {}
203- self ._open_sockets = {}
204-
205- def _free_sockets (self ) -> None :
206- available_sockets = []
207- for socket , free in self ._available_socket .items ():
208- if free :
209- available_sockets .append (socket )
202+ self ._available_sockets = set ()
203+ self ._key_by_managed_socket = {}
204+ self ._managed_socket_by_key = {}
210205
206+ def _free_sockets (self , force : bool = False ) -> None :
207+ # cloning lists since items are being removed
208+ available_sockets = list (self ._available_sockets )
211209 for socket in available_sockets :
212210 self .close_socket (socket )
211+ if force :
212+ open_sockets = list (self ._managed_socket_by_key .values ())
213+ for socket in open_sockets :
214+ self .close_socket (socket )
213215
214- def _get_key_for_socket (self , socket ):
216+ def _get_connected_socket ( # pylint: disable=too-many-arguments
217+ self ,
218+ addr_info : List [Tuple [int , int , int , str , Tuple [str , int ]]],
219+ host : str ,
220+ port : int ,
221+ timeout : float ,
222+ is_ssl : bool ,
223+ ssl_context : Optional [SSLContextType ] = None ,
224+ ):
215225 try :
216- return next (
217- key for key , value in self ._open_sockets .items () if value == socket
218- )
219- except StopIteration :
220- return None
226+ socket = self ._socket_pool .socket (addr_info [0 ], addr_info [1 ])
227+ except (OSError , RuntimeError ) as exc :
228+ return exc
229+
230+ if is_ssl :
231+ socket = ssl_context .wrap_socket (socket , server_hostname = host )
232+ connect_host = host
233+ else :
234+ connect_host = addr_info [- 1 ][0 ]
235+ socket .settimeout (timeout ) # socket read timeout
236+
237+ try :
238+ socket .connect ((connect_host , port ))
239+ except (MemoryError , OSError ) as exc :
240+ socket .close ()
241+ return exc
242+
243+ return socket
244+
245+ @property
246+ def available_socket_count (self ) -> int :
247+ """Get the count of freeable open sockets"""
248+ return len (self ._available_sockets )
249+
250+ @property
251+ def managed_socket_count (self ) -> int :
252+ """Get the count of open sockets"""
253+ return len (self ._managed_socket_by_key )
221254
222255 def close_socket (self , socket : SocketType ) -> None :
223256 """Close a previously opened socket."""
224- if socket not in self ._open_sockets .values ():
257+ if socket not in self ._managed_socket_by_key .values ():
225258 raise RuntimeError ("Socket not managed" )
226- key = self ._get_key_for_socket (socket )
227259 socket .close ()
228- del self ._available_socket [socket ]
229- del self ._open_sockets [key ]
260+ key = self ._key_by_managed_socket .pop (socket )
261+ del self ._managed_socket_by_key [key ]
262+ if socket in self ._available_sockets :
263+ self ._available_sockets .remove (socket )
230264
231265 def free_socket (self , socket : SocketType ) -> None :
232266 """Mark a previously opened socket as available so it can be reused if needed."""
233- if socket not in self ._open_sockets .values ():
267+ if socket not in self ._managed_socket_by_key .values ():
234268 raise RuntimeError ("Socket not managed" )
235- self ._available_socket [ socket ] = True
269+ self ._available_sockets . add ( socket )
236270
237- # pylint: disable=too-many-branches,too-many-locals,too-many-statements
238271 def get_socket (
239272 self ,
240273 host : str ,
@@ -250,10 +283,10 @@ def get_socket(
250283 if session_id :
251284 session_id = str (session_id )
252285 key = (host , port , proto , session_id )
253- if key in self ._open_sockets :
254- socket = self ._open_sockets [key ]
255- if self ._available_socket [ socket ] :
256- self ._available_socket [ socket ] = False
286+ if key in self ._managed_socket_by_key :
287+ socket = self ._managed_socket_by_key [key ]
288+ if socket in self ._available_sockets :
289+ self ._available_sockets . remove ( socket )
257290 return socket
258291
259292 raise RuntimeError (f"Socket already connected to { proto } //{ host } :{ port } " )
@@ -269,64 +302,68 @@ def get_socket(
269302 host , port , 0 , self ._socket_pool .SOCK_STREAM
270303 )[0 ]
271304
272- try_count = 0
273- socket = None
274- last_exc = None
275- while try_count < 2 and socket is None :
276- try_count += 1
277- if try_count > 1 :
278- if any (
279- socket
280- for socket , free in self ._available_socket .items ()
281- if free is True
282- ):
283- self ._free_sockets ()
284- else :
285- break
286-
287- try :
288- socket = self ._socket_pool .socket (addr_info [0 ], addr_info [1 ])
289- except OSError as exc :
290- last_exc = exc
291- continue
292- except RuntimeError as exc :
293- last_exc = exc
294- continue
295-
296- if is_ssl :
297- socket = ssl_context .wrap_socket (socket , server_hostname = host )
298- connect_host = host
299- else :
300- connect_host = addr_info [- 1 ][0 ]
301- socket .settimeout (timeout ) # socket read timeout
302-
303- try :
304- socket .connect ((connect_host , port ))
305- except MemoryError as exc :
306- last_exc = exc
307- socket .close ()
308- socket = None
309- except OSError as exc :
310- last_exc = exc
311- socket .close ()
312- socket = None
313-
314- if socket is None :
315- raise RuntimeError (f"Error connecting socket: { last_exc } " ) from last_exc
316-
317- self ._available_socket [socket ] = False
318- self ._open_sockets [key ] = socket
319- return socket
305+ first_exception = None
306+ result = self ._get_connected_socket (
307+ addr_info , host , port , timeout , is_ssl , ssl_context
308+ )
309+ if isinstance (result , Exception ):
310+ # Got an error, if there are any available sockets, free them and try again
311+ if self .available_socket_count :
312+ first_exception = result
313+ self ._free_sockets ()
314+ result = self ._get_connected_socket (
315+ addr_info , host , port , timeout , is_ssl , ssl_context
316+ )
317+ if isinstance (result , Exception ):
318+ last_result = f", first error: { first_exception } " if first_exception else ""
319+ raise RuntimeError (
320+ f"Error connecting socket: { result } { last_result } "
321+ ) from result
322+
323+ self ._key_by_managed_socket [result ] = key
324+ self ._managed_socket_by_key [key ] = result
325+ return result
320326
321327
322328# global helpers
323329
324330
325- _global_connection_manager = {}
331+ def connection_manager_close_all (
332+ socket_pool : Optional [SocketpoolModuleType ] = None , release_references : bool = False
333+ ) -> None :
334+ """Close all open sockets for pool"""
335+ if socket_pool :
336+ socket_pools = [socket_pool ]
337+ else :
338+ socket_pools = _global_connection_managers .keys ()
339+
340+ for pool in socket_pools :
341+ connection_manager = _global_connection_managers .get (pool , None )
342+ if connection_manager is None :
343+ raise RuntimeError ("SocketPool not managed" )
344+
345+ connection_manager ._free_sockets (force = True ) # pylint: disable=protected-access
346+
347+ if release_references :
348+ radio_key = None
349+ for radio_check , pool_check in _global_socketpools .items ():
350+ if pool == pool_check :
351+ radio_key = radio_check
352+ break
353+
354+ if radio_key :
355+ if radio_key in _global_socketpools :
356+ del _global_socketpools [radio_key ]
357+
358+ if radio_key in _global_ssl_contexts :
359+ del _global_ssl_contexts [radio_key ]
360+
361+ if pool in _global_connection_managers :
362+ del _global_connection_managers [pool ]
326363
327364
328365def get_connection_manager (socket_pool : SocketpoolModuleType ) -> ConnectionManager :
329366 """Get the ConnectionManager singleton for the given pool"""
330- if socket_pool not in _global_connection_manager :
331- _global_connection_manager [socket_pool ] = ConnectionManager (socket_pool )
332- return _global_connection_manager [socket_pool ]
367+ if socket_pool not in _global_connection_managers :
368+ _global_connection_managers [socket_pool ] = ConnectionManager (socket_pool )
369+ return _global_connection_managers [socket_pool ]
0 commit comments