@@ -226,6 +226,7 @@ def __init__(
226226 """
227227 if family != AF_INET :
228228 raise RuntimeError ("Only AF_INET family supported by W5K modules." )
229+ self ._socket_closed = False
229230 self ._sock_type = type
230231 self ._buffer = b""
231232 self ._timeout = _default_socket_timeout
@@ -255,6 +256,17 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
255256 if time .monotonic () - stamp > 1000 :
256257 raise RuntimeError ("Failed to close socket" )
257258
259+ # This works around problems with using a class method as a decorator.
260+ def _check_socket_closed (func ): # pylint: disable=no-self-argument
261+ """Decorator to check whether the socket object has been closed."""
262+
263+ def wrapper (self , * args , ** kwargs ):
264+ if self ._socket_closed : # pylint: disable=protected-access
265+ raise RuntimeError ("The socket has been closed." )
266+ return func (self , * args , ** kwargs ) # pylint: disable=not-callable
267+
268+ return wrapper
269+
258270 @property
259271 def _status (self ) -> int :
260272 """
@@ -292,6 +304,7 @@ def _connected(self) -> bool:
292304 self .close ()
293305 return result
294306
307+ @_check_socket_closed
295308 def getpeername (self ) -> Tuple [str , int ]:
296309 """
297310 Return the remote address to which the socket is connected.
@@ -302,6 +315,7 @@ def getpeername(self) -> Tuple[str, int]:
302315 self ._socknum
303316 )
304317
318+ @_check_socket_closed
305319 def bind (self , address : Tuple [Optional [str ], int ]) -> None :
306320 """
307321 Bind the socket to address. The socket must not already be bound.
@@ -347,6 +361,7 @@ def _bind(self, address: Tuple[Optional[str], int]) -> None:
347361 )
348362 self ._buffer = b""
349363
364+ @_check_socket_closed
350365 def listen (self , backlog : int = 0 ) -> None :
351366 """
352367 Enable a server to accept connections.
@@ -358,6 +373,7 @@ def listen(self, backlog: int = 0) -> None:
358373 _the_interface .socket_listen (self ._socknum , self ._listen_port )
359374 self ._buffer = b""
360375
376+ @_check_socket_closed
361377 def accept (
362378 self ,
363379 ) -> Tuple [socket , Tuple [str , int ]]:
@@ -392,6 +408,7 @@ def accept(
392408 raise RuntimeError ("Failed to open new listening socket" )
393409 return client_sock , addr
394410
411+ @_check_socket_closed
395412 def connect (self , address : Tuple [str , int ]) -> None :
396413 """
397414 Connect to a remote socket at address.
@@ -411,6 +428,7 @@ def connect(self, address: Tuple[str, int]) -> None:
411428 raise RuntimeError ("Failed to connect to host " , address [0 ])
412429 self ._buffer = b""
413430
431+ @_check_socket_closed
414432 def send (self , data : Union [bytes , bytearray ]) -> int :
415433 """
416434 Send data to the socket. The socket must be connected to a remote socket.
@@ -426,6 +444,7 @@ def send(self, data: Union[bytes, bytearray]) -> int:
426444 gc .collect ()
427445 return bytes_sent
428446
447+ @_check_socket_closed
429448 def sendto (self , data : bytearray , * flags_and_or_address : any ) -> int :
430449 """
431450 Send data to the socket. The socket should not be connected to a remote socket, since the
@@ -449,6 +468,7 @@ def sendto(self, data: bytearray, *flags_and_or_address: any) -> int:
449468 self .connect (address )
450469 return self .send (data )
451470
471+ @_check_socket_closed
452472 def recv (
453473 # pylint: disable=too-many-branches
454474 self ,
@@ -504,6 +524,7 @@ def _embed_recv(
504524 gc .collect ()
505525 return ret
506526
527+ @_check_socket_closed
507528 def recvfrom (self , bufsize : int , flags : int = 0 ) -> Tuple [bytes , Tuple [str , int ]]:
508529 """
509530 Receive data from the socket. The return value is a pair (bytes, address) where bytes is
@@ -524,6 +545,7 @@ def recvfrom(self, bufsize: int, flags: int = 0) -> Tuple[bytes, Tuple[str, int]
524545 ),
525546 )
526547
548+ @_check_socket_closed
527549 def recv_into (self , buffer : bytearray , nbytes : int = 0 , flags : int = 0 ) -> int :
528550 """
529551 Receive up to nbytes bytes from the socket, storing the data into a buffer
@@ -542,6 +564,7 @@ def recv_into(self, buffer: bytearray, nbytes: int = 0, flags: int = 0) -> int:
542564 buffer [:nbytes ] = bytes_received
543565 return nbytes
544566
567+ @_check_socket_closed
545568 def recvfrom_into (
546569 self , buffer : bytearray , nbytes : int = 0 , flags : int = 0
547570 ) -> Tuple [int , Tuple [str , int ]]:
@@ -600,13 +623,15 @@ def _disconnect(self) -> None:
600623 raise RuntimeError ("Socket must be a TCP socket." )
601624 _the_interface .socket_disconnect (self ._socknum )
602625
626+ @_check_socket_closed
603627 def close (self ) -> None :
604628 """
605629 Mark the socket closed. Once that happens, all future operations on the socket object
606630 will fail. The remote end will receive no more data.
607631 """
608632 _the_interface .release_socket (self ._socknum )
609633 _the_interface .socket_close (self ._socknum )
634+ self ._socket_closed = True
610635
611636 def _available (self ) -> int :
612637 """
@@ -616,6 +641,7 @@ def _available(self) -> int:
616641 """
617642 return _the_interface .socket_available (self ._socknum , self ._sock_type )
618643
644+ @_check_socket_closed
619645 def settimeout (self , value : Optional [float ]) -> None :
620646 """
621647 Set a timeout on blocking socket operations. The value argument can be a
@@ -632,6 +658,7 @@ def settimeout(self, value: Optional[float]) -> None:
632658 else :
633659 raise ValueError ("Timeout must be None, 0.0 or a positive numeric value." )
634660
661+ @_check_socket_closed
635662 def gettimeout (self ) -> Optional [float ]:
636663 """
637664 Return the timeout in seconds (float) associated with socket operations, or None if no
@@ -641,6 +668,7 @@ def gettimeout(self) -> Optional[float]:
641668 """
642669 return self ._timeout
643670
671+ @_check_socket_closed
644672 def setblocking (self , flag : bool ) -> None :
645673 """
646674 Set blocking or non-blocking mode of the socket: if flag is false, the socket is set
@@ -663,6 +691,7 @@ def setblocking(self, flag: bool) -> None:
663691 else :
664692 raise TypeError ("Flag must be a boolean." )
665693
694+ @_check_socket_closed
666695 def getblocking (self ) -> bool :
667696 """
668697 Return True if socket is in blocking mode, False if in non-blocking.
@@ -674,16 +703,19 @@ def getblocking(self) -> bool:
674703 return self .gettimeout () == 0
675704
676705 @property
706+ @_check_socket_closed
677707 def family (self ) -> int :
678708 """Socket family (always 0x03 in this implementation)."""
679709 return 3
680710
681711 @property
712+ @_check_socket_closed
682713 def type (self ):
683714 """Socket type."""
684715 return self ._sock_type
685716
686717 @property
718+ @_check_socket_closed
687719 def proto (self ):
688720 """Socket protocol (always 0x00 in this implementation)."""
689721 return 0
0 commit comments