diff --git a/.coveragerc b/.coveragerc
index 89ec6fe0..bf5bfda2 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -26,3 +26,6 @@ exclude_lines =
     # Don't complain if non-runnable code isn't run:
     if 0:
     if __name__ == .__main__.:
+
+    # Don't complain about ineffective code:
+    pass
diff --git a/setup.cfg b/setup.cfg
index 48eb7f9b..1b5fa494 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -79,6 +79,8 @@ testing =
     black
     isort
     flake8
+    aiodns
+    peewee
 mqtt =
     aiomqtt<=0.1.3
     certifi
@@ -103,6 +105,8 @@ ledger =
     ledgereth==0.9.0
 docs =
     sphinxcontrib-plantuml
+cache =
+    peewee
 
 [options.entry_points]
 # Add here console scripts like:
diff --git a/src/aleph/sdk/base.py b/src/aleph/sdk/base.py
new file mode 100644
index 00000000..62bd5725
--- /dev/null
+++ b/src/aleph/sdk/base.py
@@ -0,0 +1,476 @@
+# An interface for all clients to implement.
+
+import logging
+from abc import ABC, abstractmethod
+from datetime import datetime
+from pathlib import Path
+from typing import (
+    Any,
+    AsyncIterable,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
+
+from aleph_message.models import (
+    AlephMessage,
+    MessagesResponse,
+    MessageType,
+    PostMessage,
+)
+from aleph_message.models.execution.program import Encoding
+from aleph_message.status import MessageStatus
+
+from aleph.sdk.models import PostsResponse
+from aleph.sdk.types import GenericMessage, StorageEnum
+
+DEFAULT_PAGE_SIZE = 200
+
+
+class AlephClientBase(ABC):
+    @abstractmethod
+    async def fetch_aggregate(
+        self,
+        address: str,
+        key: str,
+        limit: int = 100,
+    ) -> Dict[str, Dict]:
+        """
+        Fetch a value from the aggregate store by owner address and item key.
+
+        :param address: Address of the owner of the aggregate
+        :param key: Key of the aggregate
+        :param limit: Maximum number of items to fetch (Default: 100)
+        """
+        pass
+
+    @abstractmethod
+    async def fetch_aggregates(
+        self,
+        address: str,
+        keys: Optional[Iterable[str]] = None,
+        limit: int = 100,
+    ) -> Dict[str, Dict]:
+        """
+        Fetch key-value pairs from the aggregate store by owner address.
+
+        :param address: Address of the owner of the aggregate
+        :param keys: Keys of the aggregates to fetch (Default: all items)
+        :param limit: Maximum number of items to fetch (Default: 100)
+        """
+        pass
+
+    @abstractmethod
+    async def get_posts(
+        self,
+        pagination: int = DEFAULT_PAGE_SIZE,
+        page: int = 1,
+        types: Optional[Iterable[str]] = None,
+        refs: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        hashes: Optional[Iterable[str]] = None,
+        channels: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[str]] = None,
+        start_date: Optional[Union[datetime, float]] = None,
+        end_date: Optional[Union[datetime, float]] = None,
+        ignore_invalid_messages: bool = True,
+        invalid_messages_log_level: int = logging.NOTSET,
+    ) -> PostsResponse:
+        """
+        Fetch a list of posts from the network.
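+
+        A hypothetical usage sketch ("client" stands for any concrete
+        implementation of this interface; the filter values are invented):
+
+            response = await client.get_posts(pagination=50, tags=["mainnet"])
+            for post in response.posts:
+                print(post.item_hash)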
+
+        :param pagination: Number of items to fetch (Default: 200)
+        :param page: Page to fetch, begins at 1 (Default: 1)
+        :param types: Types of posts to fetch (Default: all types)
+        :param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
+        :param addresses: Addresses of the posts to fetch (Default: all addresses)
+        :param tags: Tags of the posts to fetch (Default: all tags)
+        :param hashes: Specific item_hashes to fetch
+        :param channels: Channels of the posts to fetch (Default: all channels)
+        :param chains: Chains of the posts to fetch (Default: all chains)
+        :param start_date: Earliest date to fetch messages from
+        :param end_date: Latest date to fetch messages from
+        :param ignore_invalid_messages: Ignore invalid messages (Default: True)
+        :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
+        """
+        pass
+
+    async def get_posts_iterator(
+        self,
+        types: Optional[Iterable[str]] = None,
+        refs: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        hashes: Optional[Iterable[str]] = None,
+        channels: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[str]] = None,
+        start_date: Optional[Union[datetime, float]] = None,
+        end_date: Optional[Union[datetime, float]] = None,
+    ) -> AsyncIterable[PostMessage]:
+        """
+        Fetch all filtered posts, returning an async iterator and fetching them page by page. Might return duplicates
+        but will always return all posts.
+
+        :param types: Types of posts to fetch (Default: all types)
+        :param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
+        :param addresses: Addresses of the posts to fetch (Default: all addresses)
+        :param tags: Tags of the posts to fetch (Default: all tags)
+        :param hashes: Specific item_hashes to fetch
+        :param channels: Channels of the posts to fetch (Default: all channels)
+        :param chains: Chains of the posts to fetch (Default: all chains)
+        :param start_date: Earliest date to fetch messages from
+        :param end_date: Latest date to fetch messages from
+        """
+        total_items = None
+        per_page = DEFAULT_PAGE_SIZE
+        page = 1
+        # Keep paginating while pages remain: before fetching page N,
+        # (N - 1) * per_page items have been seen so far.
+        while total_items is None or (page - 1) * per_page < total_items:
+            resp = await self.get_posts(
+                page=page,
+                types=types,
+                refs=refs,
+                addresses=addresses,
+                tags=tags,
+                hashes=hashes,
+                channels=channels,
+                chains=chains,
+                start_date=start_date,
+                end_date=end_date,
+            )
+            total_items = resp.pagination_total
+            page += 1
+            for post in resp.posts:
+                yield post
+
+    @abstractmethod
+    async def download_file(
+        self,
+        file_hash: str,
+    ) -> bytes:
+        """
+        Get a file from the storage engine as raw bytes.
+
+        Warning: Downloading large files can be slow and memory intensive.
+
+        :param file_hash: The hash of the file to retrieve.
+        """
+        pass
+
+    @abstractmethod
+    async def get_messages(
+        self,
+        pagination: int = DEFAULT_PAGE_SIZE,
+        page: int = 1,
+        message_type: Optional[MessageType] = None,
+        content_types: Optional[Iterable[str]] = None,
+        content_keys: Optional[Iterable[str]] = None,
+        refs: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        hashes: Optional[Iterable[str]] = None,
+        channels: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[str]] = None,
+        start_date: Optional[Union[datetime, float]] = None,
+        end_date: Optional[Union[datetime, float]] = None,
+        ignore_invalid_messages: bool = True,
+        invalid_messages_log_level: int = logging.NOTSET,
+    ) -> MessagesResponse:
+        """
+        Fetch a list of messages from the network.
+
+        :param pagination: Number of items to fetch (Default: 200)
+        :param page: Page to fetch, begins at 1 (Default: 1)
+        :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET"
+        :param content_types: Filter by content type
+        :param content_keys: Filter by aggregate key
+        :param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
+        :param addresses: Addresses of the posts to fetch (Default: all addresses)
+        :param tags: Tags of the posts to fetch (Default: all tags)
+        :param hashes: Specific item_hashes to fetch
+        :param channels: Channels of the posts to fetch (Default: all channels)
+        :param chains: Filter by sender address chain
+        :param start_date: Earliest date to fetch messages from
+        :param end_date: Latest date to fetch messages from
+        :param ignore_invalid_messages: Ignore invalid messages (Default: True)
+        :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
+        """
+        pass
+
+    async def get_messages_iterator(
+        self,
+        message_type: Optional[MessageType] = None,
+        content_types: Optional[Iterable[str]] = None,
+        content_keys: Optional[Iterable[str]] = None,
+        refs: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        hashes: Optional[Iterable[str]] = None,
+        channels: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[str]] = None,
+        start_date: Optional[Union[datetime, float]] = None,
+        end_date: Optional[Union[datetime, float]] = None,
+    ) -> AsyncIterable[AlephMessage]:
+        """
+        Fetch all filtered messages, returning an async iterator and fetching them page by page. Might return duplicates
+        but will always return all messages.
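+
+        For example (an illustrative sketch only; "client" is any concrete
+        implementation and the address below is a placeholder):
+
+            async for message in client.get_messages_iterator(
+                addresses=["0x0000000000000000000000000000000000000000"],
+            ):
+                print(message.item_hash)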
+
+        :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET"
+        :param content_types: Filter by content type
+        :param content_keys: Filter by content key
+        :param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
+        :param addresses: Addresses of the posts to fetch (Default: all addresses)
+        :param tags: Tags of the posts to fetch (Default: all tags)
+        :param hashes: Specific item_hashes to fetch
+        :param channels: Channels of the posts to fetch (Default: all channels)
+        :param chains: Filter by sender address chain
+        :param start_date: Earliest date to fetch messages from
+        :param end_date: Latest date to fetch messages from
+        """
+        total_items = None
+        per_page = DEFAULT_PAGE_SIZE
+        page = 1
+        # Same pagination invariant as get_posts_iterator: page N is still
+        # needed while (N - 1) * per_page < total_items.
+        while total_items is None or (page - 1) * per_page < total_items:
+            resp = await self.get_messages(
+                page=page,
+                message_type=message_type,
+                content_types=content_types,
+                content_keys=content_keys,
+                refs=refs,
+                addresses=addresses,
+                tags=tags,
+                hashes=hashes,
+                channels=channels,
+                chains=chains,
+                start_date=start_date,
+                end_date=end_date,
+            )
+            total_items = resp.pagination_total
+            page += 1
+            for message in resp.messages:
+                yield message
+
+    @abstractmethod
+    async def get_message(
+        self,
+        item_hash: str,
+        message_type: Optional[Type[GenericMessage]] = None,
+        channel: Optional[str] = None,
+    ) -> GenericMessage:
+        """
+        Get a single message from its `item_hash` and perform some basic validation.
+
+        :param item_hash: Hash of the message to fetch
+        :param message_type: Type of message to fetch
+        :param channel: Channel of the message to fetch
+        """
+        pass
+
+    @abstractmethod
+    def watch_messages(
+        self,
+        message_type: Optional[MessageType] = None,
+        content_types: Optional[Iterable[str]] = None,
+        content_keys: Optional[Iterable[str]] = None,
+        refs: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        hashes: Optional[Iterable[str]] = None,
+        channels: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[str]] = None,
+        start_date: Optional[Union[datetime, float]] = None,
+        end_date: Optional[Union[datetime, float]] = None,
+    ) -> AsyncIterable[AlephMessage]:
+        """
+        Iterate over current and future matching messages asynchronously.
+
+        :param message_type: Type of message to watch
+        :param content_types: Content types to watch
+        :param content_keys: Filter by aggregate key
+        :param refs: References to watch
+        :param addresses: Addresses to watch
+        :param tags: Tags to watch
+        :param hashes: Hashes to watch
+        :param channels: Channels to watch
+        :param chains: Chains to watch
+        :param start_date: Start date from when to watch
+        :param end_date: End date until when to watch
+        """
+        pass
+
+
+class AuthenticatedAlephClientBase(AlephClientBase):
+    @abstractmethod
+    async def create_post(
+        self,
+        post_content: Any,
+        post_type: str,
+        ref: Optional[str] = None,
+        address: Optional[str] = None,
+        channel: Optional[str] = None,
+        inline: bool = True,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        """
+        Create a POST message on the Aleph network. It is associated with a channel and owned by an account.
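+
+        A minimal sketch (assumes an authenticated client; the content,
+        type and channel values are made up for illustration):
+
+            message, status = await client.create_post(
+                post_content={"hello": "world"},
+                post_type="example",
+                channel="TEST",
+            )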
+
+        :param post_content: The content of the message
+        :param post_type: An arbitrary content type that helps to describe the post_content
+        :param ref: A reference to a previous message that it replaces
+        :param address: The address that will be displayed as the author of the message
+        :param channel: The channel that the message will be posted on
+        :param inline: An optional flag to indicate if the content should be inlined in the message or not
+        :param storage_engine: An optional storage engine to use for the message, if not inlined (Default: "storage")
+        :param sync: If true, waits for the message to be processed by the API server (Default: False)
+        """
+        pass
+
+    @abstractmethod
+    async def create_aggregate(
+        self,
+        key: str,
+        content: Mapping[str, Any],
+        address: Optional[str] = None,
+        channel: Optional[str] = None,
+        inline: bool = True,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        """
+        Create an AGGREGATE message. It is meant to be used as a quick access storage associated with an account.
+
+        :param key: Key to use to store the content
+        :param content: Content to store
+        :param address: Address to use to sign the message
+        :param channel: Channel to use (Default: "TEST")
+        :param inline: Whether to write content inside the message (Default: True)
+        :param sync: If true, waits for the message to be processed by the API server (Default: False)
+        """
+        pass
+
+    @abstractmethod
+    async def create_store(
+        self,
+        address: Optional[str] = None,
+        file_content: Optional[bytes] = None,
+        file_path: Optional[Union[str, Path]] = None,
+        file_hash: Optional[str] = None,
+        guess_mime_type: bool = False,
+        ref: Optional[str] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        extra_fields: Optional[dict] = None,
+        channel: Optional[str] = None,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        """
+        Create a STORE message to store a file on the Aleph network.
+
+        Can be passed either a file path, an IPFS hash or the file's content as raw bytes.
+
+        :param address: Address to display as the author of the message (Default: account.get_address())
+        :param file_content: Byte stream of the file to store (Default: None)
+        :param file_path: Path to the file to store (Default: None)
+        :param file_hash: Hash of the file to store (Default: None)
+        :param guess_mime_type: Guess the MIME type of the file (Default: False)
+        :param ref: Reference to a previous message (Default: None)
+        :param storage_engine: Storage engine to use (Default: "storage")
+        :param extra_fields: Extra fields to add to the STORE message (Default: None)
+        :param channel: Channel to post the message to (Default: "TEST")
+        :param sync: If true, waits for the message to be processed by the API server (Default: False)
+        """
+        pass
+
+    @abstractmethod
+    async def create_program(
+        self,
+        program_ref: str,
+        entrypoint: str,
+        runtime: str,
+        environment_variables: Optional[Mapping[str, str]] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        channel: Optional[str] = None,
+        address: Optional[str] = None,
+        sync: bool = False,
+        memory: Optional[int] = None,
+        vcpus: Optional[int] = None,
+        timeout_seconds: Optional[float] = None,
+        persistent: bool = False,
+        encoding: Encoding = Encoding.zip,
+        volumes: Optional[List[Mapping]] = None,
+        subscriptions: Optional[List[Mapping]] = None,
+        metadata: Optional[Mapping[str, Any]] = None,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        """
+        Post a (create) PROGRAM message.
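+
+        A hedged example (the hashes are placeholders for a real program
+        archive and runtime, not actual references):
+
+            message, status = await client.create_program(
+                program_ref="<program item hash>",
+                entrypoint="main:app",
+                runtime="<runtime item hash>",
+                memory=128,
+                vcpus=1,
+            )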
+
+        :param program_ref: Reference to the program to run
+        :param entrypoint: Entrypoint to run
+        :param runtime: Runtime to use
+        :param environment_variables: Environment variables to pass to the program
+        :param storage_engine: Storage engine to use (Default: "storage")
+        :param channel: Channel to use (Default: "TEST")
+        :param address: Address to use (Default: account.get_address())
+        :param sync: If true, waits for the message to be processed by the API server
+        :param memory: Memory in MB for the VM to be allocated (Default: 128)
+        :param vcpus: Number of vCPUs to allocate (Default: 1)
+        :param timeout_seconds: Timeout in seconds (Default: 30.0)
+        :param persistent: Whether the program should be persistent or not (Default: False)
+        :param encoding: Encoding to use (Default: Encoding.zip)
+        :param volumes: Volumes to mount
+        :param subscriptions: Patterns of Aleph messages to forward to the program's event receiver
+        :param metadata: Metadata to attach to the message
+        """
+        pass
+
+    @abstractmethod
+    async def forget(
+        self,
+        hashes: List[str],
+        reason: Optional[str],
+        storage_engine: StorageEnum = StorageEnum.storage,
+        channel: Optional[str] = None,
+        address: Optional[str] = None,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        """
+        Post a FORGET message to remove previous messages from the network.
+
+        Targeted messages need to be signed by the same account that is attempting to forget them,
+        if the creating address did not delegate the access rights to the forgetting account.
+
+        :param hashes: Hashes of the messages to forget
+        :param reason: Reason for forgetting the messages
+        :param storage_engine: Storage engine to use (Default: "storage")
+        :param channel: Channel to use (Default: "TEST")
+        :param address: Address to use (Default: account.get_address())
+        :param sync: If true, waits for the message to be processed by the API server (Default: False)
+        """
+        pass
+
+    @abstractmethod
+    async def submit(
+        self,
+        content: Dict[str, Any],
+        message_type: MessageType,
+        channel: Optional[str] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        allow_inlining: bool = True,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        """
+        Submit a message to the network. This is a generic method that can be used to submit any type of message.
+        Prefer using the more specific methods to submit messages.
+
+        :param content: Content of the message
+        :param message_type: Type of the message
+        :param channel: Channel to use (Default: "TEST")
+        :param storage_engine: Storage engine to use (Default: "storage")
+        :param allow_inlining: Whether to allow inlining the content of the message (Default: True)
+        :param sync: If true, waits for the message to be processed by the API server (Default: False)
+        """
+        pass
diff --git a/src/aleph/sdk/client.py b/src/aleph/sdk/client.py
index d48fded9..7a79959f 100644
--- a/src/aleph/sdk/client.py
+++ b/src/aleph/sdk/client.py
@@ -50,6 +50,7 @@
 from aleph.sdk.types import Account, GenericMessage, StorageEnum
 from aleph.sdk.utils import Writable, copy_async_readable_to_buffer
 
+from .base import AlephClientBase, AuthenticatedAlephClientBase
 from .conf import settings
 from .exceptions import (
     BroadcastError,
@@ -58,7 +59,7 @@
     MessageNotFoundError,
     MultipleMessagesError,
 )
-from .models import MessagesResponse
+from .models import MessagesResponse, PostsResponse
 from .utils import check_unix_socket_valid, get_message_type_value
 
 logger = logging.getLogger(__name__)
@@ -214,7 +215,7 @@ def get_posts(
         chains: Optional[Iterable[str]] = None,
         start_date: Optional[Union[datetime, float]] = None,
         end_date: Optional[Union[datetime, float]] = None,
-    ) -> Dict[str, Dict]:
+    ) -> PostsResponse:
         return self._wrap(
             self.async_session.get_posts,
             pagination=pagination,
@@ -468,7 +469,7 @@ def submit(
         )
 
 
-class AlephClient:
+class AlephClient(AlephClientBase):
     api_server: str
     http_session: aiohttp.ClientSession
 
@@ -529,14 +530,6 @@ async def fetch_aggregate(
         key: str,
         limit: int = 100,
     ) -> Dict[str, Dict]:
-        """
-        Fetch a value from the aggregate store by owner address and item key.
-
-        :param address: Address of the owner of the aggregate
-        :param key: Key of the aggregate
-        :param limit: Maximum number of items to fetch (Default: 100)
-        """
-
         params: Dict[str, Any] = {"keys": key}
         if limit:
             params["limit"] = limit
@@ -554,14 +547,6 @@ async def fetch_aggregates(
         keys: Optional[Iterable[str]] = None,
         limit: int = 100,
     ) -> Dict[str, Dict]:
-        """
-        Fetch key-value pairs from the aggregate store by owner address.
-
-        :param address: Address of the owner of the aggregate
-        :param keys: Keys of the aggregates to fetch (Default: all items)
-        :param limit: Maximum number of items to fetch (Default: 100)
-        """
-
        keys_str = ",".join(keys) if keys else ""
         params: Dict[str, Any] = {}
         if keys_str:
@@ -590,22 +575,17 @@ async def get_posts(
         chains: Optional[Iterable[str]] = None,
         start_date: Optional[Union[datetime, float]] = None,
         end_date: Optional[Union[datetime, float]] = None,
-    ) -> Dict[str, Dict]:
-        """
-        Fetch a list of posts from the network.
-
-        :param pagination: Number of items to fetch (Default: 200)
-        :param page: Page to fetch, begins at 1 (Default: 1)
-        :param types: Types of posts to fetch (Default: all types)
-        :param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
-        :param addresses: Addresses of the posts to fetch (Default: all addresses)
-        :param tags: Tags of the posts to fetch (Default: all tags)
-        :param hashes: Specific item_hashes to fetch
-        :param channels: Channels of the posts to fetch (Default: all channels)
-        :param chains: Chains of the posts to fetch (Default: all chains)
-        :param start_date: Earliest date to fetch messages from
-        :param end_date: Latest date to fetch messages from
-        """
+        ignore_invalid_messages: bool = True,
+        invalid_messages_log_level: int = logging.NOTSET,
+    ) -> PostsResponse:
+        ignore_invalid_messages = (
+            True if ignore_invalid_messages is None else ignore_invalid_messages
+        )
+        invalid_messages_log_level = (
+            logging.NOTSET
+            if invalid_messages_log_level is None
+            else invalid_messages_log_level
+        )
 
         params: Dict[str, Any] = dict(pagination=pagination, page=page)
 
@@ -635,7 +615,36 @@ async def get_posts(
 
         async with self.http_session.get("/api/v0/posts.json", params=params) as resp:
             resp.raise_for_status()
-            return await resp.json()
+            response_json = await resp.json()
+            posts_raw = response_json["posts"]
+
+            # All posts may not be valid according to the latest specification in
+            # aleph-message. This allows the user to specify how errors should be handled.
+            posts: List[AlephMessage] = []
+            for post_raw in posts_raw:
+                try:
+                    message = parse_message(post_raw)
+                    posts.append(message)
+                except KeyError as e:
+                    if not ignore_invalid_messages:
+                        raise e
+                    logger.log(
+                        level=invalid_messages_log_level,
+                        msg=f"KeyError: Field '{e.args[0]}' not found",
+                    )
+                except ValidationError as e:
+                    if not ignore_invalid_messages:
+                        raise e
+                    if invalid_messages_log_level:
+                        logger.log(level=invalid_messages_log_level, msg=e)
+
+            return PostsResponse(
+                posts=posts,
+                pagination_page=response_json["pagination_page"],
+                pagination_total=response_json["pagination_total"],
+                pagination_per_page=response_json["pagination_per_page"],
+                pagination_item=response_json["pagination_item"],
+            )
 
     async def download_file_to_buffer(
         self,
@@ -734,25 +743,6 @@ async def get_messages(
         ignore_invalid_messages: bool = True,
         invalid_messages_log_level: int = logging.NOTSET,
     ) -> MessagesResponse:
-        """
-        Fetch a list of messages from the network.
-
-        :param pagination: Number of items to fetch (Default: 200)
-        :param page: Page to fetch, begins at 1 (Default: 1)
-        :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET"
-        :param content_types: Filter by content type
-        :param content_keys: Filter by content key
-        :param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
-        :param addresses: Addresses of the posts to fetch (Default: all addresses)
-        :param tags: Tags of the posts to fetch (Default: all tags)
-        :param hashes: Specific item_hashes to fetch
-        :param channels: Channels of the posts to fetch (Default: all channels)
-        :param chains: Filter by sender address chain
-        :param start_date: Earliest date to fetch messages from
-        :param end_date: Latest date to fetch messages from
-        :param ignore_invalid_messages: Ignore invalid messages (Default: False)
-        :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
-        """
         ignore_invalid_messages = (
             True if ignore_invalid_messages is None else ignore_invalid_messages
         )
@@ -833,13 +823,6 @@ async def get_message(
         message_type: Optional[Type[GenericMessage]] = None,
         channel: Optional[str] = None,
     ) -> GenericMessage:
-        """
-        Get a single message from its `item_hash` and perform some basic validation.
-
-        :param item_hash: Hash of the message to fetch
-        :param message_type: Type of message to fetch
-        :param channel: Channel of the message to fetch
-        """
         messages_response = await self.get_messages(
             hashes=[item_hash],
             channels=[channel] if channel else None,
@@ -864,6 +847,7 @@ async def watch_messages(
         self,
         message_type: Optional[MessageType] = None,
         content_types: Optional[Iterable[str]] = None,
+        content_keys: Optional[Iterable[str]] = None,
         refs: Optional[Iterable[str]] = None,
         addresses: Optional[Iterable[str]] = None,
         tags: Optional[Iterable[str]] = None,
@@ -873,26 +857,14 @@ async def watch_messages(
         start_date: Optional[Union[datetime, float]] = None,
         end_date: Optional[Union[datetime, float]] = None,
     ) -> AsyncIterable[AlephMessage]:
-        """
-        Iterate over current and future matching messages asynchronously.
-
-        :param message_type: Type of message to watch
-        :param content_types: Content types to watch
-        :param refs: References to watch
-        :param addresses: Addresses to watch
-        :param tags: Tags to watch
-        :param hashes: Hashes to watch
-        :param channels: Channels to watch
-        :param chains: Chains to watch
-        :param start_date: Start date from when to watch
-        :param end_date: End date until when to watch
-        """
         params: Dict[str, Any] = dict()
 
         if message_type is not None:
             params["msgType"] = message_type.value
         if content_types is not None:
             params["contentTypes"] = ",".join(content_types)
+        if content_keys is not None:
+            params["contentKeys"] = ",".join(content_keys)
         if refs is not None:
             params["refs"] = ",".join(refs)
         if addresses is not None:
@@ -931,7 +903,7 @@ async def watch_messages(
                     break
 
 
-class AuthenticatedAlephClient(AlephClient):
+class AuthenticatedAlephClient(AlephClient, AuthenticatedAlephClientBase):
     account: Account
 
     BROADCAST_MESSAGE_FIELDS = {
@@ -969,8 +941,11 @@ async def __aenter__(self) -> "AuthenticatedAlephClient":
         return self
 
     async def ipfs_push(self, content: Mapping) -> str:
-        """Push arbitrary content as JSON to the IPFS service."""
+        """
+        Push arbitrary content as JSON to the IPFS service.
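+
+        Usage sketch (the payload is illustrative; the returned value is the
+        IPFS hash of the stored JSON):
+
+            item_hash = await client.ipfs_push({"hello": "world"})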
+
+        :param content: The dict-like content to upload
+        """
         url = "/api/v0/ipfs/add_json"
         logger.debug(f"Pushing to IPFS on {url}")
 
@@ -979,8 +954,11 @@ async def ipfs_push(self, content: Mapping) -> str:
         return (await resp.json()).get("hash")
 
     async def storage_push(self, content: Mapping) -> str:
-        """Push arbitrary content as JSON to the storage service."""
+        """
+        Push arbitrary content as JSON to the storage service.
+
+        :param content: The dict-like content to upload
+        """
         url = "/api/v0/storage/add_json"
         logger.debug(f"Pushing to storage on {url}")
 
@@ -989,7 +967,11 @@ async def storage_push(self, content: Mapping) -> str:
         return (await resp.json()).get("hash")
 
     async def ipfs_push_file(self, file_content: Union[str, bytes]) -> str:
-        """Push a file to the IPFS service."""
+        """
+        Push a file to the IPFS service.
+
+        :param file_content: The file content to upload
+        """
         data = aiohttp.FormData()
         data.add_field("file", file_content)
 
@@ -1001,7 +983,9 @@ async def ipfs_push_file(self, file_content: Union[str, bytes]) -> str:
         return (await resp.json()).get("hash")
 
     async def storage_push_file(self, file_content) -> str:
-        """Push a file to the storage service."""
+        """
+        Push a file to the storage service.
+        """
         data = aiohttp.FormData()
         data.add_field("file", file_content)
 
@@ -1144,18 +1128,6 @@ async def create_post(
         storage_engine: StorageEnum = StorageEnum.storage,
         sync: bool = False,
     ) -> Tuple[PostMessage, MessageStatus]:
-        """
-        Create a POST message on the Aleph network. It is associated with a channel and owned by an account.
-
-        :param post_content: The content of the message
-        :param post_type: An arbitrary content type that helps to describe the post_content
-        :param ref: A reference to a previous message that it replaces
-        :param address: The address that will be displayed as the author of the message
-        :param channel: The channel that the message will be posted on
-        :param inline: An optional flag to indicate if the content should be inlined in the message or not
-        :param storage_engine: An optional storage engine to use for the message, if not inlined (Default: "storage")
-        :param sync: If true, waits for the message to be processed by the API server (Default: False)
-        """
         address = address or settings.ADDRESS_TO_USE or self.account.get_address()
 
         content = PostContent(
@@ -1184,16 +1156,6 @@ async def create_aggregate(
         inline: bool = True,
         sync: bool = False,
     ) -> Tuple[AggregateMessage, MessageStatus]:
-        """
-        Create an AGGREGATE message. It is meant to be used as a quick access storage associated with an account.
-
-        :param key: Key to use to store the content
-        :param content: Content to store
-        :param address: Address to use to sign the message
-        :param channel: Channel to use (Default: "TEST")
-        :param inline: Whether to write content inside the message (Default: True)
-        :param sync: If true, waits for the message to be processed by the API server (Default: False)
-        """
         address = address or settings.ADDRESS_TO_USE or self.account.get_address()
 
         content_ = AggregateContent(
@@ -1224,22 +1186,6 @@ async def create_store(
         channel: Optional[str] = None,
         sync: bool = False,
     ) -> Tuple[StoreMessage, MessageStatus]:
-        """
-        Create a STORE message to store a file on the Aleph network.
-
-        Can be passed either a file path, an IPFS hash or the file's content as raw bytes.
-
-        :param address: Address to display as the author of the message (Default: account.get_address())
-        :param file_content: Byte stream of the file to store (Default: None)
-        :param file_path: Path to the file to store (Default: None)
-        :param file_hash: Hash of the file to store (Default: None)
-        :param guess_mime_type: Guess the MIME type of the file (Default: False)
-        :param ref: Reference to a previous message (Default: None)
-        :param storage_engine: Storage engine to use (Default: "storage")
-        :param extra_fields: Extra fields to add to the STORE message (Default: None)
-        :param channel: Channel to post the message to (Default: "TEST")
-        :param sync: If true, waits for the message to be processed by the API server (Default: False)
-        """
         address = address or settings.ADDRESS_TO_USE or self.account.get_address()
         extra_fields = extra_fields or {}
 
@@ -1260,7 +1206,7 @@ async def create_store(
         else:
             raise ValueError(f"Unknown storage engine: '{storage_engine}'")
 
-        assert file_hash, "File hash should be empty"
+        assert file_hash, "File hash should not be empty"
 
         if magic is None:
             pass
@@ -1308,25 +1254,6 @@ async def create_program(
         subscriptions: Optional[List[Mapping]] = None,
         metadata: Optional[Mapping[str, Any]] = None,
     ) -> Tuple[ProgramMessage, MessageStatus]:
-        """
-        Post a (create) PROGRAM message.
-
-        :param program_ref: Reference to the program to run
-        :param entrypoint: Entrypoint to run
-        :param runtime: Runtime to use
-        :param environment_variables: Environment variables to pass to the program
-        :param storage_engine: Storage engine to use (Default: "storage")
-        :param channel: Channel to use (Default: "TEST")
-        :param address: Address to use (Default: account.get_address())
-        :param sync: If true, waits for the message to be processed by the API server
-        :param memory: Memory in MB for the VM to be allocated (Default: 128)
-        :param vcpus: Number of vCPUs to allocate (Default: 1)
-        :param timeout_seconds: Timeout in seconds (Default: 30.0)
-        :param persistent: Whether the program should be persistent or not (Default: False)
-        :param encoding: Encoding to use (Default: Encoding.zip)
-        :param volumes: Volumes to mount
-        :param subscriptions: Patterns of Aleph messages to forward to the program's event receiver
-        """
         address = address or settings.ADDRESS_TO_USE or self.account.get_address()
 
         volumes = volumes if volumes is not None else []
@@ -1404,19 +1331,6 @@ async def forget(
         address: Optional[str] = None,
         sync: bool = False,
     ) -> Tuple[ForgetMessage, MessageStatus]:
-        """
-        Post a FORGET message to remove previous messages from the network.
-
-        Targeted messages need to be signed by the same account that is attempting to forget them,
-        if the creating address did not delegate the access rights to the forgetting account.
-
-        :param hashes: Hashes of the messages to forget
-        :param reason: Reason for forgetting the messages
-        :param storage_engine: Storage engine to use (Default: "storage")
-        :param channel: Channel to use (Default: "TEST")
-        :param address: Address to use (Default: account.get_address())
-        :param sync: If true, waits for the message to be processed by the API server (Default: False)
-        """
         address = address or settings.ADDRESS_TO_USE or self.account.get_address()
 
         content = ForgetContent(
diff --git a/src/aleph/sdk/conf.py b/src/aleph/sdk/conf.py
index 264c8c9f..64cbeef5 100644
--- a/src/aleph/sdk/conf.py
+++ b/src/aleph/sdk/conf.py
@@ -33,6 +33,22 @@ class Settings(BaseSettings):
 
     CODE_USES_SQUASHFS: bool = which("mksquashfs") is not None  # True if command exists
 
+    # DNS resolver
+    DNS_IPFS_DOMAIN = "ipfs.public.aleph.sh"
+    DNS_PROGRAM_DOMAIN = "program.public.aleph.sh"
+    DNS_INSTANCE_DOMAIN = "instance.public.aleph.sh"
+    DNS_ROOT_DOMAIN = "static.public.aleph.sh"
+    DNS_RESOLVERS = ["1.1.1.1", "1.0.0.1"]
+
+    CACHE_DATABASE_PATH: Path = Field(
+        default=Path(":memory:"),  # can also be :memory: for in-memory caching
+        description="Path to the cache database",
+    )
+    CACHE_FILES_PATH: Path = Field(
+        default=Path("cache", "files"),
+        description="Path to the cache files",
+    )
+
     class Config:
         env_prefix = "ALEPH_"
         case_sensitive = False
diff --git a/src/aleph/sdk/domain.py b/src/aleph/sdk/domain.py
new file mode 100644
index 00000000..93862023
--- /dev/null
+++ b/src/aleph/sdk/domain.py
@@ -0,0 +1,141 @@
+import re
+from typing import Optional
+
+import aiodns
+
+from aleph.sdk.exceptions import DomainConfigurationError
+
+from .conf import settings
+
+
+class AlephDNS:
+    def __init__(self):
+        self.resolver = aiodns.DNSResolver(servers=settings.DNS_RESOLVERS)
+        self.fqdn_matcher = re.compile(r"https?://?")
+
+    async def query(self, name: str, query_type: str):
+        try:
+            return await self.resolver.query(name, query_type)
+        except Exception as e:
+            print(e)
+            return None
+
+    def url_to_domain(self, url):
+        return self.fqdn_matcher.sub("", url).strip().strip("/")
+
+    async def get_ipv6_address(self, url: str):
+        domain = self.url_to_domain(url)
+        ipv6 = []
+        query = await self.query(domain, "AAAA")
+        if query:
+            for entry in query:
+                ipv6.append(entry.host)
+        return ipv6
+
+    async def get_dnslink(self, url: str):
+        domain = self.url_to_domain(url)
+        query = await self.query(f"_dnslink.{domain}", "TXT")
+        if query is not None and len(query) > 0:
+            return query[0].text
+
+    async def check_domain_configured(self, domain, target, owner):
+        try:
+            print("Check...", target)
+            return await self.check_domain(domain, target, owner)
+        except Exception as error:
+            raise DomainConfigurationError(error)
+
+    async def check_domain(self, url: str, target: str, owner: Optional[str] = None):
+        status = {"cname": False, "owner_proof": False}
+
+        target = target.lower()
+
+        dns_rules = self.get_required_dns_rules(url, target, owner)
+
+        for dns_rule in dns_rules:
+            status[dns_rule["rule_name"]] = False
+
+            record_name = dns_rule["dns"]["name"]
+            record_type = dns_rule["dns"]["type"]
+            record_value = dns_rule["dns"]["value"]
+
+            res = await self.query(record_name, record_type.upper())
+
+            if record_type == "txt":
+                found = False
+
+                for _res in res:
+                    if hasattr(_res, "text") and _res.text == record_value:
+                        found = True
+
+                if not found:
+                    raise DomainConfigurationError(
+                        (dns_rule["info"], dns_rule["on_error"], status)
+                    )
+
+            elif (
+                res is None
+                or not hasattr(res, record_type)
+                or getattr(res, record_type) != record_value
+            ):
+                raise DomainConfigurationError(
+                    (dns_rule["info"], dns_rule["on_error"], status)
+                )
+
+            status[dns_rule["rule_name"]] = True
+
+        return status
+
+    def get_required_dns_rules(self, url, target, owner: Optional[str] = None):
+        domain = self.url_to_domain(url)
+        target = target.lower()
+        dns_rules = []
+
+        if target == "ipfs":
+            cname_value = settings.DNS_IPFS_DOMAIN
+        elif target == "program":
+            cname_value = settings.DNS_PROGRAM_DOMAIN
+        elif target == "instance":
+            cname_value = f"{domain}.{settings.DNS_INSTANCE_DOMAIN}"
+
+        # cname rule
+        dns_rules.append(
+            {
+                "rule_name": "cname",
+                "dns": {"type": "cname", "name": domain, "value": cname_value},
+                "info": f"Create a CNAME record for {domain} with value {cname_value}",
+                "on_error": f"CNAME record not found: {domain}",
+            }
+        )
+
+        if target == "ipfs":
+            # ipfs rule
+            dns_rules.append(
+                {
+                    "rule_name": "delegation",
+                    "dns": {
+                        "type": "cname",
+                        "name": f"_dnslink.{domain}",
+                        "value": f"_dnslink.{domain}.{settings.DNS_ROOT_DOMAIN}",
+                    },
+                    "info": f"Create a CNAME record for _dnslink.{domain} with value _dnslink.{domain}.{settings.DNS_ROOT_DOMAIN}",
+                    "on_error": f"CNAME record not found: _dnslink.{domain}",
+                }
+            )
+
+        if owner:
+            # ownership rule
+            dns_rules.append(
+                {
+                    "rule_name": "owner_proof",
+                    "dns": {
+                        "type": "txt",
+                        "name": f"_control.{domain}",
+                        "value": owner,
+                    },
+                    "info": f"Create a TXT record for _control.{domain} with value = owner address",
+                    "on_error": "Owner address mismatch",
+                }
+            )
+
+        return dns_rules
diff --git a/src/aleph/sdk/exceptions.py b/src/aleph/sdk/exceptions.py
index 51762925..b885da50 100644
--- a/src/aleph/sdk/exceptions.py
+++ b/src/aleph/sdk/exceptions.py
@@ -50,3 +50,8 @@ class FileTooLarge(Exception):
     """
 
     pass
+
+
+class DomainConfigurationError(Exception):
+    """Raised when the domain checks are not satisfied."""
+    pass
diff --git a/src/aleph/sdk/models.py b/src/aleph/sdk/models.py
index f7dfcd6e..4e8b9990 100644
--- a/src/aleph/sdk/models.py
+++ b/src/aleph/sdk/models.py
@@ -1,14 +1,25 @@
 from typing import List
 
-from aleph_message.models import AlephMessage
+from aleph_message.models import AlephMessage, PostMessage
 from pydantic import BaseModel
 
 
-class MessagesResponse(BaseModel):
-    """Response from an Aleph node API on the path /api/v0/messages.json"""
-
-    messages: List[AlephMessage]
+class PaginationResponse(BaseModel):
     pagination_page: int
     pagination_total: int
     pagination_per_page: int
     pagination_item: str
+
+
+class MessagesResponse(PaginationResponse):
+    """Response from an Aleph node API on the path /api/v0/messages.json"""
+
+    messages: List[AlephMessage]
+    pagination_item = "messages"
+
+
+class PostsResponse(PaginationResponse):
+    """Response from an Aleph node API on the path /api/v0/posts.json"""
+
+    posts: List[PostMessage]
+    pagination_item = "posts"
diff --git a/src/aleph/sdk/node.py b/src/aleph/sdk/node.py
new file mode 100644
index 00000000..bb1c0d89
--- /dev/null
+++ b/src/aleph/sdk/node.py
@@ -0,0 +1,749 @@
+import asyncio
+import json
+import logging
+import typing
+from datetime import datetime
+from functools import partial
+from pathlib import Path
+from typing import (
+    Any,
+    AsyncIterable,
+    Coroutine,
+    Dict,
+    Generic,
+    Iterable,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
+
+from aleph_message import MessagesResponse, parse_message
+from aleph_message.models import (
+    AlephMessage,
+    Chain,
+    ItemHash,
+    MessageConfirmation,
+    MessageType,
+)
+from aleph_message.models.execution.base import Encoding
+from aleph_message.status import MessageStatus
+from peewee import (
+    BooleanField,
+    CharField,
+    FloatField,
+    IntegerField,
+    Model,
+    SqliteDatabase,
+)
+from playhouse.shortcuts import model_to_dict
+from playhouse.sqlite_ext import JSONField
+from pydantic import BaseModel
+
+from aleph.sdk import AuthenticatedAlephClient
+from aleph.sdk.base import AlephClientBase, AuthenticatedAlephClientBase
+from aleph.sdk.conf import settings
+from aleph.sdk.exceptions import MessageNotFoundError
+from aleph.sdk.models import PostsResponse
+from aleph.sdk.types import GenericMessage, StorageEnum
+
+db = SqliteDatabase(settings.CACHE_DATABASE_PATH)
+T = TypeVar("T", bound=BaseModel)
+
+
+class JSONDictEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, BaseModel):
+            return obj.dict()
+        return json.JSONEncoder.default(self, obj)
+
+
+pydantic_json_dumps = partial(json.dumps, cls=JSONDictEncoder)
+
+
+class PydanticField(JSONField, Generic[T]):
+    """
+    A field for storing pydantic model types as JSON in a database. Uses json for serialization.
+    """
+
+    type: T
+
+    def __init__(self, *args, **kwargs):
+        self.type = kwargs.pop("type")
+        super().__init__(*args, **kwargs)
+
+    def db_value(self, value: Optional[T]) -> Optional[str]:
+        if value is None:
+            return None
+        return value.json()
+
+    def python_value(self, value: Optional[str]) -> Optional[T]:
+        if value is None:
+            return None
+        return self.type.parse_raw(value)
+
+
+class MessageModel(Model):
+    """
+    A simple database model for storing AlephMessage objects.
+    """
+
+    item_hash = CharField(primary_key=True)
+    chain = CharField(5)
+    type = CharField(9)
+    sender = CharField()
+    channel = CharField(null=True)
+    confirmations: PydanticField[MessageConfirmation] = PydanticField(
+        type=MessageConfirmation, null=True
+    )
+    confirmed = BooleanField(null=True)
+    signature = CharField(null=True)
+    size = IntegerField(null=True)
+    time = FloatField()
+    item_type = CharField(7)
+    item_content = CharField(null=True)
+    hash_type = CharField(6, null=True)
+    content = JSONField(json_dumps=pydantic_json_dumps)
+    forgotten_by = CharField(null=True)
+    tags = JSONField(json_dumps=pydantic_json_dumps, null=True)
+    key = CharField(null=True)
+    ref = CharField(null=True)
+    content_type = CharField(null=True)
+
+    class Meta:
+        database = db
+
+
+def message_to_model(message: AlephMessage) -> Dict:
+    return {
+        "item_hash": str(message.item_hash),
+        "chain": message.chain,
+        "type": message.type,
+        "sender": message.sender,
+        "channel": message.channel,
+        "confirmations": message.confirmations[0] if message.confirmations else None,
+        "confirmed": message.confirmed,
+        "signature": message.signature,
+        "size": message.size,
+        "time": message.time,
+        "item_type": message.item_type,
+        "item_content": message.item_content,
+        "hash_type": message.hash_type,
+        "content": message.content,
+        "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None,
+        "tags": message.content.content.get("tags", None)
+        if hasattr(message.content, "content")
+        else None,
+        "key": message.content.key if hasattr(message.content, "key") else None,
+        "ref": message.content.ref if hasattr(message.content, "ref") else None,
+        "content_type": message.content.type
+        if hasattr(message.content, "type")
+        else None,
+    }
+
+
+def model_to_message(item: Any) -> AlephMessage:
+    item.confirmations = [item.confirmations] if item.confirmations else []
+    item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None
+
+    to_exclude = [
+        MessageModel.tags,
+        MessageModel.ref,
+        MessageModel.key,
+        MessageModel.content_type,
+    ]
+
+    item_dict = model_to_dict(item, exclude=to_exclude)
+    return parse_message(item_dict)
+
+
+def query_field(field_name, field_values: Iterable[str]):
+    field = getattr(MessageModel, field_name)
+    values = list(field_values)
+
+    if len(values) == 1:
+        return field == values[0]
+    return field.in_(values)
+
+
+def get_message_query(
+    message_type: Optional[MessageType] = None,
+    content_keys: Optional[Iterable[str]] = None,
+    content_types: Optional[Iterable[str]] = None,
+    refs: Optional[Iterable[str]] = None,
+    addresses: Optional[Iterable[str]] = None,
+    tags: Optional[Iterable[str]] = None,
+    hashes: Optional[Iterable[str]] = None,
+    channels: Optional[Iterable[str]] = None,
+    chains: Optional[Iterable[str]] = None,
+    start_date: Optional[Union[datetime, float]] = None,
+    end_date: Optional[Union[datetime, float]] = None,
+):
+    query = MessageModel.select().order_by(MessageModel.time.desc())
+    conditions = []
+    if message_type:
+        conditions.append(query_field("type", [message_type.value]))
+    if content_keys:
+        conditions.append(query_field("key", content_keys))
+    if content_types:
+        conditions.append(query_field("content_type", content_types))
+    if refs:
+        conditions.append(query_field("ref", refs))
+    if addresses:
+        conditions.append(query_field("sender", addresses))
+    if tags:
+        for tag in tags:
+            conditions.append(MessageModel.tags.contains(tag))
+    if hashes:
+        conditions.append(query_field("item_hash", hashes))
+    if channels:
+        conditions.append(query_field("channel", channels))
+    if chains:
+        conditions.append(query_field("chain", chains))
+    if start_date:
+        conditions.append(MessageModel.time >= start_date)
+    if end_date:
+        conditions.append(MessageModel.time <= end_date)
+
+    if conditions:
+        query = query.where(*conditions)
+    return query
+
+
+class MessageCache(AlephClientBase):
+    """
+    A wrapper around a sqlite3 database for caching AlephMessage objects.
+
+    It can be used independently of a DomainNode to implement any kind of caching strategy.
+    """
+
+    _instance_count = 0  # Class-level counter for active instances
+
+    def __init__(self):
+        if db.is_closed():
+            db.connect()
+        if not MessageModel.table_exists():
+            db.create_tables([MessageModel])
+
+        MessageCache._instance_count += 1
+
+    def __del__(self):
+        MessageCache._instance_count -= 1
+
+        if MessageCache._instance_count == 0:
+            db.close()
+
+    def __getitem__(self, item_hash: Union[ItemHash, str]) -> Optional[AlephMessage]:
+        try:
+            item = MessageModel.get(MessageModel.item_hash == str(item_hash))
+        except MessageModel.DoesNotExist:
+            return None
+        return model_to_message(item)
+
+    def __delitem__(self, item_hash: Union[ItemHash, str]):
+        MessageModel.delete().where(MessageModel.item_hash == str(item_hash)).execute()
+
+    def __contains__(self, item_hash: Union[ItemHash, str]) -> bool:
+        return (
+            MessageModel.select()
+            .where(MessageModel.item_hash == str(item_hash))
+            .exists()
+        )
+
+    def __len__(self):
+        return MessageModel.select().count()
+
+    def __iter__(self) -> Iterator[AlephMessage]:
+        """
+        Iterate over all messages in the cache, the latest first.
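+
+        For example (sketch; assumes the cache was populated beforehand):
+
+            cache = MessageCache()
+            for message in cache:
+                print(message.item_hash)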
+        """
+        for item in iter(MessageModel.select().order_by(-MessageModel.time)):
+            yield model_to_message(item)
+
+    def __repr__(self) -> str:
+        return f"<MessageCache: {len(self)} messages>"
+
+    def __str__(self) -> str:
+        return repr(self)
+
+    @staticmethod
+    def add(messages: Union[AlephMessage, Iterable[AlephMessage]]):
+        if isinstance(messages, typing.get_args(AlephMessage)):
+            messages = [messages]
+
+        data_source = (message_to_model(message) for message in messages)
+        MessageModel.insert_many(data_source).on_conflict_replace().execute()
+
+    @staticmethod
+    def get(
+        item_hashes: Union[Union[ItemHash, str], Iterable[Union[ItemHash, str]]]
+    ) -> List[AlephMessage]:
+        """
+        Get many messages from the cache by their item hash.
+        """
+        if not isinstance(item_hashes, list):
+            item_hashes = [item_hashes]
+        item_hashes = [str(item_hash) for item_hash in item_hashes]
+        items = (
+            MessageModel.select()
+            .where(MessageModel.item_hash.in_(item_hashes))
+            .execute()
+        )
+        return [model_to_message(item) for item in items]
+
+    def listen_to(self, message_stream: AsyncIterable[AlephMessage]) -> Coroutine:
+        """
+        Listen to a stream of messages and add them to the cache.
+        """
+
+        async def _listen():
+            async for message in message_stream:
+                self.add(message)
+                print(f"Added message {message.item_hash} to cache")
+
+        return _listen()
+
+    async def fetch_aggregate(
+        self, address: str, key: str, limit: int = 100
+    ) -> Dict[str, Dict]:
+        item = (
+            MessageModel.select()
+            .where(MessageModel.type == MessageType.aggregate.value)
+            .where(MessageModel.sender == address)
+            .where(MessageModel.key == key)
+            .order_by(MessageModel.time.desc())
+            .first()
+        )
+        return item.content["content"]
+
+    async def fetch_aggregates(
+        self, address: str, keys: Optional[Iterable[str]] = None, limit: int = 100
+    ) -> Dict[str, Dict]:
+        query = (
+            MessageModel.select()
+            .where(MessageModel.type == MessageType.aggregate.value)
+            .where(MessageModel.sender == address)
+            .order_by(MessageModel.time.desc())
+        )
+        if keys:
+            query = query.where(MessageModel.key.in_(keys))
+        query = query.limit(limit)
+        return {item.key: item.content["content"] for item in list(query)}
+
+    async def get_posts(
+        self,
+        pagination: int = 200,
+        page: int = 1,
+        types: Optional[Iterable[str]] = None,
+        refs: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        hashes: Optional[Iterable[str]] = None,
+        channels: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[str]] = None,
+        start_date: Optional[Union[datetime, float]] = None,
+        end_date: Optional[Union[datetime, float]] = None,
+        ignore_invalid_messages: bool = True,
+        invalid_messages_log_level: int = logging.NOTSET,
+    ) -> PostsResponse:
+        query = get_message_query(
+            message_type=MessageType.post,
+            content_types=types,
+            refs=refs,
+            addresses=addresses,
+            tags=tags,
+            hashes=hashes,
+            channels=channels,
+            chains=chains,
+            start_date=start_date,
+            end_date=end_date,
+        )
+
+        query = query.paginate(page, pagination)
+
+        posts = [model_to_message(item) for item in list(query)]
+
+        return PostsResponse(
+            posts=posts,
+            pagination_page=page,
+            pagination_per_page=pagination,
+            pagination_total=query.count(),
+            pagination_item="posts",
+        )
+
+    async def download_file(self, file_hash: str) -> bytes:
+        raise NotImplementedError
+
+    async def get_messages(
+        self,
+        pagination: int = 200,
+        page: int = 1,
+        message_type: Optional[MessageType] = None,
+        content_types: Optional[Iterable[str]] = None,
+        content_keys: Optional[Iterable[str]] = None,
+        refs: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        hashes: Optional[Iterable[str]] = None,
+        channels: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[str]] = None,
+        start_date: Optional[Union[datetime, float]] = None,
+        end_date: Optional[Union[datetime, float]] = None,
+        ignore_invalid_messages: bool = True,
+        invalid_messages_log_level: int = logging.NOTSET,
+    ) -> MessagesResponse:
+        """
+        Get many messages from the cache.
+        """
+        query = get_message_query(
+            message_type=message_type,
+            content_keys=content_keys,
+            content_types=content_types,
+            refs=refs,
+            addresses=addresses,
+            tags=tags,
+            hashes=hashes,
+            channels=channels,
+            chains=chains,
+            start_date=start_date,
+            end_date=end_date,
+        )
+
+        query = query.paginate(page, pagination)
+
+        messages = [model_to_message(item) for item in list(query)]
+
+        return MessagesResponse(
+            messages=messages,
+            pagination_page=page,
+            pagination_per_page=pagination,
+            pagination_total=query.count(),
+            pagination_item="messages",
+        )
+
+    async def get_message(
+        self,
+        item_hash: str,
+        message_type: Optional[Type[GenericMessage]] = None,
+        channel: Optional[str] = None,
+    ) -> GenericMessage:
+        """
+        Get a single message from the cache.
+        """
+        query = MessageModel.select().where(MessageModel.item_hash == item_hash)
+
+        if message_type:
+            query = query.where(MessageModel.type == message_type.value)
+        if channel:
+            query = query.where(MessageModel.channel == channel)
+
+        item = query.first()
+
+        if item:
+            return model_to_message(item)
+
+        raise MessageNotFoundError(f"No such hash {item_hash}")
+
+    async def watch_messages(
+        self,
+        message_type: Optional[MessageType] = None,
+        content_types: Optional[Iterable[str]] = None,
+        content_keys: Optional[Iterable[str]] = None,
+        refs: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        hashes: Optional[Iterable[str]] = None,
+        channels: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[str]] = None,
+        start_date: Optional[Union[datetime, float]] = None,
+        end_date: Optional[Union[datetime, float]] = None,
+    ) -> AsyncIterable[AlephMessage]:
+        """
+        Watch messages from the cache.
+        """
+        query = get_message_query(
+            message_type=message_type,
+            content_keys=content_keys,
+            content_types=content_types,
+            refs=refs,
+            addresses=addresses,
+            tags=tags,
+            hashes=hashes,
+            channels=channels,
+            chains=chains,
+            start_date=start_date,
+            end_date=end_date,
+        )
+
+        # peewee queries are plain (synchronous) iterables, so iterate with a
+        # regular for-loop inside this async generator.
+        for item in query:
+            yield model_to_message(item)
+
+
+class DomainNode(MessageCache, AuthenticatedAlephClientBase):
+    """
+    A Domain Node is a queryable proxy for Aleph Messages that are stored in a database cache and/or in the Aleph
+    network.
+
+    It synchronizes with the network on a subset of the messages by listening to the network and storing the
+    messages in the cache. The user may define the subset by specifying channels, tags, senders, chains,
+    message types, and/or a time window.
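+
+    A hypothetical setup (sketch only; "session" is an already-created
+    AuthenticatedAlephClient and the hash is a placeholder):
+
+        node = DomainNode(session=session, channels=["MY_CHANNEL"])
+        message = await node.get_message("<item hash>")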
+    """
+
+    def __init__(
+        self,
+        session: AuthenticatedAlephClient,
+        channels: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[Chain]] = None,
+        message_type: Optional[MessageType] = None,
+    ):
+        super().__init__()
+        self.session = session
+        self.channels = channels
+        self.tags = tags
+        self.addresses = addresses
+        self.chains = chains
+        self.message_type = message_type
+
+        # start listening to the network and storing messages in the cache
+        asyncio.get_event_loop().create_task(
+            self.listen_to(
+                self.session.watch_messages(
+                    channels=self.channels,
+                    tags=self.tags,
+                    addresses=self.addresses,
+                    chains=self.chains,
+                    message_type=self.message_type,
+                )
+            )
+        )
+
+        # synchronize with past messages
+        asyncio.get_event_loop().run_until_complete(
+            self.synchronize(
+                channels=self.channels,
+                tags=self.tags,
+                addresses=self.addresses,
+                chains=self.chains,
+                message_type=self.message_type,
+            )
+        )
+
+    async def __aenter__(self) -> "DomainNode":
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        ...
+
+    async def synchronize(
+        self,
+        channels: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[Chain]] = None,
+        message_type: Optional[MessageType] = None,
+        start_date: Optional[Union[datetime, float]] = None,
+        end_date: Optional[Union[datetime, float]] = None,
+    ):
+        """
+        Synchronize with past messages.
+        """
+        chunk_size = 200
+        messages = []
+        async for message in self.session.get_messages_iterator(
+            channels=channels,
+            tags=tags,
+            addresses=addresses,
+            chains=chains,
+            message_type=message_type,
+            start_date=start_date,
+            end_date=end_date,
+        ):
+            messages.append(message)
+            if len(messages) >= chunk_size:
+                self.add(messages)
+                messages = []
+        if messages:
+            self.add(messages)
+
+    async def download_file(self, file_hash: str) -> bytes:
+        """
+        Opens a file that has been locally stored by its hash.
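+
+        If the file is not cached locally yet, it is downloaded from the
+        network and written to the cache first.
+
+        Example (sketch; the hash is a placeholder):
+
+            data = await node.download_file("<file hash>")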
+        """
+        try:
+            with open(self._file_path(file_hash), "rb") as f:
+                return f.read()
+        except FileNotFoundError:
+            file = await self.session.download_file(file_hash)
+            self._file_path(file_hash).parent.mkdir(parents=True, exist_ok=True)
+            with open(self._file_path(file_hash), "wb") as f:
+                f.write(file)
+            return file
+
+    @staticmethod
+    def _file_path(file_hash: str) -> Path:
+        return settings.CACHE_FILES_PATH / Path(file_hash)
+
+    async def create_post(
+        self,
+        post_content: Any,
+        post_type: str,
+        ref: Optional[str] = None,
+        address: Optional[str] = None,
+        channel: Optional[str] = None,
+        inline: bool = True,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        resp, status = await self.session.create_post(
+            post_content=post_content,
+            post_type=post_type,
+            ref=ref,
+            address=address,
+            channel=channel,
+            inline=inline,
+            storage_engine=storage_engine,
+            sync=sync,
+        )
+        # TODO: This can cause inconsistencies, if the message is rejected by the aleph node
+        if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]:
+            self.add(resp)
+        return resp, status
+
+    async def create_aggregate(
+        self,
+        key: str,
+        content: Mapping[str, Any],
+        address: Optional[str] = None,
+        channel: Optional[str] = None,
+        inline: bool = True,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        resp, status = await self.session.create_aggregate(
+            key=key,
+            content=content,
+            address=address,
+            channel=channel,
+            inline=inline,
+            sync=sync,
+        )
+        if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]:
+            self.add(resp)
+        return resp, status
+
+    async def create_store(
+        self,
+        address: Optional[str] = None,
+        file_content: Optional[bytes] = None,
+        file_path: Optional[Union[str, Path]] = None,
+        file_hash: Optional[str] = None,
+        guess_mime_type: bool = False,
+        ref: Optional[str] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        extra_fields: Optional[dict] = None,
+        channel: Optional[str] = None,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        resp, status = await self.session.create_store(
+            address=address,
+            file_content=file_content,
+            file_path=file_path,
+            file_hash=file_hash,
+            guess_mime_type=guess_mime_type,
+            ref=ref,
+            storage_engine=storage_engine,
+            extra_fields=extra_fields,
+            channel=channel,
+            sync=sync,
+        )
+        if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]:
+            self.add(resp)
+        return resp, status
+
+    async def create_program(
+        self,
+        program_ref: str,
+        entrypoint: str,
+        runtime: str,
+        environment_variables: Optional[Mapping[str, str]] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        channel: Optional[str] = None,
+        address: Optional[str] = None,
+        sync: bool = False,
+        memory: Optional[int] = None,
+        vcpus: Optional[int] = None,
+        timeout_seconds: Optional[float] = None,
+        persistent: bool = False,
+        encoding: Encoding = Encoding.zip,
+        volumes: Optional[List[Mapping]] = None,
+        subscriptions: Optional[List[Mapping]] = None,
+        metadata: Optional[Mapping[str, Any]] = None,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        resp, status = await self.session.create_program(
+            program_ref=program_ref,
+            entrypoint=entrypoint,
+            runtime=runtime,
+            environment_variables=environment_variables,
+            storage_engine=storage_engine,
+            channel=channel,
+            address=address,
+            sync=sync,
+            memory=memory,
+            vcpus=vcpus,
+            timeout_seconds=timeout_seconds,
+            persistent=persistent,
+            encoding=encoding,
+            volumes=volumes,
+            subscriptions=subscriptions,
+    async def forget(
+        self,
+        hashes: List[str],
+        reason: Optional[str],
+        storage_engine: StorageEnum = StorageEnum.storage,
+        channel: Optional[str] = None,
+        address: Optional[str] = None,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        resp, status = await self.session.forget(
+            hashes=hashes,
+            reason=reason,
+            storage_engine=storage_engine,
+            channel=channel,
+            address=address,
+            sync=sync,
+        )
+        # drop the forgotten messages from the local cache,
+        # not the hash of the FORGET message itself
+        for item_hash in hashes:
+            del self[item_hash]
+        return resp, status
+
+    async def submit(
+        self,
+        content: Dict[str, Any],
+        message_type: MessageType,
+        channel: Optional[str] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        allow_inlining: bool = True,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        resp, status = await self.session.submit(
+            content=content,
+            message_type=message_type,
+            channel=channel,
+            storage_engine=storage_engine,
+            allow_inlining=allow_inlining,
+            sync=sync,
+        )
+        # TODO: this can cause inconsistencies if the message is dropped
+        if status in [MessageStatus.PROCESSED, MessageStatus.PENDING]:
+            self.add(resp)
+        return resp, status
diff --git a/src/aleph/sdk/wallets/ledger/ethereum.py b/src/aleph/sdk/wallets/ledger/ethereum.py
index a8ba6899..ea1ba8bc 100644
--- a/src/aleph/sdk/wallets/ledger/ethereum.py
+++ b/src/aleph/sdk/wallets/ledger/ethereum.py
@@ -39,6 +39,8 @@ def from_address(
     """
     device = device or init_dongle()
     account = find_account(address=address, dongle=device, count=5)
+    if not account:
+        return None
     return LedgerETHAccount(
         account=account,
         device=device,
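
With this change, from_address() returns None instead of failing later when no
matching account exists on the device, so its return annotation should widen to
Optional[LedgerETHAccount]. A hedged caller-side sketch, assuming from_address
is exposed on LedgerETHAccount as in this module:

    from aleph.sdk.wallets.ledger.ethereum import LedgerETHAccount

    account = LedgerETHAccount.from_address(
        "0x51A58800b26AA1451aaA803d1746687cB88E0501"
    )
    if account is None:
        raise RuntimeError("no matching account on the connected Ledger device")
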
diff --git a/tests/integration/itest_forget.py b/tests/integration/itest_forget.py
index 8339e316..29b6c6d9 100644
--- a/tests/integration/itest_forget.py
+++ b/tests/integration/itest_forget.py
@@ -107,10 +107,10 @@ async def test_forget_a_forget_message(fixture_account):
         account=fixture_account, api_server=TARGET_NODE
     ) as session:
         get_post_response = await session.get_posts(hashes=[post_hash])
-        assert len(get_post_response["posts"]) == 1
-        post = get_post_response["posts"][0]
+        assert len(get_post_response.posts) == 1
+        post = get_post_response.posts[0]

-        forget_message_hash = post["forgotten_by"][0]
+        forget_message_hash = post.forgotten_by[0]
         forget_message, forget_status = await session.forget(
             hashes=[forget_message_hash],
             reason="I want to remember this post. Maybe I can forget I forgot it?",
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 9952f847..311d32f3 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -1,7 +1,9 @@
 from pathlib import Path
 from tempfile import NamedTemporaryFile
+from typing import List

 import pytest as pytest
+from aleph_message.models import AggregateMessage, AlephMessage, PostMessage

 import aleph.sdk.chains.ethereum as ethereum
 import aleph.sdk.chains.sol as solana
@@ -34,3 +36,71 @@ def tezos_account() -> tezos.TezosAccount:
     with NamedTemporaryFile(delete=False) as private_key_file:
         private_key_file.close()
         yield tezos.get_fallback_account(path=Path(private_key_file.name))
+
+
+@pytest.fixture
+def messages() -> List[AlephMessage]:
+    return [
+        AggregateMessage.parse_obj(
+            {
+                "item_hash": "5b26d949fe05e38f535ef990a89da0473f9d700077cced228f2d36e73fca1fd6",
+                "type": "AGGREGATE",
+                "chain": "ETH",
+                "sender": "0x51A58800b26AA1451aaA803d1746687cB88E0501",
+                "signature": "0xca5825b6b93390482b436cb7f28b4628f8c9f56dc6af08260c869b79dd6017c94248839bd9fd0ffa1230dc3b1f4f7572a8d1f6fed6c6e1fb4d70ccda0ab5d4f21b",
+                "item_type": "inline",
+                "item_content": '{"address":"0x51A58800b26AA1451aaA803d1746687cB88E0501","key":"0xce844d79e5c0c325490c530aa41e8f602f0b5999binance","content":{"1692026263168":{"version":"x25519-xsalsa20-poly1305","nonce":"RT4Lbqs7Xzk+op2XC+VpXgwOgg21BotN","ephemPublicKey":"CVW8ECE3m8BepytHMTLan6/jgIfCxGdnKmX47YirF08=","ciphertext":"VuGJ9vMkJSbaYZCCv6Zemx4ixeb+9IW8H1vFB9vLtz1a8d87R4BfYUisLoCQxRkeUXqfW0/KIGQ5idVjr8Yj7QnKglW5AJ8UX7wEWMhiRFLatpWP8P9FI2n8Z7Rblu7Oz/OeKnuljKL3KsalcUQSsFa/1qACsIoycPZ6Wq6t1mXxVxxJWzClLyKRihv1pokZGT9UWxh7+tpoMGlRdYainyAt0/RygFw+r8iCMOilHnyv4ndLkKQJXyttb0tdNr/gr57+9761+trioGSysLQKZQWW6Ih6aE8V9t3BenfzYwiCnfFw3YAAKBPMdm9QdIETyrOi7YhD/w==","sha256":"bbeb499f681aed2bc18b6f3b6a30d25254bd30fbfde43444e9085f3bcd075c3c"}},"time":1692026263.662}',
+                "content": {
+                    "key": "0xce844d79e5c0c325490c530aa41e8f602f0b5999binance",
+                    "time": 1692026263.662,
+                    "address": "0x51A58800b26AA1451aaA803d1746687cB88E0501",
+                    "content": {
+                        "hello": "world",
+                    },
+                },
+                "time": 1692026263.662,
+                "channel": "UNSLASHED",
+                "size": 734,
+                "confirmations": [],
+                "confirmed": False,
+            }
+        ),
+        PostMessage.parse_obj(
+            {
+                "item_hash": "70f3798fdc68ce0ee03715a5547ee24e2c3e259bf02e3f5d1e4bf5a6f6a5e99f",
+                "type": "POST",
+                "chain": "SOL",
+                "sender": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4",
+                "signature": "0x91616ee45cfba55742954ff87ebf86db4988bcc5e3334b49a4caa6436e28e28d4ab38667cbd4bfb8903abf8d71f70d9ceb2c0a8d0a15c04fc1af5657f0050c101b",
+                "item_type": "storage",
+                "item_content": None,
+                "content": {
+                    "time": 1692026021.1257718,
+                    "type": "aleph-network-metrics",
+                    "address": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4",
+                    "ref": "0123456789abcdef",
+                    "content": {
+                        "tags": ["mainnet"],
+                        "hello": "world",
+                        "version": "1.0",
+                    },
+                },
+                "time": 1692026021.132849,
+                "channel": "aleph-scoring",
+                "size": 122537,
+                "confirmations": [],
+                "confirmed": False,
+            }
+        ),
+    ]
+
+
+@pytest.fixture
+def raw_messages_response(messages):
+    return {
+        "messages": [message.dict() for message in messages],
+        "pagination_item": "messages",
+        "pagination_page": 1,
+        "pagination_per_page": 20,
+        "pagination_total": 2,
+    }
diff --git a/tests/unit/test_domains.py b/tests/unit/test_domains.py
new file mode 100644
index 00000000..afad05cb
--- /dev/null
+++ b/tests/unit/test_domains.py
@@ -0,0 +1,48 @@
+import pytest
+
+from aleph.sdk.domain import AlephDNS
+from aleph.sdk.exceptions import DomainConfigurationError
+
+
+@pytest.mark.asyncio
+async def test_url_to_domain():
+    alephdns = AlephDNS()
+    domain = alephdns.url_to_domain("https://aleph.im")
+    query = await alephdns.query(domain, "A")
+    assert query is not None
+    assert len(query) > 0
+    assert hasattr(query[0], "host")
+
+
+@pytest.mark.asyncio
+async def test_get_ipv6_address():
+    alephdns = AlephDNS()
+    url = "https://aleph.im"
+    ipv6_address = await alephdns.get_ipv6_address(url)
+    assert ipv6_address is not None
+    assert len(ipv6_address) > 0
+    assert ":" in ipv6_address[0]
+
+
+@pytest.mark.asyncio
+async def test_dnslink():
+    alephdns = AlephDNS()
+    url = "https://aleph.im"
+    dnslink = await alephdns.get_dnslink(url)
+    assert dnslink is not None
+
+
+@pytest.mark.asyncio
+async def test_configured_domain():
+    alephdns = AlephDNS()
+    url = "https://custom-domain-unit-test.aleph.sh"
+    status = await alephdns.check_domain(url, "ipfs", "0xfakeaddress")
+    assert isinstance(status, dict)
+
+
+@pytest.mark.asyncio
+async def test_not_configured_domain():
+    alephdns = AlephDNS()
+    url = "https://not-configured-domain.aleph.sh"
+    with pytest.raises(DomainConfigurationError):
+        await alephdns.check_domain(url, "ipfs", "0xfakeaddress")
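
These tests pin down the AlephDNS surface the SDK relies on: query(),
get_ipv6_address(), get_dnslink() and check_domain(), with
DomainConfigurationError raised for unconfigured domains. A small sketch of the
same flow outside pytest, using only the calls exercised above:

    from aleph.sdk.domain import AlephDNS
    from aleph.sdk.exceptions import DomainConfigurationError

    async def domain_is_configured(url: str, target: str, owner: str) -> bool:
        # True when the domain's DNS records match the expected configuration.
        alephdns = AlephDNS()
        try:
            await alephdns.check_domain(url, target, owner)
            return True
        except DomainConfigurationError:
            return False
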
diff --git a/tests/unit/test_node.py b/tests/unit/test_node.py
new file mode 100644
index 00000000..0b844e50
--- /dev/null
+++ b/tests/unit/test_node.py
@@ -0,0 +1,255 @@
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from aleph_message.models import (
+    AggregateMessage,
+    ForgetMessage,
+    MessageType,
+    PostMessage,
+    ProgramMessage,
+    StoreMessage,
+)
+from aleph_message.status import MessageStatus
+
+from aleph.sdk import AuthenticatedAlephClient
+from aleph.sdk.conf import settings
+from aleph.sdk.node import DomainNode
+from aleph.sdk.types import Account, StorageEnum
+
+
+class MockPostResponse:
+    def __init__(self, response_message: Any, sync: bool):
+        self.response_message = response_message
+        self.sync = sync
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        ...
+
+    @property
+    def status(self):
+        return 200 if self.sync else 202
+
+    def raise_for_status(self):
+        if self.status not in [200, 202]:
+            raise Exception("Bad status code")
+
+    async def json(self):
+        message_status = "processed" if self.sync else "pending"
+        return {
+            "message_status": message_status,
+            "publication_status": {"status": "success", "failed": []},
+            "hash": "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy",
+            "message": self.response_message,
+        }
+
+    async def text(self):
+        return json.dumps(await self.json())
+
+
+class MockGetResponse:
+    def __init__(self, response):
+        self.response = response
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        ...
+
+    @property
+    def status(self):
+        return 200
+
+    def raise_for_status(self):
+        if self.status != 200:
+            raise Exception("Bad status code")
+
+    async def json(self):
+        return self.response
+
+
+@pytest.fixture
+def mock_session_with_two_messages(
+    ethereum_account: Account, raw_messages_response: Dict[str, Any]
+) -> AuthenticatedAlephClient:
+    http_session = AsyncMock()
+    http_session.post = MagicMock()
+    http_session.post.side_effect = lambda *args, **kwargs: MockPostResponse(
+        response_message={
+            "type": "post",
+            "channel": "TEST",
+            "content": {"Hello": "World"},
+            "key": "QmBlahBlahBlah",
+            "item_hash": "QmBlahBlahBlah",
+        },
+        sync=kwargs.get("sync", False),
+    )
+    http_session.get = MagicMock()
+    http_session.get.return_value = MockGetResponse(raw_messages_response)
+
+    client = AuthenticatedAlephClient(
+        account=ethereum_account, api_server="http://localhost"
+    )
+    client.http_session = http_session
+
+    return client
+
+
+def test_node_init(mock_session_with_two_messages):
+    # DomainNode drives the event loop itself in __init__, so this test must stay synchronous
+    node = DomainNode(session=mock_session_with_two_messages)
+    assert node.session == mock_session_with_two_messages
+    assert len(node) >= 2
+
+
+@pytest.fixture
+def mock_node_with_post_success(mock_session_with_two_messages) -> DomainNode:
+    node = DomainNode(session=mock_session_with_two_messages)
+    return node
+
+
+@pytest.mark.asyncio
+async def test_create_post(mock_node_with_post_success):
+    async with mock_node_with_post_success as session:
+        content = {"Hello": "World"}
+
+        post_message, message_status = await session.create_post(
+            post_content=content,
+            post_type="TEST",
+            channel="TEST",
+            sync=False,
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(post_message, PostMessage)
+    assert message_status == MessageStatus.PENDING
+
+
+@pytest.mark.asyncio
+async def test_create_aggregate(mock_node_with_post_success):
+    async with mock_node_with_post_success as session:
+        aggregate_message, message_status = await session.create_aggregate(
+            key="hello",
+            content={"Hello": "world"},
+            channel="TEST",
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(aggregate_message, AggregateMessage)
+
+
+@pytest.mark.asyncio
+async def test_create_store(mock_node_with_post_success):
+    mock_ipfs_push_file = AsyncMock()
+    mock_ipfs_push_file.return_value = "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"
+
+    mock_node_with_post_success.ipfs_push_file = mock_ipfs_push_file
+
+    async with mock_node_with_post_success as node:
+        _ = await node.create_store(
+            file_content=b"HELLO",
+            channel="TEST",
+            storage_engine=StorageEnum.ipfs,
+        )
+
+        _ = await node.create_store(
+            file_hash="QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy",
+            channel="TEST",
+            storage_engine=StorageEnum.ipfs,
+        )
+
+    mock_storage_push_file = AsyncMock()
+    mock_storage_push_file.return_value = (
+        "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"
+    )
+    mock_node_with_post_success.storage_push_file = mock_storage_push_file
+    async with mock_node_with_post_success as node:
+        store_message, message_status = await node.create_store(
+            file_content=b"HELLO",
+            channel="TEST",
+            storage_engine=StorageEnum.storage,
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(store_message, StoreMessage)
+
+
+@pytest.mark.asyncio
+async def test_create_program(mock_node_with_post_success):
+    async with mock_node_with_post_success as node:
+        program_message, message_status = await node.create_program(
+            program_ref="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe",
+            entrypoint="main:app",
+            runtime="facefacefacefacefacefacefacefacefacefacefacefacefacefacefaceface",
+            channel="TEST",
+            metadata={"tags": ["test"]},
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(program_message, ProgramMessage)
+@pytest.mark.asyncio
+async def test_forget(mock_node_with_post_success):
+    async with mock_node_with_post_success as node:
+        forget_message, message_status = await node.forget(
+            hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"],
+            reason="GDPR",
+            channel="TEST",
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(forget_message, ForgetMessage)
+
+
+@pytest.mark.asyncio
+async def test_download_file(mock_node_with_post_success):
+    mock_node_with_post_success.session.download_file = AsyncMock()
+    mock_node_with_post_success.session.download_file.return_value = b"HELLO"
+
+    # remove the file locally so the first call has to hit the (mocked) network
+    if os.path.exists(settings.CACHE_FILES_PATH / Path("QmAndSoOn")):
+        os.remove(settings.CACHE_FILES_PATH / Path("QmAndSoOn"))
+
+    # fetch from mocked response
+    async with mock_node_with_post_success as node:
+        file_content = await node.download_file(
+            file_hash="QmAndSoOn",
+        )
+
+    assert mock_node_with_post_success.session.download_file.called
+    assert file_content == b"HELLO"
+
+    # fetch cached
+    async with mock_node_with_post_success as node:
+        file_content = await node.download_file(
+            file_hash="QmAndSoOn",
+        )
+
+    assert file_content == b"HELLO"
+
+
+@pytest.mark.asyncio
+async def test_submit_message(mock_node_with_post_success):
+    content = {"Hello": "World"}
+    async with mock_node_with_post_success as node:
+        message, status = await node.submit(
+            content={
+                "address": "0x1234567890123456789012345678901234567890",
+                "time": 1234567890,
+                "type": "TEST",
+                "content": content,
+            },
+            message_type=MessageType.post,
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert message.content.content == content
+    assert status == MessageStatus.PENDING
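
The fixture pattern used throughout test_node.py, an AsyncMock session whose
post/get attributes are MagicMocks returning async-context-manager stand-ins,
keeps every test offline. A condensed sketch of the same idea for a GET-only
stub (MockGetResponse is the helper class defined above; the function name is
illustrative):

    from unittest.mock import AsyncMock, MagicMock

    def make_offline_client(client, response: dict):
        # Replace the aiohttp-style session with mocks so no request leaves the test.
        http_session = AsyncMock()
        http_session.get = MagicMock(return_value=MockGetResponse(response))
        client.http_session = http_session
        return client
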
diff --git a/tests/unit/test_node_get.py b/tests/unit/test_node_get.py
new file mode 100644
index 00000000..48bff3b8
--- /dev/null
+++ b/tests/unit/test_node_get.py
@@ -0,0 +1,231 @@
+import json
+from hashlib import sha256
+from typing import List
+
+import pytest
+from aleph_message.models import (
+    AlephMessage,
+    Chain,
+    MessageType,
+    PostContent,
+    PostMessage,
+)
+
+from aleph.sdk.chains.ethereum import get_fallback_account
+from aleph.sdk.exceptions import MessageNotFoundError
+from aleph.sdk.node import MessageCache
+
+
+@pytest.mark.asyncio
+async def test_base(messages):
+    # exercise add, get, __contains__, __delitem__ and __len__ in one pass
+    cache = MessageCache()
+    cache.add(messages)
+    assert len(cache) == len(messages)
+
+    item_hashes = [message.item_hash for message in messages]
+    cached_messages = cache.get(item_hashes)
+    assert len(cached_messages) == len(messages)
+
+    for message in messages:
+        assert cache[message.item_hash] == message
+
+    for message in messages:
+        assert message.item_hash in cache
+
+    for message in cache:
+        del cache[message.item_hash]
+        assert message.item_hash not in cache
+
+    assert len(cache) == 0
+    del cache
+
+
+class TestMessageQueries:
+    messages: List[AlephMessage]
+    cache: MessageCache
+
+    @pytest.fixture(autouse=True)
+    def class_setup(self, messages):
+        self.messages = messages
+        self.cache = MessageCache()
+        self.cache.add(self.messages)
+        yield
+        del self.cache
+
+    @pytest.mark.asyncio
+    async def test_iterate(self):
+        assert len(self.cache) == len(self.messages)
+        for message in self.cache:
+            assert message in self.messages
+
+    @pytest.mark.asyncio
+    async def test_addresses(self):
+        items = (
+            await self.cache.get_messages(addresses=[self.messages[0].sender])
+        ).messages
+        assert items[0] == self.messages[0]
+
+    @pytest.mark.asyncio
+    async def test_tags(self):
+        assert (
+            len((await self.cache.get_messages(tags=["thistagdoesnotexist"])).messages)
+            == 0
+        )
+
+    @pytest.mark.asyncio
+    async def test_message_type(self):
+        assert (await self.cache.get_messages(message_type=MessageType.post)).messages[
+            0
+        ] == self.messages[1]
+
+    @pytest.mark.asyncio
+    async def test_refs(self):
+        assert (
+            await self.cache.get_messages(refs=[self.messages[1].content.ref])
+        ).messages[0] == self.messages[1]
+
+    @pytest.mark.asyncio
+    async def test_hashes(self):
+        assert (
+            await self.cache.get_messages(hashes=[self.messages[0].item_hash])
+        ).messages[0] == self.messages[0]
+
+    @pytest.mark.asyncio
+    async def test_pagination(self):
+        assert len((await self.cache.get_messages(pagination=1)).messages) == 1
+
+    @pytest.mark.asyncio
+    async def test_content_types(self):
+        assert (
+            await self.cache.get_messages(content_types=[self.messages[1].content.type])
+        ).messages[0] == self.messages[1]
+
+    @pytest.mark.asyncio
+    async def test_channels(self):
+        assert (
+            await self.cache.get_messages(channels=[self.messages[1].channel])
+        ).messages[0] == self.messages[1]
+
+    @pytest.mark.asyncio
+    async def test_chains(self):
+        assert (
+            await self.cache.get_messages(chains=[self.messages[1].chain])
+        ).messages[0] == self.messages[1]
+
+    @pytest.mark.asyncio
+    async def test_content_keys(self):
+        assert (
+            await self.cache.get_messages(content_keys=[self.messages[0].content.key])
+        ).messages[0] == self.messages[0]
+
+
+class TestPostQueries:
+    messages: List[AlephMessage]
+    cache: MessageCache
+
+    @pytest.fixture(autouse=True)
+    def class_setup(self, messages):
+        self.messages = messages
+        self.cache = MessageCache()
+        self.cache.add(self.messages)
+        yield
+        del self.cache
+
+    @pytest.mark.asyncio
+    async def test_addresses(self):
+        items = (await self.cache.get_posts(addresses=[self.messages[1].sender])).posts
+        assert items[0] == self.messages[1]
+
+    @pytest.mark.asyncio
+    async def test_tags(self):
+        assert (
+            len((await self.cache.get_posts(tags=["thistagdoesnotexist"])).posts) == 0
+        )
+
+    @pytest.mark.asyncio
+    async def test_types(self):
+        assert (
+            len((await self.cache.get_posts(types=["thistypedoesnotexist"])).posts) == 0
+        )
+
+    @pytest.mark.asyncio
+    async def test_channels(self):
+        assert (await self.cache.get_posts(channels=[self.messages[1].channel])).posts[
+            0
+        ] == self.messages[1]
+
+    @pytest.mark.asyncio
+    async def test_chains(self):
+        assert (await self.cache.get_posts(chains=[self.messages[1].chain])).posts[
+            0
+        ] == self.messages[1]
+@pytest.mark.asyncio
+async def test_message_cache_listener():
+    async def mock_message_stream():
+        for i in range(3):
+            content = PostContent(
+                content={"hello": f"world{i}"},
+                type="test",
+                address=get_fallback_account().get_address(),
+                time=0,
+            )
+            message = PostMessage(
+                sender=get_fallback_account().get_address(),
+                item_hash=sha256(json.dumps(content.dict()).encode()).hexdigest(),
+                chain=Chain.ETH.value,
+                type=MessageType.post.value,
+                item_type="inline",
+                time=0,
+                content=content,
+                item_content=json.dumps(content.dict()),
+            )
+            yield message
+
+    cache = MessageCache()
+    # listen_to() returns once the finite mock stream is exhausted
+    await cache.listen_to(mock_message_stream())
+    assert len(cache) >= 3
+
+
+@pytest.mark.asyncio
+async def test_fetch_aggregate(messages):
+    cache = MessageCache()
+    cache.add(messages)
+
+    aggregate = await cache.fetch_aggregate(messages[0].sender, messages[0].content.key)
+
+    assert aggregate == messages[0].content.content
+
+
+@pytest.mark.asyncio
+async def test_fetch_aggregates(messages):
+    cache = MessageCache()
+    cache.add(messages)
+
+    aggregates = await cache.fetch_aggregates(messages[0].sender)
+
+    assert aggregates == {messages[0].content.key: messages[0].content.content}
+
+
+@pytest.mark.asyncio
+async def test_get_message(messages):
+    cache = MessageCache()
+    cache.add(messages)
+
+    message: AlephMessage = await cache.get_message(messages[0].item_hash)
+
+    assert message == messages[0]
+
+
+@pytest.mark.asyncio
+async def test_get_message_fail():
+    cache = MessageCache()
+
+    with pytest.raises(MessageNotFoundError):
+        await cache.get_message("0x1234567890123456789012345678901234567890")
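
For reference, the aggregate helpers mirror the network client: fetch_aggregate()
returns the inner content for one (owner, key) pair, while fetch_aggregates()
merges every key owned by an address into a single mapping. A short sketch
against the conftest.py fixture data, to be run inside an async test:

    cache = MessageCache()
    cache.add(messages)
    owner = "0x51A58800b26AA1451aaA803d1746687cB88E0501"
    all_keys = await cache.fetch_aggregates(owner)  # {key: content, ...}
    one_key = await cache.fetch_aggregate(
        owner, "0xce844d79e5c0c325490c530aa41e8f602f0b5999binance"
    )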