Skip to content

Commit dbf2048

Browse files
authored
Allow multiple message types (#444)
Problem: `msgType` is the only parameter when fetching messages, which does not take a list of items as a parameter. Solution: add `msgTypes` as a parameter to the API and deprecate `msgType`.
1 parent 71cf823 commit dbf2048

File tree

3 files changed

+27
-2
lines changed

3 files changed

+27
-2
lines changed

src/aleph/db/accessors/messages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def make_matching_messages_query(
5252
refs: Optional[Sequence[str]] = None,
5353
chains: Optional[Sequence[Chain]] = None,
5454
message_type: Optional[MessageType] = None,
55+
message_types: Optional[Sequence[MessageType]] = None,
5556
start_date: Optional[Union[float, dt.datetime]] = None,
5657
end_date: Optional[Union[float, dt.datetime]] = None,
5758
content_hashes: Optional[Sequence[ItemHash]] = None,
@@ -87,6 +88,8 @@ def make_matching_messages_query(
8788
select_stmt = select_stmt.where(MessageDb.sender.in_(addresses))
8889
if chains:
8990
select_stmt = select_stmt.where(MessageDb.chain.in_(chains))
91+
if message_types:
92+
select_stmt = select_stmt.where(MessageDb.type.in_(message_types))
9093
if message_type:
9194
select_stmt = select_stmt.where(MessageDb.type == message_type)
9295
if start_datetime:

src/aleph/web/controllers/messages.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,10 @@ class BaseMessageQueryParams(BaseModel):
6767
"-1 means most recent messages first, 1 means older messages first.",
6868
)
6969
message_type: Optional[MessageType] = Field(
70-
default=None, alias="msgType", description="Message type."
70+
default=None, alias="msgType", description="Message type. Deprecated: use msgTypes instead"
71+
)
72+
message_types: Optional[List[MessageType]] = Field(
73+
default=None, alias="msgTypes", description="Accepted message types."
7174
)
7275
addresses: Optional[List[str]] = Field(
7376
default=None, description="Accepted values for the 'sender' field."
@@ -120,6 +123,7 @@ def validate_field_dependencies(cls, values):
120123
"content_types",
121124
"chains",
122125
"channels",
126+
"message_types",
123127
"tags",
124128
pre=True,
125129
)
@@ -356,7 +360,6 @@ async def messages_ws(request: web.Request) -> web.WebSocketResponse:
356360
except ValidationError as e:
357361
raise web.HTTPUnprocessableEntity(body=e.json(indent=4))
358362

359-
message_filters = query_params.dict(exclude_none=True)
360363
history = query_params.history
361364

362365
if history:

tests/api/test_list_messages.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime as dt
22
import itertools
3+
from collections import defaultdict
34
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union, Tuple
45

56
import aiohttp
@@ -221,6 +222,24 @@ async def test_get_messages_filter_by_tags(
221222
assert messages[0]["item_hash"] == amend_message_db.item_hash
222223

223224

225+
@pytest.mark.parametrize("type_field", ("msgType", "msgTypes"))
226+
@pytest.mark.asyncio
227+
async def test_get_by_message_type(fixture_messages, ccn_api_client, type_field: str):
228+
messages_by_type = defaultdict(list)
229+
for message in fixture_messages:
230+
messages_by_type[message["type"]].append(message)
231+
232+
for message_type, expected_messages in messages_by_type.items():
233+
response = await ccn_api_client.get(
234+
MESSAGES_URI, params={type_field: message_type}
235+
)
236+
assert response.status == 200, await response.text()
237+
messages = (await response.json())["messages"]
238+
assert set(msg["item_hash"] for msg in messages) == set(
239+
msg["item_hash"] for msg in expected_messages
240+
)
241+
242+
224243
@pytest.mark.asyncio
225244
async def test_get_messages_filter_by_tags_no_match(fixture_messages, ccn_api_client):
226245
"""

0 commit comments

Comments
 (0)