diff --git a/adafruit_wiznet5k/adafruit_wiznet5k_socket.py b/adafruit_wiznet5k/adafruit_wiznet5k_socket.py index 0378d03..2059fa9 100644 --- a/adafruit_wiznet5k/adafruit_wiznet5k_socket.py +++ b/adafruit_wiznet5k/adafruit_wiznet5k_socket.py @@ -226,6 +226,7 @@ def __init__( """ if family != AF_INET: raise RuntimeError("Only AF_INET family supported by W5K modules.") + self._socket_closed = False self._sock_type = type self._buffer = b"" self._timeout = _default_socket_timeout @@ -251,6 +252,17 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: if time.monotonic() - stamp > 1000: raise RuntimeError("Failed to close socket") + # This works around problems with using a class method as a decorator. + def _check_socket_closed(func): # pylint: disable=no-self-argument + """Decorator to check whether the socket object has been closed.""" + + def wrapper(self, *args, **kwargs): + if self._socket_closed: # pylint: disable=protected-access + raise RuntimeError("The socket has been closed.") + return func(self, *args, **kwargs) # pylint: disable=not-callable + + return wrapper + @property def _status(self) -> int: """ @@ -288,6 +300,7 @@ def _connected(self) -> bool: self.close() return result + @_check_socket_closed def getpeername(self) -> Tuple[str, int]: """ Return the remote address to which the socket is connected. @@ -298,6 +311,7 @@ def getpeername(self) -> Tuple[str, int]: self._socknum ) + @_check_socket_closed def bind(self, address: Tuple[Optional[str], int]) -> None: """ Bind the socket to address. The socket must not already be bound. @@ -343,6 +357,7 @@ def _bind(self, address: Tuple[Optional[str], int]) -> None: ) self._buffer = b"" + @_check_socket_closed def listen(self, backlog: int = 0) -> None: """ Enable a server to accept connections. @@ -354,6 +369,7 @@ def listen(self, backlog: int = 0) -> None: _the_interface.socket_listen(self._socknum, self._listen_port) self._buffer = b"" + @_check_socket_closed def accept( self, ) -> Tuple[socket, Tuple[str, int]]: @@ -388,6 +404,7 @@ def accept( raise RuntimeError("Failed to open new listening socket") return client_sock, addr + @_check_socket_closed def connect(self, address: Tuple[str, int]) -> None: """ Connect to a remote socket at address. @@ -407,6 +424,7 @@ def connect(self, address: Tuple[str, int]) -> None: raise RuntimeError("Failed to connect to host ", address[0]) self._buffer = b"" + @_check_socket_closed def send(self, data: Union[bytes, bytearray]) -> int: """ Send data to the socket. The socket must be connected to a remote socket. @@ -422,6 +440,7 @@ def send(self, data: Union[bytes, bytearray]) -> int: gc.collect() return bytes_sent + @_check_socket_closed def sendto(self, data: bytearray, *flags_and_or_address: any) -> int: """ Send data to the socket. The socket should not be connected to a remote socket, since the @@ -445,6 +464,7 @@ def sendto(self, data: bytearray, *flags_and_or_address: any) -> int: self.connect(address) return self.send(data) + @_check_socket_closed def recv( # pylint: disable=too-many-branches self, @@ -500,6 +520,7 @@ def _embed_recv( gc.collect() return ret + @_check_socket_closed def recvfrom(self, bufsize: int, flags: int = 0) -> Tuple[bytes, Tuple[str, int]]: """ Receive data from the socket. The return value is a pair (bytes, address) where bytes is @@ -520,6 +541,7 @@ def recvfrom(self, bufsize: int, flags: int = 0) -> Tuple[bytes, Tuple[str, int] ), ) + @_check_socket_closed def recv_into(self, buffer: bytearray, nbytes: int = 0, flags: int = 0) -> int: """ Receive up to nbytes bytes from the socket, storing the data into a buffer @@ -538,6 +560,7 @@ def recv_into(self, buffer: bytearray, nbytes: int = 0, flags: int = 0) -> int: buffer[:nbytes] = bytes_received return nbytes + @_check_socket_closed def recvfrom_into( self, buffer: bytearray, nbytes: int = 0, flags: int = 0 ) -> Tuple[int, Tuple[str, int]]: @@ -596,11 +619,13 @@ def _disconnect(self) -> None: raise RuntimeError("Socket must be a TCP socket.") _the_interface.socket_disconnect(self._socknum) + @_check_socket_closed def close(self) -> None: """ Mark the socket closed. Once that happens, all future operations on the socket object will fail. The remote end will receive no more data. """ + self._socket_closed = True _the_interface.socket_close(self._socknum) def _available(self) -> int: @@ -611,6 +636,7 @@ def _available(self) -> int: """ return _the_interface.socket_available(self._socknum, self._sock_type) + @_check_socket_closed def settimeout(self, value: Optional[float]) -> None: """ Set a timeout on blocking socket operations. The value argument can be a @@ -627,6 +653,7 @@ def settimeout(self, value: Optional[float]) -> None: else: raise ValueError("Timeout must be None, 0.0 or a positive numeric value.") + @_check_socket_closed def gettimeout(self) -> Optional[float]: """ Return the timeout in seconds (float) associated with socket operations, or None if no @@ -636,6 +663,7 @@ def gettimeout(self) -> Optional[float]: """ return self._timeout + @_check_socket_closed def setblocking(self, flag: bool) -> None: """ Set blocking or non-blocking mode of the socket: if flag is false, the socket is set @@ -658,6 +686,7 @@ def setblocking(self, flag: bool) -> None: else: raise TypeError("Flag must be a boolean.") + @_check_socket_closed def getblocking(self) -> bool: """ Return True if socket is in blocking mode, False if in non-blocking. @@ -669,16 +698,19 @@ def getblocking(self) -> bool: return self.gettimeout() == 0 @property + @_check_socket_closed def family(self) -> int: """Socket family (always 0x03 in this implementation).""" return 3 @property + @_check_socket_closed def type(self): """Socket type.""" return self._sock_type @property + @_check_socket_closed def proto(self): """Socket protocol (always 0x00 in this implementation).""" return 0