diff --git a/src/aleph/sdk/client.py b/src/aleph/sdk/client.py index 430ace1f..d48fded9 100644 --- a/src/aleph/sdk/client.py +++ b/src/aleph/sdk/client.py @@ -6,6 +6,7 @@ import threading import time from datetime import datetime +from io import BytesIO from pathlib import Path from typing import ( Any, @@ -23,7 +24,6 @@ TypeVar, Union, ) -from io import BytesIO import aiohttp from aleph_message.models import ( @@ -45,17 +45,18 @@ ) from aleph_message.models.execution.base import Encoding from aleph_message.status import MessageStatus -from pydantic import ValidationError, BaseModel +from pydantic import ValidationError from aleph.sdk.types import Account, GenericMessage, StorageEnum -from aleph.sdk.utils import copy_async_readable_to_buffer, Writable, AsyncReadable +from aleph.sdk.utils import Writable, copy_async_readable_to_buffer + from .conf import settings from .exceptions import ( BroadcastError, + FileTooLarge, InvalidMessageError, MessageNotFoundError, MultipleMessagesError, - FileTooLarge, ) from .models import MessagesResponse from .utils import check_unix_socket_valid, get_message_type_value @@ -237,12 +238,24 @@ def download_file_ipfs(self, file_hash: str) -> bytes: self.async_session.download_file_ipfs, file_hash=file_hash, ) - def download_file_to_buffer(self, file_hash: str, output_buffer: Writable[bytes]) -> bytes: - return self._wrap(self.async_session.download_file_to_buffer, file_hash=file_hash, output_buffer=output_buffer) - def download_file_ipfs_to_buffer(self, file_hash: str, output_buffer: Writable[bytes]) -> bytes: - return self._wrap(self.async_session.download_file_ipfs_to_buffer, file_hash=file_hash, output_buffer=output_buffer) + def download_file_to_buffer( + self, file_hash: str, output_buffer: Writable[bytes] + ) -> bytes: + return self._wrap( + self.async_session.download_file_to_buffer, + file_hash=file_hash, + output_buffer=output_buffer, + ) + def download_file_ipfs_to_buffer( + self, file_hash: str, output_buffer: Writable[bytes] + ) -> bytes: + return self._wrap( + self.async_session.download_file_ipfs_to_buffer, + file_hash=file_hash, + output_buffer=output_buffer, + ) def watch_messages( self, diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index 91c45c26..be56cc2c 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -4,7 +4,7 @@ from enum import Enum from pathlib import Path from shutil import make_archive -from typing import Tuple, Type, Union +from typing import Protocol, Tuple, Type, TypeVar, Union from zipfile import BadZipFile, ZipFile from aleph_message.models import MessageType @@ -13,13 +13,6 @@ from aleph.sdk.conf import settings from aleph.sdk.types import GenericMessage -from typing import ( - Tuple, - Type, - TypeVar, - Protocol, -) - logger = logging.getLogger(__name__) try: @@ -54,7 +47,7 @@ def create_archive(path: Path) -> Tuple[Path, Encoding]: return archive_path, Encoding.zip elif os.path.isfile(path): if path.suffix == ".squashfs" or ( - magic and magic.from_file(path).startswith("Squashfs filesystem") + magic and magic.from_file(path).startswith("Squashfs filesystem") ): return path, Encoding.squashfs else: @@ -101,7 +94,7 @@ def write(self, buffer: U) -> int: async def copy_async_readable_to_buffer( - readable: AsyncReadable[T], buffer: Writable[T], chunk_size: int + readable: AsyncReadable[T], buffer: Writable[T], chunk_size: int ): while True: chunk = await readable.read(chunk_size) diff --git a/tests/unit/test_chain_ethereum.py b/tests/unit/test_chain_ethereum.py index 72293664..62e941ec 100644 --- a/tests/unit/test_chain_ethereum.py +++ b/tests/unit/test_chain_ethereum.py @@ -36,11 +36,11 @@ async def test_ETHAccount(ethereum_account): address = account.get_address() assert address - assert type(address) == str + assert isinstance(address, str) assert len(address) == 42 pubkey = account.get_public_key() - assert type(pubkey) == str + assert isinstance(pubkey, str) assert len(pubkey) == 68 @@ -111,9 +111,9 @@ async def test_decrypt_secp256k1(ethereum_account): content = b"SomeContent" encrypted = await account.encrypt(content) - assert type(encrypted) == bytes + assert isinstance(encrypted, bytes) decrypted = await account.decrypt(encrypted) - assert type(decrypted) == bytes + assert isinstance(decrypted, bytes) assert content == decrypted diff --git a/tests/unit/test_chain_nuls1.py b/tests/unit/test_chain_nuls1.py index 6c268b0d..d480e644 100644 --- a/tests/unit/test_chain_nuls1.py +++ b/tests/unit/test_chain_nuls1.py @@ -18,9 +18,9 @@ async def test_sign_data(): ) assert sign - assert type(sign.pub_key) == bytes - assert type(sign.digest_bytes) == bytes - assert type(sign.sig_ser) == bytes + assert isinstance(sign.pub_key, bytes) + assert isinstance(sign.digest_bytes, bytes) + assert isinstance(sign.sig_ser, bytes) assert sign.ecc_type is None @@ -37,9 +37,9 @@ async def test_sign_message(): assert len(sign.sig_ser) == 70 assert sign - assert type(sign.pub_key) == bytes + assert isinstance(sign.pub_key, bytes) assert sign.digest_bytes is None - assert type(sign.sig_ser) == bytes + assert isinstance(sign.sig_ser, bytes) assert sign.ecc_type is None diff --git a/tests/unit/test_chain_solana.py b/tests/unit/test_chain_solana.py index 15f9e12f..7384229f 100644 --- a/tests/unit/test_chain_solana.py +++ b/tests/unit/test_chain_solana.py @@ -27,7 +27,7 @@ def test_get_fallback_account(): assert account.CHAIN == "SOL" assert account.CURVE == "curve25519" assert account._signing_key.verify_key - assert type(account.private_key) == bytes + assert isinstance(account.private_key, bytes) assert len(account.private_key) == 32 @@ -42,12 +42,12 @@ async def test_SOLAccount(solana_account): address = message["sender"] assert address - assert type(address) == str + assert isinstance(address, str) # assert len(address) == 44 # can also be 43? signature = json.loads(message["signature"]) pubkey = base58.b58decode(signature["publicKey"]) - assert type(pubkey) == bytes + assert isinstance(pubkey, bytes) assert len(pubkey) == 32 verify_key = VerifyKey(pubkey) @@ -61,7 +61,7 @@ async def test_SOLAccount(solana_account): assert message["sender"] == signature["publicKey"] pubkey = solana_account.get_public_key() - assert type(pubkey) == str + assert isinstance(pubkey, str) assert len(pubkey) == 64 @@ -71,9 +71,9 @@ async def test_decrypt_curve25516(solana_account): content = b"SomeContent" encrypted = await solana_account.encrypt(content) - assert type(encrypted) == bytes + assert isinstance(encrypted, bytes) decrypted = await solana_account.decrypt(encrypted) - assert type(decrypted) == bytes + assert isinstance(decrypted, bytes) assert content == decrypted @@ -90,7 +90,7 @@ async def test_verify_signature(solana_account): await solana_account.sign_message(message) assert message["signature"] raw_signature = json.loads(message["signature"])["signature"] - assert type(raw_signature) == str + assert isinstance(raw_signature, str) verify_signature(raw_signature, message["sender"], get_verification_buffer(message)) diff --git a/tests/unit/test_download.py b/tests/unit/test_download.py index 219ccde5..b16e0d75 100644 --- a/tests/unit/test_download.py +++ b/tests/unit/test_download.py @@ -1,4 +1,5 @@ import pytest + from aleph.sdk import AlephClient from aleph.sdk.conf import settings as sdk_settings @@ -28,6 +29,6 @@ async def test_download(file_hash: str, expected_size: int): @pytest.mark.asyncio async def test_download_ipfs(file_hash: str, expected_size: int): async with AlephClient(api_server=sdk_settings.API_HOST) as client: - file_content = await client.download_file_ipfs(file_hash) ## 5817703 B FILE + file_content = await client.download_file_ipfs(file_hash) # 5817703 B FILE file_size = len(file_content) assert file_size == expected_size