Skip to content

Simplify transport_serial (modbus use) #1808

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 35 additions & 47 deletions pymodbus/transport/transport_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,27 @@
class SerialTransport(asyncio.Transport):
"""An asyncio serial transport."""

force_poll: bool = False

def __init__(self, loop, protocol, *args, **kwargs):
"""Initialize."""
super().__init__()
self.async_loop = loop
self._protocol: asyncio.BaseProtocol = protocol
self.sync_serial = serial.serial_for_url(*args, **kwargs)
self._write_buffer = []
self._has_reader = False
self._has_writer = False
self.poll_task = None
self._poll_wait_time = 0.0005
self.sync_serial.timeout = 0
self.sync_serial.write_timeout = 0

def setup(self):
"""Prepare to read/write"""
self.async_loop.call_soon(self._protocol.connection_made, self)
if os.name == "nt":
self._has_reader = self.async_loop.call_later(
self._poll_wait_time, self._poll_read
)
if os.name == "nt" or self.force_poll:
self.poll_task = asyncio.create_task(self._polling_task())
else:
self.async_loop.add_reader(self.sync_serial.fileno(), self._read_ready)
self._has_reader = True
self.async_loop.call_soon(self._protocol.connection_made, self)

def close(self, exc=None):
"""Close the transport gracefully."""
Expand All @@ -43,13 +41,13 @@ def close(self, exc=None):
with contextlib.suppress(Exception):
self.sync_serial.flush()

if self._has_reader:
if os.name == "nt":
self._has_reader.cancel()
else:
self.async_loop.remove_reader(self.sync_serial.fileno())
self._has_reader = False
self.flush()
if self.poll_task:
self.poll_task.cancel()
_ = asyncio.ensure_future(self.poll_task)
self.poll_task = None
else:
self.async_loop.remove_reader(self.sync_serial.fileno())
self.sync_serial.close()
self.sync_serial = None
with contextlib.suppress(Exception):
Expand All @@ -58,21 +56,13 @@ def close(self, exc=None):
def write(self, data):
"""Write some data to the transport."""
self._write_buffer.append(data)
if not self._has_writer:
if os.name == "nt":
self._has_writer = self.async_loop.call_soon(self._poll_write)
else:
self.async_loop.add_writer(self.sync_serial.fileno(), self._write_ready)
self._has_writer = True
if not self.poll_task:
self.async_loop.add_writer(self.sync_serial.fileno(), self._write_ready)

def flush(self):
"""Clear output buffer and stops any more data being written"""
if self._has_writer:
if os.name == "nt":
self._has_writer.cancel()
else:
self.async_loop.remove_writer(self.sync_serial.fileno())
self._has_writer = False
if not self.poll_task:
self.async_loop.remove_writer(self.sync_serial.fileno())
self._write_buffer.clear()

# ------------------------------------------------
Expand Down Expand Up @@ -141,34 +131,32 @@ def _write_ready(self):
"""Asynchronously write buffered data."""
data = b"".join(self._write_buffer)
try:
if nlen := self.sync_serial.write(data) < len(data):
self._write_buffer = data[nlen:]
return True
if (nlen := self.sync_serial.write(data)) < len(data):
self._write_buffer = [data[nlen:]]
if not self.poll_task:
self.async_loop.add_writer(
self.sync_serial.fileno(), self._write_ready
)
return
self.flush()
except (BlockingIOError, InterruptedError):
return True
return
except serial.SerialException as exc:
self.close(exc=exc)
return False

def _poll_read(self):
if self._has_reader:
try:
self._has_reader = self.async_loop.call_later(
self._poll_wait_time, self._poll_read
)
async def _polling_task(self):
"""Poll and try to read/write."""
try:
while True:
await asyncio.sleep(self._poll_wait_time)
while self._write_buffer:
self._write_ready()
if self.sync_serial.in_waiting:
self._read_ready()
except serial.SerialException as exc:
self.close(exc=exc)

def _poll_write(self):
if not self._has_writer:
return
if self._write_ready():
self._has_writer = self.async_loop.call_later(
self._poll_wait_time, self._poll_write
)
except serial.SerialException as exc:
self.close(exc=exc)
except asyncio.CancelledError:
pass


async def create_serial_connection(loop, protocol_factory, *args, **kwargs):
Expand Down
6 changes: 5 additions & 1 deletion test/sub_transport/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ async def test_external_methods(self):
comm.close()
comm = SerialTransport(mock.MagicMock(), mock.Mock(), "dummy")
comm.abort()
assert await create_serial_connection(
transport, protocol = await create_serial_connection(
asyncio.get_running_loop(), mock.Mock, url="dummy"
)
await asyncio.sleep(0.1)
assert transport
assert protocol
transport.close()
56 changes: 56 additions & 0 deletions test/sub_transport/test_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
CommType,
ModbusProtocol,
)
from pymodbus.transport.transport_serial import SerialTransport


FACTOR = 1.2 if not pytest.IS_WINDOWS else 4.2
Expand Down Expand Up @@ -125,6 +126,61 @@ async def test_connected(self, client, server, use_comm_type):
assert not server.active_connections
server.transport_close()

def wrapped_write(self, data):
"""Wrap serial write, to split parameters."""
return self.serial_write(data[:2])

@pytest.mark.parametrize(
("use_comm_type", "use_host"),
[
(CommType.SERIAL, "socket://localhost:5020"),
],
)
async def test_split_serial_packet(self, client, server):
"""Test connection and data exchange."""
assert await server.transport_listen()
assert await client.transport_connect()
await asyncio.sleep(0.5)
assert len(server.active_connections) == 1
server_connected = list(server.active_connections.values())[0]
test_data = b"abcd"

self.serial_write = ( # pylint: disable=attribute-defined-outside-init
client.transport.sync_serial.write
)
with mock.patch.object(
client.transport.sync_serial, "write", wraps=self.wrapped_write
):
client.transport_send(test_data)
await asyncio.sleep(0.5)
assert server_connected.recv_buffer == test_data
assert not client.recv_buffer
client.transport_close()
server.transport_close()

@pytest.mark.parametrize(
("use_comm_type", "use_host"),
[
(CommType.SERIAL, "socket://localhost:5020"),
],
)
async def test_serial_poll(self, client, server):
"""Test connection and data exchange."""
assert await server.transport_listen()
SerialTransport.force_poll = True
assert await client.transport_connect()
await asyncio.sleep(0.5)
SerialTransport.force_poll = False
assert len(server.active_connections) == 1
server_connected = list(server.active_connections.values())[0]
test_data = b"abcd" * 1000
client.transport_send(test_data)
await asyncio.sleep(0.5)
assert server_connected.recv_buffer == test_data
assert not client.recv_buffer
client.transport_close()
server.transport_close()

@pytest.mark.parametrize(
("use_comm_type", "use_host"),
[
Expand Down