Skip to content

Use AbstractMethod in transport. #2051

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions examples/package_test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,21 @@ def __init__(
async def start_run(self):
"""Call need functions to start server/client."""
if self.is_server:
return await self.transport_listen()
return await self.transport_connect()
return await self.listen()
return await self.connect()

def callback_data(self, data: bytes, addr: tuple | None = None) -> int:
"""Handle received data."""
self.stub_handle_data(self, data)
return len(data)

def callback_connected(self) -> None:
"""Call when connection is succcesfull."""

def callback_disconnected(self, exc: Exception | None) -> None:
"""Call when connection is lost."""
Log.debug("callback_disconnected called: {}", exc)

def callback_new_connection(self) -> ModbusProtocol:
"""Call when listener receive new connection request."""
new_stub = TransportStub(self.comm_params, False, self.stub_handle_data)
Expand Down Expand Up @@ -166,7 +173,7 @@ async def run(self):
"""Execute test run."""
pymodbus_apply_logging_config()
Log.debug("--> Start testing.")
await self.server.transport_listen()
await self.server.listen()
await self.stub.start_run()
await server_calls(self.stub)
Log.debug("--> Shutting down.")
Expand All @@ -193,7 +200,7 @@ async def client_calls(client):
async def server_calls(transport: ModbusProtocol):
"""Test client API."""
Log.debug("--> Server calls starting.")
_resp = transport.transport_send(b'\x00\x02\x00\x00\x00\x06\x01\x03\x00\x00\x00\x01' +
_resp = transport.send(b'\x00\x02\x00\x00\x00\x06\x01\x03\x00\x00\x00\x01' +
b'\x07\x00\x03\x00\x00\x06\x01\x03\x00\x00\x00\x01')
await asyncio.sleep(1)
print("---> all done")
Expand All @@ -206,8 +213,8 @@ def handle_client_data(transport: ModbusProtocol, data: bytes):
# Multiple send is allowed, to test fragmentation
# for data in response:
# to_send = data.to_bytes()
# transport.transport_send(to_send)
transport.transport_send(response)
# transport.send(to_send)
transport.send(response)


def handle_server_data(_transport: ModbusProtocol, data: bytes):
Expand Down
27 changes: 18 additions & 9 deletions pymodbus/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def connected(self) -> bool:
"""Return state of connection."""
return self.is_active()

async def base_connect(self) -> bool:
"""Call transport connect."""
return await super().connect()


def register(self, custom_response_class: ModbusResponse) -> None:
"""Register a custom response class with the decoder (call **sync**).

Expand All @@ -119,12 +124,12 @@ def register(self, custom_response_class: ModbusResponse) -> None:
"""
self.framer.decoder.register(custom_response_class)

def close(self, reconnect: bool = False) -> None:
def close(self, reconnect: bool = False) -> None: # type: ignore[override] # pylint: disable=arguments-differ
"""Close connection."""
if reconnect:
self.connection_lost(asyncio.TimeoutError("Server not responding"))
else:
self.transport_close()
super().close()

def idle_time(self) -> float:
"""Time before initiating next transaction (call **sync**).
Expand Down Expand Up @@ -159,7 +164,7 @@ async def async_execute(self, request) -> ModbusResponse:
while count <= self.retries:
req = self.build_response(request.transaction_id)
if not count or not self.no_resend_on_retry:
self.transport_send(packet)
self.send(packet)
if self.broadcast_enable and not request.slave_id:
resp = None
break
Expand All @@ -178,6 +183,16 @@ async def async_execute(self, request) -> ModbusResponse:

return resp # type: ignore[return-value]

def callback_new_connection(self):
"""Call when listener receive new connection request."""

def callback_connected(self) -> None:
"""Call when connection is succcesfull."""

def callback_disconnected(self, exc: Exception | None) -> None:
"""Call when connection is lost."""
Log.debug("callback_disconnected called: {}", exc)

def callback_data(self, data: bytes, addr: tuple | None = None) -> int:
"""Handle received data.

Expand Down Expand Up @@ -216,12 +231,6 @@ def build_response(self, tid):
# ----------------------------------------------------------------------- #
# Internal methods
# ----------------------------------------------------------------------- #
def send(self, request) -> int: # type: ignore [empty-body]
"""Send request.

:meta private:
"""

def recv(self, size):
"""Receive data.

Expand Down
4 changes: 2 additions & 2 deletions pymodbus/client/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ async def connect(self) -> bool:
"""Connect Async client."""
self.reset_delay()
Log.debug("Connecting to {}.", self.comm_params.host)
return await self.transport_connect()
return await self.base_connect()

def close(self, reconnect: bool = False) -> None:
def close(self, reconnect: bool = False) -> None: # type: ignore[override]
"""Close connection."""
super().close(reconnect=reconnect)

Expand Down
4 changes: 2 additions & 2 deletions pymodbus/client/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ async def connect(self) -> bool:
self.comm_params.host,
self.comm_params.port,
)
return await self.transport_connect()
return await self.base_connect()

def close(self, reconnect: bool = False) -> None:
def close(self, reconnect: bool = False) -> None: # type: ignore[override]
"""Close connection."""
super().close(reconnect=reconnect)

Expand Down
2 changes: 1 addition & 1 deletion pymodbus/client/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def connect(self) -> bool:
self.comm_params.host,
self.comm_params.port,
)
return await self.transport_connect()
return await self.base_connect()


class ModbusTlsClient(ModbusTcpClient):
Expand Down
2 changes: 1 addition & 1 deletion pymodbus/client/udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def connect(self) -> bool:
self.comm_params.host,
self.comm_params.port,
)
return await self.transport_connect()
return await self.base_connect()


class ModbusUdpClient(ModbusBaseSyncClient):
Expand Down
30 changes: 23 additions & 7 deletions pymodbus/server/async_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def _log_exception(self):
"Handler for stream [{}] has been canceled", self.comm_params.comm_name
)

def callback_new_connection(self) -> ModbusProtocol:
"""Call when listener receive new connection request."""
Log.debug("callback_new_connection called")
return ModbusServerRequestHandler(self)

def callback_connected(self) -> None:
"""Call when connection is succcesfull."""
try:
Expand Down Expand Up @@ -160,7 +165,7 @@ async def handle(self):
exc,
self.comm_params.comm_name,
)
self.transport_close()
self.close()
self.callback_disconnected(exc)
else:
Log.error("Unknown error occurred {}", exc)
Expand Down Expand Up @@ -209,15 +214,15 @@ def execute(self, request, *addr):
skip_encoding = False
if self.server.response_manipulator:
response, skip_encoding = self.server.response_manipulator(response)
self.send(response, *addr, skip_encoding=skip_encoding)
self.server_send(response, *addr, skip_encoding=skip_encoding)

def send(self, message, addr, **kwargs):
def server_send(self, message, addr, **kwargs):
"""Send message."""
if kwargs.get("skip_encoding", False):
self.transport_send(message, addr=addr)
self.send(message, addr=addr)
elif message.should_respond:
pdu = self.framer.buildPacket(message)
self.transport_send(pdu, addr=addr)
self.send(pdu, addr=addr)
else:
Log.debug("Skipping sending response!!")

Expand Down Expand Up @@ -286,19 +291,30 @@ async def shutdown(self):
"""Close server."""
if not self.serving.done():
self.serving.set_result(True)
self.transport_close()
self.close()

async def serve_forever(self):
"""Start endless loop."""
if self.transport:
raise RuntimeError(
"Can't call serve_forever on an already running server object"
)
await self.transport_listen()
await self.listen()
Log.info("Server listening.")
await self.serving
Log.info("Server graceful shutdown.")

def callback_connected(self) -> None:
"""Call when connection is succcesfull."""

def callback_disconnected(self, exc: Exception | None) -> None:
"""Call when connection is lost."""
Log.debug("callback_disconnected called: {}", exc)

def callback_data(self, data: bytes, addr: tuple | None = None) -> int:
"""Handle received data."""
Log.debug("callback_data called: {} addr={}", data, ":hex", addr)
return 0

class ModbusTcpServer(ModbusBaseServer):
"""A modbus threaded tcp socket server.
Expand Down
29 changes: 14 additions & 15 deletions pymodbus/transport/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import asyncio
import dataclasses
import ssl
from abc import abstractmethod
from contextlib import suppress
from enum import Enum
from functools import partial
Expand Down Expand Up @@ -235,7 +236,7 @@ def init_setup_connect_listen(self, host: str, port: int) -> None:
ssl=self.comm_params.sslctx,
)

async def transport_connect(self) -> bool:
async def connect(self) -> bool:
"""Handle generic connect and call on to specific transport connect."""
Log.debug("Connecting {}", self.comm_params.comm_name)
self.is_closing = False
Expand All @@ -249,7 +250,7 @@ async def transport_connect(self) -> bool:
return False
return bool(self.transport)

async def transport_listen(self) -> bool:
async def listen(self) -> bool:
"""Handle generic listen and call on to specific transport listen."""
Log.debug("Awaiting connections {}", self.comm_params.comm_name)
self.is_closing = False
Expand All @@ -259,7 +260,7 @@ async def transport_listen(self) -> bool:
self.transport = self.transport[0]
except OSError as exc:
Log.warning("Failed to start server {}", exc)
# self.transport_close(intern=True)
# self.close(intern=True)
return False
return True

Expand All @@ -284,7 +285,7 @@ def connection_lost(self, reason: Exception | None) -> None:
if not self.transport or self.is_closing:
return
Log.debug("Connection lost {} due to {}", self.comm_params.comm_name, reason)
self.transport_close(intern=True)
self.close(intern=True)
if (
not self.is_server
and not self.listener
Expand Down Expand Up @@ -353,28 +354,26 @@ def error_received(self, exc):
# --------- #
# callbacks #
# --------- #
@abstractmethod
def callback_new_connection(self) -> ModbusProtocol:
"""Call when listener receive new connection request."""
Log.debug("callback_new_connection called")
return ModbusProtocol(self.comm_params, False)

@abstractmethod
def callback_connected(self) -> None:
"""Call when connection is succcesfull."""
Log.debug("callback_connected called")

@abstractmethod
def callback_disconnected(self, exc: Exception | None) -> None:
"""Call when connection is lost."""
Log.debug("callback_disconnected called: {}", exc)

@abstractmethod
def callback_data(self, data: bytes, addr: tuple | None = None) -> int:
"""Handle received data."""
Log.debug("callback_data called: {} addr={}", data, ":hex", addr)
return 0

# ----------------------------------- #
# Helper methods for external classes #
# ----------------------------------- #
def transport_send(self, data: bytes, addr: tuple | None = None) -> None:
def send(self, data: bytes, addr: tuple | None = None) -> None:
"""Send request.

:param data: non-empty bytes object with data to send.
Expand All @@ -391,7 +390,7 @@ def transport_send(self, data: bytes, addr: tuple | None = None) -> None:
else:
self.transport.write(data) # type: ignore[attr-defined]

def transport_close(self, intern: bool = False, reconnect: bool = False) -> None:
def close(self, intern: bool = False, reconnect: bool = False) -> None:
"""Close connection.

:param intern: (default false), True if called internally (temporary close)
Expand All @@ -409,7 +408,7 @@ def transport_close(self, intern: bool = False, reconnect: bool = False) -> None
for _key, value in self.active_connections.items():
value.listener = None
value.callback_disconnected(None)
value.transport_close()
value.close()
self.active_connections = {}
return
if not reconnect and self.reconnect_task:
Expand Down Expand Up @@ -466,7 +465,7 @@ async def do_reconnect(self) -> None:
await asyncio.sleep(self.reconnect_delay_current)
if self.comm_params.on_reconnect_callback:
self.comm_params.on_reconnect_callback()
if await self.transport_connect():
if await self.connect():
break
self.reconnect_delay_current = min(
2 * self.reconnect_delay_current,
Expand All @@ -485,7 +484,7 @@ async def __aenter__(self) -> ModbusProtocol:

async def __aexit__(self, _class, _value, _traceback) -> None:
"""Implement the client with async exit block."""
self.transport_close()
self.close()

def __str__(self) -> str:
"""Build a string representation of the connection."""
Expand Down
21 changes: 14 additions & 7 deletions test/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,21 @@ def __init__(
async def start_run(self):
"""Call need functions to start server/client."""
if self.is_server:
return await self.transport_listen()
return await self.transport_connect()
return await self.listen()
return await self.connect()


def callback_connected(self) -> None:
"""Call when connection is succcesfull."""

def callback_disconnected(self, exc: Exception | None) -> None:
"""Call when connection is lost."""
Log.debug("callback_disconnected called: {}", exc)

def callback_data(self, data: bytes, addr: tuple | None = None) -> int:
"""Handle received data."""
if (response := self.stub_handle_data(data)):
self.transport_send(response)
self.send(response)
return len(data)

def callback_new_connection(self) -> ModbusProtocol:
Expand Down Expand Up @@ -70,8 +77,8 @@ async def test_stub(self, use_port, use_cls):
assert await client.connect()
test_data = b"Data got echoed."
client.transport.write(test_data)
client.transport_close()
stub.transport_close()
client.close()
stub.close()

async def test_double_packet(self, use_port, use_cls):
"""Test double packet on network."""
Expand Down Expand Up @@ -122,5 +129,5 @@ async def local_call(addr: int) -> bool:

assert await client.connect()
await asyncio.gather(*[local_call(x) for x in range(1, 10)])
client.transport_close()
stub.transport_close()
client.close()
stub.close()
Loading