diff --git a/examples/httpgateway.py b/examples/httpgateway.py index 304215c4..da053c95 100644 --- a/examples/httpgateway.py +++ b/examples/httpgateway.py @@ -2,22 +2,14 @@ """ # -*- coding: utf-8 -*- -import os - -# import requests -import platform - -# import socket -import time import asyncio + import click -from aleph_client.asynchronous import create_post +from aiohttp import web -# from aleph_client.chains.nuls1 import NULSAccount, get_fallback_account -from aleph_client.chains.ethereum import ETHAccount from aleph_client.chains.common import get_fallback_private_key - -from aiohttp import web +from aleph_client.chains.ethereum import ETHAccount +from aleph_client.user_session import AuthenticatedUserSession app = web.Application() routes = web.RouteTableDef() @@ -42,13 +34,14 @@ async def source_post(request): return web.json_response( {"status": "error", "message": "unauthorized secret"} ) - message, _status = await create_post( - app["account"], - data, - "event", - channel=app["channel"], - api_server="https://api2.aleph.im", - ) + async with AuthenticatedUserSession( + account=app["account"], api_server="https://api2.aleph.im" + ) as session: + message, _status = await session.create_post( + post_content=data, + post_type="event", + channel=app["channel"], + ) return web.json_response({"status": "success", "item_hash": message.item_hash}) diff --git a/examples/metrics.py b/examples/metrics.py index 1a2733b6..4ff553e1 100644 --- a/examples/metrics.py +++ b/examples/metrics.py @@ -5,11 +5,16 @@ import os import platform import time +from typing import Tuple import psutil +from aleph_message.models import AlephMessage +from aleph_client import AuthenticatedUserSession from aleph_client.chains.ethereum import get_fallback_account -from aleph_client.synchronous import create_aggregate +from aleph_client.conf import settings +from aleph_client.types import MessageStatus +from aleph_client.user_session import AuthenticatedUserSessionSync def get_sysinfo(): @@ 
-49,8 +54,10 @@ def get_cpu_cores(): return [c._asdict() for c in psutil.cpu_times_percent(0, percpu=True)] -def send_metrics(account, metrics): - return create_aggregate(account, "metrics", metrics, channel="SYSINFO") +def send_metrics( + session: AuthenticatedUserSessionSync, metrics +) -> Tuple[AlephMessage, MessageStatus]: + return session.create_aggregate(key="metrics", content=metrics, channel="SYSINFO") def collect_metrics(): @@ -64,11 +71,14 @@ def collect_metrics(): def main(): account = get_fallback_account() - while True: - metrics = collect_metrics() - message, status = send_metrics(account, metrics) - print("sent", message.item_hash) - time.sleep(10) + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + while True: + metrics = collect_metrics() + message, status = send_metrics(session, metrics) + print("sent", message.item_hash) + time.sleep(10) if __name__ == "__main__": diff --git a/examples/mqtt.py b/examples/mqtt.py index 3bc8bec1..3df400e0 100644 --- a/examples/mqtt.py +++ b/examples/mqtt.py @@ -10,7 +10,8 @@ from aleph_client.chains.common import get_fallback_private_key from aleph_client.chains.ethereum import ETHAccount -from aleph_client.main import create_aggregate +from aleph_client import AuthenticatedUserSession +from aleph_client.conf import settings def get_input_data(value): @@ -26,7 +27,12 @@ def get_input_data(value): def send_metrics(account, metrics): - return create_aggregate(account, "metrics", metrics, channel="SYSINFO") + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + return session.create_aggregate( + key="metrics", content=metrics, channel="SYSINFO" + ) def on_disconnect(client, userdata, rc): @@ -95,10 +101,15 @@ async def gateway( if not userdata["received"]: await client.reconnect() - for key, value in state.items(): - message, status = create_aggregate(account, key, value, channel="IOT_TEST") - print("sent", message.item_hash) - 
userdata["received"] = False + async with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + for key, value in state.items(): + message, status = await session.create_aggregate( + key=key, content=value, channel="IOT_TEST" + ) + print("sent", message.item_hash) + userdata["received"] = False @click.command() diff --git a/examples/store.py b/examples/store.py index 493292ed..26374f0f 100644 --- a/examples/store.py +++ b/examples/store.py @@ -1,13 +1,13 @@ import asyncio -import aiohttp import click from aleph_message.models import StoreMessage -from aleph_client.asynchronous import create_store from aleph_client.chains.common import get_fallback_private_key from aleph_client.chains.ethereum import ETHAccount +from aleph_client.conf import settings from aleph_client.types import MessageStatus +from aleph_client.user_session import AuthenticatedUserSession DEFAULT_SERVER = "https://api2.aleph.im" @@ -23,7 +23,9 @@ async def print_output_hash(message: StoreMessage, status: MessageStatus): async def do_upload(account, engine, channel, filename=None, file_hash=None): - async with aiohttp.ClientSession() as session: + async with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: print(filename, account.get_address()) if filename: try: @@ -33,24 +35,20 @@ async def do_upload(account, engine, channel, filename=None, file_hash=None): if len(content) > 4 * 1024 * 1024 and engine == "STORAGE": print("File too big for native STORAGE engine") return - message, status = await create_store( - account, + message, status = await session.create_store( file_content=content, channel=channel, storage_engine=engine.lower(), - session=session, ) except IOError: print("File not accessible") raise elif file_hash: - message, status = await create_store( - account, + message, status = await session.create_store( file_hash=file_hash, channel=channel, storage_engine=engine.lower(), - session=session, ) await 
print_output_hash(message, status) diff --git a/setup.cfg b/setup.cfg index 51c7bed4..8268da61 100644 --- a/setup.cfg +++ b/setup.cfg @@ -60,6 +60,7 @@ testing = pytest pytest-cov pytest-asyncio + pytest-mock mypy secp256k1 pynacl @@ -69,6 +70,7 @@ testing = httpx requests aleph-pytezos==0.1.0 + types-certifi types-setuptools black mqtt = diff --git a/src/aleph_client/__init__.py b/src/aleph_client/__init__.py index 34225eca..ff0fb1bf 100644 --- a/src/aleph_client/__init__.py +++ b/src/aleph_client/__init__.py @@ -1,4 +1,5 @@ from pkg_resources import get_distribution, DistributionNotFound +from .user_session import AuthenticatedUserSession, UserSession try: # Change here if project is renamed and does not equal the package name @@ -8,3 +9,6 @@ __version__ = "unknown" finally: del get_distribution, DistributionNotFound + + +__all__ = ["AuthenticatedUserSession", "UserSession"] diff --git a/src/aleph_client/asynchronous.py b/src/aleph_client/asynchronous.py deleted file mode 100644 index 5e441b88..00000000 --- a/src/aleph_client/asynchronous.py +++ /dev/null @@ -1,1111 +0,0 @@ -""" This is the simplest aleph network client available. 
-""" -import asyncio -import hashlib -import json -import logging -import queue -import threading -import time -from datetime import datetime -from functools import lru_cache -from pathlib import Path -from typing import ( - Optional, - Union, - Any, - Dict, - List, - Iterable, - AsyncIterable, -) -from typing import Type, Mapping, Tuple, NoReturn - -from aleph_message.models import ( - ForgetContent, - MessageType, - AggregateContent, - PostContent, - StoreContent, - PostMessage, - Message, - ForgetMessage, - AlephMessage, - AggregateMessage, - StoreMessage, - ProgramMessage, - ItemType, -) -from pydantic import ValidationError - -from aleph_client.types import Account, StorageEnum, GenericMessage, MessageStatus -from .exceptions import ( - MessageNotFoundError, - MultipleMessagesError, - InvalidMessageError, - BroadcastError, -) -from .models import MessagesResponse -from .utils import get_message_type_value - -logger = logging.getLogger(__name__) - -try: - import magic -except ImportError: - logger.info("Could not import library 'magic', MIME type detection disabled") - magic = None # type:ignore - -from .conf import settings - -import aiohttp -from aiohttp import ClientSession - -from aleph_message.models.program import ProgramContent, Encoding - - -@lru_cache() -def _get_fallback_session(thread_id: Optional[int]) -> ClientSession: - if settings.API_UNIX_SOCKET: - connector = aiohttp.UnixConnector(path=settings.API_UNIX_SOCKET) - return aiohttp.ClientSession(connector=connector) - else: - return aiohttp.ClientSession() - - -def get_fallback_session() -> ClientSession: - thread_id = threading.get_native_id() - return _get_fallback_session(thread_id=thread_id) - - -async def ipfs_push( - content: Mapping, - session: ClientSession, - api_server: str, -) -> str: - """Push arbitrary content as JSON to the IPFS service.""" - url = f"{api_server}/api/v0/ipfs/add_json" - logger.debug(f"Pushing to IPFS on {url}") - - async with session.post(url, json=content) as resp: - 
resp.raise_for_status() - return (await resp.json()).get("hash") - - -async def storage_push( - content: Mapping, - session: ClientSession, - api_server: str, -) -> str: - """Push arbitrary content as JSON to the storage service.""" - url = f"{api_server}/api/v0/storage/add_json" - logger.debug(f"Pushing to storage on {url}") - - async with session.post(url, json=content) as resp: - resp.raise_for_status() - return (await resp.json()).get("hash") - - -async def ipfs_push_file( - file_content, - session: ClientSession, - api_server: str, -) -> str: - """Push a file to the IPFS service.""" - data = aiohttp.FormData() - data.add_field("file", file_content) - - url = f"{api_server}/api/v0/ipfs/add_file" - logger.debug(f"Pushing file to IPFS on {url}") - - async with session.post(url, data=data) as resp: - resp.raise_for_status() - return (await resp.json()).get("hash") - - -async def storage_push_file( - file_content, - session: ClientSession, - api_server: str, -) -> str: - """Push a file to the storage service.""" - data = aiohttp.FormData() - data.add_field("file", file_content) - - url = f"{api_server}/api/v0/storage/add_file" - logger.debug(f"Posting file on {url}") - - async with session.post(url, data=data) as resp: - resp.raise_for_status() - return (await resp.json()).get("hash") - - -def _log_publication_status(publication_status: Mapping[str, Any]): - status = publication_status.get("status") - failures = publication_status.get("failed") - - if status == "success": - return - elif status == "warning": - logger.warning("Broadcast failed on the following network(s): %s", failures) - elif status == "error": - logger.error( - "Broadcast failed on all protocols. The message was not published." 
- ) - else: - raise ValueError( - f"Invalid response from server, status in missing or unknown: '{status}'" - ) - - -async def _handle_broadcast_error(response: aiohttp.ClientResponse) -> NoReturn: - if response.status == 500: - # Assume a broadcast error, no need to read the JSON - if response.content_type == "application/json": - error_msg = "Internal error - broadcast failed on all protocols" - else: - error_msg = f"Internal error - the message was not broadcast: {await response.text()}" - - logger.error(error_msg) - raise BroadcastError(error_msg) - elif response.status == 422: - errors = await response.json() - logger.error( - "The message could not be processed because of the following errors: %s", - errors, - ) - raise InvalidMessageError(errors) - else: - error_msg = ( - f"Unexpected HTTP response ({response.status}: {await response.text()})" - ) - logger.error(error_msg) - raise BroadcastError(error_msg) - - -async def _handle_broadcast_deprecated_response( - response: aiohttp.ClientResponse, -) -> None: - if response.status != 200: - await _handle_broadcast_error(response) - else: - publication_status = await response.json() - _log_publication_status(publication_status) - - -async def _broadcast_deprecated( - message_dict: Mapping[str, Any], - session: ClientSession, - api_server: str = settings.API_HOST, -): - - """ - Broadcast a message on the Aleph network using the deprecated - /ipfs/pubsub/pub/ endpoint. 
- """ - - url = f"{api_server}/api/v0/ipfs/pubsub/pub" - logger.debug(f"Posting message on {url}") - - async with session.post( - url, - json={"topic": "ALEPH-TEST", "data": json.dumps(message_dict)}, - ) as response: - await _handle_broadcast_deprecated_response(response) - - -async def _handle_broadcast_response( - response: aiohttp.ClientResponse, sync: bool -) -> MessageStatus: - if response.status in (200, 202): - status = await response.json() - _log_publication_status(status["publication_status"]) - - if response.status == 202: - if sync: - logger.warning("Timed out while waiting for processing of sync message") - return MessageStatus.PENDING - - return MessageStatus.PROCESSED - - else: - await _handle_broadcast_error(response) - - -BROADCAST_MESSAGE_FIELDS = { - "sender", - "chain", - "signature", - "type", - "item_hash", - "item_type", - "item_content", - "time", - "channel", -} - - -async def _broadcast( - message: AlephMessage, - sync: bool, - session: ClientSession, - api_server: str, -) -> MessageStatus: - """ - Broadcast a message on the Aleph network. - - Uses the POST /messages/ endpoint or the deprecated /ipfs/pubsub/pub/ endpoint - if the first method is not available. - """ - - url = f"{api_server}/api/v0/messages" - logger.debug(f"Posting message on {url}") - - message_dict = message.dict(include=BROADCAST_MESSAGE_FIELDS) - - async with session.post( - url, - json={"sync": sync, "message": message_dict}, - ) as response: - # The endpoint may be unavailable on this node, try the deprecated version. - if response.status == 404: - logger.warning( - "POST /messages/ not found. Defaulting to legacy endpoint..." 
- ) - await _broadcast_deprecated( - message_dict=message_dict, session=session, api_server=api_server - ) - return MessageStatus.PENDING - else: - message_status = await _handle_broadcast_response( - response=response, sync=sync - ) - return message_status - - -async def create_post( - account: Account, - post_content, - post_type: str, - ref: Optional[str] = None, - address: Optional[str] = None, - channel: Optional[str] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, - inline: bool = True, - storage_engine: StorageEnum = StorageEnum.storage, - sync: bool = False, -) -> Tuple[PostMessage, MessageStatus]: - """ - Create a POST message on the Aleph network. It is associated with a channel and owned by an account. - - :param account: The account that will sign and own the message - :param post_content: The content of the message - :param post_type: An arbitrary content type that helps to describe the post_content - :param ref: A reference to a previous message that it replaces - :param address: The address that will be displayed as the author of the message - :param channel: The channel that the message will be posted on - :param session: An optional aiohttp session to use for the request - :param api_server: An optional API server to use for the request (Default: "https://api2.aleph.im") - :param inline: An optional flag to indicate if the content should be inlined in the message or not - :param storage_engine: An optional storage engine to use for the message, if not inlined (Default: "storage") - :param sync: If true, waits for the message to be processed by the API server (Default: False) - """ - address = address or settings.ADDRESS_TO_USE or account.get_address() - api_server = api_server or settings.API_HOST - - content = PostContent( - type=post_type, - address=address, - content=post_content, - time=time.time(), - ref=ref, - ) - - return await submit( - account=account, - content=content.dict(exclude_none=True), - 
message_type=MessageType.post, - channel=channel, - api_server=api_server, - session=session, - allow_inlining=inline, - storage_engine=storage_engine, - sync=sync, - ) - - -async def create_aggregate( - account: Account, - key, - content, - address: Optional[str] = None, - channel: Optional[str] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, - inline: bool = True, - sync: bool = False, -) -> Tuple[AggregateMessage, MessageStatus]: - """ - Create an AGGREGATE message. It is meant to be used as a quick access storage associated with an account. - - :param account: Account to use to sign the message - :param key: Key to use to store the content - :param content: Content to store - :param address: Address to use to sign the message - :param channel: Channel to use (Default: "TEST") - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") - :param inline: Whether to write content inside the message (Default: True) - :param sync: If true, waits for the message to be processed by the API server (Default: False) - """ - address = address or settings.ADDRESS_TO_USE or account.get_address() - api_server = api_server or settings.API_HOST - - content_ = AggregateContent( - key=key, - address=address, - content=content, - time=time.time(), - ) - - return await submit( - account=account, - content=content_.dict(exclude_none=True), - message_type=MessageType.aggregate, - channel=channel, - api_server=api_server, - session=session, - allow_inlining=inline, - sync=sync, - ) - - -async def create_store( - account: Account, - address: Optional[str] = None, - file_content: Optional[bytes] = None, - file_path: Optional[Union[str, Path]] = None, - file_hash: Optional[str] = None, - guess_mime_type: bool = False, - ref: Optional[str] = None, - storage_engine: StorageEnum = StorageEnum.storage, - extra_fields: Optional[dict] = None, - channel: Optional[str] = None, 
- session: Optional[ClientSession] = None, - api_server: Optional[str] = None, - sync: bool = False, -) -> Tuple[StoreMessage, MessageStatus]: - """ - Create a STORE message to store a file on the Aleph network. - - Can be passed either a file path, an IPFS hash or the file's content as raw bytes. - - :param account: Account to use to sign the message - :param address: Address to display as the author of the message (Default: account.get_address()) - :param file_content: Byte stream of the file to store (Default: None) - :param file_path: Path to the file to store (Default: None) - :param file_hash: Hash of the file to store (Default: None) - :param guess_mime_type: Guess the MIME type of the file (Default: False) - :param ref: Reference to a previous message (Default: None) - :param storage_engine: Storage engine to use (Default: "storage") - :param extra_fields: Extra fields to add to the STORE message (Default: None) - :param channel: Channel to post the message to (Default: "TEST") - :param session: aiohttp session to use (Default: get_fallback_session()) - :param api_server: Aleph API server to use (Default: "https://api2.aleph.im") - :param sync: If true, waits for the message to be processed by the API server (Default: False) - """ - address = address or settings.ADDRESS_TO_USE or account.get_address() - api_server = api_server or settings.API_HOST - - extra_fields = extra_fields or {} - session = session or get_fallback_session() - - if file_hash is None: - if file_content is None: - if file_path is None: - raise ValueError( - "Please specify at least a file_content, a file_hash or a file_path" - ) - else: - file_content = open(file_path, "rb").read() - - if storage_engine == StorageEnum.storage: - file_hash = await storage_push_file( - file_content, session=session, api_server=api_server - ) - elif storage_engine == StorageEnum.ipfs: - file_hash = await ipfs_push_file( - file_content, session=session, api_server=api_server - ) - else: - raise 
ValueError(f"Unknown storage engine: '{storage_engine}'") - - assert file_hash, "File hash should be empty" - - if magic is None: - pass - elif file_content and guess_mime_type and ("mime_type" not in extra_fields): - extra_fields["mime_type"] = magic.from_buffer(file_content, mime=True) - - if ref: - extra_fields["ref"] = ref - - values = { - "address": address, - "item_type": storage_engine, - "item_hash": file_hash, - "time": time.time(), - } - if extra_fields is not None: - values.update(extra_fields) - - content = StoreContent(**values) - - return await submit( - account=account, - content=content.dict(exclude_none=True), - message_type=MessageType.store, - channel=channel, - api_server=api_server, - session=session, - allow_inlining=True, - sync=sync, - ) - - -async def create_program( - account: Account, - program_ref: str, - entrypoint: str, - runtime: str, - environment_variables: Optional[Mapping[str, str]] = None, - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, - sync: bool = False, - memory: Optional[int] = None, - vcpus: Optional[int] = None, - timeout_seconds: Optional[float] = None, - persistent: bool = False, - encoding: Encoding = Encoding.zip, - volumes: Optional[List[Mapping]] = None, - subscriptions: Optional[List[Mapping]] = None, -) -> Tuple[ProgramMessage, MessageStatus]: - """ - Post a (create) PROGRAM message. 
- - :param account: Account to use to sign the message - :param program_ref: Reference to the program to run - :param entrypoint: Entrypoint to run - :param runtime: Runtime to use - :param environment_variables: Environment variables to pass to the program - :param storage_engine: Storage engine to use (Default: "storage") - :param channel: Channel to use (Default: "TEST") - :param address: Address to use (Default: account.get_address()) - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") - :param sync: If true, waits for the message to be processed by the API server - :param memory: Memory in MB for the VM to be allocated (Default: 128) - :param vcpus: Number of vCPUs to allocate (Default: 1) - :param timeout_seconds: Timeout in seconds (Default: 30.0) - :param persistent: Whether the program should be persistent or not (Default: False) - :param encoding: Encoding to use (Default: Encoding.zip) - :param volumes: Volumes to mount - :param subscriptions: Patterns of Aleph messages to forward to the program's event receiver - """ - address = address or settings.ADDRESS_TO_USE or account.get_address() - api_server = api_server or settings.API_HOST - - volumes = volumes if volumes is not None else [] - memory = memory or settings.DEFAULT_VM_MEMORY - vcpus = vcpus or settings.DEFAULT_VM_VCPUS - timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT - - # TODO: Check that program_ref, runtime and data_ref exist - - # Register the different ways to trigger a VM - if subscriptions: - # Trigger on HTTP calls and on Aleph message subscriptions. - triggers = {"http": True, "persistent": persistent, "message": subscriptions} - else: - # Trigger on HTTP calls. 
- triggers = {"http": True, "persistent": persistent} - - content = ProgramContent( - **{ - "type": "vm-function", - "address": address, - "allow_amend": False, - "code": { - "encoding": encoding, - "entrypoint": entrypoint, - "ref": program_ref, - "use_latest": True, - }, - "on": triggers, - "environment": { - "reproducible": False, - "internet": True, - "aleph_api": True, - }, - "variables": environment_variables, - "resources": { - "vcpus": vcpus, - "memory": memory, - "seconds": timeout_seconds, - }, - "runtime": { - "ref": runtime, - "use_latest": True, - "comment": "Official Aleph runtime" - if runtime == settings.DEFAULT_RUNTIME_ID - else "", - }, - "volumes": volumes, - "time": time.time(), - } - ) - - # Ensure that the version of aleph-message used supports the field. - assert content.on.persistent == persistent - - return await submit( - account=account, - content=content.dict(exclude_none=True), - message_type=MessageType.program, - channel=channel, - api_server=api_server, - storage_engine=storage_engine, - session=session, - sync=sync, - ) - - -async def forget( - account: Account, - hashes: List[str], - reason: Optional[str], - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, - sync: bool = False, -) -> Tuple[ForgetMessage, MessageStatus]: - """ - Post a FORGET message to remove previous messages from the network. - - Targeted messages need to be signed by the same account that is attempting to forget them, - if the creating address did not delegate the access rights to the forgetting account. 
- - :param account: Account to use to sign the message - :param hashes: Hashes of the messages to forget - :param reason: Reason for forgetting the messages - :param storage_engine: Storage engine to use (Default: "storage") - :param channel: Channel to use (Default: "TEST") - :param address: Address to use (Default: account.get_address()) - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") - :param sync: If true, waits for the message to be processed by the API server (Default: False) - """ - address = address or settings.ADDRESS_TO_USE or account.get_address() - api_server = api_server or settings.API_HOST - - content = ForgetContent( - hashes=hashes, - reason=reason, - address=address, - time=time.time(), - ) - - return await submit( - account, - content=content.dict(exclude_none=True), - message_type=MessageType.forget, - channel=channel, - api_server=api_server, - storage_engine=storage_engine, - session=session, - allow_inlining=True, - sync=sync, - ) - - -def compute_sha256(s: str) -> str: - h = hashlib.sha256() - h.update(s.encode("utf-8")) - return h.hexdigest() - - -async def _prepare_aleph_message( - account: Account, - message_type: MessageType, - content: Dict[str, Any], - channel: Optional[str], - session: aiohttp.ClientSession, - api_server: str, - allow_inlining: bool = True, - storage_engine: StorageEnum = StorageEnum.storage, -) -> AlephMessage: - - message_dict: Dict[str, Any] = { - "sender": account.get_address(), - "chain": account.CHAIN, - "type": message_type, - "content": content, - "time": time.time(), - "channel": channel, - } - - item_content: str = json.dumps(content, separators=(",", ":")) - - if allow_inlining and (len(item_content) < settings.MAX_INLINE_SIZE): - message_dict["item_content"] = item_content - message_dict["item_hash"] = compute_sha256(item_content) - message_dict["item_type"] = ItemType.inline - else: - if storage_engine == 
StorageEnum.ipfs: - message_dict["item_hash"] = await ipfs_push( - content, session=session, api_server=api_server - ) - message_dict["item_type"] = ItemType.ipfs - else: # storage - assert storage_engine == StorageEnum.storage - message_dict["item_hash"] = await storage_push( - content, session=session, api_server=api_server - ) - message_dict["item_type"] = ItemType.storage - - message_dict = await account.sign_message(message_dict) - return Message(**message_dict) - - -async def submit( - account: Account, - content: Dict[str, Any], - message_type: MessageType, - api_server: str, - channel: Optional[str] = None, - storage_engine: StorageEnum = StorageEnum.storage, - session: Optional[ClientSession] = None, - allow_inlining: bool = True, - sync: bool = False, -) -> Tuple[AlephMessage, MessageStatus]: - - session = session or get_fallback_session() - - message = await _prepare_aleph_message( - account=account, - message_type=message_type, - content=content, - channel=channel, - session=session, - api_server=api_server, - allow_inlining=allow_inlining, - storage_engine=storage_engine, - ) - message_status = await _broadcast( - message=message, session=session, api_server=api_server, sync=sync - ) - return message, message_status - - -async def fetch_aggregate( - address: str, - key: str, - limit: int = 100, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, -) -> Dict[str, Dict]: - """ - Fetch a value from the aggregate store by owner address and item key. 
- - :param address: Address of the owner of the aggregate - :param key: Key of the aggregate - :param limit: Maximum number of items to fetch (Default: 100) - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") - """ - session = session or get_fallback_session() - api_server = api_server or settings.API_HOST - - params: Dict[str, Any] = {"keys": key} - if limit: - params["limit"] = limit - - async with session.get( - f"{api_server}/api/v0/aggregates/{address}.json", params=params - ) as resp: - result = await resp.json() - data = result.get("data", dict()) - return data.get(key) - - -async def fetch_aggregates( - address: str, - keys: Optional[Iterable[str]] = None, - limit: int = 100, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, -) -> Dict[str, Dict]: - """ - Fetch key-value pairs from the aggregate store by owner address. - - :param address: Address of the owner of the aggregate - :param keys: Keys of the aggregates to fetch (Default: all items) - :param limit: Maximum number of items to fetch (Default: 100) - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") - """ - session = session or get_fallback_session() - api_server = api_server or settings.API_HOST - - keys_str = ",".join(keys) if keys else "" - params: Dict[str, Any] = {} - if keys_str: - params["keys"] = keys_str - if limit: - params["limit"] = limit - - async with session.get( - f"{api_server}/api/v0/aggregates/{address}.json", - params=params, - ) as resp: - result = await resp.json() - data = result.get("data", dict()) - return data - - -async def get_posts( - pagination: int = 200, - page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: 
Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, -) -> Dict[str, Dict]: - """ - Fetch a list of posts from the network. - - :param pagination: Number of items to fetch (Default: 200) - :param page: Page to fetch, begins at 1 (Default: 1) - :param types: Types of posts to fetch (Default: all types) - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Chains of the posts to fetch (Default: all chains) - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") - """ - session = session or get_fallback_session() - api_server = api_server or settings.API_HOST - - params: Dict[str, Any] = dict(pagination=pagination, page=page) - - if types is not None: - params["types"] = ",".join(types) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - 
params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date - - async with session.get(f"{api_server}/api/v0/posts.json", params=params) as resp: - resp.raise_for_status() - return await resp.json() - - -async def download_file( - file_hash: str, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, -) -> bytes: - """ - Get a file from the storage engine as raw bytes. - - Warning: Downloading large files can be slow and memory intensive. - - :param file_hash: The hash of the file to retrieve. - :param session: The aiohttp session to use. (Default: get_fallback_session()) - :param api_server: The API server to use. (Default: "https://api2.aleph.im") - """ - session = session or get_fallback_session() - api_server = api_server or settings.API_HOST - - async with session.get(f"{api_server}/api/v0/storage/raw/{file_hash}") as response: - response.raise_for_status() - return await response.read() - - -async def get_messages( - pagination: int = 200, - page: int = 1, - message_type: Optional[MessageType] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, - ignore_invalid_messages: bool = True, - invalid_messages_log_level: int = logging.NOTSET, -) -> MessagesResponse: - """ - Fetch a list of messages from the network. 
- - :param pagination: Number of items to fetch (Default: 200) - :param page: Page to fetch, begins at 1 (Default: 1) - :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param content_types: Filter by content type - :param content_keys: Filter by content key - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Filter by sender address chain - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") - :param ignore_invalid_messages: Ignore invalid messages (Default: False) - :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) - """ - session = session or get_fallback_session() - api_server = api_server or settings.API_HOST - ignore_invalid_messages = ( - True if ignore_invalid_messages is None else ignore_invalid_messages - ) - invalid_messages_log_level = ( - logging.NOTSET - if invalid_messages_log_level is None - else invalid_messages_log_level - ) - - params: Dict[str, Any] = dict(pagination=pagination, page=page) - - if message_type is not None: - params["msgType"] = message_type.value - if content_types is not None: - params["contentTypes"] = ",".join(content_types) - if content_keys is not None: - params["contentKeys"] = ",".join(content_keys) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - 
params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date - - async with session.get(f"{api_server}/api/v0/messages.json", params=params) as resp: - resp.raise_for_status() - response_json = await resp.json() - messages_raw = response_json["messages"] - - # All messages may not be valid according to the latest specification in - # aleph-message. This allows the user to specify how errors should be handled. - messages: List[AlephMessage] = [] - for message_raw in messages_raw: - try: - message = Message(**message_raw) - messages.append(message) - except KeyError as e: - if not ignore_invalid_messages: - raise e - logger.log( - level=invalid_messages_log_level, - msg=f"KeyError: Field '{e.args[0]}' not found", - ) - except ValidationError as e: - if not ignore_invalid_messages: - raise e - if invalid_messages_log_level: - logger.log(level=invalid_messages_log_level, msg=e) - - return MessagesResponse( - messages=messages, - pagination_page=response_json["pagination_page"], - pagination_total=response_json["pagination_total"], - pagination_per_page=response_json["pagination_per_page"], - pagination_item=response_json["pagination_item"], - ) - - -async def get_message( - item_hash: str, - message_type: Optional[Type[GenericMessage]] = None, - channel: Optional[str] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, -) -> GenericMessage: - """ - Get a single message from its `item_hash` and perform some basic validation. 
- - :param item_hash: Hash of the message to fetch - :param message_type: Type of message to fetch - :param channel: Channel of the message to fetch - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") - """ - messages_response = await get_messages( - hashes=[item_hash], - session=session, - channels=[channel] if channel else None, - api_server=api_server, - ) - if len(messages_response.messages) < 1: - raise MessageNotFoundError(f"No such hash {item_hash}") - if len(messages_response.messages) != 1: - raise MultipleMessagesError( - f"Multiple messages found for the same item_hash `{item_hash}`" - ) - message: GenericMessage = messages_response.messages[0] - if message_type: - expected_type = get_message_type_value(message_type) - if message.type != expected_type: - raise TypeError( - f"The message type '{message.type}' " - f"does not match the expected type '{expected_type}'" - ) - return message - - -async def watch_messages( - message_type: Optional[MessageType] = None, - content_types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, -) -> AsyncIterable[AlephMessage]: - """ - Iterate over current and future matching messages asynchronously. 
- - :param message_type: Type of message to watch - :param content_types: Content types to watch - :param refs: References to watch - :param addresses: Addresses to watch - :param tags: Tags to watch - :param hashes: Hashes to watch - :param channels: Channels to watch - :param chains: Chains to watch - :param start_date: Start date from when to watch - :param end_date: End date until when to watch - :param session: Session to use (Default: get_fallback_session()) - :param api_server: API server to use (Default: "https://api2.aleph.im") - """ - session = session or get_fallback_session() - api_server = api_server or settings.API_HOST - - params: Dict[str, Any] = dict() - - if message_type is not None: - params["msgType"] = message_type.value - if content_types is not None: - params["contentTypes"] = ",".join(content_types) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date - - async with session.ws_connect( - f"{api_server}/api/ws0/messages", params=params - ) as ws: - logger.debug("Websocket connected") - async for msg in ws: - if msg.type == aiohttp.WSMsgType.TEXT: - if msg.data == "close cmd": - await ws.close() - break - else: - data = json.loads(msg.data) - yield Message(**data) - elif msg.type == aiohttp.WSMsgType.ERROR: - break - - -async def _run_watch_messages(coroutine: AsyncIterable, 
output_queue: queue.Queue): - """Forward messages from the coroutine to the synchronous queue""" - async for message in coroutine: - output_queue.put(message) - - -def _start_run_watch_messages(output_queue: queue.Queue, args: List, kwargs: Dict): - """Thread entrypoint to run the `watch_messages` asynchronous generator in a thread.""" - watcher = watch_messages(*args, **kwargs) - runner = _run_watch_messages(watcher, output_queue) - asyncio.run(runner) diff --git a/src/aleph_client/commands/aggregate.py b/src/aleph_client/commands/aggregate.py index 94f73484..98b0a350 100644 --- a/src/aleph_client/commands/aggregate.py +++ b/src/aleph_client/commands/aggregate.py @@ -1,10 +1,11 @@ import typer from typing import Optional + +from aleph_client import UserSession from aleph_client.types import AccountFromPrivateKey from aleph_client.account import _load_account from aleph_client.conf import settings from pathlib import Path -from aleph_client import synchronous from aleph_client.commands import help_strings from aleph_client.commands.message import forget_messages @@ -37,10 +38,11 @@ def forget( account: AccountFromPrivateKey = _load_account(private_key, private_key_file) - message_response = synchronous.get_messages( - addresses=[account.get_address()], - message_type=MessageType.aggregate.value, - content_keys=[key], - ) + with UserSession(api_server=settings.API_HOST) as session: + message_response = session.get_messages( + addresses=[account.get_address()], + message_type=MessageType.aggregate.value, + content_keys=[key], + ) hash_list = [message["item_hash"] for message in message_response.messages] forget_messages(account, hash_list, reason, channel) diff --git a/src/aleph_client/commands/files.py b/src/aleph_client/commands/files.py index 2c09b89b..4b3e7931 100644 --- a/src/aleph_client/commands/files.py +++ b/src/aleph_client/commands/files.py @@ -1,4 +1,3 @@ -import asyncio import logging from pathlib import Path from typing import Optional @@ -6,9 +5,8 @@ 
import typer from aleph_message.models import StoreMessage -from aleph_client import synchronous +from aleph_client import AuthenticatedUserSession from aleph_client.account import _load_account -from aleph_client.asynchronous import get_fallback_session from aleph_client.commands import help_strings from aleph_client.commands.utils import setup_logging from aleph_client.conf import settings @@ -36,20 +34,17 @@ def pin( setup_logging(debug) account: AccountFromPrivateKey = _load_account(private_key, private_key_file) - - try: - result: StoreMessage = synchronous.create_store( - account=account, + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + message, _status = session.create_store( file_hash=hash, storage_engine=StorageEnum.ipfs, channel=channel, ref=ref, ) - logger.debug("Upload finished") - typer.echo(f"{result.json(indent=4)}") - finally: - # Prevent aiohttp unclosed connector warning - asyncio.run(get_fallback_session().close()) + logger.debug("Upload finished") + typer.echo(f"{message.json(indent=4)}") @app.command() @@ -71,31 +66,29 @@ def upload( account: AccountFromPrivateKey = _load_account(private_key, private_key_file) - try: - if not path.is_file(): - typer.echo(f"Error: File not found: '{path}'") - raise typer.Exit(code=1) + if not path.is_file(): + typer.echo(f"Error: File not found: '{path}'") + raise typer.Exit(code=1) - with open(path, "rb") as fd: - logger.debug("Reading file") - # TODO: Read in lazy mode instead of copying everything in memory - file_content = fd.read() - storage_engine = ( - StorageEnum.ipfs - if len(file_content) > 4 * 1024 * 1024 - else StorageEnum.storage - ) - logger.debug("Uploading file") - result: StoreMessage = synchronous.create_store( - account=account, + with open(path, "rb") as fd: + logger.debug("Reading file") + # TODO: Read in lazy mode instead of copying everything in memory + file_content = fd.read() + storage_engine = ( + StorageEnum.ipfs + if len(file_content) 
> 4 * 1024 * 1024 + else StorageEnum.storage + ) + logger.debug("Uploading file") + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + message, status = session.create_store( file_content=file_content, storage_engine=storage_engine, channel=channel, guess_mime_type=True, ref=ref, ) - logger.debug("Upload finished") - typer.echo(f"{result.json(indent=4)}") - finally: - # Prevent aiohttp unclosed connector warning - asyncio.run(get_fallback_session().close()) + logger.debug("Upload finished") + typer.echo(f"{message.json(indent=4)}") diff --git a/src/aleph_client/commands/message.py b/src/aleph_client/commands/message.py index 6fb5675c..f0aec3d2 100644 --- a/src/aleph_client/commands/message.py +++ b/src/aleph_client/commands/message.py @@ -1,4 +1,3 @@ -import asyncio import json import os.path import subprocess @@ -7,15 +6,10 @@ from typing import Optional, Dict, List import typer -from aleph_message.models import ( - PostMessage, - ForgetMessage, - AlephMessage, -) +from aleph_message.models import AlephMessage -from aleph_client import synchronous +from aleph_client import AuthenticatedUserSession, UserSession from aleph_client.account import _load_account -from aleph_client.asynchronous import get_fallback_session from aleph_client.commands import help_strings from aleph_client.commands.utils import ( setup_logging, @@ -49,7 +43,6 @@ def post( setup_logging(debug) account: AccountFromPrivateKey = _load_account(private_key, private_key_file) - storage_engine: str content: Dict if path: @@ -78,9 +71,10 @@ def post( typer.echo("Not valid JSON") raise typer.Exit(code=2) - try: - result: PostMessage = synchronous.create_post( - account=account, + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + message, status = session.create_post( post_content=content, post_type=type, ref=ref, @@ -88,10 +82,7 @@ def post( inline=True, storage_engine=storage_engine, ) - 
typer.echo(result.json(indent=4)) - finally: - # Prevent aiohttp unclosed connector warning - asyncio.run(get_fallback_session().close()) + typer.echo(message.json(indent=4)) @app.command() @@ -110,34 +101,35 @@ def amend( setup_logging(debug) account: AccountFromPrivateKey = _load_account(private_key, private_key_file) - - existing_message: AlephMessage = synchronous.get_message(item_hash=hash) - - editor: str = os.getenv("EDITOR", default="nano") - with tempfile.NamedTemporaryFile(suffix="json") as fd: - # Fill in message template - fd.write(existing_message.content.json(indent=4).encode()) - fd.seek(0) - - # Launch editor - subprocess.run([editor, fd.name], check=True) - - # Read new message - fd.seek(0) - new_content_json = fd.read() - - content_type = type(existing_message).__annotations__["content"] - new_content_dict = json.loads(new_content_json) - new_content = content_type(**new_content_dict) - new_content.ref = existing_message.item_hash - typer.echo(new_content) - message, _status = synchronous.submit( - account=account, - content=new_content.dict(), - message_type=existing_message.type, - channel=existing_message.channel, - ) - typer.echo(f"{message.json(indent=4)}") + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + existing_message: AlephMessage = session.get_message(item_hash=hash) + + editor: str = os.getenv("EDITOR", default="nano") + with tempfile.NamedTemporaryFile(suffix="json") as fd: + # Fill in message template + fd.write(existing_message.content.json(indent=4).encode()) + fd.seek(0) + + # Launch editor + subprocess.run([editor, fd.name], check=True) + + # Read new message + fd.seek(0) + new_content_json = fd.read() + + content_type = type(existing_message).__annotations__["content"] + new_content_dict = json.loads(new_content_json) + new_content = content_type(**new_content_dict) + new_content.ref = existing_message.item_hash + typer.echo(new_content) + message, _status = session.submit( + 
content=new_content.dict(), + message_type=existing_message.type, + channel=existing_message.channel, + ) + typer.echo(f"{message.json(indent=4)}") def forget_messages( @@ -146,17 +138,15 @@ def forget_messages( reason: Optional[str], channel: str, ): - try: - result: ForgetMessage = synchronous.forget( - account=account, + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + message, status = session.forget( hashes=hashes, reason=reason, channel=channel, ) - typer.echo(f"{result.json(indent=4)}") - finally: - # Prevent aiohttp unclosed connector warning - asyncio.run(get_fallback_session().close()) + typer.echo(f"{message.json(indent=4)}") @app.command() @@ -196,9 +186,10 @@ def watch( setup_logging(debug) - original: AlephMessage = synchronous.get_message(item_hash=ref) + with UserSession(api_server=settings.API_HOST) as session: + original: AlephMessage = session.get_message(item_hash=ref) - for message in synchronous.watch_messages( - refs=[ref], addresses=[original.content.address] - ): - typer.echo(f"{message.json(indent=indent)}") + for message in session.watch_messages( + refs=[ref], addresses=[original.content.address] + ): + typer.echo(f"{message.json(indent=indent)}") diff --git a/src/aleph_client/commands/program.py b/src/aleph_client/commands/program.py index 2e5712d0..ea45c35d 100644 --- a/src/aleph_client/commands/program.py +++ b/src/aleph_client/commands/program.py @@ -1,9 +1,8 @@ -import asyncio import json import logging from base64 import b32encode, b16decode from pathlib import Path -from typing import Optional, Dict, List +from typing import Optional, List, Mapping from zipfile import BadZipFile import typer @@ -14,9 +13,8 @@ ProgramContent, ) -from aleph_client import synchronous +from aleph_client import AuthenticatedUserSession from aleph_client.account import _load_account -from aleph_client.asynchronous import get_fallback_session from aleph_client.commands import help_strings from 
aleph_client.commands.utils import ( setup_logging, @@ -136,7 +134,7 @@ def upload( immutable_volume_dict = volume_to_dict(volume=immutable_volume) volumes.append(immutable_volume_dict) - subscriptions: Optional[List[Dict]] + subscriptions: Optional[List[Mapping]] if beta and yes_no_input("Subscribe to messages ?", default=False): content_raw = input_multiline() try: @@ -147,8 +145,10 @@ def upload( else: subscriptions = None - try: - # Upload the source code + # Upload the source code + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: with open(path_object, "rb") as fd: logger.debug("Reading file") # TODO: Read in lazy mode instead of copying everything in memory @@ -159,8 +159,7 @@ def upload( else StorageEnum.storage ) logger.debug("Uploading file") - user_code: StoreMessage = synchronous.create_store( - account=account, + user_code, _status = session.create_store( file_content=file_content, storage_engine=storage_engine, channel=channel, @@ -173,8 +172,7 @@ def upload( program_ref = user_code.item_hash # Register the program - message, status = synchronous.create_program( - account=account, + message, status = session.create_program( program_ref=program_ref, entrypoint=entrypoint, runtime=runtime, @@ -195,18 +193,14 @@ def upload( hash: str = message.item_hash hash_base32 = b32encode(b16decode(hash.upper())).strip(b"=").lower().decode() - typer.echo( - f"Your program has been uploaded on Aleph .\n\n" - "Available on:\n" - f" {settings.VM_URL_PATH.format(hash=hash)}\n" - f" {settings.VM_URL_HOST.format(hash_base32=hash_base32)}\n" - "Visualise on:\n https://explorer.aleph.im/address/" - f"{message.chain}/{message.sender}/message/PROGRAM/{hash}\n" - ) - - finally: - # Prevent aiohttp unclosed connector warning - asyncio.run(get_fallback_session().close()) + typer.echo( + f"Your program has been uploaded on Aleph .\n\n" + "Available on:\n" + f" {settings.VM_URL_PATH.format(hash=hash)}\n" + f" 
{settings.VM_URL_HOST.format(hash_base32=hash_base32)}\n" + "Visualise on:\n https://explorer.aleph.im/address/" + f"{message.chain}/{message.sender}/message/PROGRAM/{hash}\n" + ) @app.command() @@ -225,12 +219,14 @@ def update( account = _load_account(private_key, private_key_file) path = path.absolute() - try: - program_message: ProgramMessage = synchronous.get_message( + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + program_message: ProgramMessage = session.get_message( item_hash=hash, message_type=ProgramMessage ) code_ref = program_message.content.code.ref - code_message: StoreMessage = synchronous.get_message( + code_message: StoreMessage = session.get_message( item_hash=code_ref, message_type=StoreMessage ) @@ -256,8 +252,7 @@ def update( # TODO: Read in lazy mode instead of copying everything in memory file_content = fd.read() logger.debug("Uploading file") - message, status = synchronous.create_store( - account=account, + message, status = session.create_store( file_content=file_content, storage_engine=code_message.content.item_type, channel=code_message.channel, @@ -267,9 +262,6 @@ def update( logger.debug("Upload finished") if print_message: typer.echo(f"{message.json(indent=4)}") - finally: - # Prevent aiohttp unclosed connector warning - asyncio.run(get_fallback_session().close()) @app.command() @@ -285,17 +277,19 @@ def unpersist( account = _load_account(private_key, private_key_file) - existing: MessagesResponse = synchronous.get_messages(hashes=[hash]) - message: ProgramMessage = existing.messages[0] - content: ProgramContent = message.content.copy() + with AuthenticatedUserSession( + account=account, api_server=settings.API_HOST + ) as session: + existing: MessagesResponse = session.get_messages(hashes=[hash]) + message: ProgramMessage = existing.messages[0] + content: ProgramContent = message.content.copy() - content.on.persistent = False - content.replaces = message.item_hash + 
content.on.persistent = False + content.replaces = message.item_hash - message, _status = synchronous.submit( - account=account, - content=content.dict(exclude_none=True), - message_type=message.type, - channel=message.channel, - ) - typer.echo(f"{message.json(indent=4)}") + message, _status = session.submit( + content=content.dict(exclude_none=True), + message_type=message.type, + channel=message.channel, + ) + typer.echo(f"{message.json(indent=4)}") diff --git a/src/aleph_client/main.py b/src/aleph_client/main.py index ef4e59d0..c3ceb427 100644 --- a/src/aleph_client/main.py +++ b/src/aleph_client/main.py @@ -9,4 +9,4 @@ DeprecationWarning, ) -from .synchronous import * +from .user_session import * diff --git a/src/aleph_client/synchronous.py b/src/aleph_client/synchronous.py deleted file mode 100644 index f0fa0c69..00000000 --- a/src/aleph_client/synchronous.py +++ /dev/null @@ -1,150 +0,0 @@ -import asyncio -import queue -import threading -from typing import ( - Any, - Callable, - List, - Optional, - Dict, - Iterable, - Type, - Protocol, - TypeVar, - Awaitable, -) - -from aiohttp import ClientSession -from aleph_message.models import AlephMessage -from aleph_message.models.program import ProgramContent, Encoding - -from . import asynchronous -from .conf import settings -from .types import Account, StorageEnum, GenericMessage - - -T = TypeVar("T") - - -def wrap_async(func: Callable[..., Awaitable[T]]) -> Callable[..., T]: - """Wrap an asynchronous function into a synchronous one, - for easy use in synchronous code. 
- """ - - def func_caller(*args, **kwargs): - loop = asyncio.get_event_loop() - return loop.run_until_complete(func(*args, **kwargs)) - - # Copy wrapped function interface: - func_caller.__doc__ = func.__doc__ - func_caller.__annotations__ = func.__annotations__ - func_caller.__defaults__ = func.__defaults__ - func_caller.__kwdefaults__ = func.__kwdefaults__ - return func_caller - - -create_post = wrap_async(asynchronous.create_post) -forget = wrap_async(asynchronous.forget) -ipfs_push = wrap_async(asynchronous.ipfs_push) -storage_push = wrap_async(asynchronous.storage_push) -ipfs_push_file = wrap_async(asynchronous.ipfs_push_file) -storage_push_file = wrap_async(asynchronous.storage_push_file) -create_aggregate = wrap_async(asynchronous.create_aggregate) -create_store = wrap_async(asynchronous.create_store) -submit = wrap_async(asynchronous.submit) -fetch_aggregate = wrap_async(asynchronous.fetch_aggregate) -fetch_aggregates = wrap_async(asynchronous.fetch_aggregates) -get_posts = wrap_async(asynchronous.get_posts) -get_messages = wrap_async(asynchronous.get_messages) - - -def get_message( - item_hash: str, - message_type: Optional[Type[GenericMessage]] = None, - channel: Optional[str] = None, - session: Optional[ClientSession] = None, - api_server: str = settings.API_HOST, -) -> GenericMessage: - return wrap_async(asynchronous.get_message)( - item_hash=item_hash, - message_type=message_type, - channel=channel, - session=session, - api_server=api_server, - ) - - -def create_program( - account: Account, - program_ref: str, - entrypoint: str, - runtime: str, - environment_variables: Optional[Dict[str, str]] = None, - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - session: Optional[ClientSession] = None, - api_server: Optional[str] = None, - memory: Optional[int] = None, - vcpus: Optional[int] = None, - timeout_seconds: Optional[float] = None, - persistent: bool = False, - encoding: Encoding = 
Encoding.zip, - volumes: Optional[List[Dict]] = None, - subscriptions: Optional[List[Dict]] = None, -): - """ - Post a (create) PROGRAM message. - - :param account: Account to use to sign the message - :param program_ref: Reference to the program to run - :param entrypoint: Entrypoint to run - :param runtime: Runtime to use - :param environment_variables: Environment variables to pass to the program - :param storage_engine: Storage engine to use (DEFAULT: "storage") - :param channel: Channel to use (DEFAULT: "TEST") - :param address: Address to use (DEFAULT: account.get_address()) - :param session: Session to use (DEFAULT: get_fallback_session()) - :param api_server: API server to use (DEFAULT: "https://api2.aleph.im") - :param memory: Memory in MB for the VM to be allocated (DEFAULT: 128) - :param vcpus: Number of vCPUs to allocate (DEFAULT: 1) - :param timeout_seconds: Timeout in seconds (DEFAULT: 30.0) - :param persistent: Whether the program should be persistent or not (DEFAULT: False) - :param encoding: Encoding to use (DEFAULT: Encoding.zip) - :param volumes: Volumes to mount - :param subscriptions: Patterns of Aleph messages to forward to the program's event receiver - """ - return wrap_async(asynchronous.create_program)( - account=account, - program_ref=program_ref, - entrypoint=entrypoint, - environment_variables=environment_variables, - runtime=runtime, - storage_engine=storage_engine, - channel=channel, - address=address, - session=session, - api_server=api_server, - memory=memory, - vcpus=vcpus, - timeout_seconds=timeout_seconds, - persistent=persistent, - encoding=encoding, - volumes=volumes, - subscriptions=subscriptions, - ) - - -def watch_messages(*args, **kwargs) -> Iterable[AlephMessage]: - """ - Iterate over current and future matching messages synchronously. - - Runs the `watch_messages` asynchronous generator in a thread. 
- """ - output_queue: queue.Queue[AlephMessage] = queue.Queue() - thread = threading.Thread( - target=asynchronous._start_run_watch_messages, args=(output_queue, args, kwargs) - ) - thread.start() - while True: - yield output_queue.get() diff --git a/src/aleph_client/user_session.py b/src/aleph_client/user_session.py new file mode 100644 index 00000000..10c797c2 --- /dev/null +++ b/src/aleph_client/user_session.py @@ -0,0 +1,1364 @@ +import asyncio +import hashlib +import json +import logging +import queue +import threading +import time +from datetime import datetime +from pathlib import Path +from typing import ( + Optional, + Union, + Any, + Dict, + List, + Iterable, + AsyncIterable, + Awaitable, + Callable, + TypeVar, +) +from typing import Type, Mapping, Tuple, NoReturn + +import aiohttp +from aleph_message.models import ( + ForgetContent, + MessageType, + AggregateContent, + PostContent, + StoreContent, + PostMessage, + Message, + ForgetMessage, + AlephMessage, + AggregateMessage, + StoreMessage, + ProgramMessage, + ItemType, +) +from aleph_message.models.program import ProgramContent, Encoding +from pydantic import ValidationError + +from aleph_client.types import Account, StorageEnum, GenericMessage, MessageStatus +from .conf import settings +from .exceptions import ( + MessageNotFoundError, + MultipleMessagesError, + InvalidMessageError, + BroadcastError, +) +from .models import MessagesResponse +from .utils import get_message_type_value + +logger = logging.getLogger(__name__) + + +try: + import magic +except ImportError: + logger.info("Could not import library 'magic', MIME type detection disabled") + magic = None # type:ignore + + +T = TypeVar("T") + + +def async_wrapper(f): + """ + Copies the docstring of wrapped functions. 
+ """ + + wrapped = getattr(AuthenticatedUserSession, f.__name__) + f.__doc__ = wrapped.__doc__ + + +def wrap_async(func: Callable[..., Awaitable[T]]) -> Callable[..., T]: + """Wrap an asynchronous function into a synchronous one, + for easy use in synchronous code. + """ + + def func_caller(*args, **kwargs): + loop = asyncio.get_event_loop() + return loop.run_until_complete(func(*args, **kwargs)) + + # Copy wrapped function interface: + func_caller.__doc__ = func.__doc__ + func_caller.__annotations__ = func.__annotations__ + func_caller.__defaults__ = func.__defaults__ + func_caller.__kwdefaults__ = func.__kwdefaults__ + return func_caller + + +async def run_async_watcher( + *args, output_queue: queue.Queue, api_server: str, **kwargs +): + async with UserSession(api_server=api_server) as session: + async for message in session.watch_messages(*args, **kwargs): + output_queue.put(message) + + +def watcher_thread(output_queue: queue.Queue, api_server: str, args, kwargs): + asyncio.run( + run_async_watcher( + output_queue=output_queue, api_server=api_server, *args, **kwargs + ) + ) + + +class UserSessionSync: + """ + A sync version of `UserSession`, used in sync code. + + This class is returned by the context manager of `UserSession` and is + intended as a wrapper around the methods of `UserSession` and not as a public class. 
+ The methods are fully typed to enable static type checking, but most (all) methods + should look like this (using args and kwargs for brevity, but the functions should + be fully typed): + + >>> def func(self, *args, **kwargs): + >>> return self._wrap(self.async_session.func)(*args, **kwargs) + """ + + def __init__(self, async_session: "UserSession"): + self.async_session = async_session + + def _wrap(self, method: Callable[..., Awaitable[T]], *args, **kwargs): + return wrap_async(method)(*args, **kwargs) + + def get_messages( + self, + pagination: int = 200, + page: int = 1, + message_type: Optional[MessageType] = None, + content_types: Optional[Iterable[str]] = None, + content_keys: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ignore_invalid_messages: bool = True, + invalid_messages_log_level: int = logging.NOTSET, + ) -> MessagesResponse: + + return self._wrap( + self.async_session.get_messages, + pagination=pagination, + page=page, + message_type=message_type, + content_types=content_types, + content_keys=content_keys, + refs=refs, + addresses=addresses, + tags=tags, + hashes=hashes, + channels=channels, + chains=chains, + start_date=start_date, + end_date=end_date, + ignore_invalid_messages=ignore_invalid_messages, + invalid_messages_log_level=invalid_messages_log_level, + ) + + # @async_wrapper + def get_message( + self, + item_hash: str, + message_type: Optional[Type[GenericMessage]] = None, + channel: Optional[str] = None, + ) -> GenericMessage: + return self._wrap( + self.async_session.get_message, + item_hash=item_hash, + message_type=message_type, + channel=channel, + ) + + def fetch_aggregate( + self, + 
address: str, + key: str, + limit: int = 100, + ) -> Dict[str, Dict]: + return self._wrap(self.async_session.fetch_aggregate, address, key, limit) + + def fetch_aggregates( + self, + address: str, + keys: Optional[Iterable[str]] = None, + limit: int = 100, + ) -> Dict[str, Dict]: + return self._wrap(self.async_session.fetch_aggregates, address, keys, limit) + + def get_posts( + self, + pagination: int = 200, + page: int = 1, + types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ) -> Dict[str, Dict]: + return self._wrap( + self.async_session.get_posts, + pagination=pagination, + page=page, + types=types, + refs=refs, + addresses=addresses, + tags=tags, + hashes=hashes, + channels=channels, + chains=chains, + start_date=start_date, + end_date=end_date, + ) + + def download_file(self, file_hash: str) -> bytes: + return self._wrap(self.async_session.download_file, file_hash=file_hash) + + def watch_messages( + self, + message_type: Optional[MessageType] = None, + content_types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ) -> Iterable[AlephMessage]: + """ + Iterate over current and future matching messages synchronously. + + Runs the `watch_messages` asynchronous generator in a thread. 
+ """ + output_queue: queue.Queue[AlephMessage] = queue.Queue() + thread = threading.Thread( + target=watcher_thread, + args=( + output_queue, + self.async_session.api_server, + ( + message_type, + content_types, + refs, + addresses, + tags, + hashes, + channels, + chains, + start_date, + end_date, + ), + {}, + ), + ) + thread.start() + while True: + yield output_queue.get() + + +class AuthenticatedUserSessionSync(UserSessionSync): + async_session: "AuthenticatedUserSession" + + def __init__(self, async_session: "AuthenticatedUserSession"): + super().__init__(async_session=async_session) + + def ipfs_push(self, content: Mapping) -> str: + return self._wrap(self.async_session.ipfs_push, content=content) + + def storage_push(self, content: Mapping) -> str: + return self._wrap(self.async_session.storage_push, content=content) + + def ipfs_push_file(self, file_content: Union[str, bytes]) -> str: + return self._wrap(self.async_session.ipfs_push_file, file_content=file_content) + + def storage_push_file(self, file_content: Union[str, bytes]) -> str: + return self._wrap( + self.async_session.storage_push_file, file_content=file_content + ) + + def create_post( + self, + post_content, + post_type: str, + ref: Optional[str] = None, + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, + sync: bool = False, + ) -> Tuple[PostMessage, MessageStatus]: + return self._wrap( + self.async_session.create_post, + post_content=post_content, + post_type=post_type, + ref=ref, + address=address, + channel=channel, + inline=inline, + storage_engine=storage_engine, + sync=sync, + ) + + def create_aggregate( + self, + key: str, + content: Mapping[str, Any], + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + sync: bool = False, + ) -> Tuple[AggregateMessage, MessageStatus]: + return self._wrap( + self.async_session.create_aggregate, + key=key, + content=content, + 
address=address, + channel=channel, + inline=inline, + sync=sync, + ) + + def create_store( + self, + address: Optional[str] = None, + file_content: Optional[bytes] = None, + file_path: Optional[Union[str, Path]] = None, + file_hash: Optional[str] = None, + guess_mime_type: bool = False, + ref: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + extra_fields: Optional[dict] = None, + channel: Optional[str] = None, + sync: bool = False, + ) -> Tuple[StoreMessage, MessageStatus]: + return self._wrap( + self.async_session.create_store, + address=address, + file_content=file_content, + file_path=file_path, + file_hash=file_hash, + guess_mime_type=guess_mime_type, + ref=ref, + storage_engine=storage_engine, + extra_fields=extra_fields, + channel=channel, + sync=sync, + ) + + def create_program( + self, + program_ref: str, + entrypoint: str, + runtime: str, + environment_variables: Optional[Mapping[str, str]] = None, + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + memory: Optional[int] = None, + vcpus: Optional[int] = None, + timeout_seconds: Optional[float] = None, + persistent: bool = False, + encoding: Encoding = Encoding.zip, + volumes: Optional[List[Mapping]] = None, + subscriptions: Optional[List[Mapping]] = None, + ) -> Tuple[ProgramMessage, MessageStatus]: + return self._wrap( + self.async_session.create_program, + program_ref=program_ref, + entrypoint=entrypoint, + runtime=runtime, + environment_variables=environment_variables, + storage_engine=storage_engine, + channel=channel, + address=address, + sync=sync, + memory=memory, + vcpus=vcpus, + timeout_seconds=timeout_seconds, + persistent=persistent, + encoding=encoding, + volumes=volumes, + subscriptions=subscriptions, + ) + + def forget( + self, + hashes: List[str], + reason: Optional[str], + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: 
Optional[str] = None, + sync: bool = False, + ) -> Tuple[ForgetMessage, MessageStatus]: + return self._wrap( + self.async_session.forget, + hashes=hashes, + reason=reason, + storage_engine=storage_engine, + channel=channel, + address=address, + sync=sync, + ) + + def submit( + self, + content: Dict[str, Any], + message_type: MessageType, + channel: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + allow_inlining: bool = True, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + return self._wrap( + self.async_session.submit, + content=content, + message_type=message_type, + channel=channel, + storage_engine=storage_engine, + allow_inlining=allow_inlining, + sync=sync, + ) + + +class UserSession: + api_server: str + http_session: aiohttp.ClientSession + + def __init__(self, api_server: str): + self.api_server = api_server + self.http_session = aiohttp.ClientSession(base_url=api_server) + + def __enter__(self) -> UserSessionSync: + return UserSessionSync(async_session=self) + + def __exit__(self, exc_type, exc_val, exc_tb): + close_fut = self.http_session.close() + try: + loop = asyncio.get_running_loop() + loop.run_until_complete(close_fut) + except RuntimeError: + asyncio.run(close_fut) + + async def __aenter__(self) -> "UserSession": + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.http_session.close() + + async def fetch_aggregate( + self, + address: str, + key: str, + limit: int = 100, + ) -> Dict[str, Dict]: + """ + Fetch a value from the aggregate store by owner address and item key. 
+ + :param address: Address of the owner of the aggregate + :param key: Key of the aggregate + :param limit: Maximum number of items to fetch (Default: 100) + """ + + params: Dict[str, Any] = {"keys": key} + if limit: + params["limit"] = limit + + async with self.http_session.get( + f"/api/v0/aggregates/{address}.json", params=params + ) as resp: + result = await resp.json() + data = result.get("data", dict()) + return data.get(key) + + async def fetch_aggregates( + self, + address: str, + keys: Optional[Iterable[str]] = None, + limit: int = 100, + ) -> Dict[str, Dict]: + """ + Fetch key-value pairs from the aggregate store by owner address. + + :param address: Address of the owner of the aggregate + :param keys: Keys of the aggregates to fetch (Default: all items) + :param limit: Maximum number of items to fetch (Default: 100) + """ + + keys_str = ",".join(keys) if keys else "" + params: Dict[str, Any] = {} + if keys_str: + params["keys"] = keys_str + if limit: + params["limit"] = limit + + async with self.http_session.get( + f"/api/v0/aggregates/{address}.json", + params=params, + ) as resp: + result = await resp.json() + data = result.get("data", dict()) + return data + + async def get_posts( + self, + pagination: int = 200, + page: int = 1, + types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ) -> Dict[str, Dict]: + """ + Fetch a list of posts from the network. 
+ + :param pagination: Number of items to fetch (Default: 200) + :param page: Page to fetch, begins at 1 (Default: 1) + :param types: Types of posts to fetch (Default: all types) + :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) + :param addresses: Addresses of the posts to fetch (Default: all addresses) + :param tags: Tags of the posts to fetch (Default: all tags) + :param hashes: Specific item_hashes to fetch + :param channels: Channels of the posts to fetch (Default: all channels) + :param chains: Chains of the posts to fetch (Default: all chains) + :param start_date: Earliest date to fetch messages from + :param end_date: Latest date to fetch messages from + """ + + params: Dict[str, Any] = dict(pagination=pagination, page=page) + + if types is not None: + params["types"] = ",".join(types) + if refs is not None: + params["refs"] = ",".join(refs) + if addresses is not None: + params["addresses"] = ",".join(addresses) + if tags is not None: + params["tags"] = ",".join(tags) + if hashes is not None: + params["hashes"] = ",".join(hashes) + if channels is not None: + params["channels"] = ",".join(channels) + if chains is not None: + params["chains"] = ",".join(chains) + + if start_date is not None: + if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): + start_date = start_date.timestamp() + params["startDate"] = start_date + if end_date is not None: + if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): + end_date = end_date.timestamp() + params["endDate"] = end_date + + async with self.http_session.get("/api/v0/posts.json", params=params) as resp: + resp.raise_for_status() + return await resp.json() + + async def download_file( + self, + file_hash: str, + ) -> bytes: + """ + Get a file from the storage engine as raw bytes. + + Warning: Downloading large files can be slow and memory intensive. + + :param file_hash: The hash of the file to retrieve. 
+ """ + async with self.http_session.get( + f"/api/v0/storage/raw/{file_hash}" + ) as response: + response.raise_for_status() + return await response.read() + + async def get_messages( + self, + pagination: int = 200, + page: int = 1, + message_type: Optional[MessageType] = None, + content_types: Optional[Iterable[str]] = None, + content_keys: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ignore_invalid_messages: bool = True, + invalid_messages_log_level: int = logging.NOTSET, + ) -> MessagesResponse: + """ + Fetch a list of messages from the network. + + :param pagination: Number of items to fetch (Default: 200) + :param page: Page to fetch, begins at 1 (Default: 1) + :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" + :param content_types: Filter by content type + :param content_keys: Filter by content key + :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) + :param addresses: Addresses of the posts to fetch (Default: all addresses) + :param tags: Tags of the posts to fetch (Default: all tags) + :param hashes: Specific item_hashes to fetch + :param channels: Channels of the posts to fetch (Default: all channels) + :param chains: Filter by sender address chain + :param start_date: Earliest date to fetch messages from + :param end_date: Latest date to fetch messages from + :param ignore_invalid_messages: Ignore invalid messages (Default: False) + :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) + """ + ignore_invalid_messages = ( + True if ignore_invalid_messages is None else 
ignore_invalid_messages + ) + invalid_messages_log_level = ( + logging.NOTSET + if invalid_messages_log_level is None + else invalid_messages_log_level + ) + + params: Dict[str, Any] = dict(pagination=pagination, page=page) + + if message_type is not None: + params["msgType"] = message_type.value + if content_types is not None: + params["contentTypes"] = ",".join(content_types) + if content_keys is not None: + params["contentKeys"] = ",".join(content_keys) + if refs is not None: + params["refs"] = ",".join(refs) + if addresses is not None: + params["addresses"] = ",".join(addresses) + if tags is not None: + params["tags"] = ",".join(tags) + if hashes is not None: + params["hashes"] = ",".join(hashes) + if channels is not None: + params["channels"] = ",".join(channels) + if chains is not None: + params["chains"] = ",".join(chains) + + if start_date is not None: + if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): + start_date = start_date.timestamp() + params["startDate"] = start_date + if end_date is not None: + if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): + end_date = end_date.timestamp() + params["endDate"] = end_date + + async with self.http_session.get( + "/api/v0/messages.json", params=params + ) as resp: + resp.raise_for_status() + response_json = await resp.json() + messages_raw = response_json["messages"] + + # All messages may not be valid according to the latest specification in + # aleph-message. This allows the user to specify how errors should be handled. 
+ messages: List[AlephMessage] = [] + for message_raw in messages_raw: + try: + message = Message(**message_raw) + messages.append(message) + except KeyError as e: + if not ignore_invalid_messages: + raise e + logger.log( + level=invalid_messages_log_level, + msg=f"KeyError: Field '{e.args[0]}' not found", + ) + except ValidationError as e: + if not ignore_invalid_messages: + raise e + if invalid_messages_log_level: + logger.log(level=invalid_messages_log_level, msg=e) + + return MessagesResponse( + messages=messages, + pagination_page=response_json["pagination_page"], + pagination_total=response_json["pagination_total"], + pagination_per_page=response_json["pagination_per_page"], + pagination_item=response_json["pagination_item"], + ) + + async def get_message( + self, + item_hash: str, + message_type: Optional[Type[GenericMessage]] = None, + channel: Optional[str] = None, + ) -> GenericMessage: + """ + Get a single message from its `item_hash` and perform some basic validation. + + :param item_hash: Hash of the message to fetch + :param message_type: Type of message to fetch + :param channel: Channel of the message to fetch + """ + messages_response = await self.get_messages( + hashes=[item_hash], + channels=[channel] if channel else None, + ) + if len(messages_response.messages) < 1: + raise MessageNotFoundError(f"No such hash {item_hash}") + if len(messages_response.messages) != 1: + raise MultipleMessagesError( + f"Multiple messages found for the same item_hash `{item_hash}`" + ) + message: GenericMessage = messages_response.messages[0] + if message_type: + expected_type = get_message_type_value(message_type) + if message.type != expected_type: + raise TypeError( + f"The message type '{message.type}' " + f"does not match the expected type '{expected_type}'" + ) + return message + + async def watch_messages( + self, + message_type: Optional[MessageType] = None, + content_types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: 
Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ) -> AsyncIterable[AlephMessage]: + """ + Iterate over current and future matching messages asynchronously. + + :param message_type: Type of message to watch + :param content_types: Content types to watch + :param refs: References to watch + :param addresses: Addresses to watch + :param tags: Tags to watch + :param hashes: Hashes to watch + :param channels: Channels to watch + :param chains: Chains to watch + :param start_date: Start date from when to watch + :param end_date: End date until when to watch + """ + params: Dict[str, Any] = dict() + + if message_type is not None: + params["msgType"] = message_type.value + if content_types is not None: + params["contentTypes"] = ",".join(content_types) + if refs is not None: + params["refs"] = ",".join(refs) + if addresses is not None: + params["addresses"] = ",".join(addresses) + if tags is not None: + params["tags"] = ",".join(tags) + if hashes is not None: + params["hashes"] = ",".join(hashes) + if channels is not None: + params["channels"] = ",".join(channels) + if chains is not None: + params["chains"] = ",".join(chains) + + if start_date is not None: + if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): + start_date = start_date.timestamp() + params["startDate"] = start_date + if end_date is not None: + if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): + end_date = end_date.timestamp() + params["endDate"] = end_date + + async with self.http_session.ws_connect( + f"/api/ws0/messages", params=params + ) as ws: + logger.debug("Websocket connected") + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + if msg.data == "close cmd": + await ws.close() + break 
+ else: + data = json.loads(msg.data) + yield Message(**data) + elif msg.type == aiohttp.WSMsgType.ERROR: + break + + +class AuthenticatedUserSession(UserSession): + account: Account + + BROADCAST_MESSAGE_FIELDS = { + "sender", + "chain", + "signature", + "type", + "item_hash", + "item_type", + "item_content", + "time", + "channel", + } + + def __init__(self, account: Account, api_server: str): + super().__init__(api_server=api_server) + self.account = account + + def __enter__(self) -> "AuthenticatedUserSessionSync": + return AuthenticatedUserSessionSync(async_session=self) + + async def __aenter__(self) -> "AuthenticatedUserSession": + return self + + async def ipfs_push(self, content: Mapping) -> str: + """Push arbitrary content as JSON to the IPFS service.""" + + url = "/api/v0/ipfs/add_json" + logger.debug(f"Pushing to IPFS on {url}") + + async with self.http_session.post(url, json=content) as resp: + resp.raise_for_status() + return (await resp.json()).get("hash") + + async def storage_push(self, content: Mapping) -> str: + """Push arbitrary content as JSON to the storage service.""" + + url = "/api/v0/storage/add_json" + logger.debug(f"Pushing to storage on {url}") + + async with self.http_session.post(url, json=content) as resp: + resp.raise_for_status() + return (await resp.json()).get("hash") + + async def ipfs_push_file(self, file_content: Union[str, bytes]) -> str: + """Push a file to the IPFS service.""" + data = aiohttp.FormData() + data.add_field("file", file_content) + + url = "/api/v0/ipfs/add_file" + logger.debug(f"Pushing file to IPFS on {url}") + + async with self.http_session.post(url, data=data) as resp: + resp.raise_for_status() + return (await resp.json()).get("hash") + + async def storage_push_file(self, file_content) -> str: + """Push a file to the storage service.""" + data = aiohttp.FormData() + data.add_field("file", file_content) + + url = "/api/v0/storage/add_file" + logger.debug(f"Posting file on {url}") + + async with 
self.http_session.post(url, data=data) as resp: + resp.raise_for_status() + return (await resp.json()).get("hash") + + @staticmethod + def _log_publication_status(publication_status: Mapping[str, Any]): + status = publication_status.get("status") + failures = publication_status.get("failed") + + if status == "success": + return + elif status == "warning": + logger.warning("Broadcast failed on the following network(s): %s", failures) + elif status == "error": + logger.error( + "Broadcast failed on all protocols. The message was not published." + ) + else: + raise ValueError( + f"Invalid response from server, status in missing or unknown: '{status}'" + ) + + @staticmethod + async def _handle_broadcast_error(response: aiohttp.ClientResponse) -> NoReturn: + if response.status == 500: + # Assume a broadcast error, no need to read the JSON + if response.content_type == "application/json": + error_msg = "Internal error - broadcast failed on all protocols" + else: + error_msg = f"Internal error - the message was not broadcast: {await response.text()}" + + logger.error(error_msg) + raise BroadcastError(error_msg) + elif response.status == 422: + errors = await response.json() + logger.error( + "The message could not be processed because of the following errors: %s", + errors, + ) + raise InvalidMessageError(errors) + else: + error_msg = ( + f"Unexpected HTTP response ({response.status}: {await response.text()})" + ) + logger.error(error_msg) + raise BroadcastError(error_msg) + + async def _handle_broadcast_deprecated_response( + self, + response: aiohttp.ClientResponse, + ) -> None: + if response.status != 200: + await self._handle_broadcast_error(response) + else: + publication_status = await response.json() + self._log_publication_status(publication_status) + + async def _broadcast_deprecated(self, message_dict: Mapping[str, Any]) -> None: + + """ + Broadcast a message on the Aleph network using the deprecated + /ipfs/pubsub/pub/ endpoint. 
+ """ + + url = "/api/v0/ipfs/pubsub/pub" + logger.debug(f"Posting message on {url}") + + async with self.http_session.post( + url, + json={"topic": "ALEPH-TEST", "data": json.dumps(message_dict)}, + ) as response: + await self._handle_broadcast_deprecated_response(response) + + async def _handle_broadcast_response( + self, response: aiohttp.ClientResponse, sync: bool + ) -> MessageStatus: + if response.status in (200, 202): + status = await response.json() + self._log_publication_status(status["publication_status"]) + + if response.status == 202: + if sync: + logger.warning( + "Timed out while waiting for processing of sync message" + ) + return MessageStatus.PENDING + + return MessageStatus.PROCESSED + + else: + await self._handle_broadcast_error(response) + + async def _broadcast( + self, + message: AlephMessage, + sync: bool, + ) -> MessageStatus: + """ + Broadcast a message on the Aleph network. + + Uses the POST /messages/ endpoint or the deprecated /ipfs/pubsub/pub/ endpoint + if the first method is not available. + """ + + url = "/api/v0/messages" + logger.debug(f"Posting message on {url}") + + message_dict = message.dict(include=self.BROADCAST_MESSAGE_FIELDS) + + async with self.http_session.post( + url, + json={"sync": sync, "message": message_dict}, + ) as response: + # The endpoint may be unavailable on this node, try the deprecated version. + if response.status == 404: + logger.warning( + "POST /messages/ not found. Defaulting to legacy endpoint..." 
+ ) + await self._broadcast_deprecated(message_dict=message_dict) + return MessageStatus.PENDING + else: + message_status = await self._handle_broadcast_response( + response=response, sync=sync + ) + return message_status + + async def create_post( + self, + post_content, + post_type: str, + ref: Optional[str] = None, + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, + sync: bool = False, + ) -> Tuple[PostMessage, MessageStatus]: + """ + Create a POST message on the Aleph network. It is associated with a channel and owned by an account. + + :param post_content: The content of the message + :param post_type: An arbitrary content type that helps to describe the post_content + :param ref: A reference to a previous message that it replaces + :param address: The address that will be displayed as the author of the message + :param channel: The channel that the message will be posted on + :param inline: An optional flag to indicate if the content should be inlined in the message or not + :param storage_engine: An optional storage engine to use for the message, if not inlined (Default: "storage") + :param sync: If true, waits for the message to be processed by the API server (Default: False) + """ + address = address or settings.ADDRESS_TO_USE or self.account.get_address() + + content = PostContent( + type=post_type, + address=address, + content=post_content, + time=time.time(), + ref=ref, + ) + + return await self.submit( + content=content.dict(exclude_none=True), + message_type=MessageType.post, + channel=channel, + allow_inlining=inline, + storage_engine=storage_engine, + sync=sync, + ) + + async def create_aggregate( + self, + key: str, + content: Mapping[str, Any], + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + sync: bool = False, + ) -> Tuple[AggregateMessage, MessageStatus]: + """ + Create an AGGREGATE message. 
It is meant to be used as a quick access storage associated with an account. + + :param key: Key to use to store the content + :param content: Content to store + :param address: Address to use to sign the message + :param channel: Channel to use (Default: "TEST") + :param inline: Whether to write content inside the message (Default: True) + :param sync: If true, waits for the message to be processed by the API server (Default: False) + """ + address = address or settings.ADDRESS_TO_USE or self.account.get_address() + + content_ = AggregateContent( + key=key, + address=address, + content=content, + time=time.time(), + ) + + return await self.submit( + content=content_.dict(exclude_none=True), + message_type=MessageType.aggregate, + channel=channel, + allow_inlining=inline, + sync=sync, + ) + + async def create_store( + self, + address: Optional[str] = None, + file_content: Optional[bytes] = None, + file_path: Optional[Union[str, Path]] = None, + file_hash: Optional[str] = None, + guess_mime_type: bool = False, + ref: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + extra_fields: Optional[dict] = None, + channel: Optional[str] = None, + sync: bool = False, + ) -> Tuple[StoreMessage, MessageStatus]: + """ + Create a STORE message to store a file on the Aleph network. + + Can be passed either a file path, an IPFS hash or the file's content as raw bytes. 
+ + :param address: Address to display as the author of the message (Default: account.get_address()) + :param file_content: Byte stream of the file to store (Default: None) + :param file_path: Path to the file to store (Default: None) + :param file_hash: Hash of the file to store (Default: None) + :param guess_mime_type: Guess the MIME type of the file (Default: False) + :param ref: Reference to a previous message (Default: None) + :param storage_engine: Storage engine to use (Default: "storage") + :param extra_fields: Extra fields to add to the STORE message (Default: None) + :param channel: Channel to post the message to (Default: "TEST") + :param sync: If true, waits for the message to be processed by the API server (Default: False) + """ + address = address or settings.ADDRESS_TO_USE or self.account.get_address() + + extra_fields = extra_fields or {} + + if file_hash is None: + if file_content is None: + if file_path is None: + raise ValueError( + "Please specify at least a file_content, a file_hash or a file_path" + ) + else: + file_content = open(file_path, "rb").read() + + if storage_engine == StorageEnum.storage: + file_hash = await self.storage_push_file(file_content=file_content) + elif storage_engine == StorageEnum.ipfs: + file_hash = await self.ipfs_push_file(file_content=file_content) + else: + raise ValueError(f"Unknown storage engine: '{storage_engine}'") + + assert file_hash, "File hash should be empty" + + if magic is None: + pass + elif file_content and guess_mime_type and ("mime_type" not in extra_fields): + extra_fields["mime_type"] = magic.from_buffer(file_content, mime=True) + + if ref: + extra_fields["ref"] = ref + + values = { + "address": address, + "item_type": storage_engine, + "item_hash": file_hash, + "time": time.time(), + } + if extra_fields is not None: + values.update(extra_fields) + + content = StoreContent(**values) + + return await self.submit( + content=content.dict(exclude_none=True), + message_type=MessageType.store, + 
channel=channel, + allow_inlining=True, + sync=sync, + ) + + async def create_program( + self, + program_ref: str, + entrypoint: str, + runtime: str, + environment_variables: Optional[Mapping[str, str]] = None, + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + memory: Optional[int] = None, + vcpus: Optional[int] = None, + timeout_seconds: Optional[float] = None, + persistent: bool = False, + encoding: Encoding = Encoding.zip, + volumes: Optional[List[Mapping]] = None, + subscriptions: Optional[List[Mapping]] = None, + ) -> Tuple[ProgramMessage, MessageStatus]: + """ + Post a (create) PROGRAM message. + + :param program_ref: Reference to the program to run + :param entrypoint: Entrypoint to run + :param runtime: Runtime to use + :param environment_variables: Environment variables to pass to the program + :param storage_engine: Storage engine to use (Default: "storage") + :param channel: Channel to use (Default: "TEST") + :param address: Address to use (Default: account.get_address()) + :param sync: If true, waits for the message to be processed by the API server + :param memory: Memory in MB for the VM to be allocated (Default: 128) + :param vcpus: Number of vCPUs to allocate (Default: 1) + :param timeout_seconds: Timeout in seconds (Default: 30.0) + :param persistent: Whether the program should be persistent or not (Default: False) + :param encoding: Encoding to use (Default: Encoding.zip) + :param volumes: Volumes to mount + :param subscriptions: Patterns of Aleph messages to forward to the program's event receiver + """ + address = address or settings.ADDRESS_TO_USE or self.account.get_address() + + volumes = volumes if volumes is not None else [] + memory = memory or settings.DEFAULT_VM_MEMORY + vcpus = vcpus or settings.DEFAULT_VM_VCPUS + timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT + + # TODO: Check that program_ref, runtime and data_ref exist + + # 
Register the different ways to trigger a VM + if subscriptions: + # Trigger on HTTP calls and on Aleph message subscriptions. + triggers = { + "http": True, + "persistent": persistent, + "message": subscriptions, + } + else: + # Trigger on HTTP calls. + triggers = {"http": True, "persistent": persistent} + + content = ProgramContent( + **{ + "type": "vm-function", + "address": address, + "allow_amend": False, + "code": { + "encoding": encoding, + "entrypoint": entrypoint, + "ref": program_ref, + "use_latest": True, + }, + "on": triggers, + "environment": { + "reproducible": False, + "internet": True, + "aleph_api": True, + }, + "variables": environment_variables, + "resources": { + "vcpus": vcpus, + "memory": memory, + "seconds": timeout_seconds, + }, + "runtime": { + "ref": runtime, + "use_latest": True, + "comment": "Official Aleph runtime" + if runtime == settings.DEFAULT_RUNTIME_ID + else "", + }, + "volumes": volumes, + "time": time.time(), + } + ) + + # Ensure that the version of aleph-message used supports the field. + assert content.on.persistent == persistent + + return await self.submit( + content=content.dict(exclude_none=True), + message_type=MessageType.program, + channel=channel, + storage_engine=storage_engine, + sync=sync, + ) + + async def forget( + self, + hashes: List[str], + reason: Optional[str], + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + ) -> Tuple[ForgetMessage, MessageStatus]: + """ + Post a FORGET message to remove previous messages from the network. + + Targeted messages need to be signed by the same account that is attempting to forget them, + if the creating address did not delegate the access rights to the forgetting account. 
+ + :param hashes: Hashes of the messages to forget + :param reason: Reason for forgetting the messages + :param storage_engine: Storage engine to use (Default: "storage") + :param channel: Channel to use (Default: "TEST") + :param address: Address to use (Default: account.get_address()) + :param sync: If true, waits for the message to be processed by the API server (Default: False) + """ + address = address or settings.ADDRESS_TO_USE or self.account.get_address() + + content = ForgetContent( + hashes=hashes, + reason=reason, + address=address, + time=time.time(), + ) + + return await self.submit( + content=content.dict(exclude_none=True), + message_type=MessageType.forget, + channel=channel, + storage_engine=storage_engine, + allow_inlining=True, + sync=sync, + ) + + @staticmethod + def compute_sha256(s: str) -> str: + h = hashlib.sha256() + h.update(s.encode("utf-8")) + return h.hexdigest() + + async def _prepare_aleph_message( + self, + message_type: MessageType, + content: Dict[str, Any], + channel: Optional[str], + allow_inlining: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, + ) -> AlephMessage: + + message_dict: Dict[str, Any] = { + "sender": self.account.get_address(), + "chain": self.account.CHAIN, + "type": message_type, + "content": content, + "time": time.time(), + "channel": channel, + } + + item_content: str = json.dumps(content, separators=(",", ":")) + + if allow_inlining and (len(item_content) < settings.MAX_INLINE_SIZE): + message_dict["item_content"] = item_content + message_dict["item_hash"] = self.compute_sha256(item_content) + message_dict["item_type"] = ItemType.inline + else: + if storage_engine == StorageEnum.ipfs: + message_dict["item_hash"] = await self.ipfs_push( + content=content, + ) + message_dict["item_type"] = ItemType.ipfs + else: # storage + assert storage_engine == StorageEnum.storage + message_dict["item_hash"] = await self.storage_push( + content=content, + ) + message_dict["item_type"] = ItemType.storage + 
+ message_dict = await self.account.sign_message(message_dict) + return Message(**message_dict) + + async def submit( + self, + content: Dict[str, Any], + message_type: MessageType, + channel: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + allow_inlining: bool = True, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + + message = await self._prepare_aleph_message( + message_type=message_type, + content=content, + channel=channel, + allow_inlining=allow_inlining, + storage_engine=storage_engine, + ) + message_status = await self._broadcast(message=message, sync=sync) + return message, message_status diff --git a/src/aleph_client/vm/cache.py b/src/aleph_client/vm/cache.py index 08429ea7..1fceb8f2 100644 --- a/src/aleph_client/vm/cache.py +++ b/src/aleph_client/vm/cache.py @@ -1,14 +1,13 @@ -import re -import fnmatch import abc +import fnmatch +import re from typing import Union, Optional, Any, Dict, List, NewType -from aiohttp import ClientSession +import aiohttp +from pydantic import AnyHttpUrl -from aleph_client.asynchronous import get_fallback_session from ..conf import settings - CacheKey = NewType("CacheKey", str) @@ -45,16 +44,18 @@ async def keys(self, pattern: str = "*") -> List[str]: class VmCache(BaseVmCache): """Virtual Machines can use this cache to store temporary data in memory on the host.""" - session: ClientSession + session: aiohttp.ClientSession cache: Dict[str, bytes] api_host: str def __init__( - self, session: Optional[ClientSession] = None, api_host: Optional[str] = None + self, + session: Optional[aiohttp.ClientSession] = None, + connector_url: Optional[AnyHttpUrl] = None, ): - self.session = session or get_fallback_session() + self.session = session or aiohttp.ClientSession(base_url=connector_url) self.cache = {} - self.api_host = api_host if api_host else settings.API_HOST + self.api_host = connector_url if connector_url else settings.API_HOST async def get(self, key: str) -> Optional[bytes]: 
sanitized_key = sanitize_cache_key(key) diff --git a/tests/integration/itest_aggregates.py b/tests/integration/itest_aggregates.py index 87b71c66..30614703 100644 --- a/tests/integration/itest_aggregates.py +++ b/tests/integration/itest_aggregates.py @@ -2,10 +2,7 @@ import json from typing import Dict -from aleph_client.asynchronous import ( - create_aggregate, - fetch_aggregate, -) +from aleph_client.user_session import AuthenticatedUserSession from tests.integration.toolkit import try_until from .config import REFERENCE_NODE, TARGET_NODE @@ -20,13 +17,14 @@ async def create_aggregate_on_target( receiver_node: str, channel="INTEGRATION_TESTS", ): - aggregate_message, message_status = await create_aggregate( - account=account, - key=key, - content=content, - channel="INTEGRATION_TESTS", - api_server=emitter_node, - ) + async with AuthenticatedUserSession( + account=account, api_server=emitter_node + ) as tx_session: + aggregate_message, message_status = await tx_session.create_aggregate( + key=key, + content=content, + channel="INTEGRATION_TESTS", + ) assert aggregate_message.sender == account.get_address() assert aggregate_message.channel == channel @@ -39,14 +37,16 @@ async def create_aggregate_on_target( assert aggregate_message.content.address == account.get_address() assert aggregate_message.content.content == content - aggregate_from_receiver = await try_until( - fetch_aggregate, - lambda aggregate: aggregate is not None, - timeout=5, - address=account.get_address(), - key=key, - api_server=receiver_node, - ) + async with AuthenticatedUserSession( + account=account, api_server=receiver_node + ) as rx_session: + aggregate_from_receiver = await try_until( + rx_session.fetch_aggregate, + lambda aggregate: aggregate is not None, + timeout=5, + address=account.get_address(), + key=key, + ) for key, value in content.items(): assert key in aggregate_from_receiver diff --git a/tests/integration/itest_forget.py b/tests/integration/itest_forget.py index 
12b3fe5f..09659589 100644 --- a/tests/integration/itest_forget.py +++ b/tests/integration/itest_forget.py @@ -2,8 +2,8 @@ import pytest -from aleph_client.asynchronous import create_post, get_posts, get_messages, forget from aleph_client.types import Account +from aleph_client.user_session import AuthenticatedUserSession from .config import REFERENCE_NODE, TARGET_NODE, TEST_CHANNEL from .toolkit import try_until @@ -12,25 +12,29 @@ async def create_and_forget_post( account: Account, emitter_node: str, receiver_node: str, channel=TEST_CHANNEL ) -> str: async def wait_matching_posts( - item_hash: str, condition: Callable[[Dict], bool], timeout: int = 5 + item_hash: str, + condition: Callable[[Dict], bool], + timeout: int = 5, ): - return await try_until( - get_posts, - condition, - timeout=timeout, - hashes=[item_hash], - api_server=receiver_node, + async with AuthenticatedUserSession( + account=account, api_server=receiver_node + ) as rx_session: + return await try_until( + rx_session.get_posts, + condition, + timeout=timeout, + hashes=[item_hash], + ) + + async with AuthenticatedUserSession( + account=account, api_server=emitter_node + ) as tx_session: + post_message, message_status = await tx_session.create_post( + post_content="A considerate and politically correct post.", + post_type="POST", + channel="INTEGRATION_TESTS", ) - post_message, message_status = await create_post( - account=account, - post_content="A considerate and politically correct post.", - post_type="POST", - channel="INTEGRATION_TESTS", - session=None, - api_server=emitter_node, - ) - # Wait for the message to appear on the receiver. We don't check the values, # they're checked in other integration tests. get_post_response = await wait_matching_posts( @@ -41,13 +45,14 @@ async def wait_matching_posts( post_hash = post_message.item_hash reason = "This well thought-out content offends me!" 
- forget_message, forget_status = await forget( - account, - hashes=[post_hash], - reason=reason, - channel=channel, - api_server=emitter_node, - ) + async with AuthenticatedUserSession( + account=account, api_server=emitter_node + ) as tx_session: + forget_message, forget_status = await tx_session.forget( + hashes=[post_hash], + reason=reason, + channel=channel, + ) assert forget_message.sender == account.get_address() assert forget_message.content.reason == reason @@ -97,26 +102,28 @@ async def test_forget_a_forget_message(fixture_account): # TODO: this test should be moved to the PyAleph API tests, once a framework is in place. post_hash = await create_and_forget_post(fixture_account, TARGET_NODE, TARGET_NODE) - get_post_response = await get_posts(hashes=[post_hash]) - assert len(get_post_response["posts"]) == 1 - post = get_post_response["posts"][0] - - forget_message_hash = post["forgotten_by"][0] - forget_message, forget_status = await forget( - fixture_account, - hashes=[forget_message_hash], - reason="I want to remember this post. Maybe I can forget I forgot it?", - channel=TEST_CHANNEL, - api_server=TARGET_NODE, - ) + async with AuthenticatedUserSession( + account=fixture_account, api_server=TARGET_NODE + ) as session: + get_post_response = await session.get_posts(hashes=[post_hash]) + assert len(get_post_response["posts"]) == 1 + post = get_post_response["posts"][0] + + forget_message_hash = post["forgotten_by"][0] + forget_message, forget_status = await session.forget( + hashes=[forget_message_hash], + reason="I want to remember this post. 
Maybe I can forget I forgot it?", + channel=TEST_CHANNEL, + ) - print(forget_message) + print(forget_message) - get_forget_message_response = await get_messages( - hashes=[forget_message_hash], channels=[TEST_CHANNEL], api_server=TARGET_NODE - ) - assert len(get_forget_message_response.messages) == 1 - forget_message = get_forget_message_response.messages[0] - print(forget_message) + get_forget_message_response = await session.get_messages( + hashes=[forget_message_hash], + channels=[TEST_CHANNEL], + ) + assert len(get_forget_message_response.messages) == 1 + forget_message = get_forget_message_response.messages[0] + print(forget_message) - assert "forgotten_by" not in forget_message + assert "forgotten_by" not in forget_message diff --git a/tests/integration/itest_posts.py b/tests/integration/itest_posts.py index 7ee89f3d..7ad0184f 100644 --- a/tests/integration/itest_posts.py +++ b/tests/integration/itest_posts.py @@ -1,11 +1,7 @@ import pytest -from aleph_message import Message -from aleph_message.models import PostMessage, MessagesResponse +from aleph_message.models import MessagesResponse -from aleph_client.asynchronous import ( - create_post, - get_messages, -) +from aleph_client.user_session import AuthenticatedUserSession from tests.integration.toolkit import try_until from .config import REFERENCE_NODE, TARGET_NODE @@ -16,29 +12,29 @@ async def create_message_on_target( """ Create a POST message on the target node, then fetch it from the reference node. 
""" - - post_message, message_status = await create_post( - account=fixture_account, - post_content=None, - post_type="POST", - channel="INTEGRATION_TESTS", - session=None, - api_server=emitter_node, - ) + async with AuthenticatedUserSession( + account=fixture_account, api_server=emitter_node + ) as tx_session: + post_message, message_status = await tx_session.create_post( + post_content=None, + post_type="POST", + channel="INTEGRATION_TESTS", + ) def response_contains_messages(response: MessagesResponse) -> bool: return len(response.messages) > 0 - # create_message = Message(**created_message_dict) - response_dict = await try_until( - get_messages, - response_contains_messages, - timeout=5, - hashes=[post_message.item_hash], - api_server=receiver_node, - ) - - message_from_target = Message(**response_dict["messages"][0]) + async with AuthenticatedUserSession( + account=fixture_account, api_server=receiver_node + ) as rx_session: + responses = await try_until( + rx_session.get_messages, + response_contains_messages, + timeout=5, + hashes=[post_message.item_hash], + ) + + message_from_target = responses.messages[0] assert post_message.item_hash == message_from_target.item_hash diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 30c3c49c..9a13498c 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,20 +1,17 @@ -# -*- coding: utf-8 -*- -""" - Dummy conftest.py for aleph_client. - - If you don't know what this is for, just leave it empty. 
- Read more about conftest.py under: - https://pytest.org/latest/plugins.html -""" +import os from pathlib import Path from tempfile import NamedTemporaryFile -import pytest +import pytest as pytest -from aleph_client.chains.common import get_fallback_private_key import aleph_client.chains.ethereum as ethereum import aleph_client.chains.sol as solana import aleph_client.chains.tezos as tezos +from aleph_client.chains.common import get_fallback_private_key +from aleph_client.chains.ethereum import ETHAccount +from aleph_client.conf import settings +from aleph_client.types import Account + @pytest.fixture def fallback_private_key() -> bytes: diff --git a/tests/unit/test_asynchronous.py b/tests/unit/test_asynchronous.py index bb689f55..74ae7e19 100644 --- a/tests/unit/test_asynchronous.py +++ b/tests/unit/test_asynchronous.py @@ -1,4 +1,5 @@ -from unittest.mock import MagicMock, patch, AsyncMock +import json +from unittest.mock import patch, AsyncMock, MagicMock import pytest as pytest from aleph_message.models import ( @@ -9,168 +10,148 @@ ForgetMessage, ) -from aleph_client.types import StorageEnum, MessageStatus - -from aleph_client.asynchronous import ( - create_post, - _get_fallback_session, - create_aggregate, - create_store, - create_program, - forget, -) +from aleph_client import AuthenticatedUserSession +from aleph_client.types import StorageEnum, MessageStatus, Account -def new_mock_session_with_post_success(): - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.json.return_value = { - "message_status": "processed", - "publication_status": {"status": "success", "failed": []}, - } +@pytest.fixture +def mock_session_with_post_success( + ethereum_account: Account, +) -> AuthenticatedUserSession: + class MockResponse: + def __init__(self, sync: bool): + self.sync = sync - mock_post = AsyncMock() - mock_post.return_value = mock_response + async def __aenter__(self): + return self - mock_session = MagicMock() - 
mock_session.post.return_value.__aenter__ = mock_post
-    return mock_session
+        async def __aexit__(self, exc_type, exc_val, exc_tb):
+            ...
 
+        @property
+        def status(self):
+            return 200 if self.sync else 202
 
-@pytest.mark.asyncio
-async def test_create_post(ethereum_account):
-    _get_fallback_session.cache_clear()
+        async def json(self):
+            message_status = "processed" if self.sync else "pending"
+            return {
+                "message_status": message_status,
+                "publication_status": {"status": "success", "failed": []},
+            }
 
-    content = {"Hello": "World"}
+        async def text(self):
+            return json.dumps(await self.json())
 
-    mock_session = new_mock_session_with_post_success()
+    http_session = AsyncMock()
+    http_session.post = MagicMock()
+    http_session.post.side_effect = lambda *args, **kwargs: MockResponse(
+        sync=kwargs.get("sync", False)
+    )
 
-    post_message, message_status = await create_post(
-        account=ethereum_account,
-        post_content=content,
-        post_type="TEST",
-        channel="TEST",
-        session=mock_session,
-        api_server="https://example.org",
-        sync=True,
+    user_session = AuthenticatedUserSession(
+        account=ethereum_account, api_server="http://localhost"
     )
+    user_session.http_session = http_session
 
-    assert mock_session.post.called
-    assert isinstance(post_message, PostMessage)
-    assert message_status == MessageStatus.PROCESSED
+    return user_session
 
 
 @pytest.mark.asyncio
-async def test_create_aggregate(ethereum_account):
-    _get_fallback_session.cache_clear()
+async def test_create_post(mock_session_with_post_success):
+
+    async with mock_session_with_post_success as session:
+        content = {"Hello": "World"}
+
+        post_message, message_status = await session.create_post(
+            post_content=content,
+            post_type="TEST",
+            channel="TEST",
+            sync=False,
+        )
 
-    content = {"Hello": "World"}
+    mock_session_with_post_success.http_session.post.assert_called_once()
+    assert isinstance(post_message, PostMessage)
+    assert message_status == MessageStatus.PENDING
 
-    mock_session = new_mock_session_with_post_success()
-
-    _ 
= await create_aggregate(
-        account=ethereum_account,
-        key="hello",
-        content=content,
-        channel="TEST",
-        session=mock_session,
-    )
+@pytest.mark.asyncio
+async def test_create_aggregate(mock_session_with_post_success):
 
-    aggregate_message, message_status = await create_aggregate(
-        account=ethereum_account,
-        key="hello",
-        content="world",
-        channel="TEST",
-        session=mock_session,
-        api_server="https://example.org",
-    )
+    async with mock_session_with_post_success as session:
+
+        aggregate_message, message_status = await session.create_aggregate(
+            key="hello",
+            content={"Hello": "world"},
+            channel="TEST",
+        )
 
-    assert mock_session.post.called
+    mock_session_with_post_success.http_session.post.assert_called_once()
     assert isinstance(aggregate_message, AggregateMessage)
 
 
 @pytest.mark.asyncio
-async def test_create_store(ethereum_account):
-    _get_fallback_session.cache_clear()
-
-    mock_session = new_mock_session_with_post_success()
+async def test_create_store(mock_session_with_post_success):
 
     mock_ipfs_push_file = AsyncMock()
     mock_ipfs_push_file.return_value = "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"
 
-    with patch("aleph_client.asynchronous.ipfs_push_file", mock_ipfs_push_file):
-        _ = await create_store(
-            account=ethereum_account,
+    mock_session_with_post_success.ipfs_push_file = mock_ipfs_push_file
+
+    async with mock_session_with_post_success as session:
+        _ = await session.create_store(
             file_content=b"HELLO",
             channel="TEST",
             storage_engine=StorageEnum.ipfs,
-            session=mock_session,
-            api_server="https://example.org",
         )
 
-        _ = await create_store(
-            account=ethereum_account,
+        _ = await session.create_store(
             file_hash="QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy",
             channel="TEST",
             storage_engine=StorageEnum.ipfs,
-            session=mock_session,
-            api_server="https://example.org",
         )
 
     mock_storage_push_file = AsyncMock()
     mock_storage_push_file.return_value = (
         "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"
     )
+    mock_session_with_post_success.storage_push_file = 
mock_storage_push_file
+    async with mock_session_with_post_success as session:
 
-    with patch("aleph_client.asynchronous.storage_push_file", mock_storage_push_file):
-
-        store_message, message_status = await create_store(
-            account=ethereum_account,
+        store_message, message_status = await session.create_store(
             file_content=b"HELLO",
             channel="TEST",
             storage_engine=StorageEnum.storage,
-            session=mock_session,
-            api_server="https://example.org",
         )
 
-    assert mock_session.post.called
+    assert mock_session_with_post_success.http_session.post.called
     assert isinstance(store_message, StoreMessage)
 
 
 @pytest.mark.asyncio
-async def test_create_program(ethereum_account):
-    _get_fallback_session.cache_clear()
-
-    mock_session = new_mock_session_with_post_success()
-
-    program_message, message_status = await create_program(
-        account=ethereum_account,
-        program_ref="FAKE-HASH",
-        entrypoint="main:app",
-        runtime="FAKE-HASH",
-        channel="TEST",
-        session=mock_session,
-        api_server="https://example.org",
-    )
+async def test_create_program(mock_session_with_post_success):
+
+    async with mock_session_with_post_success as session:
 
-    assert mock_session.post.called
+        program_message, message_status = await session.create_program(
+            program_ref="FAKE-HASH",
+            entrypoint="main:app",
+            runtime="FAKE-HASH",
+            channel="TEST",
+        )
+
+    mock_session_with_post_success.http_session.post.assert_called_once()
     assert isinstance(program_message, ProgramMessage)
 
 
 @pytest.mark.asyncio
-async def test_forget(ethereum_account):
-    _get_fallback_session.cache_clear()
-
-    mock_session = new_mock_session_with_post_success()
-
-    forget_message, forget_status = await forget(
-        account=ethereum_account,
-        hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"],
-        reason="GDPR",
-        channel="TEST",
-        session=mock_session,
-        api_server="https://example.org",
-    )
+async def test_forget(mock_session_with_post_success):
+
+    async with mock_session_with_post_success as session:
+        forget_message, message_status = await 
session.forget(
+            hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"],
+            reason="GDPR",
+            channel="TEST",
+        )
 
-    assert mock_session.post.called
+    mock_session_with_post_success.http_session.post.assert_called_once()
     assert isinstance(forget_message, ForgetMessage)
diff --git a/tests/unit/test_asynchronous_get.py b/tests/unit/test_asynchronous_get.py
index 2f03a954..958ba6da 100644
--- a/tests/unit/test_asynchronous_get.py
+++ b/tests/unit/test_asynchronous_get.py
@@ -1,64 +1,94 @@
+import unittest
+from typing import Any, Dict
+from unittest.mock import AsyncMock
+
 import pytest
 from aleph_message.models import MessageType, MessagesResponse
-import unittest
 
-from aleph_client.asynchronous import (
-    get_messages,
-    fetch_aggregates,
-    fetch_aggregate,
-    _get_fallback_session,
-)
+from aleph_client.conf import settings
+from aleph_client.user_session import UserSession
+
+
+def make_mock_session(get_return_value: Dict[str, Any]) -> UserSession:
+    class MockResponse:
+        async def __aenter__(self):
+            return self
+
+        async def __aexit__(self, exc_type, exc_val, exc_tb):
+            ...
+ + @property + def status(self): + return 200 + + async def json(self): + return get_return_value + + class MockHttpSession(AsyncMock): + def get(self, *_args, **_kwargs): + return MockResponse() + + http_session = MockHttpSession() + + user_session = UserSession(api_server="http://localhost") + user_session.http_session = http_session + + return user_session @pytest.mark.asyncio async def test_fetch_aggregate(): - _get_fallback_session.cache_clear() - - response = await fetch_aggregate( - address="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10", key="corechannel" + mock_session = make_mock_session( + {"data": {"corechannel": {"nodes": [], "resource_nodes": []}}} ) + async with mock_session: + + response = await mock_session.fetch_aggregate( + address="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10", + key="corechannel", + ) assert response.keys() == {"nodes", "resource_nodes"} @pytest.mark.asyncio async def test_fetch_aggregates(): - _get_fallback_session.cache_clear() - - response = await fetch_aggregates( - address="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10" + mock_session = make_mock_session( + {"data": {"corechannel": {"nodes": [], "resource_nodes": []}}} ) - assert response.keys() == {"corechannel"} - assert response["corechannel"].keys() == {"nodes", "resource_nodes"} + + async with mock_session: + response = await mock_session.fetch_aggregates( + address="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10" + ) + assert response.keys() == {"corechannel"} + assert response["corechannel"].keys() == {"nodes", "resource_nodes"} @pytest.mark.asyncio async def test_get_posts(): - _get_fallback_session.cache_clear() - - response: MessagesResponse = await get_messages( - pagination=2, - message_type=MessageType.post, - ) + async with UserSession(api_server=settings.API_HOST) as session: + response: MessagesResponse = await session.get_messages( + message_type=MessageType.post, + ) - messages = response.messages - assert len(messages) > 1 - for message in messages: - assert 
message.type == MessageType.post
+        messages = response.messages
+        assert len(messages) > 1
+        for message in messages:
+            assert message.type == MessageType.post
 
 
 @pytest.mark.asyncio
 async def test_get_messages():
-    _get_fallback_session.cache_clear()
-
-    response: MessagesResponse = await get_messages(
-        pagination=2,
-    )
+    async with UserSession(api_server=settings.API_HOST) as session:
+        response: MessagesResponse = await session.get_messages(
+            pagination=2,
+        )
 
-    messages = response.messages
-    assert len(messages) > 1
-    assert messages[0].type
-    assert messages[0].sender
+        messages = response.messages
+        assert len(messages) > 1
+        assert messages[0].type
+        assert messages[0].sender
 
 
-if __name__ == '__main __':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tests/unit/test_chain_solana.py b/tests/unit/test_chain_solana.py
index e62cc7ef..e511f59a 100644
--- a/tests/unit/test_chain_solana.py
+++ b/tests/unit/test_chain_solana.py
@@ -33,7 +33,9 @@ def test_get_fallback_account():
 
 @pytest.mark.asyncio
 async def test_SOLAccount(solana_account):
-    message = asdict(Message("SOL", solana_account.get_address(), "SomeType", "ItemHash"))
+    message = asdict(
+        Message("SOL", solana_account.get_address(), "SomeType", "ItemHash")
+    )
     initial_message = message.copy()
     await solana_account.sign_message(message)
     assert message["signature"]
diff --git a/tests/unit/test_synchronous_get.py b/tests/unit/test_synchronous_get.py
new file mode 100644
index 00000000..b6fd1db4
--- /dev/null
+++ b/tests/unit/test_synchronous_get.py
@@ -0,0 +1,17 @@
+from aleph_message.models import MessageType, MessagesResponse
+
+from aleph_client.conf import settings
+from aleph_client.user_session import UserSession
+
+
+def test_get_posts():
+    with UserSession(api_server=settings.API_HOST) as session:
+        response: MessagesResponse = session.get_messages(
+            pagination=2,
+            message_type=MessageType.post,
+        )
+
+        messages = response.messages
+        assert len(messages) > 1
+        for message in messages:
+            
assert message.type == MessageType.post diff --git a/tests/unit/test_vm_cache.py b/tests/unit/test_vm_cache.py index d9aa31a7..ddab8362 100644 --- a/tests/unit/test_vm_cache.py +++ b/tests/unit/test_vm_cache.py @@ -1,3 +1,4 @@ +import aiohttp import pytest from aleph_client.vm.cache import TestVmCache, sanitize_cache_key @@ -5,6 +6,8 @@ @pytest.mark.asyncio async def test_local_vm_cache(): + http_session = aiohttp.ClientSession(base_url="http://localhost:8000") + cache = TestVmCache() assert (await cache.get("doesnotexist")) is None assert len(await (cache.keys())) == 0