diff --git a/CREDITS.rst b/CREDITS.rst index a861fb2..29e115f 100644 --- a/CREDITS.rst +++ b/CREDITS.rst @@ -16,6 +16,7 @@ Contributors - Nicolas Di Pietro - jabdoa2 - Chris Seymour +- Yuval Mantin - ... not all names may be listed here, see also ``git log`` or online history_ diff --git a/serial_asyncio/__init__.py b/serial_asyncio/__init__.py index e173b56..be0b53a 100644 --- a/serial_asyncio/__init__.py +++ b/serial_asyncio/__init__.py @@ -21,13 +21,14 @@ import serial from functools import partial +import time try: import termios except ImportError: termios = None -__version__ = '0.6' +__version__ = '0.7' class SerialTransport(asyncio.Transport): @@ -527,6 +528,29 @@ async def open_serial_connection(*, return reader, writer +async def read_with_timeout( + reader: asyncio.StreamReader, + n: int, + timeout: float +) -> bytes: + """A wrapper for the StreamReader.read method that adds a timeout support. + It returns the bytes read during the given timeout or until n bytes reached. + + reader is a StreamReader item, n is the amount of bytes, timeout is the + max time that the reading can take. + + The idea is to read 1 byte with a timeout every time, until the timeout is + reached or it read n bytes. In any case it returns what it had read. + """ + start_t = time.time() + data = b'' + while time.time() - start_t < timeout and len(data) < n: + try: + data += await asyncio.wait_for(reader.read(1), timeout=timeout) + except TimeoutError: + break + return data + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # test if __name__ == '__main__': diff --git a/test/test_read_with_timeout.py b/test/test_read_with_timeout.py new file mode 100644 index 0000000..412bdac --- /dev/null +++ b/test/test_read_with_timeout.py @@ -0,0 +1,53 @@ +import unittest +import asyncio +import time + +from serial_asyncio import read_with_timeout + +class DummyStreamReader: + def __init__(self, data: bytes, delay: float = 0): + self._data = data + self._delay = delay + self._index = 0 + + async def read(self, n: int): + if self._index >= len(self._data): + await asyncio.sleep(self._delay) + return b'' + await asyncio.sleep(self._delay) + chunk = self._data[self._index:self._index + n] + self._index += n + return chunk + + +class TestReadWithTimeout(unittest.IsolatedAsyncioTestCase): + async def test_reads_all_bytes_before_timeout(self): + reader = DummyStreamReader(b'abcdef', delay=0) + result = await read_with_timeout(reader, 6, 0.1) + self.assertEqual(result, b'abcdef') + + async def test_reads_partial_bytes_due_to_timeout(self): + reader = DummyStreamReader(b'abcdef', delay=0.01) + start = time.time() + result = await read_with_timeout(reader, 6, 0.03) + elapsed = time.time() - start + self.assertTrue(len(result) < 6) + self.assertLessEqual(elapsed, 0.1) + + async def test_returns_empty_if_timeout_immediate(self): + reader = DummyStreamReader(b'abcdef', delay=0) + result = await read_with_timeout(reader, 6, 0) + self.assertEqual(result, b'') + + async def test_reads_until_n_bytes(self): + reader = DummyStreamReader(b'abcdef', delay=0) + result = await read_with_timeout(reader, 3, 0.1) + self.assertEqual(result, b'abc') + + async def test_handles_stream_end(self): + reader = DummyStreamReader(b'', delay=0) + result = await read_with_timeout(reader, 5, 0.1) + self.assertEqual(result, b'') + +if __name__ == "__main__": + unittest.main() \ No newline at end of file