Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 70 additions & 41 deletions packages/stompman/stompman/serde.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import struct
from collections import deque
from collections.abc import Iterator
from contextlib import suppress
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Any, Final, cast

from stompman.frames import (
Expand Down Expand Up @@ -141,53 +140,83 @@ def make_frame_from_parts(*, command: bytes, headers: dict[str, str], body: byte
return frame_type(headers=headers_, body=body) if frame_type in FRAMES_WITH_BODY else frame_type(headers=headers_) # type: ignore[call-arg]


def parse_lines_into_frame(lines: deque[bytearray]) -> AnyClientFrame | AnyServerFrame:
command = bytes(lines.popleft())
headers = {}

while line := lines.popleft():
header = parse_header(line)
if header and header[0] not in headers:
headers[header[0]] = header[1]
body = bytes(lines.popleft()) if lines else b""
return make_frame_from_parts(command=command, headers=headers, body=body)


@dataclass(kw_only=True, slots=True)
@dataclass(kw_only=True, slots=True, init=False)
class FrameParser:
_lines: deque[bytearray] = field(default_factory=deque, init=False)
_current_line: bytearray = field(default_factory=bytearray, init=False)
_previous_byte: bytes = field(default=b"", init=False)
_headers_processed: bool = field(default=False, init=False)
_current_buf: bytearray
_previous_byte: bytes | None
_headers_processed: bool
_command: bytes | None
_headers: dict[str, str]
_content_length: int | None

def __init__(self) -> None:
self._previous_byte = None
self._reset()

def _reset(self) -> None:
self._current_buf = bytearray()
self._headers_processed = False
self._lines.clear()
self._current_line = bytearray()
self._command = None
self._headers = {}
self._content_length = None

def _handle_null_byte(self) -> Iterator[AnyClientFrame | AnyServerFrame]:
if not self._command or not self._headers_processed:
self._reset()
return
if self._content_length is not None and self._content_length != len(self._current_buf):
self._current_buf += NULL
return
yield make_frame_from_parts(command=self._command, headers=self._headers, body=bytes(self._current_buf))
self._reset()

def _handle_newline_byte(self) -> Iterator[HeartbeatFrame]:
if not self._current_buf and not self._command:
yield HeartbeatFrame()
return
if self._previous_byte == CARRIAGE:
self._current_buf.pop()
self._headers_processed = not self._current_buf # extra empty line after headers

if self._command:
self._process_header()
else:
self._process_command()

def _process_command(self) -> None:
current_buf_bytes = bytes(self._current_buf)
if current_buf_bytes not in COMMANDS_TO_FRAMES:
self._reset()
else:
self._command = current_buf_bytes
self._current_buf = bytearray()

def _process_header(self) -> None:
header = parse_header(self._current_buf)
if not header:
self._current_buf = bytearray()
return
header_key, header_value = header
if header_key not in self._headers:
self._headers[header_key] = header_value
if header_key.lower() == "content-length":
with suppress(ValueError):
self._content_length = int(header_value)
self._current_buf = bytearray()

def _handle_body_byte(self, byte: bytes) -> None:
if self._content_length is None or self._content_length != len(self._current_buf):
self._current_buf += byte

def parse_frames_from_chunk(self, chunk: bytes) -> Iterator[AnyClientFrame | AnyServerFrame]:
for byte in iter_bytes(chunk):
if byte == NULL:
if self._headers_processed:
self._lines.append(self._current_line)
yield parse_lines_into_frame(self._lines)
self._reset()

elif not self._headers_processed and byte == NEWLINE:
if self._current_line or self._lines:
if self._previous_byte == CARRIAGE:
self._current_line.pop()
self._headers_processed = not self._current_line # extra empty line after headers

if not self._lines and bytes(self._current_line) not in COMMANDS_TO_FRAMES:
self._reset()
else:
self._lines.append(self._current_line)
self._current_line = bytearray()
else:
yield HeartbeatFrame()

yield from self._handle_null_byte()
elif self._headers_processed:
self._handle_body_byte(byte)
elif byte == NEWLINE:
yield from self._handle_newline_byte()
else:
self._current_line += byte
self._current_buf += byte

self._previous_byte = byte
20 changes: 20 additions & 0 deletions packages/stompman/test_stompman/test_frame_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,26 @@ def test_dump_frame(frame: AnyClientFrame, dumped_frame: bytes) -> None:
ConnectedFrame(headers={"header": "1.2"}),
],
),
# Correct content-length with body containing NULL byte
(
b"MESSAGE\ncontent-length:5\n\nBod\x00y\x00",
[MessageFrame(headers={"content-length": "5"}, body=b"Bod\x00y")],
),
# Content-length shorter than actual body (should only read up to content-length)
(
b"MESSAGE\ncontent-length:4\n\nBody\x00 with extra\x00\n",
[MessageFrame(headers={"content-length": "4"}, body=b"Body"), HeartbeatFrame()],
),
# Content-length longer than actual body (should wait for more data)
(
b"MESSAGE\ncontent-length:10\n\nShort",
[],
),
# Content-length longer than actual body, then more data comes with NULL terminator
(
b"MESSAGE\ncontent-length:10\n\nShortMOREDATA\x00",
[MessageFrame(headers={"content-length": "10"}, body=b"ShortMORED")],
),
],
)
def test_load_frames(raw_frames: bytes, loaded_frames: list[AnyServerFrame]) -> None:
Expand Down