From ffa2f561a38bcc99279b534ffa412ae582978699 Mon Sep 17 00:00:00 2001 From: mhh Date: Thu, 27 Jun 2024 17:50:04 +0200 Subject: [PATCH 1/2] feat: add aleph.sdk.security module Co-authored-by: Laurent Peuch Co-authored-by: Hugo Herter --- src/aleph/sdk/security.py | 64 +++++++++++++++++++++++++++++++++++++ tests/unit/test_security.py | 4 +++ 2 files changed, 68 insertions(+) create mode 100644 src/aleph/sdk/security.py create mode 100644 tests/unit/test_security.py diff --git a/src/aleph/sdk/security.py b/src/aleph/sdk/security.py new file mode 100644 index 00000000..63cceac2 --- /dev/null +++ b/src/aleph/sdk/security.py @@ -0,0 +1,64 @@ +from importlib import import_module +from typing import Callable, Dict, Optional, Union + +from aleph_message.models import AlephMessage, Chain + +from aleph.sdk.chains.common import get_verification_buffer +from aleph.sdk.query.responses import Post + + +def _try_import_verify_signature( + chain: str, +) -> Optional[ + Callable[[Union[bytes, str], Union[bytes, str], Union[bytes, str]], None] +]: + """Try to import a chain signature validator.""" + try: + return import_module(f"aleph.sdk.chains.{chain}").verify_signature + except (ImportError, AttributeError): + return None + + +# This is a dict containing all currently available signature validators, +# indexed by their Chain abbreviation. +# +# Ex.: validators["SOL"] -> aleph.sdk.chains.solana.verify_signature() +VALIDATORS: Dict[ + Chain, + Optional[Callable[[Union[bytes, str], Union[bytes, str], Union[bytes, str]], None]], +] = { + key: _try_import_verify_signature(value) + for key, value in { + # TODO: Add AVAX + Chain.ETH: "ethereum", + Chain.SOL: "sol", + Chain.CSDK: "cosmos", + Chain.DOT: "substrate", + Chain.NULS2: "nuls2", + Chain.TEZOS: "tezos", + }.items() +} + + +def verify_message_signature(message: Union[AlephMessage, Post]) -> None: + """Verify the signature of a message, raise an error if invalid or unsupported. + A BadSignatureError is raised when the signature is incorrect. + A ValueError is raised when the chain is not supported or required dependencies are missing. + """ + if message.chain not in VALIDATORS: + raise ValueError(f"Chain {message.chain} is not supported.") + + validator = VALIDATORS[message.chain] + if validator is None: + raise ValueError( + f"Chain {message.chain} is not installed. Install it with `aleph-sdk-python[{message.chain}]`." + ) + + signature = message.signature + public_key = message.sender + message = get_verification_buffer(message.dict()) + + # to please mypy + assert isinstance(signature, (str, bytes)) + + validator(signature, public_key, message) diff --git a/tests/unit/test_security.py b/tests/unit/test_security.py new file mode 100644 index 00000000..ed3ca32c --- /dev/null +++ b/tests/unit/test_security.py @@ -0,0 +1,4 @@ +def test_validators_loaded(): + import aleph.sdk.security as security + + assert any([validator is not None for validator in security.validators.values()]) From 0533e4a582cff3262fa03bd736410ea00ce8b65b Mon Sep 17 00:00:00 2001 From: mhh Date: Thu, 27 Jun 2024 18:50:40 +0200 Subject: [PATCH 2/2] feat: add verify_signature parameter to fetch functions Co-authored-by: Laurent Peuch Co-authored-by: Hugo Herter --- src/aleph/sdk/client/abstract.py | 26 ++++++++++++++++++++++++++ src/aleph/sdk/client/http.py | 19 +++++++++++++++++-- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/src/aleph/sdk/client/abstract.py b/src/aleph/sdk/client/abstract.py index 9fce5469..301388ec 100644 --- a/src/aleph/sdk/client/abstract.py +++ b/src/aleph/sdk/client/abstract.py @@ -74,6 +74,7 @@ async def get_posts( post_filter: Optional[PostFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, + verify_signatures: bool = False, ) -> PostsResponse: """ Fetch a list of posts from the network. @@ -83,18 +84,25 @@ async def get_posts( :param post_filter: Filter to apply to the posts (Default: None) :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) + :param verify_signatures: Verify the signatures of the messages (Default: False) """ raise NotImplementedError("Did you mean to import `AlephHttpClient`?") async def get_posts_iterator( self, post_filter: Optional[PostFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, + verify_signatures: bool = False, ) -> AsyncIterable[PostMessage]: """ Fetch all filtered posts, returning an async iterator and fetching them page by page. Might return duplicates but will always return all posts. :param post_filter: Filter to apply to the posts (Default: None) + :param ignore_invalid_messages: Ignore invalid messages (Default: True) + :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) + :param verify_signatures: Verify the signatures of the messages (Default: False) """ page = 1 resp = None @@ -102,6 +110,9 @@ async def get_posts_iterator( resp = await self.get_posts( page=page, post_filter=post_filter, + ignore_invalid_messages=ignore_invalid_messages, + invalid_messages_log_level=invalid_messages_log_level, + verify_signatures=verify_signatures, ) page += 1 for post in resp.posts: @@ -178,6 +189,7 @@ async def get_messages( message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, + verify_signatures: bool = False, ) -> MessagesResponse: """ Fetch a list of messages from the network. @@ -187,18 +199,25 @@ async def get_messages( :param message_filter: Filter to apply to the messages :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) + :param verify_signatures: Verify the signatures of the messages (Default: False) """ raise NotImplementedError("Did you mean to import `AlephHttpClient`?") async def get_messages_iterator( self, message_filter: Optional[MessageFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, + verify_signatures: bool = False, ) -> AsyncIterable[AlephMessage]: """ Fetch all filtered messages, returning an async iterator and fetching them page by page. Might return duplicates but will always return all messages. :param message_filter: Filter to apply to the messages + :param ignore_invalid_messages: Ignore invalid messages (Default: True) + :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) + :param verify_signatures: Whether to verify the signatures of the messages (Default: False) """ page = 1 resp = None @@ -206,6 +225,9 @@ async def get_messages_iterator( resp = await self.get_messages( page=page, message_filter=message_filter, + ignore_invalid_messages=ignore_invalid_messages, + invalid_messages_log_level=invalid_messages_log_level, + verify_signatures=verify_signatures, ) page += 1 for message in resp.messages: @@ -216,12 +238,14 @@ async def get_message( self, item_hash: str, message_type: Optional[Type[GenericMessage]] = None, + verify_signature: bool = False, ) -> GenericMessage: """ Get a single message from its `item_hash` and perform some basic validation. :param item_hash: Hash of the message to fetch :param message_type: Type of message to fetch + :param verify_signature: Whether to verify the signature of the message (Default: False) """ raise NotImplementedError("Did you mean to import `AlephHttpClient`?") @@ -229,11 +253,13 @@ async def get_message( def watch_messages( self, message_filter: Optional[MessageFilter] = None, + verify_signatures: bool = False, ) -> AsyncIterable[AlephMessage]: """ Iterate over current and future matching messages asynchronously. :param message_filter: Filter to apply to the messages + :param verify_signatures: Whether to verify the signatures of the messages (Default: False) """ raise NotImplementedError("Did you mean to import `AlephHttpClient`?") diff --git a/src/aleph/sdk/client/http.py b/src/aleph/sdk/client/http.py index ae98b0d1..9ded2bcd 100644 --- a/src/aleph/sdk/client/http.py +++ b/src/aleph/sdk/client/http.py @@ -15,6 +15,7 @@ from ..exceptions import FileTooLarge, ForgottenMessageError, MessageNotFoundError from ..query.filters import MessageFilter, PostFilter from ..query.responses import MessagesResponse, Post, PostsResponse +from ..security import verify_message_signature from ..types import GenericMessage from ..utils import ( Writable, @@ -117,6 +118,7 @@ async def get_posts( post_filter: Optional[PostFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, + verify_signatures: bool = False, ) -> PostsResponse: ignore_invalid_messages = ( True if ignore_invalid_messages is None else ignore_invalid_messages @@ -145,12 +147,15 @@ async def get_posts( posts: List[Post] = [] for post_raw in posts_raw: try: - posts.append(Post.parse_obj(post_raw)) + post = Post.parse_obj(post_raw) + posts.append(post) except ValidationError as e: if not ignore_invalid_messages: raise e if invalid_messages_log_level: logger.log(level=invalid_messages_log_level, msg=e) + if verify_signatures: + verify_message_signature(post) return PostsResponse( posts=posts, pagination_page=response_json["pagination_page"], @@ -266,6 +271,7 @@ async def get_messages( message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, + verify_signatures: bool = False, ) -> MessagesResponse: ignore_invalid_messages = ( True if ignore_invalid_messages is None else ignore_invalid_messages @@ -312,6 +318,8 @@ async def get_messages( raise e if invalid_messages_log_level: logger.log(level=invalid_messages_log_level, msg=e) + if verify_signatures: + verify_message_signature(message) return MessagesResponse( messages=messages, @@ -325,6 +333,7 @@ async def get_message( self, item_hash: str, message_type: Optional[Type[GenericMessage]] = None, + verify_signature: bool = False, ) -> GenericMessage: async with self.http_session.get(f"/api/v0/messages/{item_hash}") as resp: try: @@ -339,6 +348,8 @@ async def get_message( f"The requested message {message_raw['item_hash']} has been forgotten by {', '.join(message_raw['forgotten_by'])}" ) message = parse_message(message_raw["message"]) + if verify_signature: + verify_message_signature(message) if message_type: expected_type = get_message_type_value(message_type) if message.type != expected_type: @@ -374,6 +385,7 @@ async def get_message_error( async def watch_messages( self, message_filter: Optional[MessageFilter] = None, + verify_signatures: bool = False, ) -> AsyncIterable[AlephMessage]: message_filter = message_filter or MessageFilter() params = message_filter.as_http_params() @@ -389,6 +401,9 @@ async def watch_messages( break else: data = json.loads(msg.data) - yield parse_message(data) + message = parse_message(data) + if verify_signatures: + verify_message_signature(message) + yield message elif msg.type == aiohttp.WSMsgType.ERROR: break