@@ -226,6 +226,7 @@ def __init__(
226226 """
227227 if family != AF_INET :
228228 raise RuntimeError ("Only AF_INET family supported by W5K modules." )
229+ self ._socket_closed = False
229230 self ._sock_type = type
230231 self ._buffer = b""
231232 self ._timeout = _default_socket_timeout
@@ -251,6 +252,17 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
251252 if time .monotonic () - stamp > 1000 :
252253 raise RuntimeError ("Failed to close socket" )
253254
255+ # This works around problems with using a class method as a decorator.
256+ def _check_socket_closed (func ): # pylint: disable=no-self-argument
257+ """Decorator to check whether the socket object has been closed."""
258+
259+ def wrapper (self , * args , ** kwargs ):
260+ if self ._socket_closed : # pylint: disable=protected-access
261+ raise RuntimeError ("The socket has been closed." )
262+ return func (self , * args , ** kwargs ) # pylint: disable=not-callable
263+
264+ return wrapper
265+
254266 @property
255267 def _status (self ) -> int :
256268 """
@@ -288,6 +300,7 @@ def _connected(self) -> bool:
288300 self .close ()
289301 return result
290302
303+ @_check_socket_closed
291304 def getpeername (self ) -> Tuple [str , int ]:
292305 """
293306 Return the remote address to which the socket is connected.
@@ -298,6 +311,7 @@ def getpeername(self) -> Tuple[str, int]:
298311 self ._socknum
299312 )
300313
314+ @_check_socket_closed
301315 def bind (self , address : Tuple [Optional [str ], int ]) -> None :
302316 """
303317 Bind the socket to address. The socket must not already be bound.
@@ -343,6 +357,7 @@ def _bind(self, address: Tuple[Optional[str], int]) -> None:
343357 )
344358 self ._buffer = b""
345359
360+ @_check_socket_closed
346361 def listen (self , backlog : int = 0 ) -> None :
347362 """
348363 Enable a server to accept connections.
@@ -354,6 +369,7 @@ def listen(self, backlog: int = 0) -> None:
354369 _the_interface .socket_listen (self ._socknum , self ._listen_port )
355370 self ._buffer = b""
356371
372+ @_check_socket_closed
357373 def accept (
358374 self ,
359375 ) -> Tuple [socket , Tuple [str , int ]]:
@@ -388,6 +404,7 @@ def accept(
388404 raise RuntimeError ("Failed to open new listening socket" )
389405 return client_sock , addr
390406
407+ @_check_socket_closed
391408 def connect (self , address : Tuple [str , int ]) -> None :
392409 """
393410 Connect to a remote socket at address.
@@ -407,6 +424,7 @@ def connect(self, address: Tuple[str, int]) -> None:
407424 raise RuntimeError ("Failed to connect to host " , address [0 ])
408425 self ._buffer = b""
409426
427+ @_check_socket_closed
410428 def send (self , data : Union [bytes , bytearray ]) -> int :
411429 """
412430 Send data to the socket. The socket must be connected to a remote socket.
@@ -422,6 +440,7 @@ def send(self, data: Union[bytes, bytearray]) -> int:
422440 gc .collect ()
423441 return bytes_sent
424442
443+ @_check_socket_closed
425444 def sendto (self , data : bytearray , * flags_and_or_address : any ) -> int :
426445 """
427446 Send data to the socket. The socket should not be connected to a remote socket, since the
@@ -445,6 +464,7 @@ def sendto(self, data: bytearray, *flags_and_or_address: any) -> int:
445464 self .connect (address )
446465 return self .send (data )
447466
467+ @_check_socket_closed
448468 def recv (
449469 # pylint: disable=too-many-branches
450470 self ,
@@ -500,6 +520,7 @@ def _embed_recv(
500520 gc .collect ()
501521 return ret
502522
523+ @_check_socket_closed
503524 def recvfrom (self , bufsize : int , flags : int = 0 ) -> Tuple [bytes , Tuple [str , int ]]:
504525 """
505526 Receive data from the socket. The return value is a pair (bytes, address) where bytes is
@@ -520,6 +541,7 @@ def recvfrom(self, bufsize: int, flags: int = 0) -> Tuple[bytes, Tuple[str, int]
520541 ),
521542 )
522543
544+ @_check_socket_closed
523545 def recv_into (self , buffer : bytearray , nbytes : int = 0 , flags : int = 0 ) -> int :
524546 """
525547 Receive up to nbytes bytes from the socket, storing the data into a buffer
@@ -538,6 +560,7 @@ def recv_into(self, buffer: bytearray, nbytes: int = 0, flags: int = 0) -> int:
538560 buffer [:nbytes ] = bytes_received
539561 return nbytes
540562
563+ @_check_socket_closed
541564 def recvfrom_into (
542565 self , buffer : bytearray , nbytes : int = 0 , flags : int = 0
543566 ) -> Tuple [int , Tuple [str , int ]]:
@@ -596,11 +619,13 @@ def _disconnect(self) -> None:
596619 raise RuntimeError ("Socket must be a TCP socket." )
597620 _the_interface .socket_disconnect (self ._socknum )
598621
622+ @_check_socket_closed
599623 def close (self ) -> None :
600624 """
601625 Mark the socket closed. Once that happens, all future operations on the socket object
602626 will fail. The remote end will receive no more data.
603627 """
628+ self ._socket_closed = True
604629 _the_interface .socket_close (self ._socknum )
605630
606631 def _available (self ) -> int :
@@ -611,6 +636,7 @@ def _available(self) -> int:
611636 """
612637 return _the_interface .socket_available (self ._socknum , self ._sock_type )
613638
639+ @_check_socket_closed
614640 def settimeout (self , value : Optional [float ]) -> None :
615641 """
616642 Set a timeout on blocking socket operations. The value argument can be a
@@ -627,6 +653,7 @@ def settimeout(self, value: Optional[float]) -> None:
627653 else :
628654 raise ValueError ("Timeout must be None, 0.0 or a positive numeric value." )
629655
656+ @_check_socket_closed
630657 def gettimeout (self ) -> Optional [float ]:
631658 """
632659 Return the timeout in seconds (float) associated with socket operations, or None if no
@@ -636,6 +663,7 @@ def gettimeout(self) -> Optional[float]:
636663 """
637664 return self ._timeout
638665
666+ @_check_socket_closed
639667 def setblocking (self , flag : bool ) -> None :
640668 """
641669 Set blocking or non-blocking mode of the socket: if flag is false, the socket is set
@@ -658,6 +686,7 @@ def setblocking(self, flag: bool) -> None:
658686 else :
659687 raise TypeError ("Flag must be a boolean." )
660688
689+ @_check_socket_closed
661690 def getblocking (self ) -> bool :
662691 """
663692 Return True if socket is in blocking mode, False if in non-blocking.
@@ -669,16 +698,19 @@ def getblocking(self) -> bool:
669698 return self .gettimeout () == 0
670699
671700 @property
701+ @_check_socket_closed
672702 def family (self ) -> int :
673703 """Socket family (always 0x03 in this implementation)."""
674704 return 3
675705
676706 @property
707+ @_check_socket_closed
677708 def type (self ):
678709 """Socket type."""
679710 return self ._sock_type
680711
681712 @property
713+ @_check_socket_closed
682714 def proto (self ):
683715 """Socket protocol (always 0x00 in this implementation)."""
684716 return 0
0 commit comments