Skip to content

Commit 2a3ca96

Browse files
MHHukiewitzPsycojokerhoh
committed
feat: add verify_signature parameter to fetch functions
Co-authored-by: Laurent Peuch <[email protected]> Co-authored-by: Hugo Herter <[email protected]>
1 parent 1013145 commit 2a3ca96

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

src/aleph/sdk/client/abstract.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ async def get_posts(
7474
post_filter: Optional[PostFilter] = None,
7575
ignore_invalid_messages: Optional[bool] = True,
7676
invalid_messages_log_level: Optional[int] = logging.NOTSET,
77+
verify_signatures: bool = False,
7778
) -> PostsResponse:
7879
"""
7980
Fetch a list of posts from the network.
@@ -83,25 +84,35 @@ async def get_posts(
8384
:param post_filter: Filter to apply to the posts (Default: None)
8485
:param ignore_invalid_messages: Ignore invalid messages (Default: True)
8586
:param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
87+
:param verify_signatures: Verify the signatures of the messages (Default: False)
8688
"""
8789
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")
8890

8991
async def get_posts_iterator(
9092
self,
9193
post_filter: Optional[PostFilter] = None,
94+
ignore_invalid_messages: Optional[bool] = True,
95+
invalid_messages_log_level: Optional[int] = logging.NOTSET,
96+
verify_signatures: bool = False,
9297
) -> AsyncIterable[PostMessage]:
9398
"""
9499
Fetch all filtered posts, returning an async iterator and fetching them page by page. Might return duplicates
95100
but will always return all posts.
96101
97102
:param post_filter: Filter to apply to the posts (Default: None)
103+
:param ignore_invalid_messages: Ignore invalid messages (Default: True)
104+
:param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
105+
:param verify_signatures: Verify the signatures of the messages (Default: False)
98106
"""
99107
page = 1
100108
resp = None
101109
while resp is None or len(resp.posts) > 0:
102110
resp = await self.get_posts(
103111
page=page,
104112
post_filter=post_filter,
113+
ignore_invalid_messages=ignore_invalid_messages,
114+
invalid_messages_log_level=invalid_messages_log_level,
115+
verify_signatures=verify_signatures,
105116
)
106117
page += 1
107118
for post in resp.posts:
@@ -178,6 +189,7 @@ async def get_messages(
178189
message_filter: Optional[MessageFilter] = None,
179190
ignore_invalid_messages: Optional[bool] = True,
180191
invalid_messages_log_level: Optional[int] = logging.NOTSET,
192+
verify_signatures: bool = False,
181193
) -> MessagesResponse:
182194
"""
183195
Fetch a list of messages from the network.
@@ -187,25 +199,35 @@ async def get_messages(
187199
:param message_filter: Filter to apply to the messages
188200
:param ignore_invalid_messages: Ignore invalid messages (Default: True)
189201
:param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
202+
:param verify_signatures: Verify the signatures of the messages (Default: False)
190203
"""
191204
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")
192205

193206
async def get_messages_iterator(
194207
self,
195208
message_filter: Optional[MessageFilter] = None,
209+
ignore_invalid_messages: Optional[bool] = True,
210+
invalid_messages_log_level: Optional[int] = logging.NOTSET,
211+
verify_signatures: bool = False,
196212
) -> AsyncIterable[AlephMessage]:
197213
"""
198214
Fetch all filtered messages, returning an async iterator and fetching them page by page. Might return duplicates
199215
but will always return all messages.
200216
201217
:param message_filter: Filter to apply to the messages
218+
:param ignore_invalid_messages: Ignore invalid messages (Default: True)
219+
:param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
220+
:param verify_signatures: Whether to verify the signatures of the messages (Default: False)
202221
"""
203222
page = 1
204223
resp = None
205224
while resp is None or len(resp.messages) > 0:
206225
resp = await self.get_messages(
207226
page=page,
208227
message_filter=message_filter,
228+
ignore_invalid_messages=ignore_invalid_messages,
229+
invalid_messages_log_level=invalid_messages_log_level,
230+
verify_signatures=verify_signatures,
209231
)
210232
page += 1
211233
for message in resp.messages:
@@ -216,24 +238,28 @@ async def get_message(
216238
self,
217239
item_hash: str,
218240
message_type: Optional[Type[GenericMessage]] = None,
241+
verify_signature: bool = False,
219242
) -> GenericMessage:
220243
"""
221244
Get a single message from its `item_hash` and perform some basic validation.
222245
223246
:param item_hash: Hash of the message to fetch
224247
:param message_type: Type of message to fetch
248+
:param verify_signature: Whether to verify the signature of the message (Default: False)
225249
"""
226250
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")
227251

228252
@abstractmethod
229253
def watch_messages(
230254
self,
231255
message_filter: Optional[MessageFilter] = None,
256+
verify_signatures: bool = False,
232257
) -> AsyncIterable[AlephMessage]:
233258
"""
234259
Iterate over current and future matching messages asynchronously.
235260
236261
:param message_filter: Filter to apply to the messages
262+
:param verify_signatures: Whether to verify the signatures of the messages (Default: False)
237263
"""
238264
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")
239265

src/aleph/sdk/client/http.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ..exceptions import FileTooLarge, ForgottenMessageError, MessageNotFoundError
1616
from ..query.filters import MessageFilter, PostFilter
1717
from ..query.responses import MessagesResponse, Post, PostsResponse
18+
from ..security import verify_message_signature
1819
from ..types import GenericMessage
1920
from ..utils import (
2021
Writable,
@@ -117,6 +118,7 @@ async def get_posts(
117118
post_filter: Optional[PostFilter] = None,
118119
ignore_invalid_messages: Optional[bool] = True,
119120
invalid_messages_log_level: Optional[int] = logging.NOTSET,
121+
verify_signatures: bool = False,
120122
) -> PostsResponse:
121123
ignore_invalid_messages = (
122124
True if ignore_invalid_messages is None else ignore_invalid_messages
@@ -145,12 +147,15 @@ async def get_posts(
145147
posts: List[Post] = []
146148
for post_raw in posts_raw:
147149
try:
148-
posts.append(Post.parse_obj(post_raw))
150+
post = Post.parse_obj(post_raw)
151+
posts.append(post)
149152
except ValidationError as e:
150153
if not ignore_invalid_messages:
151154
raise e
152155
if invalid_messages_log_level:
153156
logger.log(level=invalid_messages_log_level, msg=e)
157+
if verify_signatures:
158+
verify_message_signature(post)
154159
return PostsResponse(
155160
posts=posts,
156161
pagination_page=response_json["pagination_page"],
@@ -266,6 +271,7 @@ async def get_messages(
266271
message_filter: Optional[MessageFilter] = None,
267272
ignore_invalid_messages: Optional[bool] = True,
268273
invalid_messages_log_level: Optional[int] = logging.NOTSET,
274+
verify_signatures: bool = False,
269275
) -> MessagesResponse:
270276
ignore_invalid_messages = (
271277
True if ignore_invalid_messages is None else ignore_invalid_messages
@@ -312,6 +318,8 @@ async def get_messages(
312318
raise e
313319
if invalid_messages_log_level:
314320
logger.log(level=invalid_messages_log_level, msg=e)
321+
if verify_signatures:
322+
verify_message_signature(message)
315323

316324
return MessagesResponse(
317325
messages=messages,
@@ -325,6 +333,7 @@ async def get_message(
325333
self,
326334
item_hash: str,
327335
message_type: Optional[Type[GenericMessage]] = None,
336+
verify_signature: bool = False,
328337
) -> GenericMessage:
329338
async with self.http_session.get(f"/api/v0/messages/{item_hash}") as resp:
330339
try:
@@ -339,6 +348,8 @@ async def get_message(
339348
f"The requested message {message_raw['item_hash']} has been forgotten by {', '.join(message_raw['forgotten_by'])}"
340349
)
341350
message = parse_message(message_raw["message"])
351+
if verify_signature:
352+
verify_message_signature(message)
342353
if message_type:
343354
expected_type = get_message_type_value(message_type)
344355
if message.type != expected_type:
@@ -374,6 +385,7 @@ async def get_message_error(
374385
async def watch_messages(
375386
self,
376387
message_filter: Optional[MessageFilter] = None,
388+
verify_signatures: bool = False,
377389
) -> AsyncIterable[AlephMessage]:
378390
message_filter = message_filter or MessageFilter()
379391
params = message_filter.as_http_params()
@@ -389,6 +401,9 @@ async def watch_messages(
389401
break
390402
else:
391403
data = json.loads(msg.data)
392-
yield parse_message(data)
404+
message = parse_message(data)
405+
if verify_signatures:
406+
verify_message_signature(message)
407+
yield message
393408
elif msg.type == aiohttp.WSMsgType.ERROR:
394409
break

0 commit comments

Comments
 (0)