diff --git a/src/realtime/src/realtime/types.py b/src/realtime/src/realtime/types.py index 75f8d4a6..feaca122 100644 --- a/src/realtime/src/realtime/types.py +++ b/src/realtime/src/realtime/types.py @@ -107,9 +107,15 @@ class PostgresChangesPayload(TypedDict): ids: List[int] +class BroadcastMeta(TypedDict, total=False): + replayed: bool + id: str + + class BroadcastPayload(TypedDict): event: str payload: dict[str, Any] + meta: NotRequired[BroadcastMeta] @dataclass(frozen=True) @@ -172,9 +178,15 @@ def __init__(self, events: PresenceEvents): # TypedDicts -class RealtimeChannelBroadcastConfig(TypedDict): +class ReplayOption(TypedDict, total=False): + since: int + limit: int + + +class RealtimeChannelBroadcastConfig(TypedDict, total=False): ack: bool self: bool + replay: ReplayOption class RealtimeChannelPresenceConfig(TypedDict): diff --git a/src/realtime/tests/test_connection.py b/src/realtime/tests/test_connection.py index ed9ab733..0f5aeb2a 100644 --- a/src/realtime/tests/test_connection.py +++ b/src/realtime/tests/test_connection.py @@ -6,6 +6,7 @@ import pytest from dotenv import load_dotenv from pydantic import BaseModel +from websockets import broadcast from realtime import ( AsyncRealtimeChannel, @@ -297,7 +298,7 @@ def insert_callback(payload): assert insert["data"]["record"]["id"] == created_todo_id assert insert["data"]["record"]["description"] == "Test todo" - assert insert["data"]["record"]["is_completed"] == False + assert insert["data"]["record"]["is_completed"] is False assert received_events["insert"] == [insert, message_insert] @@ -488,3 +489,195 @@ async def test_send_message_reconnection(socket: AsyncRealtimeClient): await socket.send(message) await socket.close() + + +@pytest.mark.asyncio +async def test_subscribe_to_private_channel_with_broadcast_replay( + socket: AsyncRealtimeClient, +): + """Test that channel subscription sends correct payload with broadcast replay configuration.""" + import json + from unittest.mock import AsyncMock, patch + + # Mock the websocket connection + mock_ws = AsyncMock() + socket._ws_connection = mock_ws + + # Connect the socket (this will use our mock) + await socket.connect() + + # Calculate replay timestamp + ten_mins_ago = datetime.datetime.now() - datetime.timedelta(minutes=10) + ten_mins_ago_ms = int(ten_mins_ago.timestamp() * 1000) + + # Create channel with broadcast replay configuration + channel: AsyncRealtimeChannel = socket.channel( + "test-private-channel", + params={ + "config": { + "private": True, + "broadcast": {"replay": {"since": ten_mins_ago_ms, "limit": 100}}, + "presence": {"enabled": True, "key": ""}, + } + }, + ) + + # Mock the subscription callback to be called immediately + callback_called = False + + def mock_callback(state, error): + nonlocal callback_called + callback_called = True + + # Subscribe to the channel + await channel.subscribe(mock_callback) + + # Verify that send was called with the correct payload + assert mock_ws.send.called, "WebSocket send should have been called" + + # Get the sent message + sent_message = mock_ws.send.call_args[0][0] + message_data = json.loads(sent_message) + + # Verify the message structure + assert message_data["topic"] == "realtime:test-private-channel" + assert message_data["event"] == "phx_join" + assert "ref" in message_data + assert "payload" in message_data + + # Verify the payload contains the correct broadcast replay configuration + payload = message_data["payload"] + assert "config" in payload + + config = payload["config"] + assert config["private"] is True + assert "broadcast" in config + + broadcast_config = config["broadcast"] + assert "replay" in broadcast_config + + replay_config = broadcast_config["replay"] + assert replay_config["since"] == ten_mins_ago_ms + assert replay_config["limit"] == 100 + + # Verify postgres_changes array is present (even if empty) + assert "postgres_changes" in config + assert isinstance(config["postgres_changes"], list) + + await socket.close() + + +@pytest.mark.asyncio +async def test_subscribe_to_channel_with_empty_replay_config( + socket: AsyncRealtimeClient, +): + """Test that channel subscription handles empty replay configuration correctly.""" + import json + from unittest.mock import AsyncMock, patch + + # Mock the websocket connection + mock_ws = AsyncMock() + socket._ws_connection = mock_ws + + # Connect the socket + await socket.connect() + + # Create channel with empty replay configuration + channel: AsyncRealtimeChannel = socket.channel( + "test-empty-replay", + params={ + "config": { + "private": False, + "broadcast": {"ack": True, "self": False, "replay": {}}, + "presence": {"enabled": True, "key": ""}, + } + }, + ) + + # Mock the subscription callback + callback_called = False + + def mock_callback(state, error): + nonlocal callback_called + callback_called = True + + # Subscribe to the channel + await channel.subscribe(mock_callback) + + # Verify that send was called + assert mock_ws.send.called, "WebSocket send should have been called" + + # Get the sent message + sent_message = mock_ws.send.call_args[0][0] + message_data = json.loads(sent_message) + + # Verify the payload structure + payload = message_data["payload"] + config = payload["config"] + + assert config["private"] is False + assert "broadcast" in config + + broadcast_config = config["broadcast"] + assert broadcast_config["ack"] is True + assert broadcast_config["self"] is False + assert broadcast_config["replay"] == {} + + await socket.close() + + +@pytest.mark.asyncio +async def test_subscribe_to_channel_without_replay_config(socket: AsyncRealtimeClient): + """Test that channel subscription works without replay configuration.""" + import json + from unittest.mock import AsyncMock, patch + + # Mock the websocket connection + mock_ws = AsyncMock() + socket._ws_connection = mock_ws + + # Connect the socket + await socket.connect() + + # Create channel without replay configuration + channel: AsyncRealtimeChannel = socket.channel( + "test-no-replay", + params={ + "config": { + "private": False, + "broadcast": {"ack": True, "self": True}, + "presence": {"enabled": True, "key": ""}, + } + }, + ) + + # Mock the subscription callback + callback_called = False + + def mock_callback(state, error): + nonlocal callback_called + callback_called = True + + # Subscribe to the channel + await channel.subscribe(mock_callback) + + # Verify that send was called + assert mock_ws.send.called, "WebSocket send should have been called" + + # Get the sent message + sent_message = mock_ws.send.call_args[0][0] + message_data = json.loads(sent_message) + + # Verify the payload structure + payload = message_data["payload"] + config = payload["config"] + + assert config["private"] is False + assert "broadcast" in config + + broadcast_config = config["broadcast"] + assert broadcast_config["ack"] is True + assert broadcast_config["self"] is True + assert "replay" not in broadcast_config + + await socket.close()