diff --git a/src/aleph/sdk/client.py b/src/aleph/sdk/client.py index 58f3faef..d2f64c8e 100644 --- a/src/aleph/sdk/client.py +++ b/src/aleph/sdk/client.py @@ -23,6 +23,7 @@ TypeVar, Union, ) +from io import BytesIO import aiohttp from aleph_message.models import ( @@ -31,6 +32,7 @@ AlephMessage, ForgetContent, ForgetMessage, + ItemHash, ItemType, MessageType, PostContent, @@ -43,16 +45,17 @@ ) from aleph_message.models.execution.base import Encoding from aleph_message.status import MessageStatus -from pydantic import ValidationError +from pydantic import ValidationError, BaseModel from aleph.sdk.types import Account, GenericMessage, StorageEnum - +from aleph.sdk.utils import copy_async_readable_to_buffer, Writable, AsyncReadable from .conf import settings from .exceptions import ( BroadcastError, InvalidMessageError, MessageNotFoundError, MultipleMessagesError, + FileTooLarge, ) from .models import MessagesResponse from .utils import check_unix_socket_valid, get_message_type_value @@ -229,6 +232,12 @@ def get_posts( def download_file(self, file_hash: str) -> bytes: return self._wrap(self.async_session.download_file, file_hash=file_hash) + def download_file_ipfs(self, file_hash: str) -> bytes: + return self._wrap( + self.async_session.download_file_ipfs, + file_hash=file_hash, + ) + def watch_messages( self, message_type: Optional[MessageType] = None, @@ -609,6 +618,55 @@ async def get_posts( resp.raise_for_status() return await resp.json() + async def download_file_to_buffer( + self, + file_hash: str, + output_buffer: Writable[bytes], + ) -> None: + """ + Download a file from the storage engine and write it to the specified output buffer. + :param file_hash: The hash of the file to retrieve. + :param output_buffer: Writable binary buffer. The file will be written to this buffer. + """ + + async with self.http_session.get( + f"/api/v0/storage/raw/{file_hash}" + ) as response: + if response.status == 200: + await copy_async_readable_to_buffer( + response.content, output_buffer, chunk_size=16 * 1024 + ) + if response.status == 413: + ipfs_hash = ItemHash(file_hash) + if ipfs_hash.item_type == ItemType.ipfs: + return await self.download_file_ipfs_to_buffer( + file_hash, output_buffer + ) + else: + raise FileTooLarge(f"The file from {file_hash} is too large") + + async def download_file_ipfs_to_buffer( + self, + file_hash: str, + output_buffer: Writable[bytes], + ) -> None: + """ + Download a file from the storage engine and write it to the specified output buffer. + + :param file_hash: The hash of the file to retrieve. + :param output_buffer: The binary output buffer to write the file data to. + """ + async with aiohttp.ClientSession() as session: + async with session.get( + f"https://ipfs.aleph.im/ipfs/{file_hash}" + ) as response: + if response.status == 200: + await copy_async_readable_to_buffer( + response.content, output_buffer, chunk_size=16 * 1024 + ) + else: + response.raise_for_status() + async def download_file( self, file_hash: str, @@ -620,11 +678,24 @@ async def download_file( :param file_hash: The hash of the file to retrieve. """ - async with self.http_session.get( - f"/api/v0/storage/raw/{file_hash}" - ) as response: - response.raise_for_status() - return await response.read() + buffer = BytesIO() + await self.download_file_to_buffer(file_hash, output_buffer=buffer) + return buffer.getvalue() + + async def download_file_ipfs( + self, + file_hash: str, + ) -> bytes: + """ + Get a file from the ipfs storage engine as raw bytes. + + Warning: Downloading large files can be slow. + + :param file_hash: The hash of the file to retrieve. + """ + buffer = BytesIO() + await self.download_file_ipfs_to_buffer(file_hash, output_buffer=buffer) + return buffer.getvalue() async def get_messages( self, diff --git a/src/aleph/sdk/exceptions.py b/src/aleph/sdk/exceptions.py index 0c626548..51762925 100644 --- a/src/aleph/sdk/exceptions.py +++ b/src/aleph/sdk/exceptions.py @@ -42,3 +42,11 @@ class BadSignatureError(Exception): """ pass + + +class FileTooLarge(Exception): + """ + A file is too large + """ + + pass diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index df80261b..91c45c26 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -13,6 +13,13 @@ from aleph.sdk.conf import settings from aleph.sdk.types import GenericMessage +from typing import ( + Tuple, + Type, + TypeVar, + Protocol, +) + logger = logging.getLogger(__name__) try: @@ -47,7 +54,7 @@ def create_archive(path: Path) -> Tuple[Path, Encoding]: return archive_path, Encoding.zip elif os.path.isfile(path): if path.suffix == ".squashfs" or ( - magic and magic.from_file(path).startswith("Squashfs filesystem") + magic and magic.from_file(path).startswith("Squashfs filesystem") ): return path, Encoding.squashfs else: @@ -79,6 +86,30 @@ def check_unix_socket_valid(unix_socket_path: str) -> bool: return True +T = TypeVar("T", str, bytes, covariant=True) +U = TypeVar("U", str, bytes, contravariant=True) + + +class AsyncReadable(Protocol[T]): + async def read(self, n: int = -1) -> T: + ... + + +class Writable(Protocol[U]): + def write(self, buffer: U) -> int: + ... + + +async def copy_async_readable_to_buffer( + readable: AsyncReadable[T], buffer: Writable[T], chunk_size: int +): + while True: + chunk = await readable.read(chunk_size) + if not chunk: + break + buffer.write(chunk) + + def enum_as_str(obj: Union[str, Enum]) -> str: """Returns the value of an Enum, or the string itself when passing a string. diff --git a/tests/unit/test_download.py b/tests/unit/test_download.py new file mode 100644 index 00000000..219ccde5 --- /dev/null +++ b/tests/unit/test_download.py @@ -0,0 +1,33 @@ +import pytest +from aleph.sdk import AlephClient +from aleph.sdk.conf import settings as sdk_settings + + +@pytest.mark.parametrize( + "file_hash,expected_size", + [ + ("QmeomffUNfmQy76CQGy9NdmqEnnHU9soCexBnGU3ezPHVH", 5), + ("Qmdy5LaAL4eghxE7JD6Ah5o4PJGarjAV9st8az2k52i1vq", 5817703), + ], +) +@pytest.mark.asyncio +async def test_download(file_hash: str, expected_size: int): + async with AlephClient(api_server=sdk_settings.API_HOST) as client: + file_content = await client.download_file(file_hash) # File is 5B + file_size = len(file_content) + assert file_size == expected_size + + +@pytest.mark.parametrize( + "file_hash,expected_size", + [ + ("QmeomffUNfmQy76CQGy9NdmqEnnHU9soCexBnGU3ezPHVH", 5), + ("Qmdy5LaAL4eghxE7JD6Ah5o4PJGarjAV9st8az2k52i1vq", 5817703), + ], +) +@pytest.mark.asyncio +async def test_download_ipfs(file_hash: str, expected_size: int): + async with AlephClient(api_server=sdk_settings.API_HOST) as client: + file_content = await client.download_file_ipfs(file_hash) ## 5817703 B FILE + file_size = len(file_content) + assert file_size == expected_size