From 4ed8277adc7c54d8ee13f251710ebe8710a4358e Mon Sep 17 00:00:00 2001 From: Olivier Desenfans Date: Wed, 11 Jan 2023 18:36:21 +0100 Subject: [PATCH 1/3] Internal: user session class Problem: we often pass 2-3 objects to each function that represent information on how to connect to the Aleph network: * the user account * the API server URL * the aiohttp session. This can be simplified to improve the user experience and reduce the complexity of the SDK. Solution: introduce a new state class: `UserSession`. This object is now passed as the first parameter of every public function. It is initialized with the user account and the API server URL. It is in charge of configuring and managing the aiohttp session object. The typical usage of the library now looks like this: ``` account = load_my_account() api_server = "https://aleph.cloud" # example async with UserSession(account=account, api_server=api_server) as session: message, status = await create_post( session=session, ... ) ``` This enables the following improvements: * Less clutter in function signatures: all public SDK functions now have 1 or 2 fewer arguments. * The API server URL is now managed when initializing the aiohttp session inside the user session object. Implementations can simply specify the endpoint URL. Ex: `f"{api_server}/api/v0/messages.json"` can now be expressed as `"/api/v0/messages.json"`. Breaking changes: * The signatures of all public functions of the SDK have been modified. The user must now initialize the user session object and pass it as a parameter. 
--- examples/httpgateway.py | 16 +- examples/store.py | 10 +- setup.cfg | 2 + src/aleph_client/asynchronous.py | 284 ++++++++------------------ src/aleph_client/commands/files.py | 75 +++---- src/aleph_client/commands/message.py | 44 ++-- src/aleph_client/commands/program.py | 186 ++++++++--------- src/aleph_client/user_session.py | 33 +++ tests/integration/itest_aggregates.py | 32 +-- tests/integration/itest_forget.py | 91 +++++---- tests/integration/itest_posts.py | 35 ++-- tests/unit/conftest.py | 17 +- tests/unit/test_asynchronous.py | 106 ++++------ tests/unit/test_asynchronous_get.py | 102 ++++++--- tests/unit/test_synchronous_get.py | 20 ++ tests/unit/test_vm_cache.py | 3 + 16 files changed, 491 insertions(+), 565 deletions(-) create mode 100644 src/aleph_client/user_session.py create mode 100644 tests/unit/test_synchronous_get.py diff --git a/examples/httpgateway.py b/examples/httpgateway.py index 304215c4..3ece34df 100644 --- a/examples/httpgateway.py +++ b/examples/httpgateway.py @@ -19,6 +19,8 @@ from aiohttp import web +from aleph_client.user_session import UserSession + app = web.Application() routes = web.RouteTableDef() @@ -42,13 +44,13 @@ async def source_post(request): return web.json_response( {"status": "error", "message": "unauthorized secret"} ) - message, _status = await create_post( - app["account"], - data, - "event", - channel=app["channel"], - api_server="https://api2.aleph.im", - ) + async with UserSession(account=app["account"], api_server="https://api2.aleph.im") as session: + message, _status = await create_post( + session=session, + post_content=data, + post_type="event", + channel=app["channel"], + ) return web.json_response({"status": "success", "item_hash": message.item_hash}) diff --git a/examples/store.py b/examples/store.py index 493292ed..38cf7c02 100644 --- a/examples/store.py +++ b/examples/store.py @@ -7,7 +7,9 @@ from aleph_client.asynchronous import create_store from aleph_client.chains.common import 
get_fallback_private_key from aleph_client.chains.ethereum import ETHAccount +from aleph_client.conf import settings from aleph_client.types import MessageStatus +from aleph_client.user_session import UserSession DEFAULT_SERVER = "https://api2.aleph.im" @@ -23,7 +25,7 @@ async def print_output_hash(message: StoreMessage, status: MessageStatus): async def do_upload(account, engine, channel, filename=None, file_hash=None): - async with aiohttp.ClientSession() as session: + async with UserSession(account=account, api_server=settings.API_HOST) as session: print(filename, account.get_address()) if filename: try: @@ -34,11 +36,10 @@ async def do_upload(account, engine, channel, filename=None, file_hash=None): print("File too big for native STORAGE engine") return message, status = await create_store( - account, + session=session, file_content=content, channel=channel, storage_engine=engine.lower(), - session=session, ) except IOError: print("File not accessible") @@ -46,11 +47,10 @@ async def do_upload(account, engine, channel, filename=None, file_hash=None): elif file_hash: message, status = await create_store( - account, + session=session, file_hash=file_hash, channel=channel, storage_engine=engine.lower(), - session=session, ) await print_output_hash(message, status) diff --git a/setup.cfg b/setup.cfg index 51c7bed4..8268da61 100644 --- a/setup.cfg +++ b/setup.cfg @@ -60,6 +60,7 @@ testing = pytest pytest-cov pytest-asyncio + pytest-mock mypy secp256k1 pynacl @@ -69,6 +70,7 @@ testing = httpx requests aleph-pytezos==0.1.0 + types-certifi types-setuptools black mqtt = diff --git a/src/aleph_client/asynchronous.py b/src/aleph_client/asynchronous.py index 5e441b88..8f409ddd 100644 --- a/src/aleph_client/asynchronous.py +++ b/src/aleph_client/asynchronous.py @@ -5,10 +5,8 @@ import json import logging import queue -import threading import time from datetime import datetime -from functools import lru_cache from pathlib import Path from typing import ( Optional, @@ -21,6 
+19,8 @@ ) from typing import Type, Mapping, Tuple, NoReturn +import aiohttp +from aiohttp import ClientSession from aleph_message.models import ( ForgetContent, MessageType, @@ -36,9 +36,11 @@ ProgramMessage, ItemType, ) +from aleph_message.models.program import ProgramContent, Encoding from pydantic import ValidationError from aleph_client.types import Account, StorageEnum, GenericMessage, MessageStatus +from .conf import settings from .exceptions import ( MessageNotFoundError, MultipleMessagesError, @@ -46,6 +48,7 @@ BroadcastError, ) from .models import MessagesResponse +from .user_session import UserSession from .utils import get_message_type_value logger = logging.getLogger(__name__) @@ -56,86 +59,51 @@ logger.info("Could not import library 'magic', MIME type detection disabled") magic = None # type:ignore -from .conf import settings - -import aiohttp -from aiohttp import ClientSession - -from aleph_message.models.program import ProgramContent, Encoding - - -@lru_cache() -def _get_fallback_session(thread_id: Optional[int]) -> ClientSession: - if settings.API_UNIX_SOCKET: - connector = aiohttp.UnixConnector(path=settings.API_UNIX_SOCKET) - return aiohttp.ClientSession(connector=connector) - else: - return aiohttp.ClientSession() - - -def get_fallback_session() -> ClientSession: - thread_id = threading.get_native_id() - return _get_fallback_session(thread_id=thread_id) - -async def ipfs_push( - content: Mapping, - session: ClientSession, - api_server: str, -) -> str: +async def ipfs_push(session: UserSession, content: Mapping) -> str: """Push arbitrary content as JSON to the IPFS service.""" - url = f"{api_server}/api/v0/ipfs/add_json" + + url = "/api/v0/ipfs/add_json" logger.debug(f"Pushing to IPFS on {url}") - async with session.post(url, json=content) as resp: + async with session.http_session.post(url, json=content) as resp: resp.raise_for_status() return (await resp.json()).get("hash") -async def storage_push( - content: Mapping, - session: ClientSession, 
- api_server: str, -) -> str: +async def storage_push(session: UserSession, content: Mapping) -> str: """Push arbitrary content as JSON to the storage service.""" - url = f"{api_server}/api/v0/storage/add_json" + + url = "/api/v0/storage/add_json" logger.debug(f"Pushing to storage on {url}") - async with session.post(url, json=content) as resp: + async with session.http_session.post(url, json=content) as resp: resp.raise_for_status() return (await resp.json()).get("hash") -async def ipfs_push_file( - file_content, - session: ClientSession, - api_server: str, -) -> str: +async def ipfs_push_file(session: UserSession, file_content: Union[str, bytes]) -> str: """Push a file to the IPFS service.""" data = aiohttp.FormData() data.add_field("file", file_content) - url = f"{api_server}/api/v0/ipfs/add_file" + url = "/api/v0/ipfs/add_file" logger.debug(f"Pushing file to IPFS on {url}") - async with session.post(url, data=data) as resp: + async with session.http_session.post(url, data=data) as resp: resp.raise_for_status() return (await resp.json()).get("hash") -async def storage_push_file( - file_content, - session: ClientSession, - api_server: str, -) -> str: +async def storage_push_file(session: UserSession, file_content) -> str: """Push a file to the storage service.""" data = aiohttp.FormData() data.add_field("file", file_content) - url = f"{api_server}/api/v0/storage/add_file" + url = "/api/v0/storage/add_file" logger.debug(f"Posting file on {url}") - async with session.post(url, data=data) as resp: + async with session.http_session.post(url, data=data) as resp: resp.raise_for_status() return (await resp.json()).get("hash") @@ -194,20 +162,18 @@ async def _handle_broadcast_deprecated_response( async def _broadcast_deprecated( - message_dict: Mapping[str, Any], - session: ClientSession, - api_server: str = settings.API_HOST, -): + session: UserSession, message_dict: Mapping[str, Any] +) -> None: """ Broadcast a message on the Aleph network using the deprecated 
/ipfs/pubsub/pub/ endpoint. """ - url = f"{api_server}/api/v0/ipfs/pubsub/pub" + url = "/api/v0/ipfs/pubsub/pub" logger.debug(f"Posting message on {url}") - async with session.post( + async with session.http_session.post( url, json={"topic": "ALEPH-TEST", "data": json.dumps(message_dict)}, ) as response: @@ -246,10 +212,9 @@ async def _handle_broadcast_response( async def _broadcast( + session: UserSession, message: AlephMessage, sync: bool, - session: ClientSession, - api_server: str, ) -> MessageStatus: """ Broadcast a message on the Aleph network. @@ -258,12 +223,12 @@ async def _broadcast( if the first method is not available. """ - url = f"{api_server}/api/v0/messages" + url = "/api/v0/messages" logger.debug(f"Posting message on {url}") message_dict = message.dict(include=BROADCAST_MESSAGE_FIELDS) - async with session.post( + async with session.http_session.post( url, json={"sync": sync, "message": message_dict}, ) as response: @@ -272,9 +237,7 @@ async def _broadcast( logger.warning( "POST /messages/ not found. Defaulting to legacy endpoint..." ) - await _broadcast_deprecated( - message_dict=message_dict, session=session, api_server=api_server - ) + await _broadcast_deprecated(message_dict=message_dict, session=session) return MessageStatus.PENDING else: message_status = await _handle_broadcast_response( @@ -284,14 +247,12 @@ async def _broadcast( async def create_post( - account: Account, + session: UserSession, post_content, post_type: str, ref: Optional[str] = None, address: Optional[str] = None, channel: Optional[str] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, inline: bool = True, storage_engine: StorageEnum = StorageEnum.storage, sync: bool = False, @@ -299,20 +260,17 @@ async def create_post( """ Create a POST message on the Aleph network. It is associated with a channel and owned by an account. 
- :param account: The account that will sign and own the message + :param session: The current user session object :param post_content: The content of the message :param post_type: An arbitrary content type that helps to describe the post_content :param ref: A reference to a previous message that it replaces :param address: The address that will be displayed as the author of the message :param channel: The channel that the message will be posted on - :param session: An optional aiohttp session to use for the request - :param api_server: An optional API server to use for the request (Default: "https://api2.aleph.im") :param inline: An optional flag to indicate if the content should be inlined in the message or not :param storage_engine: An optional storage engine to use for the message, if not inlined (Default: "storage") :param sync: If true, waits for the message to be processed by the API server (Default: False) """ - address = address or settings.ADDRESS_TO_USE or account.get_address() - api_server = api_server or settings.API_HOST + address = address or settings.ADDRESS_TO_USE or session.account.get_address() content = PostContent( type=post_type, @@ -323,12 +281,10 @@ async def create_post( ) return await submit( - account=account, + session=session, content=content.dict(exclude_none=True), message_type=MessageType.post, channel=channel, - api_server=api_server, - session=session, allow_inlining=inline, storage_engine=storage_engine, sync=sync, @@ -336,31 +292,26 @@ async def create_post( async def create_aggregate( - account: Account, - key, - content, + session: UserSession, + key: str, + content: Mapping[str, Any], address: Optional[str] = None, channel: Optional[str] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, inline: bool = True, sync: bool = False, ) -> Tuple[AggregateMessage, MessageStatus]: """ Create an AGGREGATE message. It is meant to be used as a quick access storage associated with an account. 
- :param account: Account to use to sign the message + :param session: The current user session object :param key: Key to use to store the content :param content: Content to store :param address: Address to use to sign the message :param channel: Channel to use (Default: "TEST") - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") :param inline: Whether to write content inside the message (Default: True) :param sync: If true, waits for the message to be processed by the API server (Default: False) """ - address = address or settings.ADDRESS_TO_USE or account.get_address() - api_server = api_server or settings.API_HOST + address = address or settings.ADDRESS_TO_USE or session.account.get_address() content_ = AggregateContent( key=key, @@ -370,19 +321,17 @@ async def create_aggregate( ) return await submit( - account=account, + session=session, content=content_.dict(exclude_none=True), message_type=MessageType.aggregate, channel=channel, - api_server=api_server, - session=session, allow_inlining=inline, sync=sync, ) async def create_store( - account: Account, + session: UserSession, address: Optional[str] = None, file_content: Optional[bytes] = None, file_path: Optional[Union[str, Path]] = None, @@ -392,8 +341,6 @@ async def create_store( storage_engine: StorageEnum = StorageEnum.storage, extra_fields: Optional[dict] = None, channel: Optional[str] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, sync: bool = False, ) -> Tuple[StoreMessage, MessageStatus]: """ @@ -401,7 +348,7 @@ async def create_store( Can be passed either a file path, an IPFS hash or the file's content as raw bytes. 
- :param account: Account to use to sign the message + :param session: The current user session object :param address: Address to display as the author of the message (Default: account.get_address()) :param file_content: Byte stream of the file to store (Default: None) :param file_path: Path to the file to store (Default: None) @@ -411,15 +358,11 @@ async def create_store( :param storage_engine: Storage engine to use (Default: "storage") :param extra_fields: Extra fields to add to the STORE message (Default: None) :param channel: Channel to post the message to (Default: "TEST") - :param session: aiohttp session to use (Default: get_fallback_session()) - :param api_server: Aleph API server to use (Default: "https://api2.aleph.im") :param sync: If true, waits for the message to be processed by the API server (Default: False) """ - address = address or settings.ADDRESS_TO_USE or account.get_address() - api_server = api_server or settings.API_HOST + address = address or settings.ADDRESS_TO_USE or session.account.get_address() extra_fields = extra_fields or {} - session = session or get_fallback_session() if file_hash is None: if file_content is None: @@ -432,12 +375,10 @@ async def create_store( if storage_engine == StorageEnum.storage: file_hash = await storage_push_file( - file_content, session=session, api_server=api_server + session=session, file_content=file_content ) elif storage_engine == StorageEnum.ipfs: - file_hash = await ipfs_push_file( - file_content, session=session, api_server=api_server - ) + file_hash = await ipfs_push_file(session=session, file_content=file_content) else: raise ValueError(f"Unknown storage engine: '{storage_engine}'") @@ -463,19 +404,17 @@ async def create_store( content = StoreContent(**values) return await submit( - account=account, + session=session, content=content.dict(exclude_none=True), message_type=MessageType.store, channel=channel, - api_server=api_server, - session=session, allow_inlining=True, sync=sync, ) async def 
create_program( - account: Account, + session: UserSession, program_ref: str, entrypoint: str, runtime: str, @@ -483,8 +422,6 @@ async def create_program( storage_engine: StorageEnum = StorageEnum.storage, channel: Optional[str] = None, address: Optional[str] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, sync: bool = False, memory: Optional[int] = None, vcpus: Optional[int] = None, @@ -497,7 +434,7 @@ async def create_program( """ Post a (create) PROGRAM message. - :param account: Account to use to sign the message + :param session: The current user session object :param program_ref: Reference to the program to run :param entrypoint: Entrypoint to run :param runtime: Runtime to use @@ -505,8 +442,6 @@ async def create_program( :param storage_engine: Storage engine to use (Default: "storage") :param channel: Channel to use (Default: "TEST") :param address: Address to use (Default: account.get_address()) - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") :param sync: If true, waits for the message to be processed by the API server :param memory: Memory in MB for the VM to be allocated (Default: 128) :param vcpus: Number of vCPUs to allocate (Default: 1) @@ -516,8 +451,7 @@ async def create_program( :param volumes: Volumes to mount :param subscriptions: Patterns of Aleph messages to forward to the program's event receiver """ - address = address or settings.ADDRESS_TO_USE or account.get_address() - api_server = api_server or settings.API_HOST + address = address or settings.ADDRESS_TO_USE or session.account.get_address() volumes = volumes if volumes is not None else [] memory = memory or settings.DEFAULT_VM_MEMORY @@ -573,26 +507,22 @@ async def create_program( assert content.on.persistent == persistent return await submit( - account=account, + session=session, content=content.dict(exclude_none=True), message_type=MessageType.program, 
channel=channel, - api_server=api_server, storage_engine=storage_engine, - session=session, sync=sync, ) async def forget( - account: Account, + session: UserSession, hashes: List[str], reason: Optional[str], storage_engine: StorageEnum = StorageEnum.storage, channel: Optional[str] = None, address: Optional[str] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, sync: bool = False, ) -> Tuple[ForgetMessage, MessageStatus]: """ @@ -601,18 +531,15 @@ async def forget( Targeted messages need to be signed by the same account that is attempting to forget them, if the creating address did not delegate the access rights to the forgetting account. - :param account: Account to use to sign the message + :param session: The current user session object :param hashes: Hashes of the messages to forget :param reason: Reason for forgetting the messages :param storage_engine: Storage engine to use (Default: "storage") :param channel: Channel to use (Default: "TEST") :param address: Address to use (Default: account.get_address()) - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") :param sync: If true, waits for the message to be processed by the API server (Default: False) """ - address = address or settings.ADDRESS_TO_USE or account.get_address() - api_server = api_server or settings.API_HOST + address = address or settings.ADDRESS_TO_USE or session.account.get_address() content = ForgetContent( hashes=hashes, @@ -622,13 +549,11 @@ async def forget( ) return await submit( - account, + session=session, content=content.dict(exclude_none=True), message_type=MessageType.forget, channel=channel, - api_server=api_server, storage_engine=storage_engine, - session=session, allow_inlining=True, sync=sync, ) @@ -641,19 +566,17 @@ def compute_sha256(s: str) -> str: async def _prepare_aleph_message( - account: Account, + session: UserSession, message_type: MessageType, 
content: Dict[str, Any], channel: Optional[str], - session: aiohttp.ClientSession, - api_server: str, allow_inlining: bool = True, storage_engine: StorageEnum = StorageEnum.storage, ) -> AlephMessage: message_dict: Dict[str, Any] = { - "sender": account.get_address(), - "chain": account.CHAIN, + "sender": session.account.get_address(), + "chain": session.account.CHAIN, "type": message_type, "content": content, "time": time.time(), @@ -669,75 +592,65 @@ async def _prepare_aleph_message( else: if storage_engine == StorageEnum.ipfs: message_dict["item_hash"] = await ipfs_push( - content, session=session, api_server=api_server + session=session, + content=content, ) message_dict["item_type"] = ItemType.ipfs else: # storage assert storage_engine == StorageEnum.storage message_dict["item_hash"] = await storage_push( - content, session=session, api_server=api_server + session=session, + content=content, ) message_dict["item_type"] = ItemType.storage - message_dict = await account.sign_message(message_dict) + message_dict = await session.account.sign_message(message_dict) return Message(**message_dict) async def submit( - account: Account, + session: UserSession, content: Dict[str, Any], message_type: MessageType, - api_server: str, channel: Optional[str] = None, storage_engine: StorageEnum = StorageEnum.storage, - session: Optional[ClientSession] = None, allow_inlining: bool = True, sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: - session = session or get_fallback_session() - message = await _prepare_aleph_message( - account=account, + session=session, message_type=message_type, content=content, channel=channel, - session=session, - api_server=api_server, allow_inlining=allow_inlining, storage_engine=storage_engine, ) - message_status = await _broadcast( - message=message, session=session, api_server=api_server, sync=sync - ) + message_status = await _broadcast(session=session, message=message, sync=sync) return message, message_status async def 
fetch_aggregate( + session: UserSession, address: str, key: str, limit: int = 100, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, ) -> Dict[str, Dict]: """ Fetch a value from the aggregate store by owner address and item key. + :param session: The current user session object :param address: Address of the owner of the aggregate :param key: Key of the aggregate :param limit: Maximum number of items to fetch (Default: 100) - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") """ - session = session or get_fallback_session() - api_server = api_server or settings.API_HOST params: Dict[str, Any] = {"keys": key} if limit: params["limit"] = limit - async with session.get( - f"{api_server}/api/v0/aggregates/{address}.json", params=params + async with session.http_session.get( + f"/api/v0/aggregates/{address}.json", params=params ) as resp: result = await resp.json() data = result.get("data", dict()) @@ -745,23 +658,19 @@ async def fetch_aggregate( async def fetch_aggregates( + session: UserSession, address: str, keys: Optional[Iterable[str]] = None, limit: int = 100, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, ) -> Dict[str, Dict]: """ Fetch key-value pairs from the aggregate store by owner address. 
+ :param session: The current user session object :param address: Address of the owner of the aggregate :param keys: Keys of the aggregates to fetch (Default: all items) :param limit: Maximum number of items to fetch (Default: 100) - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") """ - session = session or get_fallback_session() - api_server = api_server or settings.API_HOST keys_str = ",".join(keys) if keys else "" params: Dict[str, Any] = {} @@ -770,8 +679,8 @@ async def fetch_aggregates( if limit: params["limit"] = limit - async with session.get( - f"{api_server}/api/v0/aggregates/{address}.json", + async with session.http_session.get( + f"/api/v0/aggregates/{address}.json", params=params, ) as resp: result = await resp.json() @@ -780,6 +689,7 @@ async def fetch_aggregates( async def get_posts( + session: UserSession, pagination: int = 200, page: int = 1, types: Optional[Iterable[str]] = None, @@ -791,12 +701,11 @@ async def get_posts( chains: Optional[Iterable[str]] = None, start_date: Optional[Union[datetime, float]] = None, end_date: Optional[Union[datetime, float]] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, ) -> Dict[str, Dict]: """ Fetch a list of posts from the network. 
+ :param session: The current user session object :param pagination: Number of items to fetch (Default: 200) :param page: Page to fetch, begins at 1 (Default: 1) :param types: Types of posts to fetch (Default: all types) @@ -808,11 +717,7 @@ async def get_posts( :param chains: Chains of the posts to fetch (Default: all chains) :param start_date: Earliest date to fetch messages from :param end_date: Latest date to fetch messages from - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") """ - session = session or get_fallback_session() - api_server = api_server or settings.API_HOST params: Dict[str, Any] = dict(pagination=pagination, page=page) @@ -840,34 +745,30 @@ async def get_posts( end_date = end_date.timestamp() params["endDate"] = end_date - async with session.get(f"{api_server}/api/v0/posts.json", params=params) as resp: + async with session.http_session.get("/api/v0/posts.json", params=params) as resp: resp.raise_for_status() return await resp.json() async def download_file( + session: UserSession, file_hash: str, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, ) -> bytes: """ Get a file from the storage engine as raw bytes. Warning: Downloading large files can be slow and memory intensive. + :param session: The current user session object :param file_hash: The hash of the file to retrieve. - :param session: The aiohttp session to use. (Default: get_fallback_session()) - :param api_server: The API server to use. 
(Default: "https://api2.aleph.im") """ - session = session or get_fallback_session() - api_server = api_server or settings.API_HOST - - async with session.get(f"{api_server}/api/v0/storage/raw/{file_hash}") as response: + async with session.http_session.get(f"/api/v0/storage/raw/{file_hash}") as response: response.raise_for_status() return await response.read() async def get_messages( + session: UserSession, pagination: int = 200, page: int = 1, message_type: Optional[MessageType] = None, @@ -881,14 +782,13 @@ async def get_messages( chains: Optional[Iterable[str]] = None, start_date: Optional[Union[datetime, float]] = None, end_date: Optional[Union[datetime, float]] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, ignore_invalid_messages: bool = True, invalid_messages_log_level: int = logging.NOTSET, ) -> MessagesResponse: """ Fetch a list of messages from the network. + :param session: The current user session object :param pagination: Number of items to fetch (Default: 200) :param page: Page to fetch, begins at 1 (Default: 1) :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" @@ -902,13 +802,9 @@ async def get_messages( :param chains: Filter by sender address chain :param start_date: Earliest date to fetch messages from :param end_date: Latest date to fetch messages from - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") :param ignore_invalid_messages: Ignore invalid messages (Default: False) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) """ - session = session or get_fallback_session() - api_server = api_server or settings.API_HOST ignore_invalid_messages = ( True if ignore_invalid_messages is None else ignore_invalid_messages ) @@ -948,7 +844,7 @@ async def get_messages( end_date = end_date.timestamp() params["endDate"] = 
end_date - async with session.get(f"{api_server}/api/v0/messages.json", params=params) as resp: + async with session.http_session.get("/api/v0/messages.json", params=params) as resp: resp.raise_for_status() response_json = await resp.json() messages_raw = response_json["messages"] @@ -983,26 +879,23 @@ async def get_messages( async def get_message( + session: UserSession, item_hash: str, message_type: Optional[Type[GenericMessage]] = None, channel: Optional[str] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, ) -> GenericMessage: """ Get a single message from its `item_hash` and perform some basic validation. + :param session: The current user session object :param item_hash: Hash of the message to fetch :param message_type: Type of message to fetch :param channel: Channel of the message to fetch - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") """ messages_response = await get_messages( - hashes=[item_hash], session=session, + hashes=[item_hash], channels=[channel] if channel else None, - api_server=api_server, ) if len(messages_response.messages) < 1: raise MessageNotFoundError(f"No such hash {item_hash}") @@ -1022,6 +915,7 @@ async def get_message( async def watch_messages( + session: UserSession, message_type: Optional[MessageType] = None, content_types: Optional[Iterable[str]] = None, refs: Optional[Iterable[str]] = None, @@ -1032,12 +926,11 @@ async def watch_messages( chains: Optional[Iterable[str]] = None, start_date: Optional[Union[datetime, float]] = None, end_date: Optional[Union[datetime, float]] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, ) -> AsyncIterable[AlephMessage]: """ Iterate over current and future matching messages asynchronously. 
+ :param session: The current user session object :param message_type: Type of message to watch :param content_types: Content types to watch :param refs: References to watch @@ -1048,12 +941,7 @@ async def watch_messages( :param chains: Chains to watch :param start_date: Start date from when to watch :param end_date: End date until when to watch - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") """ - session = session or get_fallback_session() - api_server = api_server or settings.API_HOST - params: Dict[str, Any] = dict() if message_type is not None: @@ -1082,8 +970,8 @@ async def watch_messages( end_date = end_date.timestamp() params["endDate"] = end_date - async with session.ws_connect( - f"{api_server}/api/ws0/messages", params=params + async with session.http_session.ws_connect( + f"/api/ws0/messages", params=params ) as ws: logger.debug("Websocket connected") async for msg in ws: diff --git a/src/aleph_client/commands/files.py b/src/aleph_client/commands/files.py index 2c09b89b..5b945aca 100644 --- a/src/aleph_client/commands/files.py +++ b/src/aleph_client/commands/files.py @@ -1,4 +1,3 @@ -import asyncio import logging from pathlib import Path from typing import Optional @@ -8,7 +7,6 @@ from aleph_client import synchronous from aleph_client.account import _load_account -from aleph_client.asynchronous import get_fallback_session from aleph_client.commands import help_strings from aleph_client.commands.utils import setup_logging from aleph_client.conf import settings @@ -37,20 +35,15 @@ def pin( account: AccountFromPrivateKey = _load_account(private_key, private_key_file) - try: - result: StoreMessage = synchronous.create_store( - account=account, - file_hash=hash, - storage_engine=StorageEnum.ipfs, - channel=channel, - ref=ref, - ) - logger.debug("Upload finished") - typer.echo(f"{result.json(indent=4)}") - finally: - # Prevent aiohttp unclosed connector warning - 
asyncio.run(get_fallback_session().close()) - + result: StoreMessage = synchronous.create_store( + account=account, + file_hash=hash, + storage_engine=StorageEnum.ipfs, + channel=channel, + ref=ref, + ) + logger.debug("Upload finished") + typer.echo(f"{result.json(indent=4)}") @app.command() def upload( @@ -71,31 +64,27 @@ def upload( account: AccountFromPrivateKey = _load_account(private_key, private_key_file) - try: - if not path.is_file(): - typer.echo(f"Error: File not found: '{path}'") - raise typer.Exit(code=1) + if not path.is_file(): + typer.echo(f"Error: File not found: '{path}'") + raise typer.Exit(code=1) - with open(path, "rb") as fd: - logger.debug("Reading file") - # TODO: Read in lazy mode instead of copying everything in memory - file_content = fd.read() - storage_engine = ( - StorageEnum.ipfs - if len(file_content) > 4 * 1024 * 1024 - else StorageEnum.storage - ) - logger.debug("Uploading file") - result: StoreMessage = synchronous.create_store( - account=account, - file_content=file_content, - storage_engine=storage_engine, - channel=channel, - guess_mime_type=True, - ref=ref, - ) - logger.debug("Upload finished") - typer.echo(f"{result.json(indent=4)}") - finally: - # Prevent aiohttp unclosed connector warning - asyncio.run(get_fallback_session().close()) + with open(path, "rb") as fd: + logger.debug("Reading file") + # TODO: Read in lazy mode instead of copying everything in memory + file_content = fd.read() + storage_engine = ( + StorageEnum.ipfs + if len(file_content) > 4 * 1024 * 1024 + else StorageEnum.storage + ) + logger.debug("Uploading file") + result: StoreMessage = synchronous.create_store( + account=account, + file_content=file_content, + storage_engine=storage_engine, + channel=channel, + guess_mime_type=True, + ref=ref, + ) + logger.debug("Upload finished") + typer.echo(f"{result.json(indent=4)}") diff --git a/src/aleph_client/commands/message.py b/src/aleph_client/commands/message.py index 6fb5675c..09de29da 100644 --- 
a/src/aleph_client/commands/message.py +++ b/src/aleph_client/commands/message.py @@ -1,4 +1,3 @@ -import asyncio import json import os.path import subprocess @@ -15,7 +14,6 @@ from aleph_client import synchronous from aleph_client.account import _load_account -from aleph_client.asynchronous import get_fallback_session from aleph_client.commands import help_strings from aleph_client.commands.utils import ( setup_logging, @@ -78,20 +76,16 @@ def post( typer.echo("Not valid JSON") raise typer.Exit(code=2) - try: - result: PostMessage = synchronous.create_post( - account=account, - post_content=content, - post_type=type, - ref=ref, - channel=channel, - inline=True, - storage_engine=storage_engine, - ) - typer.echo(result.json(indent=4)) - finally: - # Prevent aiohttp unclosed connector warning - asyncio.run(get_fallback_session().close()) + result: PostMessage = synchronous.create_post( + account=account, + post_content=content, + post_type=type, + ref=ref, + channel=channel, + inline=True, + storage_engine=storage_engine, + ) + typer.echo(result.json(indent=4)) @app.command() @@ -146,17 +140,13 @@ def forget_messages( reason: Optional[str], channel: str, ): - try: - result: ForgetMessage = synchronous.forget( - account=account, - hashes=hashes, - reason=reason, - channel=channel, - ) - typer.echo(f"{result.json(indent=4)}") - finally: - # Prevent aiohttp unclosed connector warning - asyncio.run(get_fallback_session().close()) + result: ForgetMessage = synchronous.forget( + account=account, + hashes=hashes, + reason=reason, + channel=channel, + ) + typer.echo(f"{result.json(indent=4)}") @app.command() diff --git a/src/aleph_client/commands/program.py b/src/aleph_client/commands/program.py index 2e5712d0..2337076e 100644 --- a/src/aleph_client/commands/program.py +++ b/src/aleph_client/commands/program.py @@ -16,7 +16,6 @@ from aleph_client import synchronous from aleph_client.account import _load_account -from aleph_client.asynchronous import get_fallback_session from 
aleph_client.commands import help_strings from aleph_client.commands.utils import ( setup_logging, @@ -147,66 +146,61 @@ def upload( else: subscriptions = None - try: - # Upload the source code - with open(path_object, "rb") as fd: - logger.debug("Reading file") - # TODO: Read in lazy mode instead of copying everything in memory - file_content = fd.read() - storage_engine = ( - StorageEnum.ipfs - if len(file_content) > 4 * 1024 * 1024 - else StorageEnum.storage - ) - logger.debug("Uploading file") - user_code: StoreMessage = synchronous.create_store( - account=account, - file_content=file_content, - storage_engine=storage_engine, - channel=channel, - guess_mime_type=True, - ref=None, - ) - logger.debug("Upload finished") - if print_messages or print_code_message: - typer.echo(f"{user_code.json(indent=4)}") - program_ref = user_code.item_hash - - # Register the program - message, status = synchronous.create_program( + # Upload the source code + with open(path_object, "rb") as fd: + logger.debug("Reading file") + # TODO: Read in lazy mode instead of copying everything in memory + file_content = fd.read() + storage_engine = ( + StorageEnum.ipfs + if len(file_content) > 4 * 1024 * 1024 + else StorageEnum.storage + ) + logger.debug("Uploading file") + user_code: StoreMessage = synchronous.create_store( account=account, - program_ref=program_ref, - entrypoint=entrypoint, - runtime=runtime, - storage_engine=StorageEnum.storage, + file_content=file_content, + storage_engine=storage_engine, channel=channel, - memory=memory, - vcpus=vcpus, - timeout_seconds=timeout_seconds, - persistent=persistent, - encoding=encoding, - volumes=volumes, - subscriptions=subscriptions, + guess_mime_type=True, + ref=None, ) logger.debug("Upload finished") - if print_messages or print_program_message: - typer.echo(f"{message.json(indent=4)}") - - hash: str = message.item_hash - hash_base32 = b32encode(b16decode(hash.upper())).strip(b"=").lower().decode() - - typer.echo( - f"Your program has 
been uploaded on Aleph .\n\n" - "Available on:\n" - f" {settings.VM_URL_PATH.format(hash=hash)}\n" - f" {settings.VM_URL_HOST.format(hash_base32=hash_base32)}\n" - "Visualise on:\n https://explorer.aleph.im/address/" - f"{message.chain}/{message.sender}/message/PROGRAM/{hash}\n" - ) + if print_messages or print_code_message: + typer.echo(f"{user_code.json(indent=4)}") + program_ref = user_code.item_hash - finally: - # Prevent aiohttp unclosed connector warning - asyncio.run(get_fallback_session().close()) + # Register the program + message, status = synchronous.create_program( + account=account, + program_ref=program_ref, + entrypoint=entrypoint, + runtime=runtime, + storage_engine=StorageEnum.storage, + channel=channel, + memory=memory, + vcpus=vcpus, + timeout_seconds=timeout_seconds, + persistent=persistent, + encoding=encoding, + volumes=volumes, + subscriptions=subscriptions, + ) + logger.debug("Upload finished") + if print_messages or print_program_message: + typer.echo(f"{message.json(indent=4)}") + + hash: str = message.item_hash + hash_base32 = b32encode(b16decode(hash.upper())).strip(b"=").lower().decode() + + typer.echo( + f"Your program has been uploaded on Aleph .\n\n" + "Available on:\n" + f" {settings.VM_URL_PATH.format(hash=hash)}\n" + f" {settings.VM_URL_HOST.format(hash_base32=hash_base32)}\n" + "Visualise on:\n https://explorer.aleph.im/address/" + f"{message.chain}/{message.sender}/message/PROGRAM/{hash}\n" + ) @app.command() @@ -225,51 +219,47 @@ def update( account = _load_account(private_key, private_key_file) path = path.absolute() + program_message: ProgramMessage = synchronous.get_message( + item_hash=hash, message_type=ProgramMessage + ) + code_ref = program_message.content.code.ref + code_message: StoreMessage = synchronous.get_message( + item_hash=code_ref, message_type=StoreMessage + ) + try: - program_message: ProgramMessage = synchronous.get_message( - item_hash=hash, message_type=ProgramMessage + path, encoding = 
create_archive(path) + except BadZipFile: + typer.echo("Invalid zip archive") + raise typer.Exit(3) + except FileNotFoundError: + typer.echo("No such file or directory") + raise typer.Exit(4) + + if encoding != program_message.content.code.encoding: + logger.error( + f"Code must be encoded with the same encoding as the previous version " + f"('{encoding}' vs '{program_message.content.code.encoding}'" ) - code_ref = program_message.content.code.ref - code_message: StoreMessage = synchronous.get_message( - item_hash=code_ref, message_type=StoreMessage + raise typer.Exit(1) + + # Upload the source code + with open(path, "rb") as fd: + logger.debug("Reading file") + # TODO: Read in lazy mode instead of copying everything in memory + file_content = fd.read() + logger.debug("Uploading file") + message, status = synchronous.create_store( + account=account, + file_content=file_content, + storage_engine=code_message.content.item_type, + channel=code_message.channel, + guess_mime_type=True, + ref=code_message.item_hash, ) - - try: - path, encoding = create_archive(path) - except BadZipFile: - typer.echo("Invalid zip archive") - raise typer.Exit(3) - except FileNotFoundError: - typer.echo("No such file or directory") - raise typer.Exit(4) - - if encoding != program_message.content.code.encoding: - logger.error( - f"Code must be encoded with the same encoding as the previous version " - f"('{encoding}' vs '{program_message.content.code.encoding}'" - ) - raise typer.Exit(1) - - # Upload the source code - with open(path, "rb") as fd: - logger.debug("Reading file") - # TODO: Read in lazy mode instead of copying everything in memory - file_content = fd.read() - logger.debug("Uploading file") - message, status = synchronous.create_store( - account=account, - file_content=file_content, - storage_engine=code_message.content.item_type, - channel=code_message.channel, - guess_mime_type=True, - ref=code_message.item_hash, - ) - logger.debug("Upload finished") - if print_message: - 
typer.echo(f"{message.json(indent=4)}") - finally: - # Prevent aiohttp unclosed connector warning - asyncio.run(get_fallback_session().close()) + logger.debug("Upload finished") + if print_message: + typer.echo(f"{message.json(indent=4)}") @app.command() diff --git a/src/aleph_client/user_session.py b/src/aleph_client/user_session.py new file mode 100644 index 00000000..c72c5082 --- /dev/null +++ b/src/aleph_client/user_session.py @@ -0,0 +1,33 @@ +import asyncio + +import aiohttp + +from aleph_client.types import Account + + +class UserSession: + account: Account + api_server: str + http_session: aiohttp.ClientSession + + def __init__(self, account: Account, api_server: str): + self.account = account + self.api_server = api_server + self.http_session = aiohttp.ClientSession(base_url=api_server) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + close_fut = self.http_session.close() + try: + loop = asyncio.get_running_loop() + loop.run_until_complete(close_fut) + except RuntimeError: + asyncio.run(close_fut) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.http_session.close() diff --git a/tests/integration/itest_aggregates.py b/tests/integration/itest_aggregates.py index 87b71c66..1ceda2e2 100644 --- a/tests/integration/itest_aggregates.py +++ b/tests/integration/itest_aggregates.py @@ -6,6 +6,7 @@ create_aggregate, fetch_aggregate, ) +from aleph_client.user_session import UserSession from tests.integration.toolkit import try_until from .config import REFERENCE_NODE, TARGET_NODE @@ -20,13 +21,13 @@ async def create_aggregate_on_target( receiver_node: str, channel="INTEGRATION_TESTS", ): - aggregate_message, message_status = await create_aggregate( - account=account, - key=key, - content=content, - channel="INTEGRATION_TESTS", - api_server=emitter_node, - ) + async with UserSession(account=account, api_server=emitter_node) as tx_session: + aggregate_message, 
message_status = await create_aggregate( + session=tx_session, + key=key, + content=content, + channel="INTEGRATION_TESTS", + ) assert aggregate_message.sender == account.get_address() assert aggregate_message.channel == channel @@ -39,14 +40,15 @@ async def create_aggregate_on_target( assert aggregate_message.content.address == account.get_address() assert aggregate_message.content.content == content - aggregate_from_receiver = await try_until( - fetch_aggregate, - lambda aggregate: aggregate is not None, - timeout=5, - address=account.get_address(), - key=key, - api_server=receiver_node, - ) + async with UserSession(account=account, api_server=receiver_node) as rx_session: + aggregate_from_receiver = await try_until( + fetch_aggregate, + lambda aggregate: aggregate is not None, + session=rx_session, + timeout=5, + address=account.get_address(), + key=key, + ) for key, value in content.items(): assert key in aggregate_from_receiver diff --git a/tests/integration/itest_forget.py b/tests/integration/itest_forget.py index 12b3fe5f..b9f2a51e 100644 --- a/tests/integration/itest_forget.py +++ b/tests/integration/itest_forget.py @@ -4,6 +4,7 @@ from aleph_client.asynchronous import create_post, get_posts, get_messages, forget from aleph_client.types import Account +from aleph_client.user_session import UserSession from .config import REFERENCE_NODE, TARGET_NODE, TEST_CHANNEL from .toolkit import try_until @@ -12,25 +13,27 @@ async def create_and_forget_post( account: Account, emitter_node: str, receiver_node: str, channel=TEST_CHANNEL ) -> str: async def wait_matching_posts( - item_hash: str, condition: Callable[[Dict], bool], timeout: int = 5 + item_hash: str, + condition: Callable[[Dict], bool], + timeout: int = 5, ): - return await try_until( - get_posts, - condition, - timeout=timeout, - hashes=[item_hash], - api_server=receiver_node, + async with UserSession(account=account, api_server=receiver_node) as rx_session: + return await try_until( + get_posts, + 
condition, + session=rx_session, + timeout=timeout, + hashes=[item_hash], + ) + + async with UserSession(account=account, api_server=emitter_node) as tx_session: + post_message, message_status = await create_post( + session=tx_session, + post_content="A considerate and politically correct post.", + post_type="POST", + channel="INTEGRATION_TESTS", ) - post_message, message_status = await create_post( - account=account, - post_content="A considerate and politically correct post.", - post_type="POST", - channel="INTEGRATION_TESTS", - session=None, - api_server=emitter_node, - ) - # Wait for the message to appear on the receiver. We don't check the values, # they're checked in other integration tests. get_post_response = await wait_matching_posts( @@ -41,13 +44,13 @@ async def wait_matching_posts( post_hash = post_message.item_hash reason = "This well thought-out content offends me!" - forget_message, forget_status = await forget( - account, - hashes=[post_hash], - reason=reason, - channel=channel, - api_server=emitter_node, - ) + async with UserSession(account=account, api_server=emitter_node) as tx_session: + forget_message, forget_status = await forget( + session=tx_session, + hashes=[post_hash], + reason=reason, + channel=channel, + ) assert forget_message.sender == account.get_address() assert forget_message.content.reason == reason @@ -97,26 +100,28 @@ async def test_forget_a_forget_message(fixture_account): # TODO: this test should be moved to the PyAleph API tests, once a framework is in place. post_hash = await create_and_forget_post(fixture_account, TARGET_NODE, TARGET_NODE) - get_post_response = await get_posts(hashes=[post_hash]) - assert len(get_post_response["posts"]) == 1 - post = get_post_response["posts"][0] - - forget_message_hash = post["forgotten_by"][0] - forget_message, forget_status = await forget( - fixture_account, - hashes=[forget_message_hash], - reason="I want to remember this post. 
Maybe I can forget I forgot it?", - channel=TEST_CHANNEL, - api_server=TARGET_NODE, - ) + async with UserSession(account=fixture_account, api_server=TARGET_NODE) as session: + get_post_response = await get_posts(session=session, hashes=[post_hash]) + assert len(get_post_response["posts"]) == 1 + post = get_post_response["posts"][0] + + forget_message_hash = post["forgotten_by"][0] + forget_message, forget_status = await forget( + session=session, + hashes=[forget_message_hash], + reason="I want to remember this post. Maybe I can forget I forgot it?", + channel=TEST_CHANNEL, + ) - print(forget_message) + print(forget_message) - get_forget_message_response = await get_messages( - hashes=[forget_message_hash], channels=[TEST_CHANNEL], api_server=TARGET_NODE - ) - assert len(get_forget_message_response.messages) == 1 - forget_message = get_forget_message_response.messages[0] - print(forget_message) + get_forget_message_response = await get_messages( + session=session, + hashes=[forget_message_hash], + channels=[TEST_CHANNEL], + ) + assert len(get_forget_message_response.messages) == 1 + forget_message = get_forget_message_response.messages[0] + print(forget_message) - assert "forgotten_by" not in forget_message + assert "forgotten_by" not in forget_message diff --git a/tests/integration/itest_posts.py b/tests/integration/itest_posts.py index 7ee89f3d..dabcdca3 100644 --- a/tests/integration/itest_posts.py +++ b/tests/integration/itest_posts.py @@ -6,6 +6,7 @@ create_post, get_messages, ) +from aleph_client.user_session import UserSession from tests.integration.toolkit import try_until from .config import REFERENCE_NODE, TARGET_NODE @@ -16,29 +17,27 @@ async def create_message_on_target( """ Create a POST message on the target node, then fetch it from the reference node. 
""" - - post_message, message_status = await create_post( - account=fixture_account, - post_content=None, - post_type="POST", - channel="INTEGRATION_TESTS", - session=None, - api_server=emitter_node, - ) + async with UserSession(account=fixture_account, api_server=emitter_node) as tx_session: + post_message, message_status = await create_post( + session=tx_session, + post_content=None, + post_type="POST", + channel="INTEGRATION_TESTS", + ) def response_contains_messages(response: MessagesResponse) -> bool: return len(response.messages) > 0 - # create_message = Message(**created_message_dict) - response_dict = await try_until( - get_messages, - response_contains_messages, - timeout=5, - hashes=[post_message.item_hash], - api_server=receiver_node, - ) + async with UserSession(account=fixture_account, api_server=receiver_node) as rx_session: + responses = await try_until( + get_messages, + response_contains_messages, + session=rx_session, + timeout=5, + hashes=[post_message.item_hash], + ) - message_from_target = Message(**response_dict["messages"][0]) + message_from_target = responses.messages[0] assert post_message.item_hash == message_from_target.item_hash diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 30c3c49c..9a13498c 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,20 +1,17 @@ -# -*- coding: utf-8 -*- -""" - Dummy conftest.py for aleph_client. - - If you don't know what this is for, just leave it empty. 
- Read more about conftest.py under: - https://pytest.org/latest/plugins.html -""" +import os from pathlib import Path from tempfile import NamedTemporaryFile -import pytest +import pytest as pytest -from aleph_client.chains.common import get_fallback_private_key import aleph_client.chains.ethereum as ethereum import aleph_client.chains.sol as solana import aleph_client.chains.tezos as tezos +from aleph_client.chains.common import get_fallback_private_key +from aleph_client.chains.ethereum import ETHAccount +from aleph_client.conf import settings +from aleph_client.types import Account + @pytest.fixture def fallback_private_key() -> bytes: diff --git a/tests/unit/test_asynchronous.py b/tests/unit/test_asynchronous.py index bb689f55..b5ef2e36 100644 --- a/tests/unit/test_asynchronous.py +++ b/tests/unit/test_asynchronous.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch, AsyncMock +from unittest.mock import patch, AsyncMock import pytest as pytest from aleph_message.models import ( @@ -9,112 +9,93 @@ ForgetMessage, ) -from aleph_client.types import StorageEnum, MessageStatus - from aleph_client.asynchronous import ( create_post, - _get_fallback_session, create_aggregate, create_store, create_program, forget, ) +from aleph_client.types import StorageEnum, MessageStatus, Account -def new_mock_session_with_post_success(): - mock_response = AsyncMock() - mock_response.status = 200 +@pytest.fixture +def mock_session_with_post_success(mocker, ethereum_account: Account): + mock_response = mocker.AsyncMock() + mock_response.status = 202 mock_response.json.return_value = { - "message_status": "processed", + "message_status": "pending", "publication_status": {"status": "success", "failed": []}, } - mock_post = AsyncMock() + mock_post = mocker.AsyncMock() mock_post.return_value = mock_response - mock_session = MagicMock() + mock_session = mocker.MagicMock() mock_session.post.return_value.__aenter__ = mock_post - return mock_session + + user_session = 
mocker.AsyncMock() + user_session.http_session = mock_session + user_session.account = ethereum_account + + return user_session @pytest.mark.asyncio -async def test_create_post(ethereum_account): - _get_fallback_session.cache_clear() +async def test_create_post(mock_session_with_post_success): + mock_session = mock_session_with_post_success content = {"Hello": "World"} - mock_session = new_mock_session_with_post_success() - post_message, message_status = await create_post( - account=ethereum_account, + session=mock_session, post_content=content, post_type="TEST", channel="TEST", - session=mock_session, - api_server="https://example.org", - sync=True, ) - assert mock_session.post.called + assert mock_session.http_session.post.called assert isinstance(post_message, PostMessage) - assert message_status == MessageStatus.PROCESSED + assert message_status == MessageStatus.PENDING @pytest.mark.asyncio -async def test_create_aggregate(ethereum_account): - _get_fallback_session.cache_clear() +async def test_create_aggregate(mock_session_with_post_success): - content = {"Hello": "World"} - - mock_session = new_mock_session_with_post_success() - - _ = await create_aggregate( - account=ethereum_account, - key="hello", - content=content, - channel="TEST", - session=mock_session, - ) + mock_session = mock_session_with_post_success aggregate_message, message_status = await create_aggregate( - account=ethereum_account, + session=mock_session, key="hello", - content="world", + content={"Hello": "world"}, channel="TEST", - session=mock_session, - api_server="https://example.org", ) - assert mock_session.post.called + assert mock_session.http_session.post.called assert isinstance(aggregate_message, AggregateMessage) @pytest.mark.asyncio -async def test_create_store(ethereum_account): - _get_fallback_session.cache_clear() +async def test_create_store(mock_session_with_post_success): - mock_session = new_mock_session_with_post_success() + mock_session = mock_session_with_post_success 
mock_ipfs_push_file = AsyncMock() mock_ipfs_push_file.return_value = "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy" with patch("aleph_client.asynchronous.ipfs_push_file", mock_ipfs_push_file): _ = await create_store( - account=ethereum_account, + session=mock_session, file_content=b"HELLO", channel="TEST", storage_engine=StorageEnum.ipfs, - session=mock_session, - api_server="https://example.org", ) _ = await create_store( - account=ethereum_account, + session=mock_session, file_hash="QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy", channel="TEST", storage_engine=StorageEnum.ipfs, - session=mock_session, - api_server="https://example.org", ) mock_storage_push_file = AsyncMock() @@ -123,54 +104,45 @@ async def test_create_store(ethereum_account): ) with patch("aleph_client.asynchronous.storage_push_file", mock_storage_push_file): - store_message, message_status = await create_store( - account=ethereum_account, + session=mock_session, file_content=b"HELLO", channel="TEST", storage_engine=StorageEnum.storage, - session=mock_session, - api_server="https://example.org", ) - assert mock_session.post.called + assert mock_session.http_session.post.called assert isinstance(store_message, StoreMessage) @pytest.mark.asyncio -async def test_create_program(ethereum_account): - _get_fallback_session.cache_clear() +async def test_create_program(mock_session_with_post_success): - mock_session = new_mock_session_with_post_success() + mock_session = mock_session_with_post_success program_message, message_status = await create_program( - account=ethereum_account, + session=mock_session, program_ref="FAKE-HASH", entrypoint="main:app", runtime="FAKE-HASH", channel="TEST", - session=mock_session, - api_server="https://example.org", ) - assert mock_session.post.called + assert mock_session.http_session.post.called assert isinstance(program_message, ProgramMessage) @pytest.mark.asyncio -async def test_forget(ethereum_account): - _get_fallback_session.cache_clear() +async def 
test_forget(mock_session_with_post_success): - mock_session = new_mock_session_with_post_success() + mock_session = mock_session_with_post_success forget_message, message_status = await forget( - account=ethereum_account, + session=mock_session, hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"], reason="GDPR", channel="TEST", - session=mock_session, - api_server="https://example.org", ) - assert mock_session.post.called + assert mock_session.http_session.post.called assert isinstance(forget_message, ForgetMessage) diff --git a/tests/unit/test_asynchronous_get.py b/tests/unit/test_asynchronous_get.py index 2f03a954..7703298e 100644 --- a/tests/unit/test_asynchronous_get.py +++ b/tests/unit/test_asynchronous_get.py @@ -1,64 +1,98 @@ +import unittest +from typing import Any, Dict +from unittest.mock import AsyncMock, MagicMock + import pytest from aleph_message.models import MessageType, MessagesResponse -import unittest from aleph_client.asynchronous import ( get_messages, fetch_aggregates, fetch_aggregate, - _get_fallback_session, ) +from aleph_client.conf import settings +from aleph_client.types import Account +from aleph_client.user_session import UserSession + + +def make_mock_session(mock_account: Account, get_return_value: Dict[str, Any]): + + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(side_effect=lambda: get_return_value) + + mock_get = AsyncMock() + mock_get.return_value = mock_response + + mock_session = MagicMock() + mock_session.get.return_value.__aenter__ = mock_get + + user_session = AsyncMock() + user_session.http_session = mock_session + user_session.account = mock_account + + return user_session @pytest.mark.asyncio -async def test_fetch_aggregate(): - _get_fallback_session.cache_clear() +async def test_fetch_aggregate(ethereum_account: Account): + mock_session = make_mock_session( + ethereum_account, {"data": {"corechannel": {"nodes": [], "resource_nodes": []}}} + ) response = await 
fetch_aggregate( - address="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10", key="corechannel" + session=mock_session, + address="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10", + key="corechannel", ) assert response.keys() == {"nodes", "resource_nodes"} @pytest.mark.asyncio -async def test_fetch_aggregates(): - _get_fallback_session.cache_clear() +async def test_fetch_aggregates(ethereum_account: Account): + mock_session = make_mock_session( + ethereum_account, {"data": {"corechannel": {"nodes": [], "resource_nodes": []}}} + ) response = await fetch_aggregates( - address="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10" + session=mock_session, address="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10" ) assert response.keys() == {"corechannel"} assert response["corechannel"].keys() == {"nodes", "resource_nodes"} @pytest.mark.asyncio -async def test_get_posts(): - _get_fallback_session.cache_clear() - - response: MessagesResponse = await get_messages( - pagination=2, - message_type=MessageType.post, - ) - - messages = response.messages - assert len(messages) > 1 - for message in messages: - assert message.type == MessageType.post +async def test_get_posts(ethereum_account: Account): + async with UserSession( + account=ethereum_account, api_server=settings.API_HOST + ) as session: + response: MessagesResponse = await get_messages( + session=session, + pagination=2, + message_type=MessageType.post, + ) + + messages = response.messages + assert len(messages) > 1 + for message in messages: + assert message.type == MessageType.post @pytest.mark.asyncio -async def test_get_messages(): - _get_fallback_session.cache_clear() - - response: MessagesResponse = await get_messages( - pagination=2, - ) - - messages = response.messages - assert len(messages) > 1 - assert messages[0].type - assert messages[0].sender - - -if __name__ == '__main __': +async def test_get_messages(ethereum_account: Account): + async with UserSession( + account=ethereum_account, api_server=settings.API_HOST + ) as 
session: + response: MessagesResponse = await get_messages( + session=session, + pagination=2, + ) + + messages = response.messages + assert len(messages) > 1 + assert messages[0].type + assert messages[0].sender + + +if __name__ == "__main __": unittest.main() diff --git a/tests/unit/test_synchronous_get.py b/tests/unit/test_synchronous_get.py new file mode 100644 index 00000000..22cf3812 --- /dev/null +++ b/tests/unit/test_synchronous_get.py @@ -0,0 +1,20 @@ +from aleph_message.models import MessageType, MessagesResponse + +from aleph_client.conf import settings +from aleph_client.synchronous import get_messages +from aleph_client.types import Account +from aleph_client.user_session import UserSession + + +def test_get_posts(ethereum_account: Account): + with UserSession(account=ethereum_account, api_server=settings.API_HOST) as session: + response: MessagesResponse = get_messages( + session=session, + pagination=2, + message_type=MessageType.post, + ) + + messages = response.messages + assert len(messages) > 1 + for message in messages: + assert message.type == MessageType.post diff --git a/tests/unit/test_vm_cache.py b/tests/unit/test_vm_cache.py index d9aa31a7..ddab8362 100644 --- a/tests/unit/test_vm_cache.py +++ b/tests/unit/test_vm_cache.py @@ -1,3 +1,4 @@ +import aiohttp import pytest from aleph_client.vm.cache import TestVmCache, sanitize_cache_key @@ -5,6 +6,8 @@ @pytest.mark.asyncio async def test_local_vm_cache(): + http_session = aiohttp.ClientSession(base_url="http://localhost:8000") + cache = TestVmCache() assert (await cache.get("doesnotexist")) is None assert len(await (cache.keys())) == 0 From 44d8c0ee759a8a2467d5377cd17611d658887114 Mon Sep 17 00:00:00 2001 From: Olivier Desenfans Date: Thu, 12 Jan 2023 16:58:48 +0100 Subject: [PATCH 2/3] UserSession + AuthenticatedUserSession --- examples/httpgateway.py | 4 ++-- examples/store.py | 4 ++-- src/aleph_client/asynchronous.py | 30 ++++++++++++++------------- src/aleph_client/user_session.py | 12 
++++++++--- tests/integration/itest_aggregates.py | 6 +++--- tests/integration/itest_forget.py | 10 ++++----- tests/integration/itest_posts.py | 6 +++--- tests/unit/test_asynchronous_get.py | 24 ++++++++------------- tests/unit/test_synchronous_get.py | 5 ++--- 9 files changed, 51 insertions(+), 50 deletions(-) diff --git a/examples/httpgateway.py b/examples/httpgateway.py index 3ece34df..821cf6f5 100644 --- a/examples/httpgateway.py +++ b/examples/httpgateway.py @@ -19,7 +19,7 @@ from aiohttp import web -from aleph_client.user_session import UserSession +from aleph_client.user_session import AuthenticatedUserSession app = web.Application() routes = web.RouteTableDef() @@ -44,7 +44,7 @@ async def source_post(request): return web.json_response( {"status": "error", "message": "unauthorized secret"} ) - async with UserSession(account=app["account"], api_server="https://api2.aleph.im") as session: + async with AuthenticatedUserSession(account=app["account"], api_server="https://api2.aleph.im") as session: message, _status = await create_post( session=session, post_content=data, diff --git a/examples/store.py b/examples/store.py index 38cf7c02..7d9bc9f3 100644 --- a/examples/store.py +++ b/examples/store.py @@ -9,7 +9,7 @@ from aleph_client.chains.ethereum import ETHAccount from aleph_client.conf import settings from aleph_client.types import MessageStatus -from aleph_client.user_session import UserSession +from aleph_client.user_session import AuthenticatedUserSession DEFAULT_SERVER = "https://api2.aleph.im" @@ -25,7 +25,7 @@ async def print_output_hash(message: StoreMessage, status: MessageStatus): async def do_upload(account, engine, channel, filename=None, file_hash=None): - async with UserSession(account=account, api_server=settings.API_HOST) as session: + async with AuthenticatedUserSession(account=account, api_server=settings.API_HOST) as session: print(filename, account.get_address()) if filename: try: diff --git a/src/aleph_client/asynchronous.py 
b/src/aleph_client/asynchronous.py index 8f409ddd..e8c33a7f 100644 --- a/src/aleph_client/asynchronous.py +++ b/src/aleph_client/asynchronous.py @@ -48,7 +48,7 @@ BroadcastError, ) from .models import MessagesResponse -from .user_session import UserSession +from .user_session import AuthenticatedUserSession, UserSession from .utils import get_message_type_value logger = logging.getLogger(__name__) @@ -60,7 +60,7 @@ magic = None # type:ignore -async def ipfs_push(session: UserSession, content: Mapping) -> str: +async def ipfs_push(session: AuthenticatedUserSession, content: Mapping) -> str: """Push arbitrary content as JSON to the IPFS service.""" url = "/api/v0/ipfs/add_json" @@ -71,7 +71,7 @@ async def ipfs_push(session: UserSession, content: Mapping) -> str: return (await resp.json()).get("hash") -async def storage_push(session: UserSession, content: Mapping) -> str: +async def storage_push(session: AuthenticatedUserSession, content: Mapping) -> str: """Push arbitrary content as JSON to the storage service.""" url = "/api/v0/storage/add_json" @@ -82,7 +82,9 @@ async def storage_push(session: UserSession, content: Mapping) -> str: return (await resp.json()).get("hash") -async def ipfs_push_file(session: UserSession, file_content: Union[str, bytes]) -> str: +async def ipfs_push_file( + session: AuthenticatedUserSession, file_content: Union[str, bytes] +) -> str: """Push a file to the IPFS service.""" data = aiohttp.FormData() data.add_field("file", file_content) @@ -95,7 +97,7 @@ async def ipfs_push_file(session: UserSession, file_content: Union[str, bytes]) return (await resp.json()).get("hash") -async def storage_push_file(session: UserSession, file_content) -> str: +async def storage_push_file(session: AuthenticatedUserSession, file_content) -> str: """Push a file to the storage service.""" data = aiohttp.FormData() data.add_field("file", file_content) @@ -162,7 +164,7 @@ async def _handle_broadcast_deprecated_response( async def _broadcast_deprecated( - 
session: UserSession, message_dict: Mapping[str, Any] + session: AuthenticatedUserSession, message_dict: Mapping[str, Any] ) -> None: """ @@ -212,7 +214,7 @@ async def _handle_broadcast_response( async def _broadcast( - session: UserSession, + session: AuthenticatedUserSession, message: AlephMessage, sync: bool, ) -> MessageStatus: @@ -247,7 +249,7 @@ async def _broadcast( async def create_post( - session: UserSession, + session: AuthenticatedUserSession, post_content, post_type: str, ref: Optional[str] = None, @@ -292,7 +294,7 @@ async def create_post( async def create_aggregate( - session: UserSession, + session: AuthenticatedUserSession, key: str, content: Mapping[str, Any], address: Optional[str] = None, @@ -331,7 +333,7 @@ async def create_aggregate( async def create_store( - session: UserSession, + session: AuthenticatedUserSession, address: Optional[str] = None, file_content: Optional[bytes] = None, file_path: Optional[Union[str, Path]] = None, @@ -414,7 +416,7 @@ async def create_store( async def create_program( - session: UserSession, + session: AuthenticatedUserSession, program_ref: str, entrypoint: str, runtime: str, @@ -517,7 +519,7 @@ async def create_program( async def forget( - session: UserSession, + session: AuthenticatedUserSession, hashes: List[str], reason: Optional[str], storage_engine: StorageEnum = StorageEnum.storage, @@ -566,7 +568,7 @@ def compute_sha256(s: str) -> str: async def _prepare_aleph_message( - session: UserSession, + session: AuthenticatedUserSession, message_type: MessageType, content: Dict[str, Any], channel: Optional[str], @@ -609,7 +611,7 @@ async def _prepare_aleph_message( async def submit( - session: UserSession, + session: AuthenticatedUserSession, content: Dict[str, Any], message_type: MessageType, channel: Optional[str] = None, diff --git a/src/aleph_client/user_session.py b/src/aleph_client/user_session.py index c72c5082..db366f91 100644 --- a/src/aleph_client/user_session.py +++ b/src/aleph_client/user_session.py @@ 
-6,12 +6,10 @@ class UserSession: - account: Account api_server: str http_session: aiohttp.ClientSession - def __init__(self, account: Account, api_server: str): - self.account = account + def __init__(self, api_server: str): self.api_server = api_server self.http_session = aiohttp.ClientSession(base_url=api_server) @@ -31,3 +29,11 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): await self.http_session.close() + + +class AuthenticatedUserSession(UserSession): + account: Account + + def __init__(self, account: Account, api_server: str): + super().__init__(api_server=api_server) + self.account = account diff --git a/tests/integration/itest_aggregates.py b/tests/integration/itest_aggregates.py index 1ceda2e2..7a7acb20 100644 --- a/tests/integration/itest_aggregates.py +++ b/tests/integration/itest_aggregates.py @@ -6,7 +6,7 @@ create_aggregate, fetch_aggregate, ) -from aleph_client.user_session import UserSession +from aleph_client.user_session import AuthenticatedUserSession from tests.integration.toolkit import try_until from .config import REFERENCE_NODE, TARGET_NODE @@ -21,7 +21,7 @@ async def create_aggregate_on_target( receiver_node: str, channel="INTEGRATION_TESTS", ): - async with UserSession(account=account, api_server=emitter_node) as tx_session: + async with AuthenticatedUserSession(account=account, api_server=emitter_node) as tx_session: aggregate_message, message_status = await create_aggregate( session=tx_session, key=key, @@ -40,7 +40,7 @@ async def create_aggregate_on_target( assert aggregate_message.content.address == account.get_address() assert aggregate_message.content.content == content - async with UserSession(account=account, api_server=receiver_node) as rx_session: + async with AuthenticatedUserSession(account=account, api_server=receiver_node) as rx_session: aggregate_from_receiver = await try_until( fetch_aggregate, lambda aggregate: aggregate is not None, diff --git a/tests/integration/itest_forget.py 
b/tests/integration/itest_forget.py index b9f2a51e..2153283b 100644 --- a/tests/integration/itest_forget.py +++ b/tests/integration/itest_forget.py @@ -4,7 +4,7 @@ from aleph_client.asynchronous import create_post, get_posts, get_messages, forget from aleph_client.types import Account -from aleph_client.user_session import UserSession +from aleph_client.user_session import AuthenticatedUserSession from .config import REFERENCE_NODE, TARGET_NODE, TEST_CHANNEL from .toolkit import try_until @@ -17,7 +17,7 @@ async def wait_matching_posts( condition: Callable[[Dict], bool], timeout: int = 5, ): - async with UserSession(account=account, api_server=receiver_node) as rx_session: + async with AuthenticatedUserSession(account=account, api_server=receiver_node) as rx_session: return await try_until( get_posts, condition, @@ -26,7 +26,7 @@ async def wait_matching_posts( hashes=[item_hash], ) - async with UserSession(account=account, api_server=emitter_node) as tx_session: + async with AuthenticatedUserSession(account=account, api_server=emitter_node) as tx_session: post_message, message_status = await create_post( session=tx_session, post_content="A considerate and politically correct post.", @@ -44,7 +44,7 @@ async def wait_matching_posts( post_hash = post_message.item_hash reason = "This well thought-out content offends me!" - async with UserSession(account=account, api_server=emitter_node) as tx_session: + async with AuthenticatedUserSession(account=account, api_server=emitter_node) as tx_session: forget_message, forget_status = await forget( session=tx_session, hashes=[post_hash], @@ -100,7 +100,7 @@ async def test_forget_a_forget_message(fixture_account): # TODO: this test should be moved to the PyAleph API tests, once a framework is in place. 
post_hash = await create_and_forget_post(fixture_account, TARGET_NODE, TARGET_NODE) - async with UserSession(account=fixture_account, api_server=TARGET_NODE) as session: + async with AuthenticatedUserSession(account=fixture_account, api_server=TARGET_NODE) as session: get_post_response = await get_posts(session=session, hashes=[post_hash]) assert len(get_post_response["posts"]) == 1 post = get_post_response["posts"][0] diff --git a/tests/integration/itest_posts.py b/tests/integration/itest_posts.py index dabcdca3..1344ad92 100644 --- a/tests/integration/itest_posts.py +++ b/tests/integration/itest_posts.py @@ -6,7 +6,7 @@ create_post, get_messages, ) -from aleph_client.user_session import UserSession +from aleph_client.user_session import AuthenticatedUserSession from tests.integration.toolkit import try_until from .config import REFERENCE_NODE, TARGET_NODE @@ -17,7 +17,7 @@ async def create_message_on_target( """ Create a POST message on the target node, then fetch it from the reference node. 
""" - async with UserSession(account=fixture_account, api_server=emitter_node) as tx_session: + async with AuthenticatedUserSession(account=fixture_account, api_server=emitter_node) as tx_session: post_message, message_status = await create_post( session=tx_session, post_content=None, @@ -28,7 +28,7 @@ async def create_message_on_target( def response_contains_messages(response: MessagesResponse) -> bool: return len(response.messages) > 0 - async with UserSession(account=fixture_account, api_server=receiver_node) as rx_session: + async with AuthenticatedUserSession(account=fixture_account, api_server=receiver_node) as rx_session: responses = await try_until( get_messages, response_contains_messages, diff --git a/tests/unit/test_asynchronous_get.py b/tests/unit/test_asynchronous_get.py index 7703298e..f69a8ee0 100644 --- a/tests/unit/test_asynchronous_get.py +++ b/tests/unit/test_asynchronous_get.py @@ -11,11 +11,10 @@ fetch_aggregate, ) from aleph_client.conf import settings -from aleph_client.types import Account from aleph_client.user_session import UserSession -def make_mock_session(mock_account: Account, get_return_value: Dict[str, Any]): +def make_mock_session(get_return_value: Dict[str, Any]): mock_response = AsyncMock() mock_response.status = 200 @@ -29,15 +28,14 @@ def make_mock_session(mock_account: Account, get_return_value: Dict[str, Any]): user_session = AsyncMock() user_session.http_session = mock_session - user_session.account = mock_account return user_session @pytest.mark.asyncio -async def test_fetch_aggregate(ethereum_account: Account): +async def test_fetch_aggregate(): mock_session = make_mock_session( - ethereum_account, {"data": {"corechannel": {"nodes": [], "resource_nodes": []}}} + {"data": {"corechannel": {"nodes": [], "resource_nodes": []}}} ) response = await fetch_aggregate( @@ -49,9 +47,9 @@ async def test_fetch_aggregate(ethereum_account: Account): @pytest.mark.asyncio -async def test_fetch_aggregates(ethereum_account: Account): +async 
def test_fetch_aggregates(): mock_session = make_mock_session( - ethereum_account, {"data": {"corechannel": {"nodes": [], "resource_nodes": []}}} + {"data": {"corechannel": {"nodes": [], "resource_nodes": []}}} ) response = await fetch_aggregates( @@ -62,10 +60,8 @@ async def test_fetch_aggregates(ethereum_account: Account): @pytest.mark.asyncio -async def test_get_posts(ethereum_account: Account): - async with UserSession( - account=ethereum_account, api_server=settings.API_HOST - ) as session: +async def test_get_posts(): + async with UserSession(api_server=settings.API_HOST) as session: response: MessagesResponse = await get_messages( session=session, pagination=2, @@ -79,10 +75,8 @@ async def test_get_posts(ethereum_account: Account): @pytest.mark.asyncio -async def test_get_messages(ethereum_account: Account): - async with UserSession( - account=ethereum_account, api_server=settings.API_HOST - ) as session: +async def test_get_messages(): + async with UserSession(api_server=settings.API_HOST) as session: response: MessagesResponse = await get_messages( session=session, pagination=2, diff --git a/tests/unit/test_synchronous_get.py b/tests/unit/test_synchronous_get.py index 22cf3812..a8a281a3 100644 --- a/tests/unit/test_synchronous_get.py +++ b/tests/unit/test_synchronous_get.py @@ -2,12 +2,11 @@ from aleph_client.conf import settings from aleph_client.synchronous import get_messages -from aleph_client.types import Account from aleph_client.user_session import UserSession -def test_get_posts(ethereum_account: Account): - with UserSession(account=ethereum_account, api_server=settings.API_HOST) as session: +def test_get_posts(): + with UserSession(api_server=settings.API_HOST) as session: response: MessagesResponse = get_messages( session=session, pagination=2, From fac315890d14cc3735782369827c212c596f3563 Mon Sep 17 00:00:00 2001 From: Olivier Desenfans Date: Fri, 13 Jan 2023 16:45:11 +0100 Subject: [PATCH 3/3] Feature: full API refactoring The session classes 
now provide all the SDK functionalities. The `synchronous` and `asynchronous` modules are removed in favor of the `UserSession` and `AuthenticatedUserSession` classes. `UserSession` is intended as a read-only session to access public API endpoints, while `AuthenticatedUserSession` requires a chain account and allows the user to post messages to the Aleph network. Example usage: ``` from aleph_client import AuthenticatedUserSession async def post_message(): async with AuthenticatedUserSession( account=fixture_account, api_server=emitter_node ) as session: post_message, message_status = await session.create_post( post_content={"Hello": "World"}, ) ``` Both classes provide a synchronous context manager and an equivalent sync class for non-async code. Breaking changes: - Everything: the SDK API is entirely modified. --- examples/httpgateway.py | 23 +- examples/metrics.py | 26 +- examples/mqtt.py | 23 +- examples/store.py | 12 +- src/aleph_client/__init__.py | 4 + src/aleph_client/asynchronous.py | 1001 ------------------ src/aleph_client/commands/aggregate.py | 14 +- src/aleph_client/commands/files.py | 42 +- src/aleph_client/commands/message.py | 115 +- src/aleph_client/commands/program.py | 196 ++-- src/aleph_client/main.py | 2 +- src/aleph_client/synchronous.py | 150 --- src/aleph_client/user_session.py | 1333 +++++++++++++++++++++++- src/aleph_client/vm/cache.py | 19 +- tests/integration/itest_aggregates.py | 18 +- tests/integration/itest_forget.py | 34 +- tests/integration/itest_posts.py | 21 +- tests/unit/test_asynchronous.py | 149 +-- tests/unit/test_asynchronous_get.py | 64 +- tests/unit/test_chain_solana.py | 4 +- tests/unit/test_synchronous_get.py | 4 +- 21 files changed, 1731 insertions(+), 1523 deletions(-) delete mode 100644 src/aleph_client/asynchronous.py delete mode 100644 src/aleph_client/synchronous.py diff --git a/examples/httpgateway.py b/examples/httpgateway.py index 821cf6f5..da053c95 100644 --- a/examples/httpgateway.py +++ 
b/examples/httpgateway.py @@ -2,23 +2,13 @@ """ # -*- coding: utf-8 -*- -import os - -# import requests -import platform - -# import socket -import time import asyncio -import click -from aleph_client.asynchronous import create_post - -# from aleph_client.chains.nuls1 import NULSAccount, get_fallback_account -from aleph_client.chains.ethereum import ETHAccount -from aleph_client.chains.common import get_fallback_private_key +import click from aiohttp import web +from aleph_client.chains.common import get_fallback_private_key +from aleph_client.chains.ethereum import ETHAccount from aleph_client.user_session import AuthenticatedUserSession app = web.Application() @@ -44,9 +34,10 @@ async def source_post(request): return web.json_response( {"status": "error", "message": "unauthorized secret"} ) - async with AuthenticatedUserSession(account=app["account"], api_server="https://api2.aleph.im") as session: - message, _status = await create_post( - session=session, + async with AuthenticatedUserSession( + account=app["account"], api_server="https://api2.aleph.im" + ) as session: + message, _status = await session.create_post( post_content=data, post_type="event", channel=app["channel"], diff --git a/examples/metrics.py b/examples/metrics.py index 1a2733b6..4ff553e1 100644 --- a/examples/metrics.py +++ b/examples/metrics.py @@ -5,11 +5,16 @@ import os import platform import time +from typing import Tuple import psutil +from aleph_message.models import AlephMessage +from aleph_client import AuthenticatedUserSession from aleph_client.chains.ethereum import get_fallback_account -from aleph_client.synchronous import create_aggregate +from aleph_client.conf import settings +from aleph_client.types import MessageStatus +from aleph_client.user_session import AuthenticatedUserSessionSync def get_sysinfo(): @@ -49,8 +54,10 @@ def get_cpu_cores(): return [c._asdict() for c in psutil.cpu_times_percent(0, percpu=True)] -def send_metrics(account, metrics): - return 
create_aggregate(account, "metrics", metrics, channel="SYSINFO") +def send_metrics( + session: AuthenticatedUserSessionSync, metrics +) -> Tuple[AlephMessage, MessageStatus]: + return session.create_aggregate(key="metrics", content=metrics, channel="SYSINFO") def collect_metrics(): @@ -64,11 +71,14 @@ def collect_metrics(): def main(): account = get_fallback_account() - while True: - metrics = collect_metrics() - message, status = send_metrics(account, metrics) - print("sent", message.item_hash) - time.sleep(10) + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + while True: + metrics = collect_metrics() + message, status = send_metrics(session, metrics) + print("sent", message.item_hash) + time.sleep(10) if __name__ == "__main__": diff --git a/examples/mqtt.py b/examples/mqtt.py index 3bc8bec1..3df400e0 100644 --- a/examples/mqtt.py +++ b/examples/mqtt.py @@ -10,7 +10,8 @@ from aleph_client.chains.common import get_fallback_private_key from aleph_client.chains.ethereum import ETHAccount -from aleph_client.main import create_aggregate +from aleph_client import AuthenticatedUserSession +from aleph_client.conf import settings def get_input_data(value): @@ -26,7 +27,12 @@ def get_input_data(value): def send_metrics(account, metrics): - return create_aggregate(account, "metrics", metrics, channel="SYSINFO") + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + return session.create_aggregate( + key="metrics", content=metrics, channel="SYSINFO" + ) def on_disconnect(client, userdata, rc): @@ -95,10 +101,15 @@ async def gateway( if not userdata["received"]: await client.reconnect() - for key, value in state.items(): - message, status = create_aggregate(account, key, value, channel="IOT_TEST") - print("sent", message.item_hash) - userdata["received"] = False + async with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + for key, value in 
state.items(): + message, status = await session.create_aggregate( + key=key, content=value, channel="IOT_TEST" + ) + print("sent", message.item_hash) + userdata["received"] = False @click.command() diff --git a/examples/store.py b/examples/store.py index 7d9bc9f3..26374f0f 100644 --- a/examples/store.py +++ b/examples/store.py @@ -1,10 +1,8 @@ import asyncio -import aiohttp import click from aleph_message.models import StoreMessage -from aleph_client.asynchronous import create_store from aleph_client.chains.common import get_fallback_private_key from aleph_client.chains.ethereum import ETHAccount from aleph_client.conf import settings @@ -25,7 +23,9 @@ async def print_output_hash(message: StoreMessage, status: MessageStatus): async def do_upload(account, engine, channel, filename=None, file_hash=None): - async with AuthenticatedUserSession(account=account, api_server=settings.API_HOST) as session: + async with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: print(filename, account.get_address()) if filename: try: @@ -35,8 +35,7 @@ async def do_upload(account, engine, channel, filename=None, file_hash=None): if len(content) > 4 * 1024 * 1024 and engine == "STORAGE": print("File too big for native STORAGE engine") return - message, status = await create_store( - session=session, + message, status = await session.create_store( file_content=content, channel=channel, storage_engine=engine.lower(), @@ -46,8 +45,7 @@ async def do_upload(account, engine, channel, filename=None, file_hash=None): raise elif file_hash: - message, status = await create_store( - session=session, + message, status = await session.create_store( file_hash=file_hash, channel=channel, storage_engine=engine.lower(), diff --git a/src/aleph_client/__init__.py b/src/aleph_client/__init__.py index 34225eca..ff0fb1bf 100644 --- a/src/aleph_client/__init__.py +++ b/src/aleph_client/__init__.py @@ -1,4 +1,5 @@ from pkg_resources import get_distribution, 
DistributionNotFound +from .user_session import AuthenticatedUserSession, UserSession try: # Change here if project is renamed and does not equal the package name @@ -8,3 +9,6 @@ __version__ = "unknown" finally: del get_distribution, DistributionNotFound + + +__all__ = ["AuthenticatedUserSession", "UserSession"] diff --git a/src/aleph_client/asynchronous.py b/src/aleph_client/asynchronous.py deleted file mode 100644 index e8c33a7f..00000000 --- a/src/aleph_client/asynchronous.py +++ /dev/null @@ -1,1001 +0,0 @@ -""" This is the simplest aleph network client available. -""" -import asyncio -import hashlib -import json -import logging -import queue -import time -from datetime import datetime -from pathlib import Path -from typing import ( - Optional, - Union, - Any, - Dict, - List, - Iterable, - AsyncIterable, -) -from typing import Type, Mapping, Tuple, NoReturn - -import aiohttp -from aiohttp import ClientSession -from aleph_message.models import ( - ForgetContent, - MessageType, - AggregateContent, - PostContent, - StoreContent, - PostMessage, - Message, - ForgetMessage, - AlephMessage, - AggregateMessage, - StoreMessage, - ProgramMessage, - ItemType, -) -from aleph_message.models.program import ProgramContent, Encoding -from pydantic import ValidationError - -from aleph_client.types import Account, StorageEnum, GenericMessage, MessageStatus -from .conf import settings -from .exceptions import ( - MessageNotFoundError, - MultipleMessagesError, - InvalidMessageError, - BroadcastError, -) -from .models import MessagesResponse -from .user_session import AuthenticatedUserSession, UserSession -from .utils import get_message_type_value - -logger = logging.getLogger(__name__) - -try: - import magic -except ImportError: - logger.info("Could not import library 'magic', MIME type detection disabled") - magic = None # type:ignore - - -async def ipfs_push(session: AuthenticatedUserSession, content: Mapping) -> str: - """Push arbitrary content as JSON to the IPFS service.""" - 
- url = "/api/v0/ipfs/add_json" - logger.debug(f"Pushing to IPFS on {url}") - - async with session.http_session.post(url, json=content) as resp: - resp.raise_for_status() - return (await resp.json()).get("hash") - - -async def storage_push(session: AuthenticatedUserSession, content: Mapping) -> str: - """Push arbitrary content as JSON to the storage service.""" - - url = "/api/v0/storage/add_json" - logger.debug(f"Pushing to storage on {url}") - - async with session.http_session.post(url, json=content) as resp: - resp.raise_for_status() - return (await resp.json()).get("hash") - - -async def ipfs_push_file( - session: AuthenticatedUserSession, file_content: Union[str, bytes] -) -> str: - """Push a file to the IPFS service.""" - data = aiohttp.FormData() - data.add_field("file", file_content) - - url = "/api/v0/ipfs/add_file" - logger.debug(f"Pushing file to IPFS on {url}") - - async with session.http_session.post(url, data=data) as resp: - resp.raise_for_status() - return (await resp.json()).get("hash") - - -async def storage_push_file(session: AuthenticatedUserSession, file_content) -> str: - """Push a file to the storage service.""" - data = aiohttp.FormData() - data.add_field("file", file_content) - - url = "/api/v0/storage/add_file" - logger.debug(f"Posting file on {url}") - - async with session.http_session.post(url, data=data) as resp: - resp.raise_for_status() - return (await resp.json()).get("hash") - - -def _log_publication_status(publication_status: Mapping[str, Any]): - status = publication_status.get("status") - failures = publication_status.get("failed") - - if status == "success": - return - elif status == "warning": - logger.warning("Broadcast failed on the following network(s): %s", failures) - elif status == "error": - logger.error( - "Broadcast failed on all protocols. The message was not published." 
- ) - else: - raise ValueError( - f"Invalid response from server, status in missing or unknown: '{status}'" - ) - - -async def _handle_broadcast_error(response: aiohttp.ClientResponse) -> NoReturn: - if response.status == 500: - # Assume a broadcast error, no need to read the JSON - if response.content_type == "application/json": - error_msg = "Internal error - broadcast failed on all protocols" - else: - error_msg = f"Internal error - the message was not broadcast: {await response.text()}" - - logger.error(error_msg) - raise BroadcastError(error_msg) - elif response.status == 422: - errors = await response.json() - logger.error( - "The message could not be processed because of the following errors: %s", - errors, - ) - raise InvalidMessageError(errors) - else: - error_msg = ( - f"Unexpected HTTP response ({response.status}: {await response.text()})" - ) - logger.error(error_msg) - raise BroadcastError(error_msg) - - -async def _handle_broadcast_deprecated_response( - response: aiohttp.ClientResponse, -) -> None: - if response.status != 200: - await _handle_broadcast_error(response) - else: - publication_status = await response.json() - _log_publication_status(publication_status) - - -async def _broadcast_deprecated( - session: AuthenticatedUserSession, message_dict: Mapping[str, Any] -) -> None: - - """ - Broadcast a message on the Aleph network using the deprecated - /ipfs/pubsub/pub/ endpoint. 
- """ - - url = "/api/v0/ipfs/pubsub/pub" - logger.debug(f"Posting message on {url}") - - async with session.http_session.post( - url, - json={"topic": "ALEPH-TEST", "data": json.dumps(message_dict)}, - ) as response: - await _handle_broadcast_deprecated_response(response) - - -async def _handle_broadcast_response( - response: aiohttp.ClientResponse, sync: bool -) -> MessageStatus: - if response.status in (200, 202): - status = await response.json() - _log_publication_status(status["publication_status"]) - - if response.status == 202: - if sync: - logger.warning("Timed out while waiting for processing of sync message") - return MessageStatus.PENDING - - return MessageStatus.PROCESSED - - else: - await _handle_broadcast_error(response) - - -BROADCAST_MESSAGE_FIELDS = { - "sender", - "chain", - "signature", - "type", - "item_hash", - "item_type", - "item_content", - "time", - "channel", -} - - -async def _broadcast( - session: AuthenticatedUserSession, - message: AlephMessage, - sync: bool, -) -> MessageStatus: - """ - Broadcast a message on the Aleph network. - - Uses the POST /messages/ endpoint or the deprecated /ipfs/pubsub/pub/ endpoint - if the first method is not available. - """ - - url = "/api/v0/messages" - logger.debug(f"Posting message on {url}") - - message_dict = message.dict(include=BROADCAST_MESSAGE_FIELDS) - - async with session.http_session.post( - url, - json={"sync": sync, "message": message_dict}, - ) as response: - # The endpoint may be unavailable on this node, try the deprecated version. - if response.status == 404: - logger.warning( - "POST /messages/ not found. Defaulting to legacy endpoint..." 
- ) - await _broadcast_deprecated(message_dict=message_dict, session=session) - return MessageStatus.PENDING - else: - message_status = await _handle_broadcast_response( - response=response, sync=sync - ) - return message_status - - -async def create_post( - session: AuthenticatedUserSession, - post_content, - post_type: str, - ref: Optional[str] = None, - address: Optional[str] = None, - channel: Optional[str] = None, - inline: bool = True, - storage_engine: StorageEnum = StorageEnum.storage, - sync: bool = False, -) -> Tuple[PostMessage, MessageStatus]: - """ - Create a POST message on the Aleph network. It is associated with a channel and owned by an account. - - :param session: The current user session object - :param post_content: The content of the message - :param post_type: An arbitrary content type that helps to describe the post_content - :param ref: A reference to a previous message that it replaces - :param address: The address that will be displayed as the author of the message - :param channel: The channel that the message will be posted on - :param inline: An optional flag to indicate if the content should be inlined in the message or not - :param storage_engine: An optional storage engine to use for the message, if not inlined (Default: "storage") - :param sync: If true, waits for the message to be processed by the API server (Default: False) - """ - address = address or settings.ADDRESS_TO_USE or session.account.get_address() - - content = PostContent( - type=post_type, - address=address, - content=post_content, - time=time.time(), - ref=ref, - ) - - return await submit( - session=session, - content=content.dict(exclude_none=True), - message_type=MessageType.post, - channel=channel, - allow_inlining=inline, - storage_engine=storage_engine, - sync=sync, - ) - - -async def create_aggregate( - session: AuthenticatedUserSession, - key: str, - content: Mapping[str, Any], - address: Optional[str] = None, - channel: Optional[str] = None, - inline: bool = 
True, - sync: bool = False, -) -> Tuple[AggregateMessage, MessageStatus]: - """ - Create an AGGREGATE message. It is meant to be used as a quick access storage associated with an account. - - :param session: The current user session object - :param key: Key to use to store the content - :param content: Content to store - :param address: Address to use to sign the message - :param channel: Channel to use (Default: "TEST") - :param inline: Whether to write content inside the message (Default: True) - :param sync: If true, waits for the message to be processed by the API server (Default: False) - """ - address = address or settings.ADDRESS_TO_USE or session.account.get_address() - - content_ = AggregateContent( - key=key, - address=address, - content=content, - time=time.time(), - ) - - return await submit( - session=session, - content=content_.dict(exclude_none=True), - message_type=MessageType.aggregate, - channel=channel, - allow_inlining=inline, - sync=sync, - ) - - -async def create_store( - session: AuthenticatedUserSession, - address: Optional[str] = None, - file_content: Optional[bytes] = None, - file_path: Optional[Union[str, Path]] = None, - file_hash: Optional[str] = None, - guess_mime_type: bool = False, - ref: Optional[str] = None, - storage_engine: StorageEnum = StorageEnum.storage, - extra_fields: Optional[dict] = None, - channel: Optional[str] = None, - sync: bool = False, -) -> Tuple[StoreMessage, MessageStatus]: - """ - Create a STORE message to store a file on the Aleph network. - - Can be passed either a file path, an IPFS hash or the file's content as raw bytes. 
- - :param session: The current user session object - :param address: Address to display as the author of the message (Default: account.get_address()) - :param file_content: Byte stream of the file to store (Default: None) - :param file_path: Path to the file to store (Default: None) - :param file_hash: Hash of the file to store (Default: None) - :param guess_mime_type: Guess the MIME type of the file (Default: False) - :param ref: Reference to a previous message (Default: None) - :param storage_engine: Storage engine to use (Default: "storage") - :param extra_fields: Extra fields to add to the STORE message (Default: None) - :param channel: Channel to post the message to (Default: "TEST") - :param sync: If true, waits for the message to be processed by the API server (Default: False) - """ - address = address or settings.ADDRESS_TO_USE or session.account.get_address() - - extra_fields = extra_fields or {} - - if file_hash is None: - if file_content is None: - if file_path is None: - raise ValueError( - "Please specify at least a file_content, a file_hash or a file_path" - ) - else: - file_content = open(file_path, "rb").read() - - if storage_engine == StorageEnum.storage: - file_hash = await storage_push_file( - session=session, file_content=file_content - ) - elif storage_engine == StorageEnum.ipfs: - file_hash = await ipfs_push_file(session=session, file_content=file_content) - else: - raise ValueError(f"Unknown storage engine: '{storage_engine}'") - - assert file_hash, "File hash should be empty" - - if magic is None: - pass - elif file_content and guess_mime_type and ("mime_type" not in extra_fields): - extra_fields["mime_type"] = magic.from_buffer(file_content, mime=True) - - if ref: - extra_fields["ref"] = ref - - values = { - "address": address, - "item_type": storage_engine, - "item_hash": file_hash, - "time": time.time(), - } - if extra_fields is not None: - values.update(extra_fields) - - content = StoreContent(**values) - - return await submit( - 
session=session, - content=content.dict(exclude_none=True), - message_type=MessageType.store, - channel=channel, - allow_inlining=True, - sync=sync, - ) - - -async def create_program( - session: AuthenticatedUserSession, - program_ref: str, - entrypoint: str, - runtime: str, - environment_variables: Optional[Mapping[str, str]] = None, - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - sync: bool = False, - memory: Optional[int] = None, - vcpus: Optional[int] = None, - timeout_seconds: Optional[float] = None, - persistent: bool = False, - encoding: Encoding = Encoding.zip, - volumes: Optional[List[Mapping]] = None, - subscriptions: Optional[List[Mapping]] = None, -) -> Tuple[ProgramMessage, MessageStatus]: - """ - Post a (create) PROGRAM message. - - :param session: The current user session object - :param program_ref: Reference to the program to run - :param entrypoint: Entrypoint to run - :param runtime: Runtime to use - :param environment_variables: Environment variables to pass to the program - :param storage_engine: Storage engine to use (Default: "storage") - :param channel: Channel to use (Default: "TEST") - :param address: Address to use (Default: account.get_address()) - :param sync: If true, waits for the message to be processed by the API server - :param memory: Memory in MB for the VM to be allocated (Default: 128) - :param vcpus: Number of vCPUs to allocate (Default: 1) - :param timeout_seconds: Timeout in seconds (Default: 30.0) - :param persistent: Whether the program should be persistent or not (Default: False) - :param encoding: Encoding to use (Default: Encoding.zip) - :param volumes: Volumes to mount - :param subscriptions: Patterns of Aleph messages to forward to the program's event receiver - """ - address = address or settings.ADDRESS_TO_USE or session.account.get_address() - - volumes = volumes if volumes is not None else [] - memory = memory or settings.DEFAULT_VM_MEMORY - 
vcpus = vcpus or settings.DEFAULT_VM_VCPUS - timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT - - # TODO: Check that program_ref, runtime and data_ref exist - - # Register the different ways to trigger a VM - if subscriptions: - # Trigger on HTTP calls and on Aleph message subscriptions. - triggers = {"http": True, "persistent": persistent, "message": subscriptions} - else: - # Trigger on HTTP calls. - triggers = {"http": True, "persistent": persistent} - - content = ProgramContent( - **{ - "type": "vm-function", - "address": address, - "allow_amend": False, - "code": { - "encoding": encoding, - "entrypoint": entrypoint, - "ref": program_ref, - "use_latest": True, - }, - "on": triggers, - "environment": { - "reproducible": False, - "internet": True, - "aleph_api": True, - }, - "variables": environment_variables, - "resources": { - "vcpus": vcpus, - "memory": memory, - "seconds": timeout_seconds, - }, - "runtime": { - "ref": runtime, - "use_latest": True, - "comment": "Official Aleph runtime" - if runtime == settings.DEFAULT_RUNTIME_ID - else "", - }, - "volumes": volumes, - "time": time.time(), - } - ) - - # Ensure that the version of aleph-message used supports the field. - assert content.on.persistent == persistent - - return await submit( - session=session, - content=content.dict(exclude_none=True), - message_type=MessageType.program, - channel=channel, - storage_engine=storage_engine, - sync=sync, - ) - - -async def forget( - session: AuthenticatedUserSession, - hashes: List[str], - reason: Optional[str], - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - sync: bool = False, -) -> Tuple[ForgetMessage, MessageStatus]: - """ - Post a FORGET message to remove previous messages from the network. - - Targeted messages need to be signed by the same account that is attempting to forget them, - if the creating address did not delegate the access rights to the forgetting account. 
- - :param session: The current user session object - :param hashes: Hashes of the messages to forget - :param reason: Reason for forgetting the messages - :param storage_engine: Storage engine to use (Default: "storage") - :param channel: Channel to use (Default: "TEST") - :param address: Address to use (Default: account.get_address()) - :param sync: If true, waits for the message to be processed by the API server (Default: False) - """ - address = address or settings.ADDRESS_TO_USE or session.account.get_address() - - content = ForgetContent( - hashes=hashes, - reason=reason, - address=address, - time=time.time(), - ) - - return await submit( - session=session, - content=content.dict(exclude_none=True), - message_type=MessageType.forget, - channel=channel, - storage_engine=storage_engine, - allow_inlining=True, - sync=sync, - ) - - -def compute_sha256(s: str) -> str: - h = hashlib.sha256() - h.update(s.encode("utf-8")) - return h.hexdigest() - - -async def _prepare_aleph_message( - session: AuthenticatedUserSession, - message_type: MessageType, - content: Dict[str, Any], - channel: Optional[str], - allow_inlining: bool = True, - storage_engine: StorageEnum = StorageEnum.storage, -) -> AlephMessage: - - message_dict: Dict[str, Any] = { - "sender": session.account.get_address(), - "chain": session.account.CHAIN, - "type": message_type, - "content": content, - "time": time.time(), - "channel": channel, - } - - item_content: str = json.dumps(content, separators=(",", ":")) - - if allow_inlining and (len(item_content) < settings.MAX_INLINE_SIZE): - message_dict["item_content"] = item_content - message_dict["item_hash"] = compute_sha256(item_content) - message_dict["item_type"] = ItemType.inline - else: - if storage_engine == StorageEnum.ipfs: - message_dict["item_hash"] = await ipfs_push( - session=session, - content=content, - ) - message_dict["item_type"] = ItemType.ipfs - else: # storage - assert storage_engine == StorageEnum.storage - message_dict["item_hash"] = 
await storage_push( - session=session, - content=content, - ) - message_dict["item_type"] = ItemType.storage - - message_dict = await session.account.sign_message(message_dict) - return Message(**message_dict) - - -async def submit( - session: AuthenticatedUserSession, - content: Dict[str, Any], - message_type: MessageType, - channel: Optional[str] = None, - storage_engine: StorageEnum = StorageEnum.storage, - allow_inlining: bool = True, - sync: bool = False, -) -> Tuple[AlephMessage, MessageStatus]: - - message = await _prepare_aleph_message( - session=session, - message_type=message_type, - content=content, - channel=channel, - allow_inlining=allow_inlining, - storage_engine=storage_engine, - ) - message_status = await _broadcast(session=session, message=message, sync=sync) - return message, message_status - - -async def fetch_aggregate( - session: UserSession, - address: str, - key: str, - limit: int = 100, -) -> Dict[str, Dict]: - """ - Fetch a value from the aggregate store by owner address and item key. - - :param session: The current user session object - :param address: Address of the owner of the aggregate - :param key: Key of the aggregate - :param limit: Maximum number of items to fetch (Default: 100) - """ - - params: Dict[str, Any] = {"keys": key} - if limit: - params["limit"] = limit - - async with session.http_session.get( - f"/api/v0/aggregates/{address}.json", params=params - ) as resp: - result = await resp.json() - data = result.get("data", dict()) - return data.get(key) - - -async def fetch_aggregates( - session: UserSession, - address: str, - keys: Optional[Iterable[str]] = None, - limit: int = 100, -) -> Dict[str, Dict]: - """ - Fetch key-value pairs from the aggregate store by owner address. 
- - :param session: The current user session object - :param address: Address of the owner of the aggregate - :param keys: Keys of the aggregates to fetch (Default: all items) - :param limit: Maximum number of items to fetch (Default: 100) - """ - - keys_str = ",".join(keys) if keys else "" - params: Dict[str, Any] = {} - if keys_str: - params["keys"] = keys_str - if limit: - params["limit"] = limit - - async with session.http_session.get( - f"/api/v0/aggregates/{address}.json", - params=params, - ) as resp: - result = await resp.json() - data = result.get("data", dict()) - return data - - -async def get_posts( - session: UserSession, - pagination: int = 200, - page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, -) -> Dict[str, Dict]: - """ - Fetch a list of posts from the network. 
- - :param session: The current user session object - :param pagination: Number of items to fetch (Default: 200) - :param page: Page to fetch, begins at 1 (Default: 1) - :param types: Types of posts to fetch (Default: all types) - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Chains of the posts to fetch (Default: all chains) - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from - """ - - params: Dict[str, Any] = dict(pagination=pagination, page=page) - - if types is not None: - params["types"] = ",".join(types) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date - - async with session.http_session.get("/api/v0/posts.json", params=params) as resp: - resp.raise_for_status() - return await resp.json() - - -async def download_file( - session: UserSession, - file_hash: str, -) -> bytes: - """ - Get a file from the storage engine as raw bytes. - - Warning: Downloading large files can be slow and memory intensive. 
- - :param session: The current user session object - :param file_hash: The hash of the file to retrieve. - """ - async with session.http_session.get(f"/api/v0/storage/raw/{file_hash}") as response: - response.raise_for_status() - return await response.read() - - -async def get_messages( - session: UserSession, - pagination: int = 200, - page: int = 1, - message_type: Optional[MessageType] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, - ignore_invalid_messages: bool = True, - invalid_messages_log_level: int = logging.NOTSET, -) -> MessagesResponse: - """ - Fetch a list of messages from the network. 
- - :param session: The current user session object - :param pagination: Number of items to fetch (Default: 200) - :param page: Page to fetch, begins at 1 (Default: 1) - :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param content_types: Filter by content type - :param content_keys: Filter by content key - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Filter by sender address chain - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from - :param ignore_invalid_messages: Ignore invalid messages (Default: False) - :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) - """ - ignore_invalid_messages = ( - True if ignore_invalid_messages is None else ignore_invalid_messages - ) - invalid_messages_log_level = ( - logging.NOTSET - if invalid_messages_log_level is None - else invalid_messages_log_level - ) - - params: Dict[str, Any] = dict(pagination=pagination, page=page) - - if message_type is not None: - params["msgType"] = message_type.value - if content_types is not None: - params["contentTypes"] = ",".join(content_types) - if content_keys is not None: - params["contentKeys"] = ",".join(content_keys) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if 
start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date - - async with session.http_session.get("/api/v0/messages.json", params=params) as resp: - resp.raise_for_status() - response_json = await resp.json() - messages_raw = response_json["messages"] - - # All messages may not be valid according to the latest specification in - # aleph-message. This allows the user to specify how errors should be handled. - messages: List[AlephMessage] = [] - for message_raw in messages_raw: - try: - message = Message(**message_raw) - messages.append(message) - except KeyError as e: - if not ignore_invalid_messages: - raise e - logger.log( - level=invalid_messages_log_level, - msg=f"KeyError: Field '{e.args[0]}' not found", - ) - except ValidationError as e: - if not ignore_invalid_messages: - raise e - if invalid_messages_log_level: - logger.log(level=invalid_messages_log_level, msg=e) - - return MessagesResponse( - messages=messages, - pagination_page=response_json["pagination_page"], - pagination_total=response_json["pagination_total"], - pagination_per_page=response_json["pagination_per_page"], - pagination_item=response_json["pagination_item"], - ) - - -async def get_message( - session: UserSession, - item_hash: str, - message_type: Optional[Type[GenericMessage]] = None, - channel: Optional[str] = None, -) -> GenericMessage: - """ - Get a single message from its `item_hash` and perform some basic validation. 
- - :param session: The current user session object - :param item_hash: Hash of the message to fetch - :param message_type: Type of message to fetch - :param channel: Channel of the message to fetch - """ - messages_response = await get_messages( - session=session, - hashes=[item_hash], - channels=[channel] if channel else None, - ) - if len(messages_response.messages) < 1: - raise MessageNotFoundError(f"No such hash {item_hash}") - if len(messages_response.messages) != 1: - raise MultipleMessagesError( - f"Multiple messages found for the same item_hash `{item_hash}`" - ) - message: GenericMessage = messages_response.messages[0] - if message_type: - expected_type = get_message_type_value(message_type) - if message.type != expected_type: - raise TypeError( - f"The message type '{message.type}' " - f"does not match the expected type '{expected_type}'" - ) - return message - - -async def watch_messages( - session: UserSession, - message_type: Optional[MessageType] = None, - content_types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, -) -> AsyncIterable[AlephMessage]: - """ - Iterate over current and future matching messages asynchronously. 
- - :param session: The current user session object - :param message_type: Type of message to watch - :param content_types: Content types to watch - :param refs: References to watch - :param addresses: Addresses to watch - :param tags: Tags to watch - :param hashes: Hashes to watch - :param channels: Channels to watch - :param chains: Chains to watch - :param start_date: Start date from when to watch - :param end_date: End date until when to watch - """ - params: Dict[str, Any] = dict() - - if message_type is not None: - params["msgType"] = message_type.value - if content_types is not None: - params["contentTypes"] = ",".join(content_types) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date - - async with session.http_session.ws_connect( - f"/api/ws0/messages", params=params - ) as ws: - logger.debug("Websocket connected") - async for msg in ws: - if msg.type == aiohttp.WSMsgType.TEXT: - if msg.data == "close cmd": - await ws.close() - break - else: - data = json.loads(msg.data) - yield Message(**data) - elif msg.type == aiohttp.WSMsgType.ERROR: - break - - -async def _run_watch_messages(coroutine: AsyncIterable, output_queue: queue.Queue): - """Forward messages from the coroutine to the synchronous queue""" - async for message in coroutine: - output_queue.put(message) - - -def 
_start_run_watch_messages(output_queue: queue.Queue, args: List, kwargs: Dict): - """Thread entrypoint to run the `watch_messages` asynchronous generator in a thread.""" - watcher = watch_messages(*args, **kwargs) - runner = _run_watch_messages(watcher, output_queue) - asyncio.run(runner) diff --git a/src/aleph_client/commands/aggregate.py b/src/aleph_client/commands/aggregate.py index 94f73484..98b0a350 100644 --- a/src/aleph_client/commands/aggregate.py +++ b/src/aleph_client/commands/aggregate.py @@ -1,10 +1,11 @@ import typer from typing import Optional + +from aleph_client import UserSession from aleph_client.types import AccountFromPrivateKey from aleph_client.account import _load_account from aleph_client.conf import settings from pathlib import Path -from aleph_client import synchronous from aleph_client.commands import help_strings from aleph_client.commands.message import forget_messages @@ -37,10 +38,11 @@ def forget( account: AccountFromPrivateKey = _load_account(private_key, private_key_file) - message_response = synchronous.get_messages( - addresses=[account.get_address()], - message_type=MessageType.aggregate.value, - content_keys=[key], - ) + with UserSession(api_server=settings.API_HOST) as session: + message_response = session.get_messages( + addresses=[account.get_address()], + message_type=MessageType.aggregate.value, + content_keys=[key], + ) hash_list = [message["item_hash"] for message in message_response.messages] forget_messages(account, hash_list, reason, channel) diff --git a/src/aleph_client/commands/files.py b/src/aleph_client/commands/files.py index 5b945aca..4b3e7931 100644 --- a/src/aleph_client/commands/files.py +++ b/src/aleph_client/commands/files.py @@ -5,7 +5,7 @@ import typer from aleph_message.models import StoreMessage -from aleph_client import synchronous +from aleph_client import AuthenticatedUserSession from aleph_client.account import _load_account from aleph_client.commands import help_strings from 
aleph_client.commands.utils import setup_logging @@ -34,16 +34,18 @@ def pin( setup_logging(debug) account: AccountFromPrivateKey = _load_account(private_key, private_key_file) - - result: StoreMessage = synchronous.create_store( - account=account, - file_hash=hash, - storage_engine=StorageEnum.ipfs, - channel=channel, - ref=ref, - ) + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + message, _status = session.create_store( + file_hash=hash, + storage_engine=StorageEnum.ipfs, + channel=channel, + ref=ref, + ) logger.debug("Upload finished") - typer.echo(f"{result.json(indent=4)}") + typer.echo(f"{message.json(indent=4)}") + @app.command() def upload( @@ -78,13 +80,15 @@ def upload( else StorageEnum.storage ) logger.debug("Uploading file") - result: StoreMessage = synchronous.create_store( - account=account, - file_content=file_content, - storage_engine=storage_engine, - channel=channel, - guess_mime_type=True, - ref=ref, - ) + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + message, status = session.create_store( + file_content=file_content, + storage_engine=storage_engine, + channel=channel, + guess_mime_type=True, + ref=ref, + ) logger.debug("Upload finished") - typer.echo(f"{result.json(indent=4)}") + typer.echo(f"{message.json(indent=4)}") diff --git a/src/aleph_client/commands/message.py b/src/aleph_client/commands/message.py index 09de29da..f0aec3d2 100644 --- a/src/aleph_client/commands/message.py +++ b/src/aleph_client/commands/message.py @@ -6,13 +6,9 @@ from typing import Optional, Dict, List import typer -from aleph_message.models import ( - PostMessage, - ForgetMessage, - AlephMessage, -) +from aleph_message.models import AlephMessage -from aleph_client import synchronous +from aleph_client import AuthenticatedUserSession, UserSession from aleph_client.account import _load_account from aleph_client.commands import help_strings from 
aleph_client.commands.utils import ( @@ -47,7 +43,6 @@ def post( setup_logging(debug) account: AccountFromPrivateKey = _load_account(private_key, private_key_file) - storage_engine: str content: Dict if path: @@ -76,16 +71,18 @@ def post( typer.echo("Not valid JSON") raise typer.Exit(code=2) - result: PostMessage = synchronous.create_post( - account=account, - post_content=content, - post_type=type, - ref=ref, - channel=channel, - inline=True, - storage_engine=storage_engine, - ) - typer.echo(result.json(indent=4)) + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + message, status = session.create_post( + post_content=content, + post_type=type, + ref=ref, + channel=channel, + inline=True, + storage_engine=storage_engine, + ) + typer.echo(message.json(indent=4)) @app.command() @@ -104,34 +101,35 @@ def amend( setup_logging(debug) account: AccountFromPrivateKey = _load_account(private_key, private_key_file) - - existing_message: AlephMessage = synchronous.get_message(item_hash=hash) - - editor: str = os.getenv("EDITOR", default="nano") - with tempfile.NamedTemporaryFile(suffix="json") as fd: - # Fill in message template - fd.write(existing_message.content.json(indent=4).encode()) - fd.seek(0) - - # Launch editor - subprocess.run([editor, fd.name], check=True) - - # Read new message - fd.seek(0) - new_content_json = fd.read() - - content_type = type(existing_message).__annotations__["content"] - new_content_dict = json.loads(new_content_json) - new_content = content_type(**new_content_dict) - new_content.ref = existing_message.item_hash - typer.echo(new_content) - message, _status = synchronous.submit( - account=account, - content=new_content.dict(), - message_type=existing_message.type, - channel=existing_message.channel, - ) - typer.echo(f"{message.json(indent=4)}") + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + existing_message: AlephMessage = 
session.get_message(item_hash=hash) + + editor: str = os.getenv("EDITOR", default="nano") + with tempfile.NamedTemporaryFile(suffix="json") as fd: + # Fill in message template + fd.write(existing_message.content.json(indent=4).encode()) + fd.seek(0) + + # Launch editor + subprocess.run([editor, fd.name], check=True) + + # Read new message + fd.seek(0) + new_content_json = fd.read() + + content_type = type(existing_message).__annotations__["content"] + new_content_dict = json.loads(new_content_json) + new_content = content_type(**new_content_dict) + new_content.ref = existing_message.item_hash + typer.echo(new_content) + message, _status = session.submit( + content=new_content.dict(), + message_type=existing_message.type, + channel=existing_message.channel, + ) + typer.echo(f"{message.json(indent=4)}") def forget_messages( @@ -140,13 +138,15 @@ def forget_messages( reason: Optional[str], channel: str, ): - result: ForgetMessage = synchronous.forget( - account=account, - hashes=hashes, - reason=reason, - channel=channel, - ) - typer.echo(f"{result.json(indent=4)}") + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + message, status = session.forget( + hashes=hashes, + reason=reason, + channel=channel, + ) + typer.echo(f"{message.json(indent=4)}") @app.command() @@ -186,9 +186,10 @@ def watch( setup_logging(debug) - original: AlephMessage = synchronous.get_message(item_hash=ref) + with UserSession(api_server=settings.API_HOST) as session: + original: AlephMessage = session.get_message(item_hash=ref) - for message in synchronous.watch_messages( - refs=[ref], addresses=[original.content.address] - ): - typer.echo(f"{message.json(indent=indent)}") + for message in session.watch_messages( + refs=[ref], addresses=[original.content.address] + ): + typer.echo(f"{message.json(indent=indent)}") diff --git a/src/aleph_client/commands/program.py b/src/aleph_client/commands/program.py index 2337076e..ea45c35d 100644 --- 
a/src/aleph_client/commands/program.py +++ b/src/aleph_client/commands/program.py @@ -1,9 +1,8 @@ -import asyncio import json import logging from base64 import b32encode, b16decode from pathlib import Path -from typing import Optional, Dict, List +from typing import Optional, List, Mapping from zipfile import BadZipFile import typer @@ -14,7 +13,7 @@ ProgramContent, ) -from aleph_client import synchronous +from aleph_client import AuthenticatedUserSession from aleph_client.account import _load_account from aleph_client.commands import help_strings from aleph_client.commands.utils import ( @@ -135,7 +134,7 @@ def upload( immutable_volume_dict = volume_to_dict(volume=immutable_volume) volumes.append(immutable_volume_dict) - subscriptions: Optional[List[Dict]] + subscriptions: Optional[List[Mapping]] if beta and yes_no_input("Subscribe to messages ?", default=False): content_raw = input_multiline() try: @@ -147,51 +146,52 @@ def upload( subscriptions = None # Upload the source code - with open(path_object, "rb") as fd: - logger.debug("Reading file") - # TODO: Read in lazy mode instead of copying everything in memory - file_content = fd.read() - storage_engine = ( - StorageEnum.ipfs - if len(file_content) > 4 * 1024 * 1024 - else StorageEnum.storage - ) - logger.debug("Uploading file") - user_code: StoreMessage = synchronous.create_store( - account=account, - file_content=file_content, - storage_engine=storage_engine, + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + with open(path_object, "rb") as fd: + logger.debug("Reading file") + # TODO: Read in lazy mode instead of copying everything in memory + file_content = fd.read() + storage_engine = ( + StorageEnum.ipfs + if len(file_content) > 4 * 1024 * 1024 + else StorageEnum.storage + ) + logger.debug("Uploading file") + user_code, _status = session.create_store( + file_content=file_content, + storage_engine=storage_engine, + channel=channel, + guess_mime_type=True, + 
ref=None, + ) + logger.debug("Upload finished") + if print_messages or print_code_message: + typer.echo(f"{user_code.json(indent=4)}") + program_ref = user_code.item_hash + + # Register the program + message, status = session.create_program( + program_ref=program_ref, + entrypoint=entrypoint, + runtime=runtime, + storage_engine=StorageEnum.storage, channel=channel, - guess_mime_type=True, - ref=None, + memory=memory, + vcpus=vcpus, + timeout_seconds=timeout_seconds, + persistent=persistent, + encoding=encoding, + volumes=volumes, + subscriptions=subscriptions, ) logger.debug("Upload finished") - if print_messages or print_code_message: - typer.echo(f"{user_code.json(indent=4)}") - program_ref = user_code.item_hash - - # Register the program - message, status = synchronous.create_program( - account=account, - program_ref=program_ref, - entrypoint=entrypoint, - runtime=runtime, - storage_engine=StorageEnum.storage, - channel=channel, - memory=memory, - vcpus=vcpus, - timeout_seconds=timeout_seconds, - persistent=persistent, - encoding=encoding, - volumes=volumes, - subscriptions=subscriptions, - ) - logger.debug("Upload finished") - if print_messages or print_program_message: - typer.echo(f"{message.json(indent=4)}") + if print_messages or print_program_message: + typer.echo(f"{message.json(indent=4)}") - hash: str = message.item_hash - hash_base32 = b32encode(b16decode(hash.upper())).strip(b"=").lower().decode() + hash: str = message.item_hash + hash_base32 = b32encode(b16decode(hash.upper())).strip(b"=").lower().decode() typer.echo( f"Your program has been uploaded on Aleph .\n\n" @@ -219,47 +219,49 @@ def update( account = _load_account(private_key, private_key_file) path = path.absolute() - program_message: ProgramMessage = synchronous.get_message( - item_hash=hash, message_type=ProgramMessage - ) - code_ref = program_message.content.code.ref - code_message: StoreMessage = synchronous.get_message( - item_hash=code_ref, message_type=StoreMessage - ) - - try: - 
path, encoding = create_archive(path) - except BadZipFile: - typer.echo("Invalid zip archive") - raise typer.Exit(3) - except FileNotFoundError: - typer.echo("No such file or directory") - raise typer.Exit(4) - - if encoding != program_message.content.code.encoding: - logger.error( - f"Code must be encoded with the same encoding as the previous version " - f"('{encoding}' vs '{program_message.content.code.encoding}'" + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + program_message: ProgramMessage = session.get_message( + item_hash=hash, message_type=ProgramMessage ) - raise typer.Exit(1) - - # Upload the source code - with open(path, "rb") as fd: - logger.debug("Reading file") - # TODO: Read in lazy mode instead of copying everything in memory - file_content = fd.read() - logger.debug("Uploading file") - message, status = synchronous.create_store( - account=account, - file_content=file_content, - storage_engine=code_message.content.item_type, - channel=code_message.channel, - guess_mime_type=True, - ref=code_message.item_hash, + code_ref = program_message.content.code.ref + code_message: StoreMessage = session.get_message( + item_hash=code_ref, message_type=StoreMessage ) - logger.debug("Upload finished") - if print_message: - typer.echo(f"{message.json(indent=4)}") + + try: + path, encoding = create_archive(path) + except BadZipFile: + typer.echo("Invalid zip archive") + raise typer.Exit(3) + except FileNotFoundError: + typer.echo("No such file or directory") + raise typer.Exit(4) + + if encoding != program_message.content.code.encoding: + logger.error( + f"Code must be encoded with the same encoding as the previous version " + f"('{encoding}' vs '{program_message.content.code.encoding}'" + ) + raise typer.Exit(1) + + # Upload the source code + with open(path, "rb") as fd: + logger.debug("Reading file") + # TODO: Read in lazy mode instead of copying everything in memory + file_content = fd.read() + 
logger.debug("Uploading file") + message, status = session.create_store( + file_content=file_content, + storage_engine=code_message.content.item_type, + channel=code_message.channel, + guess_mime_type=True, + ref=code_message.item_hash, + ) + logger.debug("Upload finished") + if print_message: + typer.echo(f"{message.json(indent=4)}") @app.command() @@ -275,17 +277,19 @@ def unpersist( account = _load_account(private_key, private_key_file) - existing: MessagesResponse = synchronous.get_messages(hashes=[hash]) - message: ProgramMessage = existing.messages[0] - content: ProgramContent = message.content.copy() + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + existing: MessagesResponse = session.get_messages(hashes=[hash]) + message: ProgramMessage = existing.messages[0] + content: ProgramContent = message.content.copy() - content.on.persistent = False - content.replaces = message.item_hash + content.on.persistent = False + content.replaces = message.item_hash - message, _status = synchronous.submit( - account=account, - content=content.dict(exclude_none=True), - message_type=message.type, - channel=message.channel, - ) - typer.echo(f"{message.json(indent=4)}") + message, _status = session.submit( + content=content.dict(exclude_none=True), + message_type=message.type, + channel=message.channel, + ) + typer.echo(f"{message.json(indent=4)}") diff --git a/src/aleph_client/main.py b/src/aleph_client/main.py index ef4e59d0..c3ceb427 100644 --- a/src/aleph_client/main.py +++ b/src/aleph_client/main.py @@ -9,4 +9,4 @@ DeprecationWarning, ) -from .synchronous import * +from .user_session import * diff --git a/src/aleph_client/synchronous.py b/src/aleph_client/synchronous.py deleted file mode 100644 index f0fa0c69..00000000 --- a/src/aleph_client/synchronous.py +++ /dev/null @@ -1,150 +0,0 @@ -import asyncio -import queue -import threading -from typing import ( - Any, - Callable, - List, - Optional, - Dict, - Iterable, - Type, - 
Protocol, - TypeVar, - Awaitable, -) - -from aiohttp import ClientSession -from aleph_message.models import AlephMessage -from aleph_message.models.program import ProgramContent, Encoding - -from . import asynchronous -from .conf import settings -from .types import Account, StorageEnum, GenericMessage - - -T = TypeVar("T") - - -def wrap_async(func: Callable[..., Awaitable[T]]) -> Callable[..., T]: - """Wrap an asynchronous function into a synchronous one, - for easy use in synchronous code. - """ - - def func_caller(*args, **kwargs): - loop = asyncio.get_event_loop() - return loop.run_until_complete(func(*args, **kwargs)) - - # Copy wrapped function interface: - func_caller.__doc__ = func.__doc__ - func_caller.__annotations__ = func.__annotations__ - func_caller.__defaults__ = func.__defaults__ - func_caller.__kwdefaults__ = func.__kwdefaults__ - return func_caller - - -create_post = wrap_async(asynchronous.create_post) -forget = wrap_async(asynchronous.forget) -ipfs_push = wrap_async(asynchronous.ipfs_push) -storage_push = wrap_async(asynchronous.storage_push) -ipfs_push_file = wrap_async(asynchronous.ipfs_push_file) -storage_push_file = wrap_async(asynchronous.storage_push_file) -create_aggregate = wrap_async(asynchronous.create_aggregate) -create_store = wrap_async(asynchronous.create_store) -submit = wrap_async(asynchronous.submit) -fetch_aggregate = wrap_async(asynchronous.fetch_aggregate) -fetch_aggregates = wrap_async(asynchronous.fetch_aggregates) -get_posts = wrap_async(asynchronous.get_posts) -get_messages = wrap_async(asynchronous.get_messages) - - -def get_message( - item_hash: str, - message_type: Optional[Type[GenericMessage]] = None, - channel: Optional[str] = None, - session: Optional[ClientSession] = None, - api_server: str = settings.API_HOST, -) -> GenericMessage: - return wrap_async(asynchronous.get_message)( - item_hash=item_hash, - message_type=message_type, - channel=channel, - session=session, - api_server=api_server, - ) - - -def 
create_program( - account: Account, - program_ref: str, - entrypoint: str, - runtime: str, - environment_variables: Optional[Dict[str, str]] = None, - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, - memory: Optional[int] = None, - vcpus: Optional[int] = None, - timeout_seconds: Optional[float] = None, - persistent: bool = False, - encoding: Encoding = Encoding.zip, - volumes: Optional[List[Dict]] = None, - subscriptions: Optional[List[Dict]] = None, -): - """ - Post a (create) PROGRAM message. - - :param account: Account to use to sign the message - :param program_ref: Reference to the program to run - :param entrypoint: Entrypoint to run - :param runtime: Runtime to use - :param environment_variables: Environment variables to pass to the program - :param storage_engine: Storage engine to use (DEFAULT: "storage") - :param channel: Channel to use (DEFAULT: "TEST") - :param address: Address to use (DEFAULT: account.get_address()) - :param session: Session to use (DEFAULT: get_fallback_session()) - :param api_server: API server to use (DEFAULT: "https://api2.aleph.im") - :param memory: Memory in MB for the VM to be allocated (DEFAULT: 128) - :param vcpus: Number of vCPUs to allocate (DEFAULT: 1) - :param timeout_seconds: Timeout in seconds (DEFAULT: 30.0) - :param persistent: Whether the program should be persistent or not (DEFAULT: False) - :param encoding: Encoding to use (DEFAULT: Encoding.zip) - :param volumes: Volumes to mount - :param subscriptions: Patterns of Aleph messages to forward to the program's event receiver - """ - return wrap_async(asynchronous.create_program)( - account=account, - program_ref=program_ref, - entrypoint=entrypoint, - environment_variables=environment_variables, - runtime=runtime, - storage_engine=storage_engine, - channel=channel, - address=address, - session=session, - 
api_server=api_server, - memory=memory, - vcpus=vcpus, - timeout_seconds=timeout_seconds, - persistent=persistent, - encoding=encoding, - volumes=volumes, - subscriptions=subscriptions, - ) - - -def watch_messages(*args, **kwargs) -> Iterable[AlephMessage]: - """ - Iterate over current and future matching messages synchronously. - - Runs the `watch_messages` asynchronous generator in a thread. - """ - output_queue: queue.Queue[AlephMessage] = queue.Queue() - thread = threading.Thread( - target=asynchronous._start_run_watch_messages, args=(output_queue, args, kwargs) - ) - thread.start() - while True: - yield output_queue.get() diff --git a/src/aleph_client/user_session.py b/src/aleph_client/user_session.py index db366f91..10c797c2 100644 --- a/src/aleph_client/user_session.py +++ b/src/aleph_client/user_session.py @@ -1,8 +1,438 @@ import asyncio +import hashlib +import json +import logging +import queue +import threading +import time +from datetime import datetime +from pathlib import Path +from typing import ( + Optional, + Union, + Any, + Dict, + List, + Iterable, + AsyncIterable, + Awaitable, + Callable, + TypeVar, +) +from typing import Type, Mapping, Tuple, NoReturn import aiohttp +from aleph_message.models import ( + ForgetContent, + MessageType, + AggregateContent, + PostContent, + StoreContent, + PostMessage, + Message, + ForgetMessage, + AlephMessage, + AggregateMessage, + StoreMessage, + ProgramMessage, + ItemType, +) +from aleph_message.models.program import ProgramContent, Encoding +from pydantic import ValidationError -from aleph_client.types import Account +from aleph_client.types import Account, StorageEnum, GenericMessage, MessageStatus +from .conf import settings +from .exceptions import ( + MessageNotFoundError, + MultipleMessagesError, + InvalidMessageError, + BroadcastError, +) +from .models import MessagesResponse +from .utils import get_message_type_value + +logger = logging.getLogger(__name__) + + +try: + import magic +except ImportError: + 
logger.info("Could not import library 'magic', MIME type detection disabled") + magic = None # type:ignore + + +T = TypeVar("T") + + +def async_wrapper(f): + """ + Copies the docstring of wrapped functions. + """ + + wrapped = getattr(AuthenticatedUserSession, f.__name__) + f.__doc__ = wrapped.__doc__ + + +def wrap_async(func: Callable[..., Awaitable[T]]) -> Callable[..., T]: + """Wrap an asynchronous function into a synchronous one, + for easy use in synchronous code. + """ + + def func_caller(*args, **kwargs): + loop = asyncio.get_event_loop() + return loop.run_until_complete(func(*args, **kwargs)) + + # Copy wrapped function interface: + func_caller.__doc__ = func.__doc__ + func_caller.__annotations__ = func.__annotations__ + func_caller.__defaults__ = func.__defaults__ + func_caller.__kwdefaults__ = func.__kwdefaults__ + return func_caller + + +async def run_async_watcher( + *args, output_queue: queue.Queue, api_server: str, **kwargs +): + async with UserSession(api_server=api_server) as session: + async for message in session.watch_messages(*args, **kwargs): + output_queue.put(message) + + +def watcher_thread(output_queue: queue.Queue, api_server: str, args, kwargs): + asyncio.run( + run_async_watcher( + output_queue=output_queue, api_server=api_server, *args, **kwargs + ) + ) + + +class UserSessionSync: + """ + A sync version of `UserSession`, used in sync code. + + This class is returned by the context manager of `UserSession` and is + intended as a wrapper around the methods of `UserSession` and not as a public class. 
+ The methods are fully typed to enable static type checking, but most (all) methods + should look like this (using args and kwargs for brevity, but the functions should + be fully typed): + + >>> def func(self, *args, **kwargs): + >>> return self._wrap(self.async_session.func)(*args, **kwargs) + """ + + def __init__(self, async_session: "UserSession"): + self.async_session = async_session + + def _wrap(self, method: Callable[..., Awaitable[T]], *args, **kwargs): + return wrap_async(method)(*args, **kwargs) + + def get_messages( + self, + pagination: int = 200, + page: int = 1, + message_type: Optional[MessageType] = None, + content_types: Optional[Iterable[str]] = None, + content_keys: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ignore_invalid_messages: bool = True, + invalid_messages_log_level: int = logging.NOTSET, + ) -> MessagesResponse: + + return self._wrap( + self.async_session.get_messages, + pagination=pagination, + page=page, + message_type=message_type, + content_types=content_types, + content_keys=content_keys, + refs=refs, + addresses=addresses, + tags=tags, + hashes=hashes, + channels=channels, + chains=chains, + start_date=start_date, + end_date=end_date, + ignore_invalid_messages=ignore_invalid_messages, + invalid_messages_log_level=invalid_messages_log_level, + ) + + # @async_wrapper + def get_message( + self, + item_hash: str, + message_type: Optional[Type[GenericMessage]] = None, + channel: Optional[str] = None, + ) -> GenericMessage: + return self._wrap( + self.async_session.get_message, + item_hash=item_hash, + message_type=message_type, + channel=channel, + ) + + def fetch_aggregate( + self, + 
address: str, + key: str, + limit: int = 100, + ) -> Dict[str, Dict]: + return self._wrap(self.async_session.fetch_aggregate, address, key, limit) + + def fetch_aggregates( + self, + address: str, + keys: Optional[Iterable[str]] = None, + limit: int = 100, + ) -> Dict[str, Dict]: + return self._wrap(self.async_session.fetch_aggregates, address, keys, limit) + + def get_posts( + self, + pagination: int = 200, + page: int = 1, + types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ) -> Dict[str, Dict]: + return self._wrap( + self.async_session.get_posts, + pagination=pagination, + page=page, + types=types, + refs=refs, + addresses=addresses, + tags=tags, + hashes=hashes, + channels=channels, + chains=chains, + start_date=start_date, + end_date=end_date, + ) + + def download_file(self, file_hash: str) -> bytes: + return self._wrap(self.async_session.download_file, file_hash=file_hash) + + def watch_messages( + self, + message_type: Optional[MessageType] = None, + content_types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ) -> Iterable[AlephMessage]: + """ + Iterate over current and future matching messages synchronously. + + Runs the `watch_messages` asynchronous generator in a thread. 
+ """ + output_queue: queue.Queue[AlephMessage] = queue.Queue() + thread = threading.Thread( + target=watcher_thread, + args=( + output_queue, + self.async_session.api_server, + ( + message_type, + content_types, + refs, + addresses, + tags, + hashes, + channels, + chains, + start_date, + end_date, + ), + {}, + ), + ) + thread.start() + while True: + yield output_queue.get() + + +class AuthenticatedUserSessionSync(UserSessionSync): + async_session: "AuthenticatedUserSession" + + def __init__(self, async_session: "AuthenticatedUserSession"): + super().__init__(async_session=async_session) + + def ipfs_push(self, content: Mapping) -> str: + return self._wrap(self.async_session.ipfs_push, content=content) + + def storage_push(self, content: Mapping) -> str: + return self._wrap(self.async_session.storage_push, content=content) + + def ipfs_push_file(self, file_content: Union[str, bytes]) -> str: + return self._wrap(self.async_session.ipfs_push_file, file_content=file_content) + + def storage_push_file(self, file_content: Union[str, bytes]) -> str: + return self._wrap( + self.async_session.storage_push_file, file_content=file_content + ) + + def create_post( + self, + post_content, + post_type: str, + ref: Optional[str] = None, + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, + sync: bool = False, + ) -> Tuple[PostMessage, MessageStatus]: + return self._wrap( + self.async_session.create_post, + post_content=post_content, + post_type=post_type, + ref=ref, + address=address, + channel=channel, + inline=inline, + storage_engine=storage_engine, + sync=sync, + ) + + def create_aggregate( + self, + key: str, + content: Mapping[str, Any], + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + sync: bool = False, + ) -> Tuple[AggregateMessage, MessageStatus]: + return self._wrap( + self.async_session.create_aggregate, + key=key, + content=content, + 
address=address, + channel=channel, + inline=inline, + sync=sync, + ) + + def create_store( + self, + address: Optional[str] = None, + file_content: Optional[bytes] = None, + file_path: Optional[Union[str, Path]] = None, + file_hash: Optional[str] = None, + guess_mime_type: bool = False, + ref: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + extra_fields: Optional[dict] = None, + channel: Optional[str] = None, + sync: bool = False, + ) -> Tuple[StoreMessage, MessageStatus]: + return self._wrap( + self.async_session.create_store, + address=address, + file_content=file_content, + file_path=file_path, + file_hash=file_hash, + guess_mime_type=guess_mime_type, + ref=ref, + storage_engine=storage_engine, + extra_fields=extra_fields, + channel=channel, + sync=sync, + ) + + def create_program( + self, + program_ref: str, + entrypoint: str, + runtime: str, + environment_variables: Optional[Mapping[str, str]] = None, + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + memory: Optional[int] = None, + vcpus: Optional[int] = None, + timeout_seconds: Optional[float] = None, + persistent: bool = False, + encoding: Encoding = Encoding.zip, + volumes: Optional[List[Mapping]] = None, + subscriptions: Optional[List[Mapping]] = None, + ) -> Tuple[ProgramMessage, MessageStatus]: + return self._wrap( + self.async_session.create_program, + program_ref=program_ref, + entrypoint=entrypoint, + runtime=runtime, + environment_variables=environment_variables, + storage_engine=storage_engine, + channel=channel, + address=address, + sync=sync, + memory=memory, + vcpus=vcpus, + timeout_seconds=timeout_seconds, + persistent=persistent, + encoding=encoding, + volumes=volumes, + subscriptions=subscriptions, + ) + + def forget( + self, + hashes: List[str], + reason: Optional[str], + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: 
Optional[str] = None, + sync: bool = False, + ) -> Tuple[ForgetMessage, MessageStatus]: + return self._wrap( + self.async_session.forget, + hashes=hashes, + reason=reason, + storage_engine=storage_engine, + channel=channel, + address=address, + sync=sync, + ) + + def submit( + self, + content: Dict[str, Any], + message_type: MessageType, + channel: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + allow_inlining: bool = True, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + return self._wrap( + self.async_session.submit, + content=content, + message_type=message_type, + channel=channel, + storage_engine=storage_engine, + allow_inlining=allow_inlining, + sync=sync, + ) class UserSession: @@ -13,8 +443,8 @@ def __init__(self, api_server: str): self.api_server = api_server self.http_session = aiohttp.ClientSession(base_url=api_server) - def __enter__(self): - return self + def __enter__(self) -> UserSessionSync: + return UserSessionSync(async_session=self) def __exit__(self, exc_type, exc_val, exc_tb): close_fut = self.http_session.close() @@ -24,16 +454,911 @@ def __exit__(self, exc_type, exc_val, exc_tb): except RuntimeError: asyncio.run(close_fut) - async def __aenter__(self): + async def __aenter__(self) -> "UserSession": return self async def __aexit__(self, exc_type, exc_val, exc_tb): await self.http_session.close() + async def fetch_aggregate( + self, + address: str, + key: str, + limit: int = 100, + ) -> Dict[str, Dict]: + """ + Fetch a value from the aggregate store by owner address and item key. 
+ + :param address: Address of the owner of the aggregate + :param key: Key of the aggregate + :param limit: Maximum number of items to fetch (Default: 100) + """ + + params: Dict[str, Any] = {"keys": key} + if limit: + params["limit"] = limit + + async with self.http_session.get( + f"/api/v0/aggregates/{address}.json", params=params + ) as resp: + result = await resp.json() + data = result.get("data", dict()) + return data.get(key) + + async def fetch_aggregates( + self, + address: str, + keys: Optional[Iterable[str]] = None, + limit: int = 100, + ) -> Dict[str, Dict]: + """ + Fetch key-value pairs from the aggregate store by owner address. + + :param address: Address of the owner of the aggregate + :param keys: Keys of the aggregates to fetch (Default: all items) + :param limit: Maximum number of items to fetch (Default: 100) + """ + + keys_str = ",".join(keys) if keys else "" + params: Dict[str, Any] = {} + if keys_str: + params["keys"] = keys_str + if limit: + params["limit"] = limit + + async with self.http_session.get( + f"/api/v0/aggregates/{address}.json", + params=params, + ) as resp: + result = await resp.json() + data = result.get("data", dict()) + return data + + async def get_posts( + self, + pagination: int = 200, + page: int = 1, + types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ) -> Dict[str, Dict]: + """ + Fetch a list of posts from the network. 
+ + :param pagination: Number of items to fetch (Default: 200) + :param page: Page to fetch, begins at 1 (Default: 1) + :param types: Types of posts to fetch (Default: all types) + :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) + :param addresses: Addresses of the posts to fetch (Default: all addresses) + :param tags: Tags of the posts to fetch (Default: all tags) + :param hashes: Specific item_hashes to fetch + :param channels: Channels of the posts to fetch (Default: all channels) + :param chains: Chains of the posts to fetch (Default: all chains) + :param start_date: Earliest date to fetch messages from + :param end_date: Latest date to fetch messages from + """ + + params: Dict[str, Any] = dict(pagination=pagination, page=page) + + if types is not None: + params["types"] = ",".join(types) + if refs is not None: + params["refs"] = ",".join(refs) + if addresses is not None: + params["addresses"] = ",".join(addresses) + if tags is not None: + params["tags"] = ",".join(tags) + if hashes is not None: + params["hashes"] = ",".join(hashes) + if channels is not None: + params["channels"] = ",".join(channels) + if chains is not None: + params["chains"] = ",".join(chains) + + if start_date is not None: + if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): + start_date = start_date.timestamp() + params["startDate"] = start_date + if end_date is not None: + if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): + end_date = end_date.timestamp() + params["endDate"] = end_date + + async with self.http_session.get("/api/v0/posts.json", params=params) as resp: + resp.raise_for_status() + return await resp.json() + + async def download_file( + self, + file_hash: str, + ) -> bytes: + """ + Get a file from the storage engine as raw bytes. + + Warning: Downloading large files can be slow and memory intensive. + + :param file_hash: The hash of the file to retrieve. 
+ """ + async with self.http_session.get( + f"/api/v0/storage/raw/{file_hash}" + ) as response: + response.raise_for_status() + return await response.read() + + async def get_messages( + self, + pagination: int = 200, + page: int = 1, + message_type: Optional[MessageType] = None, + content_types: Optional[Iterable[str]] = None, + content_keys: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ignore_invalid_messages: bool = True, + invalid_messages_log_level: int = logging.NOTSET, + ) -> MessagesResponse: + """ + Fetch a list of messages from the network. + + :param pagination: Number of items to fetch (Default: 200) + :param page: Page to fetch, begins at 1 (Default: 1) + :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" + :param content_types: Filter by content type + :param content_keys: Filter by content key + :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) + :param addresses: Addresses of the posts to fetch (Default: all addresses) + :param tags: Tags of the posts to fetch (Default: all tags) + :param hashes: Specific item_hashes to fetch + :param channels: Channels of the posts to fetch (Default: all channels) + :param chains: Filter by sender address chain + :param start_date: Earliest date to fetch messages from + :param end_date: Latest date to fetch messages from + :param ignore_invalid_messages: Ignore invalid messages (Default: False) + :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) + """ + ignore_invalid_messages = ( + True if ignore_invalid_messages is None else 
ignore_invalid_messages + ) + invalid_messages_log_level = ( + logging.NOTSET + if invalid_messages_log_level is None + else invalid_messages_log_level + ) + + params: Dict[str, Any] = dict(pagination=pagination, page=page) + + if message_type is not None: + params["msgType"] = message_type.value + if content_types is not None: + params["contentTypes"] = ",".join(content_types) + if content_keys is not None: + params["contentKeys"] = ",".join(content_keys) + if refs is not None: + params["refs"] = ",".join(refs) + if addresses is not None: + params["addresses"] = ",".join(addresses) + if tags is not None: + params["tags"] = ",".join(tags) + if hashes is not None: + params["hashes"] = ",".join(hashes) + if channels is not None: + params["channels"] = ",".join(channels) + if chains is not None: + params["chains"] = ",".join(chains) + + if start_date is not None: + if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): + start_date = start_date.timestamp() + params["startDate"] = start_date + if end_date is not None: + if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): + end_date = end_date.timestamp() + params["endDate"] = end_date + + async with self.http_session.get( + "/api/v0/messages.json", params=params + ) as resp: + resp.raise_for_status() + response_json = await resp.json() + messages_raw = response_json["messages"] + + # All messages may not be valid according to the latest specification in + # aleph-message. This allows the user to specify how errors should be handled. 
+ messages: List[AlephMessage] = [] + for message_raw in messages_raw: + try: + message = Message(**message_raw) + messages.append(message) + except KeyError as e: + if not ignore_invalid_messages: + raise e + logger.log( + level=invalid_messages_log_level, + msg=f"KeyError: Field '{e.args[0]}' not found", + ) + except ValidationError as e: + if not ignore_invalid_messages: + raise e + if invalid_messages_log_level: + logger.log(level=invalid_messages_log_level, msg=e) + + return MessagesResponse( + messages=messages, + pagination_page=response_json["pagination_page"], + pagination_total=response_json["pagination_total"], + pagination_per_page=response_json["pagination_per_page"], + pagination_item=response_json["pagination_item"], + ) + + async def get_message( + self, + item_hash: str, + message_type: Optional[Type[GenericMessage]] = None, + channel: Optional[str] = None, + ) -> GenericMessage: + """ + Get a single message from its `item_hash` and perform some basic validation. + + :param item_hash: Hash of the message to fetch + :param message_type: Type of message to fetch + :param channel: Channel of the message to fetch + """ + messages_response = await self.get_messages( + hashes=[item_hash], + channels=[channel] if channel else None, + ) + if len(messages_response.messages) < 1: + raise MessageNotFoundError(f"No such hash {item_hash}") + if len(messages_response.messages) != 1: + raise MultipleMessagesError( + f"Multiple messages found for the same item_hash `{item_hash}`" + ) + message: GenericMessage = messages_response.messages[0] + if message_type: + expected_type = get_message_type_value(message_type) + if message.type != expected_type: + raise TypeError( + f"The message type '{message.type}' " + f"does not match the expected type '{expected_type}'" + ) + return message + + async def watch_messages( + self, + message_type: Optional[MessageType] = None, + content_types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: 
Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ) -> AsyncIterable[AlephMessage]: + """ + Iterate over current and future matching messages asynchronously. + + :param message_type: Type of message to watch + :param content_types: Content types to watch + :param refs: References to watch + :param addresses: Addresses to watch + :param tags: Tags to watch + :param hashes: Hashes to watch + :param channels: Channels to watch + :param chains: Chains to watch + :param start_date: Start date from when to watch + :param end_date: End date until when to watch + """ + params: Dict[str, Any] = dict() + + if message_type is not None: + params["msgType"] = message_type.value + if content_types is not None: + params["contentTypes"] = ",".join(content_types) + if refs is not None: + params["refs"] = ",".join(refs) + if addresses is not None: + params["addresses"] = ",".join(addresses) + if tags is not None: + params["tags"] = ",".join(tags) + if hashes is not None: + params["hashes"] = ",".join(hashes) + if channels is not None: + params["channels"] = ",".join(channels) + if chains is not None: + params["chains"] = ",".join(chains) + + if start_date is not None: + if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): + start_date = start_date.timestamp() + params["startDate"] = start_date + if end_date is not None: + if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): + end_date = end_date.timestamp() + params["endDate"] = end_date + + async with self.http_session.ws_connect( + f"/api/ws0/messages", params=params + ) as ws: + logger.debug("Websocket connected") + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + if msg.data == "close cmd": + await ws.close() + break 
+ else: + data = json.loads(msg.data) + yield Message(**data) + elif msg.type == aiohttp.WSMsgType.ERROR: + break + class AuthenticatedUserSession(UserSession): account: Account + BROADCAST_MESSAGE_FIELDS = { + "sender", + "chain", + "signature", + "type", + "item_hash", + "item_type", + "item_content", + "time", + "channel", + } + def __init__(self, account: Account, api_server: str): super().__init__(api_server=api_server) self.account = account + + def __enter__(self) -> "AuthenticatedUserSessionSync": + return AuthenticatedUserSessionSync(async_session=self) + + async def __aenter__(self) -> "AuthenticatedUserSession": + return self + + async def ipfs_push(self, content: Mapping) -> str: + """Push arbitrary content as JSON to the IPFS service.""" + + url = "/api/v0/ipfs/add_json" + logger.debug(f"Pushing to IPFS on {url}") + + async with self.http_session.post(url, json=content) as resp: + resp.raise_for_status() + return (await resp.json()).get("hash") + + async def storage_push(self, content: Mapping) -> str: + """Push arbitrary content as JSON to the storage service.""" + + url = "/api/v0/storage/add_json" + logger.debug(f"Pushing to storage on {url}") + + async with self.http_session.post(url, json=content) as resp: + resp.raise_for_status() + return (await resp.json()).get("hash") + + async def ipfs_push_file(self, file_content: Union[str, bytes]) -> str: + """Push a file to the IPFS service.""" + data = aiohttp.FormData() + data.add_field("file", file_content) + + url = "/api/v0/ipfs/add_file" + logger.debug(f"Pushing file to IPFS on {url}") + + async with self.http_session.post(url, data=data) as resp: + resp.raise_for_status() + return (await resp.json()).get("hash") + + async def storage_push_file(self, file_content) -> str: + """Push a file to the storage service.""" + data = aiohttp.FormData() + data.add_field("file", file_content) + + url = "/api/v0/storage/add_file" + logger.debug(f"Posting file on {url}") + + async with self.http_session.post(url, 
data=data) as resp:
+            resp.raise_for_status()
+            return (await resp.json()).get("hash")
+
+    @staticmethod
+    def _log_publication_status(publication_status: Mapping[str, Any]):
+        status = publication_status.get("status")
+        failures = publication_status.get("failed")
+
+        if status == "success":
+            return
+        elif status == "warning":
+            logger.warning("Broadcast failed on the following network(s): %s", failures)
+        elif status == "error":
+            logger.error(
+                "Broadcast failed on all protocols. The message was not published."
+            )
+        else:
+            raise ValueError(
+                f"Invalid response from server, status is missing or unknown: '{status}'"
+            )
+
+    @staticmethod
+    async def _handle_broadcast_error(response: aiohttp.ClientResponse) -> NoReturn:
+        if response.status == 500:
+            # Assume a broadcast error, no need to read the JSON
+            if response.content_type == "application/json":
+                error_msg = "Internal error - broadcast failed on all protocols"
+            else:
+                error_msg = f"Internal error - the message was not broadcast: {await response.text()}"
+
+            logger.error(error_msg)
+            raise BroadcastError(error_msg)
+        elif response.status == 422:
+            errors = await response.json()
+            logger.error(
+                "The message could not be processed because of the following errors: %s",
+                errors,
+            )
+            raise InvalidMessageError(errors)
+        else:
+            error_msg = (
+                f"Unexpected HTTP response ({response.status}: {await response.text()})"
+            )
+            logger.error(error_msg)
+            raise BroadcastError(error_msg)
+
+    async def _handle_broadcast_deprecated_response(
+        self,
+        response: aiohttp.ClientResponse,
+    ) -> None:
+        if response.status != 200:
+            await self._handle_broadcast_error(response)
+        else:
+            publication_status = await response.json()
+            self._log_publication_status(publication_status)
+
+    async def _broadcast_deprecated(self, message_dict: Mapping[str, Any]) -> None:
+
+        """
+        Broadcast a message on the Aleph network using the deprecated
+        /ipfs/pubsub/pub/ endpoint.
+ """ + + url = "/api/v0/ipfs/pubsub/pub" + logger.debug(f"Posting message on {url}") + + async with self.http_session.post( + url, + json={"topic": "ALEPH-TEST", "data": json.dumps(message_dict)}, + ) as response: + await self._handle_broadcast_deprecated_response(response) + + async def _handle_broadcast_response( + self, response: aiohttp.ClientResponse, sync: bool + ) -> MessageStatus: + if response.status in (200, 202): + status = await response.json() + self._log_publication_status(status["publication_status"]) + + if response.status == 202: + if sync: + logger.warning( + "Timed out while waiting for processing of sync message" + ) + return MessageStatus.PENDING + + return MessageStatus.PROCESSED + + else: + await self._handle_broadcast_error(response) + + async def _broadcast( + self, + message: AlephMessage, + sync: bool, + ) -> MessageStatus: + """ + Broadcast a message on the Aleph network. + + Uses the POST /messages/ endpoint or the deprecated /ipfs/pubsub/pub/ endpoint + if the first method is not available. + """ + + url = "/api/v0/messages" + logger.debug(f"Posting message on {url}") + + message_dict = message.dict(include=self.BROADCAST_MESSAGE_FIELDS) + + async with self.http_session.post( + url, + json={"sync": sync, "message": message_dict}, + ) as response: + # The endpoint may be unavailable on this node, try the deprecated version. + if response.status == 404: + logger.warning( + "POST /messages/ not found. Defaulting to legacy endpoint..." 
+ ) + await self._broadcast_deprecated(message_dict=message_dict) + return MessageStatus.PENDING + else: + message_status = await self._handle_broadcast_response( + response=response, sync=sync + ) + return message_status + + async def create_post( + self, + post_content, + post_type: str, + ref: Optional[str] = None, + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, + sync: bool = False, + ) -> Tuple[PostMessage, MessageStatus]: + """ + Create a POST message on the Aleph network. It is associated with a channel and owned by an account. + + :param post_content: The content of the message + :param post_type: An arbitrary content type that helps to describe the post_content + :param ref: A reference to a previous message that it replaces + :param address: The address that will be displayed as the author of the message + :param channel: The channel that the message will be posted on + :param inline: An optional flag to indicate if the content should be inlined in the message or not + :param storage_engine: An optional storage engine to use for the message, if not inlined (Default: "storage") + :param sync: If true, waits for the message to be processed by the API server (Default: False) + """ + address = address or settings.ADDRESS_TO_USE or self.account.get_address() + + content = PostContent( + type=post_type, + address=address, + content=post_content, + time=time.time(), + ref=ref, + ) + + return await self.submit( + content=content.dict(exclude_none=True), + message_type=MessageType.post, + channel=channel, + allow_inlining=inline, + storage_engine=storage_engine, + sync=sync, + ) + + async def create_aggregate( + self, + key: str, + content: Mapping[str, Any], + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + sync: bool = False, + ) -> Tuple[AggregateMessage, MessageStatus]: + """ + Create an AGGREGATE message. 
It is meant to be used as a quick access storage associated with an account. + + :param key: Key to use to store the content + :param content: Content to store + :param address: Address to use to sign the message + :param channel: Channel to use (Default: "TEST") + :param inline: Whether to write content inside the message (Default: True) + :param sync: If true, waits for the message to be processed by the API server (Default: False) + """ + address = address or settings.ADDRESS_TO_USE or self.account.get_address() + + content_ = AggregateContent( + key=key, + address=address, + content=content, + time=time.time(), + ) + + return await self.submit( + content=content_.dict(exclude_none=True), + message_type=MessageType.aggregate, + channel=channel, + allow_inlining=inline, + sync=sync, + ) + + async def create_store( + self, + address: Optional[str] = None, + file_content: Optional[bytes] = None, + file_path: Optional[Union[str, Path]] = None, + file_hash: Optional[str] = None, + guess_mime_type: bool = False, + ref: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + extra_fields: Optional[dict] = None, + channel: Optional[str] = None, + sync: bool = False, + ) -> Tuple[StoreMessage, MessageStatus]: + """ + Create a STORE message to store a file on the Aleph network. + + Can be passed either a file path, an IPFS hash or the file's content as raw bytes. 
+
+        :param address: Address to display as the author of the message (Default: account.get_address())
+        :param file_content: Byte stream of the file to store (Default: None)
+        :param file_path: Path to the file to store (Default: None)
+        :param file_hash: Hash of the file to store (Default: None)
+        :param guess_mime_type: Guess the MIME type of the file (Default: False)
+        :param ref: Reference to a previous message (Default: None)
+        :param storage_engine: Storage engine to use (Default: "storage")
+        :param extra_fields: Extra fields to add to the STORE message (Default: None)
+        :param channel: Channel to post the message to (Default: "TEST")
+        :param sync: If true, waits for the message to be processed by the API server (Default: False)
+        """
+        address = address or settings.ADDRESS_TO_USE or self.account.get_address()
+
+        extra_fields = extra_fields or {}
+
+        if file_hash is None:
+            if file_content is None:
+                if file_path is None:
+                    raise ValueError(
+                        "Please specify at least a file_content, a file_hash or a file_path"
+                    )
+                else:
+                    file_content = open(file_path, "rb").read()
+
+            if storage_engine == StorageEnum.storage:
+                file_hash = await self.storage_push_file(file_content=file_content)
+            elif storage_engine == StorageEnum.ipfs:
+                file_hash = await self.ipfs_push_file(file_content=file_content)
+            else:
+                raise ValueError(f"Unknown storage engine: '{storage_engine}'")
+
+        assert file_hash, "File hash should not be empty"
+
+        if magic is None:
+            pass
+        elif file_content and guess_mime_type and ("mime_type" not in extra_fields):
+            extra_fields["mime_type"] = magic.from_buffer(file_content, mime=True)
+
+        if ref:
+            extra_fields["ref"] = ref
+
+        values = {
+            "address": address,
+            "item_type": storage_engine,
+            "item_hash": file_hash,
+            "time": time.time(),
+        }
+        if extra_fields is not None:
+            values.update(extra_fields)
+
+        content = StoreContent(**values)
+
+        return await self.submit(
+            content=content.dict(exclude_none=True),
+            message_type=MessageType.store,
+            
channel=channel, + allow_inlining=True, + sync=sync, + ) + + async def create_program( + self, + program_ref: str, + entrypoint: str, + runtime: str, + environment_variables: Optional[Mapping[str, str]] = None, + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + memory: Optional[int] = None, + vcpus: Optional[int] = None, + timeout_seconds: Optional[float] = None, + persistent: bool = False, + encoding: Encoding = Encoding.zip, + volumes: Optional[List[Mapping]] = None, + subscriptions: Optional[List[Mapping]] = None, + ) -> Tuple[ProgramMessage, MessageStatus]: + """ + Post a (create) PROGRAM message. + + :param program_ref: Reference to the program to run + :param entrypoint: Entrypoint to run + :param runtime: Runtime to use + :param environment_variables: Environment variables to pass to the program + :param storage_engine: Storage engine to use (Default: "storage") + :param channel: Channel to use (Default: "TEST") + :param address: Address to use (Default: account.get_address()) + :param sync: If true, waits for the message to be processed by the API server + :param memory: Memory in MB for the VM to be allocated (Default: 128) + :param vcpus: Number of vCPUs to allocate (Default: 1) + :param timeout_seconds: Timeout in seconds (Default: 30.0) + :param persistent: Whether the program should be persistent or not (Default: False) + :param encoding: Encoding to use (Default: Encoding.zip) + :param volumes: Volumes to mount + :param subscriptions: Patterns of Aleph messages to forward to the program's event receiver + """ + address = address or settings.ADDRESS_TO_USE or self.account.get_address() + + volumes = volumes if volumes is not None else [] + memory = memory or settings.DEFAULT_VM_MEMORY + vcpus = vcpus or settings.DEFAULT_VM_VCPUS + timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT + + # TODO: Check that program_ref, runtime and data_ref exist + + # 
Register the different ways to trigger a VM + if subscriptions: + # Trigger on HTTP calls and on Aleph message subscriptions. + triggers = { + "http": True, + "persistent": persistent, + "message": subscriptions, + } + else: + # Trigger on HTTP calls. + triggers = {"http": True, "persistent": persistent} + + content = ProgramContent( + **{ + "type": "vm-function", + "address": address, + "allow_amend": False, + "code": { + "encoding": encoding, + "entrypoint": entrypoint, + "ref": program_ref, + "use_latest": True, + }, + "on": triggers, + "environment": { + "reproducible": False, + "internet": True, + "aleph_api": True, + }, + "variables": environment_variables, + "resources": { + "vcpus": vcpus, + "memory": memory, + "seconds": timeout_seconds, + }, + "runtime": { + "ref": runtime, + "use_latest": True, + "comment": "Official Aleph runtime" + if runtime == settings.DEFAULT_RUNTIME_ID + else "", + }, + "volumes": volumes, + "time": time.time(), + } + ) + + # Ensure that the version of aleph-message used supports the field. + assert content.on.persistent == persistent + + return await self.submit( + content=content.dict(exclude_none=True), + message_type=MessageType.program, + channel=channel, + storage_engine=storage_engine, + sync=sync, + ) + + async def forget( + self, + hashes: List[str], + reason: Optional[str], + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + ) -> Tuple[ForgetMessage, MessageStatus]: + """ + Post a FORGET message to remove previous messages from the network. + + Targeted messages need to be signed by the same account that is attempting to forget them, + if the creating address did not delegate the access rights to the forgetting account. 
+ + :param hashes: Hashes of the messages to forget + :param reason: Reason for forgetting the messages + :param storage_engine: Storage engine to use (Default: "storage") + :param channel: Channel to use (Default: "TEST") + :param address: Address to use (Default: account.get_address()) + :param sync: If true, waits for the message to be processed by the API server (Default: False) + """ + address = address or settings.ADDRESS_TO_USE or self.account.get_address() + + content = ForgetContent( + hashes=hashes, + reason=reason, + address=address, + time=time.time(), + ) + + return await self.submit( + content=content.dict(exclude_none=True), + message_type=MessageType.forget, + channel=channel, + storage_engine=storage_engine, + allow_inlining=True, + sync=sync, + ) + + @staticmethod + def compute_sha256(s: str) -> str: + h = hashlib.sha256() + h.update(s.encode("utf-8")) + return h.hexdigest() + + async def _prepare_aleph_message( + self, + message_type: MessageType, + content: Dict[str, Any], + channel: Optional[str], + allow_inlining: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, + ) -> AlephMessage: + + message_dict: Dict[str, Any] = { + "sender": self.account.get_address(), + "chain": self.account.CHAIN, + "type": message_type, + "content": content, + "time": time.time(), + "channel": channel, + } + + item_content: str = json.dumps(content, separators=(",", ":")) + + if allow_inlining and (len(item_content) < settings.MAX_INLINE_SIZE): + message_dict["item_content"] = item_content + message_dict["item_hash"] = self.compute_sha256(item_content) + message_dict["item_type"] = ItemType.inline + else: + if storage_engine == StorageEnum.ipfs: + message_dict["item_hash"] = await self.ipfs_push( + content=content, + ) + message_dict["item_type"] = ItemType.ipfs + else: # storage + assert storage_engine == StorageEnum.storage + message_dict["item_hash"] = await self.storage_push( + content=content, + ) + message_dict["item_type"] = ItemType.storage + 
+ message_dict = await self.account.sign_message(message_dict) + return Message(**message_dict) + + async def submit( + self, + content: Dict[str, Any], + message_type: MessageType, + channel: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + allow_inlining: bool = True, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + + message = await self._prepare_aleph_message( + message_type=message_type, + content=content, + channel=channel, + allow_inlining=allow_inlining, + storage_engine=storage_engine, + ) + message_status = await self._broadcast(message=message, sync=sync) + return message, message_status diff --git a/src/aleph_client/vm/cache.py b/src/aleph_client/vm/cache.py index 08429ea7..1fceb8f2 100644 --- a/src/aleph_client/vm/cache.py +++ b/src/aleph_client/vm/cache.py @@ -1,14 +1,13 @@ -import re -import fnmatch import abc +import fnmatch +import re from typing import Union, Optional, Any, Dict, List, NewType -from aiohttp import ClientSession +import aiohttp +from pydantic import AnyHttpUrl -from aleph_client.asynchronous import get_fallback_session from ..conf import settings - CacheKey = NewType("CacheKey", str) @@ -45,16 +44,18 @@ async def keys(self, pattern: str = "*") -> List[str]: class VmCache(BaseVmCache): """Virtual Machines can use this cache to store temporary data in memory on the host.""" - session: ClientSession + session: aiohttp.ClientSession cache: Dict[str, bytes] api_host: str def __init__( - self, session: Optional[ClientSession] = None, api_host: Optional[str] = None + self, + session: Optional[aiohttp.ClientSession] = None, + connector_url: Optional[AnyHttpUrl] = None, ): - self.session = session or get_fallback_session() + self.session = session or aiohttp.ClientSession(base_url=connector_url) self.cache = {} - self.api_host = api_host if api_host else settings.API_HOST + self.api_host = connector_url if connector_url else settings.API_HOST async def get(self, key: str) -> Optional[bytes]: 
sanitized_key = sanitize_cache_key(key) diff --git a/tests/integration/itest_aggregates.py b/tests/integration/itest_aggregates.py index 7a7acb20..30614703 100644 --- a/tests/integration/itest_aggregates.py +++ b/tests/integration/itest_aggregates.py @@ -2,10 +2,6 @@ import json from typing import Dict -from aleph_client.asynchronous import ( - create_aggregate, - fetch_aggregate, -) from aleph_client.user_session import AuthenticatedUserSession from tests.integration.toolkit import try_until from .config import REFERENCE_NODE, TARGET_NODE @@ -21,9 +17,10 @@ async def create_aggregate_on_target( receiver_node: str, channel="INTEGRATION_TESTS", ): - async with AuthenticatedUserSession(account=account, api_server=emitter_node) as tx_session: - aggregate_message, message_status = await create_aggregate( - session=tx_session, + async with AuthenticatedUserSession( + account=account, api_server=emitter_node + ) as tx_session: + aggregate_message, message_status = await tx_session.create_aggregate( key=key, content=content, channel="INTEGRATION_TESTS", @@ -40,11 +37,12 @@ async def create_aggregate_on_target( assert aggregate_message.content.address == account.get_address() assert aggregate_message.content.content == content - async with AuthenticatedUserSession(account=account, api_server=receiver_node) as rx_session: + async with AuthenticatedUserSession( + account=account, api_server=receiver_node + ) as rx_session: aggregate_from_receiver = await try_until( - fetch_aggregate, + rx_session.fetch_aggregate, lambda aggregate: aggregate is not None, - session=rx_session, timeout=5, address=account.get_address(), key=key, diff --git a/tests/integration/itest_forget.py b/tests/integration/itest_forget.py index 2153283b..09659589 100644 --- a/tests/integration/itest_forget.py +++ b/tests/integration/itest_forget.py @@ -2,7 +2,6 @@ import pytest -from aleph_client.asynchronous import create_post, get_posts, get_messages, forget from aleph_client.types import Account from 
aleph_client.user_session import AuthenticatedUserSession from .config import REFERENCE_NODE, TARGET_NODE, TEST_CHANNEL @@ -17,18 +16,20 @@ async def wait_matching_posts( condition: Callable[[Dict], bool], timeout: int = 5, ): - async with AuthenticatedUserSession(account=account, api_server=receiver_node) as rx_session: + async with AuthenticatedUserSession( + account=account, api_server=receiver_node + ) as rx_session: return await try_until( - get_posts, + rx_session.get_posts, condition, - session=rx_session, timeout=timeout, hashes=[item_hash], ) - async with AuthenticatedUserSession(account=account, api_server=emitter_node) as tx_session: - post_message, message_status = await create_post( - session=tx_session, + async with AuthenticatedUserSession( + account=account, api_server=emitter_node + ) as tx_session: + post_message, message_status = await tx_session.create_post( post_content="A considerate and politically correct post.", post_type="POST", channel="INTEGRATION_TESTS", @@ -44,9 +45,10 @@ async def wait_matching_posts( post_hash = post_message.item_hash reason = "This well thought-out content offends me!" - async with AuthenticatedUserSession(account=account, api_server=emitter_node) as tx_session: - forget_message, forget_status = await forget( - session=tx_session, + async with AuthenticatedUserSession( + account=account, api_server=emitter_node + ) as tx_session: + forget_message, forget_status = await tx_session.forget( hashes=[post_hash], reason=reason, channel=channel, @@ -100,14 +102,15 @@ async def test_forget_a_forget_message(fixture_account): # TODO: this test should be moved to the PyAleph API tests, once a framework is in place. 
post_hash = await create_and_forget_post(fixture_account, TARGET_NODE, TARGET_NODE) - async with AuthenticatedUserSession(account=fixture_account, api_server=TARGET_NODE) as session: - get_post_response = await get_posts(session=session, hashes=[post_hash]) + async with AuthenticatedUserSession( + account=fixture_account, api_server=TARGET_NODE + ) as session: + get_post_response = await session.get_posts(hashes=[post_hash]) assert len(get_post_response["posts"]) == 1 post = get_post_response["posts"][0] forget_message_hash = post["forgotten_by"][0] - forget_message, forget_status = await forget( - session=session, + forget_message, forget_status = await session.forget( hashes=[forget_message_hash], reason="I want to remember this post. Maybe I can forget I forgot it?", channel=TEST_CHANNEL, @@ -115,8 +118,7 @@ async def test_forget_a_forget_message(fixture_account): print(forget_message) - get_forget_message_response = await get_messages( - session=session, + get_forget_message_response = await session.get_messages( hashes=[forget_message_hash], channels=[TEST_CHANNEL], ) diff --git a/tests/integration/itest_posts.py b/tests/integration/itest_posts.py index 1344ad92..7ad0184f 100644 --- a/tests/integration/itest_posts.py +++ b/tests/integration/itest_posts.py @@ -1,11 +1,6 @@ import pytest -from aleph_message import Message -from aleph_message.models import PostMessage, MessagesResponse +from aleph_message.models import MessagesResponse -from aleph_client.asynchronous import ( - create_post, - get_messages, -) from aleph_client.user_session import AuthenticatedUserSession from tests.integration.toolkit import try_until from .config import REFERENCE_NODE, TARGET_NODE @@ -17,9 +12,10 @@ async def create_message_on_target( """ Create a POST message on the target node, then fetch it from the reference node. 
""" - async with AuthenticatedUserSession(account=fixture_account, api_server=emitter_node) as tx_session: - post_message, message_status = await create_post( - session=tx_session, + async with AuthenticatedUserSession( + account=fixture_account, api_server=emitter_node + ) as tx_session: + post_message, message_status = await tx_session.create_post( post_content=None, post_type="POST", channel="INTEGRATION_TESTS", @@ -28,11 +24,12 @@ async def create_message_on_target( def response_contains_messages(response: MessagesResponse) -> bool: return len(response.messages) > 0 - async with AuthenticatedUserSession(account=fixture_account, api_server=receiver_node) as rx_session: + async with AuthenticatedUserSession( + account=fixture_account, api_server=receiver_node + ) as rx_session: responses = await try_until( - get_messages, + rx_session.get_messages, response_contains_messages, - session=rx_session, timeout=5, hashes=[post_message.item_hash], ) diff --git a/tests/unit/test_asynchronous.py b/tests/unit/test_asynchronous.py index b5ef2e36..74ae7e19 100644 --- a/tests/unit/test_asynchronous.py +++ b/tests/unit/test_asynchronous.py @@ -1,4 +1,5 @@ -from unittest.mock import patch, AsyncMock +import json +from unittest.mock import patch, AsyncMock, MagicMock import pytest as pytest from aleph_message.models import ( @@ -9,34 +10,48 @@ ForgetMessage, ) -from aleph_client.asynchronous import ( - create_post, - create_aggregate, - create_store, - create_program, - forget, -) +from aleph_client import AuthenticatedUserSession from aleph_client.types import StorageEnum, MessageStatus, Account @pytest.fixture -def mock_session_with_post_success(mocker, ethereum_account: Account): - mock_response = mocker.AsyncMock() - mock_response.status = 202 - mock_response.json.return_value = { - "message_status": "pending", - "publication_status": {"status": "success", "failed": []}, - } - - mock_post = mocker.AsyncMock() - mock_post.return_value = mock_response - - mock_session = 
mocker.MagicMock() - mock_session.post.return_value.__aenter__ = mock_post +def mock_session_with_post_success( + ethereum_account: Account, +) -> AuthenticatedUserSession: + class MockResponse: + def __init__(self, sync: bool): + self.sync = sync + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + ... + + @property + def status(self): + return 200 if self.sync else 202 + + async def json(self): + message_status = "processed" if self.sync else "pending" + return { + "message_status": message_status, + "publication_status": {"status": "success", "failed": []}, + } + + async def text(self): + return json.dumps(await self.json()) + + http_session = AsyncMock() + http_session.post = MagicMock() + http_session.post.side_effect = lambda *args, **kwargs: MockResponse( + sync=kwargs.get("sync", False) + ) - user_session = mocker.AsyncMock() - user_session.http_session = mock_session - user_session.account = ethereum_account + user_session = AuthenticatedUserSession( + account=ethereum_account, api_server="http://localhost" + ) + user_session.http_session = http_session return user_session @@ -44,17 +59,17 @@ def mock_session_with_post_success(mocker, ethereum_account: Account): @pytest.mark.asyncio async def test_create_post(mock_session_with_post_success): - mock_session = mock_session_with_post_success - content = {"Hello": "World"} + async with mock_session_with_post_success as session: + content = {"Hello": "World"} - post_message, message_status = await create_post( - session=mock_session, - post_content=content, - post_type="TEST", - channel="TEST", - ) + post_message, message_status = await session.create_post( + post_content=content, + post_type="TEST", + channel="TEST", + sync=False, + ) - assert mock_session.http_session.post.called + assert mock_session_with_post_success.http_session.post.called_once assert isinstance(post_message, PostMessage) assert message_status == MessageStatus.PENDING @@ -62,37 +77,34 
@@ async def test_create_post(mock_session_with_post_success): @pytest.mark.asyncio async def test_create_aggregate(mock_session_with_post_success): - mock_session = mock_session_with_post_success + async with mock_session_with_post_success as session: - aggregate_message, message_status = await create_aggregate( - session=mock_session, - key="hello", - content={"Hello": "world"}, - channel="TEST", - ) + aggregate_message, message_status = await session.create_aggregate( + key="hello", + content={"Hello": "world"}, + channel="TEST", + ) - assert mock_session.http_session.post.called + assert mock_session_with_post_success.http_session.post.called_once assert isinstance(aggregate_message, AggregateMessage) @pytest.mark.asyncio async def test_create_store(mock_session_with_post_success): - mock_session = mock_session_with_post_success - mock_ipfs_push_file = AsyncMock() mock_ipfs_push_file.return_value = "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy" - with patch("aleph_client.asynchronous.ipfs_push_file", mock_ipfs_push_file): - _ = await create_store( - session=mock_session, + mock_session_with_post_success.ipfs_push_file = mock_ipfs_push_file + + async with mock_session_with_post_success as session: + _ = await session.create_store( file_content=b"HELLO", channel="TEST", storage_engine=StorageEnum.ipfs, ) - _ = await create_store( - session=mock_session, + _ = await session.create_store( file_hash="QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy", channel="TEST", storage_engine=StorageEnum.ipfs, @@ -102,47 +114,44 @@ async def test_create_store(mock_session_with_post_success): mock_storage_push_file.return_value = ( "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy" ) + mock_session_with_post_success.storage_push_file = mock_storage_push_file + async with mock_session_with_post_success as session: - with patch("aleph_client.asynchronous.storage_push_file", mock_storage_push_file): - store_message, message_status = await create_store( - session=mock_session, + 
store_message, message_status = await session.create_store( file_content=b"HELLO", channel="TEST", storage_engine=StorageEnum.storage, ) - assert mock_session.http_session.post.called + assert mock_session_with_post_success.http_session.post.called assert isinstance(store_message, StoreMessage) @pytest.mark.asyncio async def test_create_program(mock_session_with_post_success): - mock_session = mock_session_with_post_success + async with mock_session_with_post_success as session: - program_message, message_status = await create_program( - session=mock_session, - program_ref="FAKE-HASH", - entrypoint="main:app", - runtime="FAKE-HASH", - channel="TEST", - ) + program_message, message_status = await session.create_program( + program_ref="FAKE-HASH", + entrypoint="main:app", + runtime="FAKE-HASH", + channel="TEST", + ) - assert mock_session.http_session.post.called + assert mock_session_with_post_success.http_session.post.called_once assert isinstance(program_message, ProgramMessage) @pytest.mark.asyncio async def test_forget(mock_session_with_post_success): - mock_session = mock_session_with_post_success - - forget_message, message_status = await forget( - session=mock_session, - hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"], - reason="GDPR", - channel="TEST", - ) + async with mock_session_with_post_success as session: + forget_message, message_status = await session.forget( + hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"], + reason="GDPR", + channel="TEST", + ) - assert mock_session.http_session.post.called + assert mock_session_with_post_success.http_session.post.called_once assert isinstance(forget_message, ForgetMessage) diff --git a/tests/unit/test_asynchronous_get.py b/tests/unit/test_asynchronous_get.py index f69a8ee0..958ba6da 100644 --- a/tests/unit/test_asynchronous_get.py +++ b/tests/unit/test_asynchronous_get.py @@ -1,33 +1,37 @@ import unittest from typing import Any, Dict -from unittest.mock import AsyncMock, MagicMock +from 
unittest.mock import AsyncMock import pytest from aleph_message.models import MessageType, MessagesResponse -from aleph_client.asynchronous import ( - get_messages, - fetch_aggregates, - fetch_aggregate, -) from aleph_client.conf import settings from aleph_client.user_session import UserSession -def make_mock_session(get_return_value: Dict[str, Any]): +def make_mock_session(get_return_value: Dict[str, Any]) -> UserSession: + class MockResponse: + async def __aenter__(self): + return self - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json = AsyncMock(side_effect=lambda: get_return_value) + async def __aexit__(self, exc_type, exc_val, exc_tb): + ... - mock_get = AsyncMock() - mock_get.return_value = mock_response + @property + def status(self): + return 200 - mock_session = MagicMock() - mock_session.get.return_value.__aenter__ = mock_get + async def json(self): + return get_return_value - user_session = AsyncMock() - user_session.http_session = mock_session + class MockHttpSession(AsyncMock): + def get(self, *_args, **_kwargs): + return MockResponse() + + http_session = MockHttpSession() + + user_session = UserSession(api_server="http://localhost") + user_session.http_session = http_session return user_session @@ -37,12 +41,12 @@ async def test_fetch_aggregate(): mock_session = make_mock_session( {"data": {"corechannel": {"nodes": [], "resource_nodes": []}}} ) + async with mock_session: - response = await fetch_aggregate( - session=mock_session, - address="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10", - key="corechannel", - ) + response = await mock_session.fetch_aggregate( + address="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10", + key="corechannel", + ) assert response.keys() == {"nodes", "resource_nodes"} @@ -52,19 +56,18 @@ async def test_fetch_aggregates(): {"data": {"corechannel": {"nodes": [], "resource_nodes": []}}} ) - response = await fetch_aggregates( - session=mock_session, address="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10" - 
) - assert response.keys() == {"corechannel"} - assert response["corechannel"].keys() == {"nodes", "resource_nodes"} + async with mock_session: + response = await mock_session.fetch_aggregates( + address="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10" + ) + assert response.keys() == {"corechannel"} + assert response["corechannel"].keys() == {"nodes", "resource_nodes"} @pytest.mark.asyncio async def test_get_posts(): async with UserSession(api_server=settings.API_HOST) as session: - response: MessagesResponse = await get_messages( - session=session, - pagination=2, + response: MessagesResponse = await session.get_messages( message_type=MessageType.post, ) @@ -77,8 +80,7 @@ async def test_get_posts(): @pytest.mark.asyncio async def test_get_messages(): async with UserSession(api_server=settings.API_HOST) as session: - response: MessagesResponse = await get_messages( - session=session, + response: MessagesResponse = await session.get_messages( pagination=2, ) diff --git a/tests/unit/test_chain_solana.py b/tests/unit/test_chain_solana.py index e62cc7ef..e511f59a 100644 --- a/tests/unit/test_chain_solana.py +++ b/tests/unit/test_chain_solana.py @@ -33,7 +33,9 @@ def test_get_fallback_account(): @pytest.mark.asyncio async def test_SOLAccount(solana_account): - message = asdict(Message("SOL", solana_account.get_address(), "SomeType", "ItemHash")) + message = asdict( + Message("SOL", solana_account.get_address(), "SomeType", "ItemHash") + ) initial_message = message.copy() await solana_account.sign_message(message) assert message["signature"] diff --git a/tests/unit/test_synchronous_get.py b/tests/unit/test_synchronous_get.py index a8a281a3..b6fd1db4 100644 --- a/tests/unit/test_synchronous_get.py +++ b/tests/unit/test_synchronous_get.py @@ -1,14 +1,12 @@ from aleph_message.models import MessageType, MessagesResponse from aleph_client.conf import settings -from aleph_client.synchronous import get_messages from aleph_client.user_session import UserSession def test_get_posts(): 
with UserSession(api_server=settings.API_HOST) as session: - response: MessagesResponse = get_messages( - session=session, + response: MessagesResponse = session.get_messages( pagination=2, message_type=MessageType.post, )